1 package org.cacert.gigi.database;
3 import java.io.IOException;
4 import java.io.InputStream;
5 import java.io.PrintWriter;
6 import java.sql.Connection;
7 import java.sql.DriverManager;
8 import java.sql.PreparedStatement;
9 import java.sql.ResultSet;
10 import java.sql.SQLException;
11 import java.sql.Statement;
12 import java.util.HashMap;
13 import java.util.HashSet;
14 import java.util.Map.Entry;
15 import java.util.Properties;
16 import java.util.StringJoiner;
17 import java.util.regex.Matcher;
18 import java.util.regex.Pattern;
20 import org.cacert.gigi.database.SQLFileManager.ImportType;
22 public class DatabaseConnection {
/**
 * Limit on the instance counter of cached statements per query: prepareInternal checks
 * {@code searchHead.instance >= MAX_CACHED_INSTANCES} right after preparing a new statement.
 * NOTE(review): the code taken when the limit is reached is not visible in this excerpt;
 * presumably such statements are handed out without being cached — confirm.
 */
24 public static final int MAX_CACHED_INSTANCES = 3;
/**
 * Cache key for a prepared statement: the query text, whether a scrollable result set is
 * wanted, and an instance counter that allows several statements for the same query to be
 * cached side by side. hashCode/equals in this class combine exactly these three fields.
 * NOTE(review): the declarations of query, scrollable and instance are not visible in this
 * excerpt. Because instance feeds hashCode, a descriptor must not have its instance changed
 * (presumably what increase() does) after it has been stored as a key in a HashMap.
 */
26 private static class StatementDescriptor {
// The statement produced by instanciate(Connection).
34 PreparedStatement target;
/**
 * Creates a descriptor for a query.
 *
 * @param query the SQL text (prepareInternal passes the preprocessed query)
 * @param scrollable whether a scrollable result set is wanted
 */
36 public StatementDescriptor(String query, boolean scrollable) {
38 this.scrollable = scrollable;
/**
 * Prepares the statement on {@code c} and stores it in {@code target}.
 * One call prepares a TYPE_SCROLL_INSENSITIVE / CONCUR_READ_ONLY statement; the other
 * requests generated keys for every query that does not start with "SELECT ".
 * NOTE(review): the condition selecting between the two calls is not visible here;
 * presumably it is {@code scrollable}.
 *
 * @param c the connection to prepare on
 * @throws SQLException if the driver fails to prepare the statement
 */
42 public synchronized void instanciate(Connection c) throws SQLException {
44 target = c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
// Only non-SELECT statements ask the driver to return generated keys.
46 target = c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS);
// Returns the prepared statement (body not visible; presumably returns target).
51 public PreparedStatement getTarget() {
/**
 * Called by prepareInternal when the cached statement for the current descriptor is
 * already in use; presumably advances the instance counter.
 * NOTE(review): the guard around the IllegalStateException and the actual update are not
 * visible in this excerpt.
 */
55 public synchronized void increase() {
57 throw new IllegalStateException();
// Hash over instance, query (null-safe) and scrollable, consistent with equals.
63 public int hashCode() {
66 result = prime * result + instance;
67 result = prime * result + ((query == null) ? 0 : query.hashCode());
68 result = prime * result + (scrollable ? 1231 : 1237);
// Equal when the runtime class, instance, query and scrollable all match.
73 public boolean equals(Object obj) {
80 if (getClass() != obj.getClass()) {
83 StatementDescriptor other = (StatementDescriptor) obj;
84 if (instance != other.instance) {
88 if (other.query != null) {
91 } else if ( !query.equals(other.query)) {
94 if (scrollable != other.scrollable) {
102 public static final int CURRENT_SCHEMA_VERSION = 6;
104 public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
106 private Connection c;
108 private HashMap<StatementDescriptor, PreparedStatement> statements = new HashMap<StatementDescriptor, PreparedStatement>();
110 HashSet<PreparedStatement> underUse = new HashSet<>();
112 private static Properties credentials;
114 private Statement adHoc;
/**
 * Creates the connection wrapper, first loading the JDBC driver class named by the
 * {@code sql.driver} credential. The static credentials must already be set: with
 * {@code credentials == null} the getProperty call throws a NullPointerException.
 * NOTE(review): the ClassNotFoundException handling and the rest of the constructor
 * (presumably a call to tryConnect()) are not visible in this excerpt.
 */
116 public DatabaseConnection() {
118 Class.forName(credentials.getProperty("sql.driver"));
119 } catch (ClassNotFoundException e) {
/**
 * Opens the JDBC connection from the sql.url, sql.user and sql.password credentials and
 * creates the {@code adHoc} statement on it. CONNECTION_TIMEOUT is appended to the URL as
 * a {@code socketTimeout} parameter.
 * NOTE(review): the URL is built with a literal "?", which produces a malformed URL when
 * sql.url already carries query parameters ("&" would be needed then). The SQLException
 * handling is not visible in this excerpt.
 */
126 private void tryConnect() {
128 c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
129 adHoc = c.createStatement();
130 } catch (SQLException e) {
/**
 * Returns a prepared statement for {@code query} without requesting a scrollable result
 * set; shorthand for {@code prepareInternal(query, false)}.
 *
 * @param query the SQL text, before preprocessing
 * @return the prepared statement
 * @throws SQLException if preparing a new statement fails
 */
135 protected synchronized PreparedStatement prepareInternal(String query) throws SQLException {
136 return prepareInternal(query, false);
/**
 * Returns a prepared statement for {@code query}, reusing a cached one when possible.
 * The query is first rewritten by preprocessQuery. The descriptor's instance counter is
 * advanced past cached statements that are currently in {@code underUse}. When nothing is
 * cached for the current descriptor, a new statement is prepared on {@code c}. Statements
 * handed out are added to {@code underUse}; presumably callers give them back through
 * returnStatement, which removes them again.
 * NOTE(review): several short lines are missing from this excerpt (including the body of
 * the MAX_CACHED_INSTANCES branch and whatever resets {@code statement} after increase()),
 * so the exact loop exit paths cannot be confirmed here.
 *
 * @param query the SQL text, before preprocessing
 * @param scrollable whether a scrollable result set is wanted
 * @return the prepared statement
 * @throws SQLException if preparing a new statement fails
 */
139 protected synchronized PreparedStatement prepareInternal(String query, boolean scrollable) throws SQLException {
142 query = preprocessQuery(query);
143 StatementDescriptor searchHead = new StatementDescriptor(query, scrollable);
144 PreparedStatement statement = null;
145 while (statement == null) {
146 statement = statements.get(searchHead);
147 if (statement == null) {
// Nothing cached for this (query, scrollable, instance): prepare a fresh statement.
148 searchHead.instanciate(c);
149 statement = searchHead.getTarget();
// At or above the limit a path not visible here is taken (presumably: hand out uncached).
150 if (searchHead.instance >= MAX_CACHED_INSTANCES) {
// Presumably the below-limit case: mark the new statement in use and cache it.
153 underUse.add(statement);
154 statements.put(searchHead, statement);
// The cached statement is busy: move on to the next instance.
155 } else if (underUse.contains(statement)) {
156 searchHead.increase();
// Presumably the cached-and-free case: mark it in use.
159 underUse.add(statement);
/**
 * Returns a prepared statement for {@code query} with {@code scrollable} requested;
 * shorthand for {@code prepareInternal(query, true)}.
 *
 * @param query the SQL text, before preprocessing
 * @return the prepared statement
 * @throws SQLException if preparing a new statement fails
 */
165 protected synchronized PreparedStatement prepareInternalScrollable(String query) throws SQLException {
166 return prepareInternal(query, true);
// Wall-clock time (System.currentTimeMillis()) last recorded by ensureOpen().
169 private long lastAction = System.currentTimeMillis();
/**
 * Probes the connection with "SELECT 1" on {@code adHoc} when more than
 * CONNECTION_TIMEOUT seconds have passed since {@code lastAction}; {@code lastAction} is
 * refreshed afterwards.
 * NOTE(review): it is not visible here whether the probe's ResultSet is closed, nor how a
 * failed probe (SQLException) is handled; presumably it reconnects. Confirm both.
 */
171 private void ensureOpen() {
// lastAction is in milliseconds, hence CONNECTION_TIMEOUT * 1000L.
172 if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) {
174 ResultSet rs = adHoc.executeQuery("SELECT 1");
176 lastAction = System.currentTimeMillis();
178 } catch (SQLException e) {
183 lastAction = System.currentTimeMillis();
186 private static DatabaseConnection instance;
/**
 * Returns the shared (static) connection wrapper, creating it on first use.
 * Creation is guarded by double-checked locking on DatabaseConnection.class.
 * NOTE(review): this pattern is only safe if the {@code instance} field is declared
 * volatile. The return statement is not visible in this excerpt.
 */
188 public static DatabaseConnection getInstance() {
189 if (instance == null) {
190 synchronized (DatabaseConnection.class) {
191 if (instance == null) {
192 instance = new DatabaseConnection();
/**
 * Reports whether the static credentials have been set.
 *
 * @return {@code true} when {@code credentials} is non-null
 */
199 public static boolean isInited() {
200 return credentials != null;
/**
 * Sets up the database layer and checks the schema version stored in "schemeVersion"
 * against CURRENT_SCHEMA_VERSION. It returns immediately when they match; a newer stored
 * version is rejected with an Error.
 * NOTE(review): lines not visible in this excerpt presumably assign {@code conf} to
 * {@code credentials}, handle an empty result set, and call upgrade(version) when the
 * stored version is older; confirm. The message "Re-initiaizing" contains a typo.
 *
 * @param conf connection properties (sql.driver, sql.url, sql.user and sql.password are
 *     read elsewhere in this class)
 * @throws Error if credentials are already set, or if the stored schema version is newer
 *     than CURRENT_SCHEMA_VERSION
 */
203 public static void init(Properties conf) {
204 if (credentials != null) {
205 throw new Error("Re-initiaizing is forbidden.");
// Fetch the highest recorded schema version.
209 try (GigiPreparedStatement gigiPreparedStatement = new GigiPreparedStatement("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;")) {
210 GigiResultSet rs = gigiPreparedStatement.executeQuery();
212 version = rs.getInt(1);
215 if (version == CURRENT_SCHEMA_VERSION) {
216 return; // Good to go
218 if (version > CURRENT_SCHEMA_VERSION) {
219 throw new Error("Invalid database version. Please fix this.");
/**
 * Starts a transaction by turning off auto-commit on the underlying connection.
 * commitTransaction() and quitTransaction() both turn auto-commit back on.
 * NOTE(review): any line after setAutoCommit(false) is not visible in this excerpt.
 *
 * @throws SQLException if the driver rejects the auto-commit change
 */
224 public void beginTransaction() throws SQLException {
225 c.setAutoCommit(false);
/**
 * Upgrades the schema step by step while {@code version} is below CURRENT_SCHEMA_VERSION.
 * Each step does the following:
 * <ul>
 * <li>loads the resource "upgrade/from_&lt;version&gt;.sql", resolved relative to this
 * class's package by Class.getResourceAsStream;</li>
 * <li>passes the resource to SQLFileManager.addFile together with a single Statement;</li>
 * <li>adds an UPDATE of "schemeVersion" to that Statement's batch;</li>
 * <li>prints progress to stdout.</li>
 * </ul>
 * NOTE(review): the lines that advance {@code version}, execute the batch, close the
 * Statement and handle the SQLException/IOException are not visible in this excerpt.
 * Confirm that the Statement is closed and that failures are not silently swallowed.
 *
 * @param version the schema version currently stored in the database
 * @throws Error if the upgrade script for some step is missing
 */
228 private static void upgrade(int version) {
230 Statement s = getInstance().c.createStatement();
232 while (version < CURRENT_SCHEMA_VERSION) {
233 try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
234 if (resourceAsStream == null) {
235 throw new Error("Upgrade script from version " + version + " was not found.");
237 SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
// NOTE(review): version is presumably already advanced here (that line is not visible).
241 s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
242 System.out.println("UPGRADING Database to version " + version);
244 System.out.println("done.");
248 } catch (SQLException e) {
250 } catch (IOException e) {
/**
 * Ends the transaction started by beginTransaction() and turns auto-commit back on.
 * Per the JDBC contract, switching auto-commit back on during a transaction commits it.
 * NOTE(review): the line before setAutoCommit(true) is not visible; presumably c.commit().
 *
 * @throws SQLException if the driver reports an error
 */
255 public void commitTransaction() throws SQLException {
257 c.setAutoCommit(true);
/**
 * If a transaction is open (auto-commit is off), ends it and turns auto-commit back on.
 * NOTE(review): the statement between the check and setAutoCommit(true) is not visible;
 * presumably c.rollback(). Note that if it were missing, setAutoCommit(true) would commit
 * instead. The SQLException handling is also not visible.
 */
260 public void quitTransaction() {
262 if ( !c.getAutoCommit()) {
264 c.setAutoCommit(true);
266 } catch (SQLException e) {
/**
 * Rewrites MySQL-style syntax:
 * <ul>
 * <li>every backtick becomes a double quote;</li>
 * <li>{@code INSERT INTO t SET a=x, b=y} is turned into
 * {@code INSERT INTO "t"("a", "b") VALUES(x, y)}.</li>
 * </ul>
 * NOTE(review): the return of the rewritten INSERT is not visible in this excerpt.
 * Caveats visible in the code:
 * <ul>
 * <li>the SET clause is split on every comma, so a value containing a comma (a string
 * literal or a multi-argument function call) is mis-split;</li>
 * <li>an assignment without "=" throws ArrayIndexOutOfBoundsException;</li>
 * <li>".*" does not match line breaks, so multi-line INSERTs are not rewritten;</li>
 * <li>the backtick replacement also affects string literals;</li>
 * <li>both regexes are recompiled on every call, which a static final Pattern would
 * avoid.</li>
 * </ul>
 *
 * @param originalQuery the SQL as written by the caller
 * @return the rewritten query; for non-matching input, the query with backticks replaced
 */
271 public static final String preprocessQuery(String originalQuery) {
272 originalQuery = originalQuery.replace('`', '"');
273 if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
274 Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
275 Matcher m = p.matcher(originalQuery);
// Table name becomes a quoted identifier.
277 String replacement = "INSERT INTO " + toIdentifier(m.group(1));
278 String[] parts = m.group(2).split(",");
279 StringJoiner columns = new StringJoiner(", ");
280 StringJoiner values = new StringJoiner(", ");
// Each "col=value" pair contributes a quoted column name and its (untrimmed) value.
281 for (int i = 0; i < parts.length; i++) {
282 String[] split = parts[i].split("=", 2);
283 columns.add(toIdentifier(split[0]));
284 values.add(split[1]);
286 replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
292 return originalQuery;
/**
 * Trims {@code ident}, then adds a leading and/or trailing double quote where one is
 * missing.
 * NOTE(review): embedded double quotes are not escaped. The return statement is not
 * visible in this excerpt.
 */
295 private static CharSequence toIdentifier(String ident) {
296 ident = ident.trim();
297 if ( !ident.startsWith("\"")) {
298 ident = "\"" + ident;
300 if ( !ident.endsWith("\"")) {
301 ident = ident + "\"";
/**
 * Gives a statement back by removing it from {@code underUse}.
 * NOTE(review): the handling for a {@code target} that was not in use, and the code that
 * may throw SQLException, are not visible in this excerpt.
 *
 * @param target the statement being given back
 */
306 protected synchronized void returnStatement(PreparedStatement target) throws SQLException {
307 if ( !underUse.remove(target)) {
/**
 * Returns the number of statements currently in use, i.e. the size of {@code underUse}.
 */
312 public synchronized int getNumberOfLockedStatements() {
313 return underUse.size();
316 public void lockedStatements(PrintWriter writer) {
317 writer.println(underUse.size());
318 for (PreparedStatement ps : underUse) {
319 for (Entry<StatementDescriptor, PreparedStatement> e : statements.entrySet()) {
320 if (e.getValue() == ps) {
321 writer.println("<br/>");
322 writer.println(e.getKey().instance + ":");
324 writer.println(e.getKey().query);