]> WPIA git - gigi.git/blobdiff - src/org/cacert/gigi/database/DatabaseConnection.java
fix: SQL change database call pattern
[gigi.git] / src / org / cacert / gigi / database / DatabaseConnection.java
index 21d701cb4bc4048589fcd845040e5660a405912d..bccae86faebb950f1ea82da64aca0e90db8d58b2 100644 (file)
@@ -1,5 +1,8 @@
 package org.cacert.gigi.database;
 
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.PrintWriter;
 import java.sql.Connection;
 import java.sql.DriverManager;
 import java.sql.PreparedStatement;
@@ -7,16 +10,27 @@ import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Statement;
 import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map.Entry;
 import java.util.Properties;
+import java.util.StringJoiner;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.cacert.gigi.database.SQLFileManager.ImportType;
 
 public class DatabaseConnection {
 
+    public static final int CURRENT_SCHEMA_VERSION = 6;
+
     public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
 
     private Connection c;
 
     private HashMap<String, PreparedStatement> statements = new HashMap<String, PreparedStatement>();
 
+    HashSet<PreparedStatement> underUse = new HashSet<>();
+
     private static Properties credentials;
 
     private Statement adHoc;
@@ -33,25 +47,51 @@ public class DatabaseConnection {
 
     private void tryConnect() {
         try {
-            c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?zeroDateTimeBehavior=convertToNull", credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
-            PreparedStatement ps = c.prepareStatement("SET SESSION wait_timeout=?;");
-            ps.setInt(1, CONNECTION_TIMEOUT);
-            ps.execute();
-            ps.close();
+            c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
             adHoc = c.createStatement();
         } catch (SQLException e) {
             e.printStackTrace();
         }
     }
 
-    public PreparedStatement prepare(String query) throws SQLException {
+    protected synchronized PreparedStatement prepareInternal(String query) throws SQLException {
         ensureOpen();
+        query = preprocessQuery(query);
         PreparedStatement statement = statements.get(query);
-        if (statement == null) {
-            statement = c.prepareStatement(query, Statement.RETURN_GENERATED_KEYS);
-            statements.put(query, statement);
+        if (statement != null) {
+            if (underUse.add(statement)) {
+                return statement;
+            } else {
+                throw new Error("Statement in Use");
+            }
+        }
+        statement = c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS);
+        statements.put(query, statement);
+        if (underUse.add(statement)) {
+            return statement;
+        } else {
+            throw new Error("Statement in Use");
+        }
+    }
+
+    protected synchronized PreparedStatement prepareInternalScrollable(String query) throws SQLException {
+        ensureOpen();
+        query = preprocessQuery(query);
+        PreparedStatement statement = statements.get("__SCROLLABLE__! " + query);
+        if (statement != null) {
+            if (underUse.add(statement)) {
+                return statement;
+            } else {
+                throw new Error("Statement in Use");
+            }
+        }
+        statement = c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
+        statements.put("__SCROLLABLE__! " + query, statement);
+        if (underUse.add(statement)) {
+            return statement;
+        } else {
+            throw new Error("Statement in Use");
         }
-        return statement;
     }
 
     private long lastAction = System.currentTimeMillis();
@@ -71,24 +111,17 @@ public class DatabaseConnection {
         lastAction = System.currentTimeMillis();
     }
 
-    public static int lastInsertId(PreparedStatement query) throws SQLException {
-        ResultSet rs = query.getGeneratedKeys();
-        rs.next();
-        int id = rs.getInt(1);
-        rs.close();
-        return id;
-    }
-
-    private static ThreadLocal<DatabaseConnection> instances = new ThreadLocal<DatabaseConnection>() {
-
-        @Override
-        protected DatabaseConnection initialValue() {
-            return new DatabaseConnection();
-        }
-    };
+    private static DatabaseConnection instance;
 
     public static DatabaseConnection getInstance() {
-        return instances.get();
+        if (instance == null) {
+            synchronized (DatabaseConnection.class) {
+                if (instance == null) {
+                    instance = new DatabaseConnection();
+                }
+            }
+        }
+        return instance;
     }
 
     public static boolean isInited() {
@@ -100,12 +133,53 @@ public class DatabaseConnection {
             throw new Error("Re-initiaizing is forbidden.");
         }
         credentials = conf;
+        int version = 0;
+        try (GigiPreparedStatement gigiPreparedStatement = new GigiPreparedStatement("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;")) {
+            GigiResultSet rs = gigiPreparedStatement.executeQuery();
+            if (rs.next()) {
+                version = rs.getInt(1);
+            }
+        }
+        if (version == CURRENT_SCHEMA_VERSION) {
+            return; // Good to go
+        }
+        if (version > CURRENT_SCHEMA_VERSION) {
+            throw new Error("Invalid database version. Please fix this.");
+        }
+        upgrade(version);
     }
 
     public void beginTransaction() throws SQLException {
         c.setAutoCommit(false);
     }
 
+    private static void upgrade(int version) {
+        try {
+            Statement s = getInstance().c.createStatement();
+            try {
+                while (version < CURRENT_SCHEMA_VERSION) {
+                    try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
+                        if (resourceAsStream == null) {
+                            throw new Error("Upgrade script from version " + version + " was not found.");
+                        }
+                        SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
+                    }
+                    version++;
+                }
+                s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
+                System.out.println("UPGRADING Database to version " + version);
+                s.executeBatch();
+                System.out.println("done.");
+            } finally {
+                s.close();
+            }
+        } catch (SQLException e) {
+            e.printStackTrace();
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+    }
+
     public void commitTransaction() throws SQLException {
         c.commit();
         c.setAutoCommit(true);
@@ -121,4 +195,55 @@ public class DatabaseConnection {
             e.printStackTrace();
         }
     }
+
+    public static final String preprocessQuery(String originalQuery) {
+        originalQuery = originalQuery.replace('`', '"');
+        if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
+            Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
+            Matcher m = p.matcher(originalQuery);
+            if (m.matches()) {
+                String replacement = "INSERT INTO " + toIdentifier(m.group(1));
+                String[] parts = m.group(2).split(",");
+                StringJoiner columns = new StringJoiner(", ");
+                StringJoiner values = new StringJoiner(", ");
+                for (int i = 0; i < parts.length; i++) {
+                    String[] split = parts[i].split("=", 2);
+                    columns.add(toIdentifier(split[0]));
+                    values.add(split[1]);
+                }
+                replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
+                return replacement;
+            }
+        }
+
+        //
+        return originalQuery;
+    }
+
+    private static CharSequence toIdentifier(String ident) {
+        ident = ident.trim();
+        if ( !ident.startsWith("\"")) {
+            ident = "\"" + ident;
+        }
+        if ( !ident.endsWith("\"")) {
+            ident = ident + "\"";
+        }
+        return ident;
+    }
+
+    protected synchronized void returnStatement(PreparedStatement target) {
+        underUse.remove(target);
+    }
+
+    public void lockedStatements(PrintWriter writer) {
+        writer.println(underUse.size());
+        for (PreparedStatement ps : underUse) {
+            for (Entry<String, PreparedStatement> e : statements.entrySet()) {
+                if (e.getValue() == ps) {
+                    writer.println("<br/>");
+                    writer.println(e.getKey());
+                }
+            }
+        }
+    }
 }