]> WPIA git - cassiopeia.git/blob - src/db/mysql.cpp
chg: remove useless destructor and simplify connect-method
[cassiopeia.git] / src / db / mysql.cpp
1 #include "mysql.h"
2
3 #include <stdio.h>
4
5 #include <iostream>
6
7 #include <mysql/errmsg.h>
8
9 //This static variable exists to handle initializing and finalizing the MySQL driver library
10 std::shared_ptr<int> MySQLJobProvider::lib_ref(
11     //Initializer: Store the return code as a pointer to an integer
12     new int( mysql_library_init( 0, NULL, NULL ) ),
13     //Finalizer: Check the pointer and free resources
14     []( int* ref ) {
15         if( !ref ) {
16             //The library is not initialized
17             return;
18         }
19
20         if( *ref ) {
21             //The library did return an error when initializing
22             delete ref;
23             return;
24         }
25
26         delete ref;
27
28         mysql_library_end();
29     } );
30
31 MySQLJobProvider::MySQLJobProvider( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
32     if( !lib_ref || *lib_ref ) {
33         throw "MySQL library not initialized!";
34     }
35
36     connect( server, user, password, database );
37 }
38
39 bool MySQLJobProvider::connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
40     disconnect();
41     conn = _connect( server, user, password, database );
42
43     return !!conn;
44 }
45
46 std::shared_ptr<MYSQL> MySQLJobProvider::_connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
47     MYSQL* tmp( mysql_init( NULL ) );
48
49     if( !tmp ) {
50         return std::shared_ptr<MYSQL>();
51     }
52
53     tmp = mysql_real_connect( tmp, server.c_str(), user.c_str(), password.c_str(), database.c_str(), 3306, NULL, CLIENT_COMPRESS );
54
55     if( !tmp ) {
56         return std::shared_ptr<MYSQL>();
57     }
58
59     auto l = lib_ref;
60     return std::shared_ptr<MYSQL>(
61         tmp,
62         [l]( MYSQL * c ) {
63             if( c ) {
64                 mysql_close( c );
65             }
66         } );
67 }
68
69 bool MySQLJobProvider::disconnect() {
70     if( !conn ) {
71         return false;
72     }
73
74     conn.reset();
75
76     return true;
77 }
78
79 std::pair< int, std::shared_ptr<MYSQL_RES> > MySQLJobProvider::query( const std::string& query ) {
80     if( !conn ) {
81         return std::make_pair( CR_SERVER_LOST, std::shared_ptr<MYSQL_RES>() );
82     }
83
84     int err = mysql_real_query( this->conn.get(), query.c_str(), query.size() );
85
86     if( err ) {
87         throw std::string( "MySQL error: " ) + mysql_error( this->conn.get() );
88     }
89
90     auto c = conn;
91     std::shared_ptr<MYSQL_RES> res(
92         mysql_store_result( conn.get() ),
93         [c]( MYSQL_RES * r ) {
94             if( !r ) {
95                 return;
96             }
97
98             mysql_free_result( r );
99         } );
100
101     return std::make_pair( err, res );
102 }
103
104 std::shared_ptr<Job> MySQLJobProvider::fetchJob() {
105     std::string q = "SELECT id, targetId, task, executeFrom, executeTo, warning FROM jobs WHERE state='open' AND warning < 3";
106
107     int err = 0;
108     std::shared_ptr<MYSQL_RES> res;
109
110     std::tie( err, res ) = query( q );
111
112     if( err ) {
113         return std::shared_ptr<Job>();
114     }
115
116     unsigned int num = mysql_num_fields( res.get() );
117
118     MYSQL_ROW row = mysql_fetch_row( res.get() );
119
120     if( !row ) {
121         return std::shared_ptr<Job>();
122     }
123
124     std::shared_ptr<Job> job( new Job() );
125
126     unsigned long* l = mysql_fetch_lengths( res.get() );
127
128     if( !l ) {
129         return std::shared_ptr<Job>();
130     }
131
132     job->id = std::string( row[0], row[0] + l[0] );
133     job->target = std::string( row[1], row[1] + l[1] );
134     job->task = std::string( row[2], row[2] + l[2] );
135     job->from = std::string( row[3], row[3] + l[3] );
136     job->to = std::string( row[4], row[4] + l[4] );
137     job->warning = std::string( row[5], row[5] + l[5] );
138
139     for( unsigned int i = 0; i < num; i++ ) {
140         printf( "[%.*s] ", ( int ) l[i], row[i] ? row[i] : "NULL" );
141     }
142
143     printf( "\n" );
144
145     return job;
146 }
147
148 std::string MySQLJobProvider::escape_string( const std::string& target ) {
149     if( !conn ) {
150         throw "Not connected!";
151     }
152
153     std::string result;
154
155     result.resize( target.size() * 2 );
156
157     long unsigned int len = mysql_real_escape_string( conn.get(), const_cast<char*>( result.data() ), target.c_str(), target.size() );
158
159     result.resize( len );
160
161     return result;
162 }
163
164 void MySQLJobProvider::finishJob( std::shared_ptr<Job> job ) {
165     if( !conn ) {
166         throw "Not connected!";
167     }
168
169     std::string q = "UPDATE jobs SET state='done' WHERE id='" + this->escape_string( job->id ) + "' LIMIT 1";
170
171     if( query( q ).first ) {
172         throw "No database entry found.";
173     }
174
175 }
176
177 void MySQLJobProvider::failJob( std::shared_ptr<Job> job ) {
178     if( !conn ) {
179         throw "Not connected!";
180     }
181
182     std::string q = "UPDATE jobs SET warning = warning + 1 WHERE id='" + this->escape_string( job->id ) + "' LIMIT 1";
183
184     if( query( q ).first ) {
185         throw "No database entry found.";
186     }
187 }
188
189 std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<Job> job ) {
190     std::shared_ptr<TBSCertificate> cert = std::shared_ptr<TBSCertificate>( new TBSCertificate() );
191     std::string q = "SELECT md, profile, csr_name, csr_type, keyname FROM certs INNER JOIN profiles ON profiles.id = certs.profile WHERE certs.id='" + this->escape_string( job->target ) + "'";
192
193     int err = 0;
194
195     std::shared_ptr<MYSQL_RES> res;
196
197     std::tie( err, res ) = query( q );
198
199     if( err ) {
200         return std::shared_ptr<TBSCertificate>();
201     }
202
203     MYSQL_ROW row = mysql_fetch_row( res.get() );
204
205     if( !row ) {
206         return std::shared_ptr<TBSCertificate>();
207     }
208
209     unsigned long* l = mysql_fetch_lengths( res.get() );
210
211     if( !l ) {
212         return std::shared_ptr<TBSCertificate>();
213     }
214
215     std::string profileName = std::string( row[4], row[4] + l[4] );
216
217     cert->md = std::string( row[0], row[0] + l[0] );
218     std::string profileId = std::string( row[1], row[1] + l[1] );
219
220     while( profileId.size() < 4 ) {
221         profileId = "0" + profileId;
222     }
223
224     cert->profile = profileId + "-" + profileName;
225
226     cert->csr = std::string( row[2], row[2] + l[2] );
227     cert->csr_type = std::string( row[3], row[3] + l[3] );
228
229     cert->SANs = std::vector<std::shared_ptr<SAN>>();
230
231     q = "SELECT contents, type FROM subjectAlternativeNames WHERE certId='" + this->escape_string( job->target ) + "'";
232     std::tie( err, res ) = query( q );
233
234     if( err ) {
235         std::cout << mysql_error( this->conn.get() );
236         return std::shared_ptr<TBSCertificate>();
237     }
238
239     std::cout << "Fetching SANs" << std::endl;
240
241     while( ( row = mysql_fetch_row( res.get() ) ) ) {
242         unsigned long* l = mysql_fetch_lengths( res.get() );
243
244         if( !l ) {
245             return std::shared_ptr<TBSCertificate>();
246         }
247
248         std::shared_ptr<SAN> nSAN = std::shared_ptr<SAN>( new SAN() );
249         nSAN->content = std::string( row[0], row[0] + l[0] );
250         nSAN->type = std::string( row[1], row[1] + l[1] );
251         cert->SANs.push_back( nSAN );
252     }
253
254     q = "SELECT name, value FROM certAvas WHERE certid='" + this->escape_string( job->target ) + "'";
255     std::tie( err, res ) = query( q );
256
257     if( err ) {
258         std::cout << mysql_error( this->conn.get() );
259         return std::shared_ptr<TBSCertificate>();
260
261     }
262
263     while( ( row = mysql_fetch_row( res.get() ) ) ) {
264         unsigned long* l = mysql_fetch_lengths( res.get() );
265
266         if( !l ) {
267             return std::shared_ptr<TBSCertificate>();
268         }
269
270         std::shared_ptr<AVA> nAVA = std::shared_ptr<AVA>( new AVA() );
271         nAVA->name = std::string( row[0], row[0] + l[0] );
272         nAVA->value = std::string( row[1], row[1] + l[1] );
273         cert->AVAs.push_back( nAVA );
274     }
275
276     return cert;
277 }
278
279 void MySQLJobProvider::writeBack( std::shared_ptr<Job> job, std::shared_ptr<SignedCertificate> res ) {
280     if( !conn ) {
281         throw "Error while writing back";
282     }
283
284     std::string id = "SELECT id FROM cacerts WHERE keyname='" + this->escape_string( res->ca_name ) + "'";
285
286     int err = 0;
287     std::shared_ptr<MYSQL_RES> resu;
288     std::tie( err, resu ) = query( id );
289
290     if( err ) {
291         throw "Error while looking ca cert id";
292     }
293
294     MYSQL_ROW row = mysql_fetch_row( resu.get() );
295     unsigned long* l = mysql_fetch_lengths( resu.get() );
296
297     std::string read_id;
298
299     if( !row || !l ) {
300         if( query( "INSERT INTO cacerts SET keyname= '" + this->escape_string( res->ca_name ) + "', subroot = 0" ).first ) {
301             throw "Error while inserting new ca cert";
302         }
303
304         my_ulonglong insert_id = mysql_insert_id( conn.get() );
305
306         read_id = std::to_string( insert_id );
307     } else {
308         read_id = std::string( row[0], row[0] + l[0] );
309     }
310
311     std::string q = "UPDATE certs SET crt_name='" + this->escape_string( res->crt_name ) + "', serial='" + this->escape_string( res->serial ) + "', caId = '" + this->escape_string( read_id ) + "', created='" + this->escape_string( res->before ) + "', expire='" + this->escape_string( res->after ) + "'  WHERE id='" + this->escape_string( job->target ) + "' LIMIT 1";
312     // TODO write more thingies back
313
314     if( query( q ).first ) {
315         throw "Error while writing back";
316     }
317 }
318
319 std::pair<std::string, std::string> MySQLJobProvider::getRevocationInfo( std::shared_ptr<Job> job ) {
320     std::string q = "SELECT certs.serial, cacerts.keyname FROM certs INNER JOIN cacerts ON certs.caId = cacerts.id WHERE certs.id = '" + this->escape_string( job->target ) + "' ";
321     int err = 0;
322     std::shared_ptr<MYSQL_RES> resu;
323     std::tie( err, resu ) = query( q );
324
325     if( err ) {
326         throw "Error while looking ca cert id";
327     }
328
329     MYSQL_ROW row = mysql_fetch_row( resu.get() );
330     unsigned long* l = mysql_fetch_lengths( resu.get() );
331
332     if( !row || !l ) {
333         throw "Error while inserting new ca cert";
334     }
335
336     return std::pair<std::string, std::string>( std::string( row[0], row[0] + l[0] ), std::string( row[1], row[1] + l[1] ) );
337 }
338
339 void MySQLJobProvider::writeBackRevocation( std::shared_ptr<Job> job, std::string date ) {
340     if( query( "UPDATE certs SET revoked = '" + this->escape_string( date ) + "' WHERE id = '" + this->escape_string( job->target ) + "'" ).first ) {
341         throw "Error while writing back revocation";
342     }
343 }