]> WPIA git - cassiopeia.git/blob - src/mysql.cpp
add: Read the CSR (naive approach)
[cassiopeia.git] / src / mysql.cpp
1 #include "mysql.h"
2
3 #include <stdio.h>
4
5 #include <mysql/errmsg.h>
6
7 //This static variable exists to handle initializing and finalizing the MySQL driver library
8 std::shared_ptr<int> MySQLJobProvider::lib_ref(
9     //Initializer: Store the return code as a pointer to an integer
10     new int( mysql_library_init( 0, NULL, NULL ) ),
11     //Finalizer: Check the pointer and free resources
12     []( int* ref ) {
13         if( !ref ) {
14             //The library is not initialized
15             return;
16         }
17
18         if( *ref ) {
19             //The library did return an error when initializing
20             delete ref;
21             return;
22         }
23
24         delete ref;
25
26         mysql_library_end();
27     } );
28
29 MySQLJobProvider::MySQLJobProvider( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
30     if( !lib_ref || *lib_ref ) {
31         throw "MySQL library not initialized!";
32     }
33
34     connect( server, user, password, database );
35 }
36
37 MySQLJobProvider::~MySQLJobProvider() {
38     disconnect();
39 }
40
41 bool MySQLJobProvider::connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
42     if( conn ) {
43         if( !disconnect() ) {
44             return false;
45         }
46
47         conn.reset();
48     }
49
50     conn = _connect( server, user, password, database );
51
52     return !!conn;
53 }
54
55 std::shared_ptr<MYSQL> MySQLJobProvider::_connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
56     MYSQL* tmp( mysql_init( NULL ) );
57
58     if( !tmp ) {
59         return std::shared_ptr<MYSQL>();
60     }
61
62     tmp = mysql_real_connect( tmp, server.c_str(), user.c_str(), password.c_str(), database.c_str(), 3306, NULL, CLIENT_COMPRESS );
63
64     if( !tmp ) {
65         return std::shared_ptr<MYSQL>();
66     }
67
68     auto l = lib_ref;
69     return std::shared_ptr<MYSQL>(
70         tmp,
71         [l]( MYSQL * c ) {
72             if( c ) {
73                 mysql_close( c );
74             }
75         } );
76 }
77
78 bool MySQLJobProvider::disconnect() {
79     if( !conn ) {
80         return false;
81     }
82
83     conn.reset();
84
85     return true;
86 }
87
88 std::pair< int, std::shared_ptr<MYSQL_RES> > MySQLJobProvider::query( const std::string& query ) {
89     if( !conn ) {
90         return std::make_pair( CR_SERVER_LOST, std::shared_ptr<MYSQL_RES>() );
91     }
92
93     int err = mysql_real_query( this->conn.get(), query.c_str(), query.size() );
94
95     if( err ) {
96         return std::make_pair( err, std::shared_ptr<MYSQL_RES>() );
97     }
98
99     auto c = conn;
100     std::shared_ptr<MYSQL_RES> res(
101         mysql_store_result( conn.get() ),
102         [c]( MYSQL_RES * r ) {
103             if( !r ) {
104                 return;
105             }
106
107             mysql_free_result( r );
108         } );
109
110     return std::make_pair( err, res );
111 }
112
113 std::shared_ptr<Job> MySQLJobProvider::fetchJob() {
114     std::string q = "SELECT id, targetId, task, executeFrom, executeTo FROM jobs WHERE state='open'";
115
116     int err = 0;
117     std::shared_ptr<MYSQL_RES> res;
118
119     std::tie( err, res ) = query( q );
120
121     if( err ) {
122         return std::shared_ptr<Job>();
123     }
124
125     unsigned int num = mysql_num_fields( res.get() );
126
127     MYSQL_ROW row = mysql_fetch_row( res.get() );
128
129     if( !row ) {
130         return std::shared_ptr<Job>();
131     }
132
133     std::shared_ptr<Job> job( new Job() );
134
135     unsigned long* l = mysql_fetch_lengths( res.get() );
136
137     if( !l ) {
138         return std::shared_ptr<Job>();
139     }
140
141     job->id = std::string( row[0], row[0] + l[0] );
142     job->target = std::string( row[1], row[1] + l[1] );
143     job->task = std::string( row[2], row[2] + l[2] );
144     job->from = std::string( row[3], row[3] + l[3] );
145     job->to = std::string( row[4], row[4] + l[4] );
146
147     for( unsigned int i = 0; i < num; i++ ) {
148         printf( "[%.*s] ", ( int ) l[i], row[i] ? row[i] : "NULL" );
149     }
150
151     printf( "\n" );
152
153     return job;
154 }
155
156 std::string MySQLJobProvider::escape_string( const std::string& target ) {
157     if( !conn ) {
158         throw "Not connected!";
159     }
160
161     std::string result;
162
163     result.resize( target.size() * 2 );
164
165     long unsigned int len = mysql_real_escape_string( conn.get(), const_cast<char*>( result.data() ), target.c_str(), target.size() );
166
167     result.resize( len );
168
169     return result;
170 }
171
172 bool MySQLJobProvider::finishJob( std::shared_ptr<Job> job ) {
173     if( !conn ) {
174         return false;
175     }
176
177     std::string q = "UPDATE jobs SET state='done' WHERE id='" + this->escape_string( job->id ) + "' LIMIT 1";
178
179     if( query( q ).first ) {
180         return false;
181     }
182
183     return true;
184 }
185
186 std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<Job> job ) {
187     std::shared_ptr<TBSCertificate> cert = std::shared_ptr<TBSCertificate>( new TBSCertificate() );
188     std::string q = "SELECT CN, subject, md, profile, csr_name, csr_type FROM certs WHERE id='" + this->escape_string( job->id ) + "'";
189
190     int err = 0;
191     std::shared_ptr<MYSQL_RES> res;
192
193     std::tie( err, res ) = query( q );
194
195     if( err ) {
196         return std::shared_ptr<TBSCertificate>();
197     }
198
199     MYSQL_ROW row = mysql_fetch_row( res.get() );
200
201     if( !row ) {
202         return std::shared_ptr<TBSCertificate>();
203     }
204
205     unsigned long* l = mysql_fetch_lengths( res.get() );
206
207     if( !l ) {
208         return std::shared_ptr<TBSCertificate>();
209     }
210
211     cert->CN = std::string( row[0], row[0] + l[0] );
212     cert->subj = std::string( row[1], row[1] + l[1] );
213     cert->md = std::string( row[2], row[2] + l[2] );
214     cert->profile = std::string( row[3], row[3] + l[3] );
215     cert->csr = std::string( row[4], row[4] + l[4] );
216     cert->csr_type = std::string( row[5], row[5] + l[5] );
217
218     return cert;
219 }