]> WPIA git - cassiopeia.git/blob - src/db/mysql.cpp
cln: Move code around, cleanup structure
[cassiopeia.git] / src / db / mysql.cpp
1 #include "mysql.h"
2
3 #include <stdio.h>
4
5 #include <iostream>
6
7 #include <mysql/errmsg.h>
8
9 //This static variable exists to handle initializing and finalizing the MySQL driver library
10 std::shared_ptr<int> MySQLJobProvider::lib_ref(
11     //Initializer: Store the return code as a pointer to an integer
12     new int( mysql_library_init( 0, NULL, NULL ) ),
13     //Finalizer: Check the pointer and free resources
14     []( int* ref ) {
15         if( !ref ) {
16             //The library is not initialized
17             return;
18         }
19
20         if( *ref ) {
21             //The library did return an error when initializing
22             delete ref;
23             return;
24         }
25
26         delete ref;
27
28         mysql_library_end();
29     } );
30
31 MySQLJobProvider::MySQLJobProvider( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
32     if( !lib_ref || *lib_ref ) {
33         throw "MySQL library not initialized!";
34     }
35
36     connect( server, user, password, database );
37 }
38
39 MySQLJobProvider::~MySQLJobProvider() {
40     disconnect();
41 }
42
43 bool MySQLJobProvider::connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
44     if( conn ) {
45         if( !disconnect() ) {
46             return false;
47         }
48
49         conn.reset();
50     }
51
52     conn = _connect( server, user, password, database );
53
54     return !!conn;
55 }
56
57 std::shared_ptr<MYSQL> MySQLJobProvider::_connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
58     MYSQL* tmp( mysql_init( NULL ) );
59
60     if( !tmp ) {
61         return std::shared_ptr<MYSQL>();
62     }
63
64     tmp = mysql_real_connect( tmp, server.c_str(), user.c_str(), password.c_str(), database.c_str(), 3306, NULL, CLIENT_COMPRESS );
65
66     if( !tmp ) {
67         return std::shared_ptr<MYSQL>();
68     }
69
70     auto l = lib_ref;
71     return std::shared_ptr<MYSQL>(
72         tmp,
73         [l]( MYSQL * c ) {
74             if( c ) {
75                 mysql_close( c );
76             }
77         } );
78 }
79
80 bool MySQLJobProvider::disconnect() {
81     if( !conn ) {
82         return false;
83     }
84
85     conn.reset();
86
87     return true;
88 }
89
90 std::pair< int, std::shared_ptr<MYSQL_RES> > MySQLJobProvider::query( const std::string& query ) {
91     if( !conn ) {
92         return std::make_pair( CR_SERVER_LOST, std::shared_ptr<MYSQL_RES>() );
93     }
94
95     int err = mysql_real_query( this->conn.get(), query.c_str(), query.size() );
96
97     if( err ) {
98         throw std::string( "MySQL error: " ) + mysql_error( this->conn.get() );
99     }
100
101     auto c = conn;
102     std::shared_ptr<MYSQL_RES> res(
103         mysql_store_result( conn.get() ),
104         [c]( MYSQL_RES * r ) {
105             if( !r ) {
106                 return;
107             }
108
109             mysql_free_result( r );
110         } );
111
112     return std::make_pair( err, res );
113 }
114
115 std::shared_ptr<Job> MySQLJobProvider::fetchJob() {
116     std::string q = "SELECT id, targetId, task, executeFrom, executeTo, warning FROM jobs WHERE state='open' AND warning < 3";
117
118     int err = 0;
119     std::shared_ptr<MYSQL_RES> res;
120
121     std::tie( err, res ) = query( q );
122
123     if( err ) {
124         return std::shared_ptr<Job>();
125     }
126
127     unsigned int num = mysql_num_fields( res.get() );
128
129     MYSQL_ROW row = mysql_fetch_row( res.get() );
130
131     if( !row ) {
132         return std::shared_ptr<Job>();
133     }
134
135     std::shared_ptr<Job> job( new Job() );
136
137     unsigned long* l = mysql_fetch_lengths( res.get() );
138
139     if( !l ) {
140         return std::shared_ptr<Job>();
141     }
142
143     job->id = std::string( row[0], row[0] + l[0] );
144     job->target = std::string( row[1], row[1] + l[1] );
145     job->task = std::string( row[2], row[2] + l[2] );
146     job->from = std::string( row[3], row[3] + l[3] );
147     job->to = std::string( row[4], row[4] + l[4] );
148     job->warning = std::string( row[5], row[5] + l[5] );
149
150     for( unsigned int i = 0; i < num; i++ ) {
151         printf( "[%.*s] ", ( int ) l[i], row[i] ? row[i] : "NULL" );
152     }
153
154     printf( "\n" );
155
156     return job;
157 }
158
159 std::string MySQLJobProvider::escape_string( const std::string& target ) {
160     if( !conn ) {
161         throw "Not connected!";
162     }
163
164     std::string result;
165
166     result.resize( target.size() * 2 );
167
168     long unsigned int len = mysql_real_escape_string( conn.get(), const_cast<char*>( result.data() ), target.c_str(), target.size() );
169
170     result.resize( len );
171
172     return result;
173 }
174
175 void MySQLJobProvider::finishJob( std::shared_ptr<Job> job ) {
176     if( !conn ) {
177         throw "Not connected!";
178     }
179
180     std::string q = "UPDATE jobs SET state='done' WHERE id='" + this->escape_string( job->id ) + "' LIMIT 1";
181
182     if( query( q ).first ) {
183         throw "No database entry found.";
184     }
185
186 }
187
188 void MySQLJobProvider::failJob( std::shared_ptr<Job> job ) {
189     if( !conn ) {
190         throw "Not connected!";
191     }
192
193     std::string q = "UPDATE jobs SET warning = warning + 1 WHERE id='" + this->escape_string( job->id ) + "' LIMIT 1";
194
195     if( query( q ).first ) {
196         throw "No database entry found.";
197     }
198 }
199
200 std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<Job> job ) {
201     std::shared_ptr<TBSCertificate> cert = std::shared_ptr<TBSCertificate>( new TBSCertificate() );
202     std::string q = "SELECT md, profile, csr_name, csr_type, keyname FROM certs INNER JOIN profiles ON profiles.id = certs.profile WHERE certs.id='" + this->escape_string( job->target ) + "'";
203
204     int err = 0;
205
206     std::shared_ptr<MYSQL_RES> res;
207
208     std::tie( err, res ) = query( q );
209
210     if( err ) {
211         return std::shared_ptr<TBSCertificate>();
212     }
213
214     MYSQL_ROW row = mysql_fetch_row( res.get() );
215
216     if( !row ) {
217         return std::shared_ptr<TBSCertificate>();
218     }
219
220     unsigned long* l = mysql_fetch_lengths( res.get() );
221
222     if( !l ) {
223         return std::shared_ptr<TBSCertificate>();
224     }
225
226     std::string profileName = std::string( row[4], row[4] + l[4] );
227
228     cert->md = std::string( row[0], row[0] + l[0] );
229     std::string profileId = std::string( row[1], row[1] + l[1] );
230
231     while( profileId.size() < 4 ) {
232         profileId = "0" + profileId;
233     }
234
235     cert->profile = profileId + "-" + profileName;
236
237     cert->csr = std::string( row[2], row[2] + l[2] );
238     cert->csr_type = std::string( row[3], row[3] + l[3] );
239
240     cert->SANs = std::vector<std::shared_ptr<SAN>>();
241
242     q = "SELECT contents, type FROM subjectAlternativeNames WHERE certId='" + this->escape_string( job->target ) + "'";
243     std::tie( err, res ) = query( q );
244
245     if( err ) {
246         std::cout << mysql_error( this->conn.get() );
247         return std::shared_ptr<TBSCertificate>();
248     }
249
250     std::cout << "Fetching SANs" << std::endl;
251
252     while( ( row = mysql_fetch_row( res.get() ) ) ) {
253         unsigned long* l = mysql_fetch_lengths( res.get() );
254
255         if( !l ) {
256             return std::shared_ptr<TBSCertificate>();
257         }
258
259         std::shared_ptr<SAN> nSAN = std::shared_ptr<SAN>( new SAN() );
260         nSAN->content = std::string( row[0], row[0] + l[0] );
261         nSAN->type = std::string( row[1], row[1] + l[1] );
262         cert->SANs.push_back( nSAN );
263     }
264
265     q = "SELECT name, value FROM certAvas WHERE certid='" + this->escape_string( job->target ) + "'";
266     std::tie( err, res ) = query( q );
267
268     if( err ) {
269         std::cout << mysql_error( this->conn.get() );
270         return std::shared_ptr<TBSCertificate>();
271
272     }
273
274     while( ( row = mysql_fetch_row( res.get() ) ) ) {
275         unsigned long* l = mysql_fetch_lengths( res.get() );
276
277         if( !l ) {
278             return std::shared_ptr<TBSCertificate>();
279         }
280
281         std::shared_ptr<AVA> nAVA = std::shared_ptr<AVA>( new AVA() );
282         nAVA->name = std::string( row[0], row[0] + l[0] );
283         nAVA->value = std::string( row[1], row[1] + l[1] );
284         cert->AVAs.push_back( nAVA );
285     }
286
287     return cert;
288 }
289
290 void MySQLJobProvider::writeBack( std::shared_ptr<Job> job, std::shared_ptr<SignedCertificate> res ) {
291     if( !conn ) {
292         throw "Error while writing back";
293     }
294
295     std::string q = "UPDATE certs SET crt_name='" + this->escape_string( res->crt_name ) + "', serial='" + this->escape_string( res->serial ) + "', created=NOW() WHERE id='" + this->escape_string( job->target ) + "' LIMIT 1";
296
297     // TODO write more thingies back
298
299     if( query( q ).first ) {
300         throw "Error while writing back";
301     }
302 }