X-Git-Url: https://code.wpia.club/?a=blobdiff_plain;f=src%2Forg%2Fcacert%2Fgigi%2Fdatabase%2FDatabaseConnection.java;h=67364d0a06fe95ee02760d5c270e8121a6c25e9b;hb=8552a0eaf52ad7a05cb045ee9423cfc8a45c336a;hp=389a82cfe84bb43fbe1b0a1b8ac39448389778e0;hpb=943d8e7ed0ea5a9d56e7e694a3cbd849c52bad16;p=gigi.git diff --git a/src/org/cacert/gigi/database/DatabaseConnection.java b/src/org/cacert/gigi/database/DatabaseConnection.java index 389a82cf..67364d0a 100644 --- a/src/org/cacert/gigi/database/DatabaseConnection.java +++ b/src/org/cacert/gigi/database/DatabaseConnection.java @@ -1,25 +1,33 @@ package org.cacert.gigi.database; +import java.io.IOException; +import java.io.InputStream; import java.sql.Connection; import java.sql.DriverManager; -import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Statement; import java.util.HashMap; import java.util.Properties; -import java.sql.Statement; +import java.util.StringJoiner; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.cacert.gigi.database.SQLFileManager.ImportType; public class DatabaseConnection { + public static final int CURRENT_SCHEMA_VERSION = 5; + public static final int CONNECTION_TIMEOUT = 24 * 60 * 60; - Connection c; + private Connection c; - HashMap statements = new HashMap(); + private HashMap statements = new HashMap(); private static Properties credentials; - Statement adHoc; + private Statement adHoc; public DatabaseConnection() { try { @@ -33,28 +41,44 @@ public class DatabaseConnection { private void tryConnect() { try { - c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?zeroDateTimeBehavior=convertToNull", credentials.getProperty("sql.user"), credentials.getProperty("sql.password")); - PreparedStatement ps = c.prepareStatement("SET SESSION wait_timeout=?;"); - ps.setInt(1, CONNECTION_TIMEOUT); - ps.execute(); - ps.close(); + c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password")); adHoc = c.createStatement(); } catch (SQLException e) { e.printStackTrace(); } } - public PreparedStatement prepare(String query) throws SQLException { + public GigiPreparedStatement prepare(String query) { + ensureOpen(); + query = preprocessQuery(query); + GigiPreparedStatement statement = statements.get(query); + if (statement == null) { + try { + statement = new GigiPreparedStatement(c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS)); + } catch (SQLException e) { + throw new Error(e); + } + statements.put(query, statement); + } + return statement; + } + + public GigiPreparedStatement prepareScrollable(String query) { ensureOpen(); - PreparedStatement statement = statements.get(query); + query = preprocessQuery(query); + GigiPreparedStatement statement = statements.get(query); if (statement == null) { - statement = c.prepareStatement(query, Statement.RETURN_GENERATED_KEYS); + try { + statement = new GigiPreparedStatement(c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY)); + } catch (SQLException e) { + throw new Error(e); + } statements.put(query, statement); } return statement; } - long lastAction = System.currentTimeMillis(); + private long lastAction = System.currentTimeMillis(); private void ensureOpen() { if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) { @@ -71,15 +95,7 @@ public class DatabaseConnection { lastAction = System.currentTimeMillis(); } - public static int lastInsertId(PreparedStatement query) throws SQLException { - ResultSet rs = query.getGeneratedKeys(); - rs.next(); - int id = rs.getInt(1); - rs.close(); - return id; - } - - static ThreadLocal instances = new ThreadLocal() { + private static ThreadLocal instances = new ThreadLocal() { @Override protected DatabaseConnection initialValue() { @@ -100,12 +116,51 @@ public class DatabaseConnection { throw new Error("Re-initiaizing is forbidden."); } credentials = conf; + GigiResultSet rs = getInstance().prepare("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;").executeQuery(); + int version = 0; + if (rs.next()) { + version = rs.getInt(1); + } + if (version == CURRENT_SCHEMA_VERSION) { + return; // Good to go + } + if (version > CURRENT_SCHEMA_VERSION) { + throw new Error("Invalid database version. Please fix this."); + } + upgrade(version); } public void beginTransaction() throws SQLException { c.setAutoCommit(false); } + private static void upgrade(int version) { + try { + Statement s = getInstance().c.createStatement(); + try { + while (version < CURRENT_SCHEMA_VERSION) { + try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) { + if (resourceAsStream == null) { + throw new Error("Upgrade script from version " + version + " was not found."); + } + SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION); + } + version++; + } + s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'"); + System.out.println("UPGRADING Database to version " + version); + s.executeBatch(); + System.out.println("done."); + } finally { + s.close(); + } + } catch (SQLException e) { + e.printStackTrace(); + } catch (IOException e) { + e.printStackTrace(); + } + } + public void commitTransaction() throws SQLException { c.commit(); c.setAutoCommit(true); @@ -121,4 +176,39 @@ public class DatabaseConnection { e.printStackTrace(); } } + + public static final String preprocessQuery(String originalQuery) { + originalQuery = originalQuery.replace('`', '"'); + if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) { + Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)"); + Matcher m = p.matcher(originalQuery); + if (m.matches()) { + String replacement = "INSERT INTO " + toIdentifier(m.group(1)); + String[] parts = m.group(2).split(","); + StringJoiner columns = new StringJoiner(", "); + StringJoiner values = new StringJoiner(", "); + for (int i = 0; i < parts.length; i++) { + String[] split = parts[i].split("=", 2); + columns.add(toIdentifier(split[0])); + values.add(split[1]); + } + replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")"; + return replacement; + } + } + + // + return originalQuery; + } + + private static CharSequence toIdentifier(String ident) { + ident = ident.trim(); + if ( !ident.startsWith("\"")) { + ident = "\"" + ident; + } + if ( !ident.endsWith("\"")) { + ident = ident + "\""; + } + return ident; + } }