]> WPIA git - cassiopeia.git/blobdiff - src/db/mysql.cpp
chg: replace default std::shared_ptr<…>() with explicit nullptr for empty return...
[cassiopeia.git] / src / db / mysql.cpp
index 973e9d0fb68f06ba13f8d57f29e07485ba632c88..db3a7b91402fba5da76cf3bebbd1eaa3b066a9d9 100644 (file)
@@ -36,19 +36,8 @@ MySQLJobProvider::MySQLJobProvider( const std::string& server, const std::string
     connect( server, user, password, database );
 }
 
-MySQLJobProvider::~MySQLJobProvider() {
-    disconnect();
-}
-
 bool MySQLJobProvider::connect( const std::string& server, const std::string& user, const std::string& password, const std::string& database ) {
-    if( conn ) {
-        if( !disconnect() ) {
-            return false;
-        }
-
-        conn.reset();
-    }
-
+    disconnect();
     conn = _connect( server, user, password, database );
 
     return !!conn;
@@ -58,13 +47,13 @@ std::shared_ptr<MYSQL> MySQLJobProvider::_connect( const std::string& server, co
     MYSQL* tmp( mysql_init( NULL ) );
 
     if( !tmp ) {
-        return std::shared_ptr<MYSQL>();
+        return nullptr;
     }
 
     tmp = mysql_real_connect( tmp, server.c_str(), user.c_str(), password.c_str(), database.c_str(), 3306, NULL, CLIENT_COMPRESS );
 
     if( !tmp ) {
-        return std::shared_ptr<MYSQL>();
+        return nullptr;
     }
 
     auto l = lib_ref;
@@ -121,7 +110,7 @@ std::shared_ptr<Job> MySQLJobProvider::fetchJob() {
     std::tie( err, res ) = query( q );
 
     if( err ) {
-        return std::shared_ptr<Job>();
+        return nullptr;
     }
 
     unsigned int num = mysql_num_fields( res.get() );
@@ -129,15 +118,15 @@ std::shared_ptr<Job> MySQLJobProvider::fetchJob() {
     MYSQL_ROW row = mysql_fetch_row( res.get() );
 
     if( !row ) {
-        return std::shared_ptr<Job>();
+        return nullptr;
     }
 
-    std::shared_ptr<Job> job( new Job() );
+    auto job = std::make_shared<Job>();
 
     unsigned long* l = mysql_fetch_lengths( res.get() );
 
     if( !l ) {
-        return std::shared_ptr<Job>();
+        return nullptr;
     }
 
     job->id = std::string( row[0], row[0] + l[0] );
@@ -198,7 +187,7 @@ void MySQLJobProvider::failJob( std::shared_ptr<Job> job ) {
 }
 
 std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<Job> job ) {
-    std::shared_ptr<TBSCertificate> cert = std::shared_ptr<TBSCertificate>( new TBSCertificate() );
+    auto cert = std::make_shared<TBSCertificate>();
     std::string q = "SELECT md, profile, csr_name, csr_type, keyname FROM certs INNER JOIN profiles ON profiles.id = certs.profile WHERE certs.id='" + this->escape_string( job->target ) + "'";
 
     int err = 0;
@@ -208,19 +197,19 @@ std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<
     std::tie( err, res ) = query( q );
 
     if( err ) {
-        return std::shared_ptr<TBSCertificate>();
+        return nullptr;
     }
 
     MYSQL_ROW row = mysql_fetch_row( res.get() );
 
     if( !row ) {
-        return std::shared_ptr<TBSCertificate>();
+        return nullptr;
     }
 
     unsigned long* l = mysql_fetch_lengths( res.get() );
 
     if( !l ) {
-        return std::shared_ptr<TBSCertificate>();
+        return nullptr;
     }
 
     std::string profileName = std::string( row[4], row[4] + l[4] );
@@ -244,7 +233,7 @@ std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<
 
     if( err ) {
         std::cout << mysql_error( this->conn.get() );
-        return std::shared_ptr<TBSCertificate>();
+        return nullptr;
     }
 
     std::cout << "Fetching SANs" << std::endl;
@@ -253,7 +242,7 @@ std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<
         unsigned long* l = mysql_fetch_lengths( res.get() );
 
         if( !l ) {
-            return std::shared_ptr<TBSCertificate>();
+            return nullptr;
         }
 
         std::shared_ptr<SAN> nSAN = std::shared_ptr<SAN>( new SAN() );
@@ -267,7 +256,7 @@ std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<
 
     if( err ) {
         std::cout << mysql_error( this->conn.get() );
-        return std::shared_ptr<TBSCertificate>();
+        return nullptr;
 
     }
 
@@ -275,7 +264,7 @@ std::shared_ptr<TBSCertificate> MySQLJobProvider::fetchTBSCert( std::shared_ptr<
         unsigned long* l = mysql_fetch_lengths( res.get() );
 
         if( !l ) {
-            return std::shared_ptr<TBSCertificate>();
+            return nullptr;
         }
 
         std::shared_ptr<AVA> nAVA = std::shared_ptr<AVA>( new AVA() );
@@ -292,11 +281,63 @@ void MySQLJobProvider::writeBack( std::shared_ptr<Job> job, std::shared_ptr<Sign
         throw "Error while writing back";
     }
 
-    std::string q = "UPDATE certs SET crt_name='" + this->escape_string( res->crt_name ) + "', serial='" + this->escape_string( res->serial ) + "', created=NOW() WHERE id='" + this->escape_string( job->target ) + "' LIMIT 1";
+    std::string id = "SELECT id FROM cacerts WHERE keyname='" + this->escape_string( res->ca_name ) + "'";
+
+    int err = 0;
+    std::shared_ptr<MYSQL_RES> resu;
+    std::tie( err, resu ) = query( id );
+
+    if( err ) {
+        throw "Error while looking ca cert id";
+    }
+
+    MYSQL_ROW row = mysql_fetch_row( resu.get() );
+    unsigned long* l = mysql_fetch_lengths( resu.get() );
+
+    std::string read_id;
+
+    if( !row || !l ) {
+        if( query( "INSERT INTO cacerts SET keyname= '" + this->escape_string( res->ca_name ) + "', subroot = 0" ).first ) {
+            throw "Error while inserting new ca cert";
+        }
 
+        my_ulonglong insert_id = mysql_insert_id( conn.get() );
+
+        read_id = std::to_string( insert_id );
+    } else {
+        read_id = std::string( row[0], row[0] + l[0] );
+    }
+
+    std::string q = "UPDATE certs SET crt_name='" + this->escape_string( res->crt_name ) + "', serial='" + this->escape_string( res->serial ) + "', caId = '" + this->escape_string( read_id ) + "', created='" + this->escape_string( res->before ) + "', expire='" + this->escape_string( res->after ) + "'  WHERE id='" + this->escape_string( job->target ) + "' LIMIT 1";
     // TODO write more thingies back
 
     if( query( q ).first ) {
         throw "Error while writing back";
     }
 }
+
+std::pair<std::string, std::string> MySQLJobProvider::getRevocationInfo( std::shared_ptr<Job> job ) {
+    std::string q = "SELECT certs.serial, cacerts.keyname FROM certs INNER JOIN cacerts ON certs.caId = cacerts.id WHERE certs.id = '" + this->escape_string( job->target ) + "' ";
+    int err = 0;
+    std::shared_ptr<MYSQL_RES> resu;
+    std::tie( err, resu ) = query( q );
+
+    if( err ) {
+        throw "Error while looking ca cert id";
+    }
+
+    MYSQL_ROW row = mysql_fetch_row( resu.get() );
+    unsigned long* l = mysql_fetch_lengths( resu.get() );
+
+    if( !row || !l ) {
+        throw "Error while inserting new ca cert";
+    }
+
+    return std::pair<std::string, std::string>( std::string( row[0], row[0] + l[0] ), std::string( row[1], row[1] + l[1] ) );
+}
+
+void MySQLJobProvider::writeBackRevocation( std::shared_ptr<Job> job, std::string date ) {
+    if( query( "UPDATE certs SET revoked = '" + this->escape_string( date ) + "' WHERE id = '" + this->escape_string( job->target ) + "'" ).first ) {
+        throw "Error while writing back revocation";
+    }
+}