]> WPIA git - cassiopeia.git/blob - src/db/mysql.cpp
e8f1d4dd44a01cdc33358d53b19d63d6dc93e732
[cassiopeia.git] / src / db / mysql.cpp
1 #include "mysql.h"
2
3 #include <stdio.h>
4
5 #include <iostream>
6
7 #include <mysql/errmsg.h>
8 #include <log/logger.hpp>
9
10 //This static variable exists to handle initializing and finalizing the MySQL driver library
11 std::shared_ptr<int> MySQLJobProvider::lib_ref(
12     //Initializer: Store the return code as a pointer to an integer
13     new int( mysql_library_init( 0, NULL, NULL ) ),
14
15     //Finalizer: Check the pointer and free resources
16     []( int* ref ) {
17         if( !ref ) {
18             //The library is not initialized
19             return;
20         }
21
22         if( *ref ) {
23             //The library did return an error when initializing
24             delete ref;
25             return;
26         }
27
28         delete ref;
29
30         mysql_library_end();
31     } );
32
33 MySQLJobProvider::MySQLJobProvider( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
34     if( !lib_ref || *lib_ref ) {
35         throw std::runtime_error("MySQL library not initialized!");
36     }
37
38     connect( server, user, password, database );
39 }
40
41 bool MySQLJobProvider::connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
42     disconnect();
43     conn = _connect( server, user, password, database );
44
45     return !!conn;
46 }
47
48 std::shared_ptr<MYSQL> MySQLJobProvider::_connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
49     MYSQL* tmp( mysql_init( NULL ) );
50
51     if( !tmp ) {
52         return nullptr;
53     }
54
55     tmp = mysql_real_connect( tmp, server.c_str(), user.c_str(), password.c_str(), database.c_str(), 3306, NULL, CLIENT_COMPRESS );
56
57     if( !tmp ) {
58         return nullptr;
59     }
60
61     auto l = lib_ref;
62     return std::shared_ptr<MYSQL>(
63         tmp,
64         [l]( MYSQL * c ) {
65             if( c ) {
66                 mysql_close( c );
67             }
68         } );
69 }
70
71 bool MySQLJobProvider::disconnect() {
72     if( !conn ) {
73         return false;
74     }
75
76     conn.reset();
77
78     return true;
79 }
80
81 std::pair< int, std::shared_ptr<MYSQL_RES> > MySQLJobProvider::query( const std::string& query ) {
82     if( !conn ) {
83         return std::make_pair( CR_SERVER_LOST, std::shared_ptr<MYSQL_RES>() );
84     }
85
86     int err = mysql_real_query( this->conn.get(), query.c_str(), query.size() );
87
88     if( err ) {
89         throw std::runtime_error(std::string( "MySQL error: " ) + mysql_error( this->conn.get() ));
90     }
91
92     auto c = conn;
93     std::shared_ptr<MYSQL_RES> res(
94         mysql_store_result( conn.get() ),
95         [c]( MYSQL_RES * r ) {
96             if( !r ) {
97                 return;
98             }
99
100             mysql_free_result( r );
101         } );
102
103     return std::make_pair( err, res );
104 }
105
106 std::shared_ptr<Job> MySQLJobProvider::fetchJob() {
107     std::string q = "SELECT id, targetId, task, executeFrom, executeTo, warning FROM jobs WHERE state='open' AND warning < 3";
108
109     int err = 0;
110     std::shared_ptr<MYSQL_RES> res;
111
112     std::tie( err, res ) = query( q );
113
114     if( err ) {
115         return nullptr;
116     }
117
118     MYSQL_ROW row = mysql_fetch_row( res.get() );
119
120     if( !row ) {
121         return nullptr;
122     }
123
124     auto job = std::make_shared<Job>();
125
126     unsigned long* l = mysql_fetch_lengths( res.get() );
127
128     if( !l ) {
129         return nullptr;
130     }
131
132     job->id = std::string( row[0], row[0] + l[0] );
133     job->target = std::string( row[1], row[1] + l[1] );
134     job->task = std::string( row[2], row[2] + l[2] );
135     job->from = std::string( row[3], row[3] + l[3] );
136     job->to = std::string( row[4], row[4] + l[4] );
137     job->warning = std::string( row[5], row[5] + l[5] );
138
139     logger::notef( "Got a job: (id=%s, target=%s, task=%s, from=%s, to=%s, warnings=%s)", job->id, job->target, job->task, job->from, job->to, job->warning );
140
141     return job;
142 }
143
144 std::string MySQLJobProvider::escape_string( const std::string& target ) {
145     if( !conn ) {
146         throw std::runtime_error("Not connected!");
147     }
148
149     std::string result;
150
151     result.resize( target.size() * 2 );
152
153     long unsigned int len = mysql_real_escape_string( conn.get(), const_cast<char*>( result.data() ), target.c_str(), target.size() );
154
155     result.resize( len );
156
157     return result;
158 }
159
160 void MySQLJobProvider::finishJob( std::shared_ptr<Job> job ) {
161     if( !conn ) {
162         throw std::runtime_error("Not connected!");
163     }
164
165     std::string q = "UPDATE jobs SET state='done' WHERE id='" + this->escape_string( job->id ) + "' LIMIT 1";
166
167     if( query( q ).first ) {
168         throw std::runtime_error("No database entry found.");
169     }
170 }
171
172 void MySQLJobProvider::failJob( std::shared_ptr<Job> job ) {
173     if( !conn ) {
174         throw std::runtime_error("Not connected!");
175     }
176
177     std::string q = "UPDATE jobs SET warning = warning + 1 WHERE id='" + this->escape_string( job->id ) + "' LIMIT 1";
178
179     if( query( q ).first ) {
180         throw std::runtime_error("No database entry found.");
181     }
182 }
183
184 std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<Job> job ) {
185     auto cert = std::make_shared<TBSCertificate>();
186     std::string q = "SELECT md, profile, csr_name, csr_type, keyname FROM certs INNER JOIN profiles ON profiles.id = certs.profile WHERE certs.id='" + this->escape_string( job->target ) + "'";
187
188     int err = 0;
189
190     std::shared_ptr<MYSQL_RES> res;
191
192     std::tie( err, res ) = query( q );
193
194     if( err ) {
195         return nullptr;
196     }
197
198     MYSQL_ROW row = mysql_fetch_row( res.get() );
199
200     if( !row ) {
201         return nullptr;
202     }
203
204     unsigned long* l = mysql_fetch_lengths( res.get() );
205
206     if( !l ) {
207         return nullptr;
208     }
209
210     std::string profileName = std::string( row[4], row[4] + l[4] );
211
212     cert->md = std::string( row[0], row[0] + l[0] );
213     std::string profileId = std::string( row[1], row[1] + l[1] );
214
215     while( profileId.size() < 4 ) {
216         profileId = "0" + profileId;
217     }
218
219     cert->profile = profileId + "-" + profileName;
220
221     cert->csr = std::string( row[2], row[2] + l[2] );
222     cert->csr_type = std::string( row[3], row[3] + l[3] );
223
224     cert->SANs = std::vector<std::shared_ptr<SAN>>();
225
226     q = "SELECT contents, type FROM subjectAlternativeNames WHERE certId='" + this->escape_string( job->target ) + "'";
227     std::tie( err, res ) = query( q );
228
229     if( err ) {
230         std::cout << mysql_error( this->conn.get() );
231         return nullptr;
232     }
233
234     std::cout << "Fetching SANs" << std::endl;
235
236     while( ( row = mysql_fetch_row( res.get() ) ) ) {
237         unsigned long* l = mysql_fetch_lengths( res.get() );
238
239         if( !l ) {
240             return nullptr;
241         }
242
243         auto nSAN = std::make_shared<SAN>();
244         nSAN->content = std::string( row[0], row[0] + l[0] );
245         nSAN->type = std::string( row[1], row[1] + l[1] );
246         cert->SANs.push_back( nSAN );
247     }
248
249     q = "SELECT name, value FROM certAvas WHERE certid='" + this->escape_string( job->target ) + "'";
250     std::tie( err, res ) = query( q );
251
252     if( err ) {
253         std::cout << mysql_error( this->conn.get() );
254         return nullptr;
255
256     }
257
258     while( ( row = mysql_fetch_row( res.get() ) ) ) {
259         unsigned long* l = mysql_fetch_lengths( res.get() );
260
261         if( !l ) {
262             return nullptr;
263         }
264
265         auto nAVA = std::make_shared<AVA>();
266         nAVA->name = std::string( row[0], row[0] + l[0] );
267         nAVA->value = std::string( row[1], row[1] + l[1] );
268         cert->AVAs.push_back( nAVA );
269     }
270
271     return cert;
272 }
273
274 void MySQLJobProvider::writeBack( std::shared_ptr<Job> job, std::shared_ptr<SignedCertificate> res ) {
275     if( !conn ) {
276         throw std::runtime_error("Error while writing back");
277     }
278
279     std::string id = "SELECT id FROM cacerts WHERE keyname='" + this->escape_string( res->ca_name ) + "'";
280
281     int err = 0;
282     std::shared_ptr<MYSQL_RES> resu;
283     std::tie( err, resu ) = query( id );
284
285     if( err ) {
286         throw std::runtime_error("Error while looking ca cert id");
287     }
288
289     MYSQL_ROW row = mysql_fetch_row( resu.get() );
290     unsigned long* l = mysql_fetch_lengths( resu.get() );
291
292     std::string read_id;
293
294     if( !row || !l ) {
295         throw std::runtime_error("Error while inserting new ca cert not found");
296     } else {
297         read_id = std::string( row[0], row[0] + l[0] );
298     }
299
300     std::string q = "UPDATE certs SET crt_name='" + this->escape_string( res->crt_name ) + "', serial='" + this->escape_string( res->serial ) + "', caId = '" + this->escape_string( read_id ) + "', created='" + this->escape_string( res->before ) + "', expire='" + this->escape_string( res->after ) + "'  WHERE id='" + this->escape_string( job->target ) + "' LIMIT 1";
301     // TODO write more thingies back
302
303     if( query( q ).first ) {
304         throw std::runtime_error("Error while writing back");
305     }
306 }
307
308 std::pair<std::string, std::string> MySQLJobProvider::getRevocationInfo( std::shared_ptr<Job> job ) {
309     std::string q = "SELECT certs.serial, cacerts.keyname FROM certs INNER JOIN cacerts ON certs.caId = cacerts.id WHERE certs.id = '" + this->escape_string( job->target ) + "' ";
310     int err = 0;
311     std::shared_ptr<MYSQL_RES> resu;
312     std::tie( err, resu ) = query( q );
313
314     if( err ) {
315         throw std::runtime_error("Error while looking ca cert id");
316     }
317
318     MYSQL_ROW row = mysql_fetch_row( resu.get() );
319     unsigned long* l = mysql_fetch_lengths( resu.get() );
320
321     if( !row || !l ) {
322         throw std::runtime_error("Error while inserting new ca cert");
323     }
324
325     return std::pair<std::string, std::string>( std::string( row[0], row[0] + l[0] ), std::string( row[1], row[1] + l[1] ) );
326 }
327
328 void MySQLJobProvider::writeBackRevocation( std::shared_ptr<Job> job, std::string date ) {
329     if( query( "UPDATE certs SET revoked = '" + this->escape_string( date ) + "' WHERE id = '" + this->escape_string( job->target ) + "'" ).first ) {
330         throw std::runtime_error("Error while writing back revocation");
331     }
332 }