package org.cacert.gigi.database;
-import java.io.FileInputStream;
import java.io.IOException;
+import java.io.InputStream;
+import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
+import java.sql.Statement;
import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map.Entry;
import java.util.Properties;
+import java.util.StringJoiner;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.cacert.gigi.database.SQLFileManager.ImportType;
public class DatabaseConnection {
- Connection c;
- HashMap<String, PreparedStatement> statements = new HashMap<String, PreparedStatement>();
- static Properties credentials = new Properties();
- static {
- try {
- credentials.load(new FileInputStream("config/sql.properties"));
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- public DatabaseConnection() {
- try {
- Class.forName(credentials.getProperty("driver"));
- } catch (ClassNotFoundException e) {
- e.printStackTrace();
- }
- try {
- c = DriverManager.getConnection(credentials.getProperty("url")
- + "?zeroDateTimeBehavior=convertToNull",
- credentials.getProperty("user"),
- credentials.getProperty("password"));
- } catch (SQLException e) {
- e.printStackTrace();
- }
-
- }
- public PreparedStatement prepare(String query) throws SQLException {
- PreparedStatement statement = statements.get(query);
- if (statement == null) {
- statement = c.prepareStatement(query);
- statements.put(query, statement);
- }
- return statement;
- }
-
- public static int lastInsertId(PreparedStatement query) throws SQLException {
- ResultSet rs = query.getGeneratedKeys();
- rs.next();
- int id = rs.getInt(1);
- rs.close();
- return id;
- }
- static ThreadLocal<DatabaseConnection> instances = new ThreadLocal<DatabaseConnection>() {
- @Override
- protected DatabaseConnection initialValue() {
- return new DatabaseConnection();
- }
- };
- public static DatabaseConnection getInstance() {
- return instances.get();
- }
+
+ public static final int MAX_CACHED_INSTANCES = 3;
+
+ private static class StatementDescriptor {
+
+ String query;
+
+ boolean scrollable;
+
+ int instance;
+
+ PreparedStatement target;
+
+ public StatementDescriptor(String query, boolean scrollable) {
+ this.query = query;
+ this.scrollable = scrollable;
+ this.instance = 0;
+ }
+
+ public synchronized void instanciate(Connection c) throws SQLException {
+ if (scrollable) {
+ target = c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
+ } else {
+ target = c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS);
+ }
+
+ }
+
+ public synchronized PreparedStatement getTarget() {
+ return target;
+ }
+
+ public synchronized void increase() {
+ if (target != null) {
+ throw new IllegalStateException();
+ }
+ instance++;
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + instance;
+ result = prime * result + ((query == null) ? 0 : query.hashCode());
+ result = prime * result + (scrollable ? 1231 : 1237);
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ StatementDescriptor other = (StatementDescriptor) obj;
+ if (instance != other.instance) {
+ return false;
+ }
+ if (query == null) {
+ if (other.query != null) {
+ return false;
+ }
+ } else if ( !query.equals(other.query)) {
+ return false;
+ }
+ if (scrollable != other.scrollable) {
+ return false;
+ }
+ return true;
+ }
+
+ }
+
+ public static final int CURRENT_SCHEMA_VERSION = 9;
+
+ public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
+
+ private Connection c;
+
+ private HashMap<StatementDescriptor, PreparedStatement> statements = new HashMap<StatementDescriptor, PreparedStatement>();
+
+ HashSet<PreparedStatement> underUse = new HashSet<>();
+
+ private static Properties credentials;
+
+ private Statement adHoc;
+
+ public DatabaseConnection() {
+ try {
+ Class.forName(credentials.getProperty("sql.driver"));
+ } catch (ClassNotFoundException e) {
+ e.printStackTrace();
+ }
+ tryConnect();
+
+ }
+
+ private void tryConnect() {
+ try {
+ c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
+ adHoc = c.createStatement();
+ } catch (SQLException e) {
+ e.printStackTrace();
+ }
+ }
+
+ protected synchronized PreparedStatement prepareInternal(String query) throws SQLException {
+ return prepareInternal(query, false);
+ }
+
+ protected synchronized PreparedStatement prepareInternal(String query, boolean scrollable) throws SQLException {
+
+ ensureOpen();
+ query = preprocessQuery(query);
+ StatementDescriptor searchHead = new StatementDescriptor(query, scrollable);
+ PreparedStatement statement = null;
+ while (statement == null) {
+ statement = statements.get(searchHead);
+ if (statement == null) {
+ searchHead.instanciate(c);
+ statement = searchHead.getTarget();
+ if (searchHead.instance >= MAX_CACHED_INSTANCES) {
+ return statement;
+ }
+ underUse.add(statement);
+ statements.put(searchHead, statement);
+ } else if (underUse.contains(statement)) {
+ searchHead.increase();
+ statement = null;
+ } else {
+ underUse.add(statement);
+ }
+ }
+ return statement;
+ }
+
+ protected synchronized PreparedStatement prepareInternalScrollable(String query) throws SQLException {
+ return prepareInternal(query, true);
+ }
+
+ private long lastAction = System.currentTimeMillis();
+
+ private void ensureOpen() {
+ if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) {
+ try {
+ ResultSet rs = adHoc.executeQuery("SELECT 1");
+ rs.close();
+ lastAction = System.currentTimeMillis();
+ return;
+ } catch (SQLException e) {
+ }
+ statements.clear();
+ tryConnect();
+ }
+ lastAction = System.currentTimeMillis();
+ }
+
+ private static volatile DatabaseConnection instance;
+
+ public static synchronized DatabaseConnection getInstance() {
+ if (instance == null) {
+ instance = new DatabaseConnection();
+ }
+ return instance;
+ }
+
+ public static boolean isInited() {
+ return credentials != null;
+ }
+
+ public static void init(Properties conf) {
+ if (credentials != null) {
+ throw new Error("Re-initiaizing is forbidden.");
+ }
+ credentials = conf;
+ int version = 0;
+ try (GigiPreparedStatement gigiPreparedStatement = new GigiPreparedStatement("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;")) {
+ GigiResultSet rs = gigiPreparedStatement.executeQuery();
+ if (rs.next()) {
+ version = rs.getInt(1);
+ }
+ }
+ if (version == CURRENT_SCHEMA_VERSION) {
+ return; // Good to go
+ }
+ if (version > CURRENT_SCHEMA_VERSION) {
+ throw new Error("Invalid database version. Please fix this.");
+ }
+ upgrade(version);
+ }
+
+ public void beginTransaction() throws SQLException {
+ c.setAutoCommit(false);
+ }
+
+ private static void upgrade(int version) {
+ try {
+ Statement s = getInstance().c.createStatement();
+ try {
+ while (version < CURRENT_SCHEMA_VERSION) {
+ try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
+ if (resourceAsStream == null) {
+ throw new Error("Upgrade script from version " + version + " was not found.");
+ }
+ SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
+ }
+ version++;
+ }
+ s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
+ System.out.println("UPGRADING Database to version " + version);
+ s.executeBatch();
+ System.out.println("done.");
+ } finally {
+ s.close();
+ }
+ } catch (SQLException e) {
+ e.printStackTrace();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ public void commitTransaction() throws SQLException {
+ c.commit();
+ c.setAutoCommit(true);
+ }
+
+ public void quitTransaction() {
+ try {
+ if ( !c.getAutoCommit()) {
+ c.rollback();
+ c.setAutoCommit(true);
+ }
+ } catch (SQLException e) {
+ e.printStackTrace();
+ }
+ }
+
+ public static final String preprocessQuery(String originalQuery) {
+ originalQuery = originalQuery.replace('`', '"');
+ if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
+ Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
+ Matcher m = p.matcher(originalQuery);
+ if (m.matches()) {
+ String replacement = "INSERT INTO " + toIdentifier(m.group(1));
+ String[] parts = m.group(2).split(",");
+ StringJoiner columns = new StringJoiner(", ");
+ StringJoiner values = new StringJoiner(", ");
+ for (int i = 0; i < parts.length; i++) {
+ String[] split = parts[i].split("=", 2);
+ columns.add(toIdentifier(split[0]));
+ values.add(split[1]);
+ }
+ replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
+ return replacement;
+ }
+ }
+
+ //
+ return originalQuery;
+ }
+
+ private static CharSequence toIdentifier(String ident) {
+ ident = ident.trim();
+ if ( !ident.startsWith("\"")) {
+ ident = "\"" + ident;
+ }
+ if ( !ident.endsWith("\"")) {
+ ident = ident + "\"";
+ }
+ return ident;
+ }
+
+ protected synchronized void returnStatement(PreparedStatement target) throws SQLException {
+ if ( !underUse.remove(target)) {
+ target.close();
+ }
+ }
+
+ public synchronized int getNumberOfLockedStatements() {
+ return underUse.size();
+ }
+
+ public synchronized void lockedStatements(PrintWriter writer) {
+ writer.println(underUse.size());
+ for (PreparedStatement ps : underUse) {
+ for (Entry<StatementDescriptor, PreparedStatement> e : statements.entrySet()) {
+ if (e.getValue() == ps) {
+ writer.println("<br/>");
+ writer.println(e.getKey().instance + ":");
+
+ writer.println(e.getKey().query);
+ }
+ }
+ }
+ }
}