]> WPIA git - gigi.git/blob - src/org/cacert/gigi/database/DatabaseConnection.java
fix: SQL change database call pattern
[gigi.git] / src / org / cacert / gigi / database / DatabaseConnection.java
1 package org.cacert.gigi.database;
2
3 import java.io.IOException;
4 import java.io.InputStream;
5 import java.io.PrintWriter;
6 import java.sql.Connection;
7 import java.sql.DriverManager;
8 import java.sql.PreparedStatement;
9 import java.sql.ResultSet;
10 import java.sql.SQLException;
11 import java.sql.Statement;
12 import java.util.HashMap;
13 import java.util.HashSet;
14 import java.util.Map.Entry;
15 import java.util.Properties;
16 import java.util.StringJoiner;
17 import java.util.regex.Matcher;
18 import java.util.regex.Pattern;
19
20 import org.cacert.gigi.database.SQLFileManager.ImportType;
21
22 public class DatabaseConnection {
23
24     public static final int CURRENT_SCHEMA_VERSION = 6;
25
26     public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
27
28     private Connection c;
29
30     private HashMap<String, PreparedStatement> statements = new HashMap<String, PreparedStatement>();
31
32     HashSet<PreparedStatement> underUse = new HashSet<>();
33
34     private static Properties credentials;
35
36     private Statement adHoc;
37
38     public DatabaseConnection() {
39         try {
40             Class.forName(credentials.getProperty("sql.driver"));
41         } catch (ClassNotFoundException e) {
42             e.printStackTrace();
43         }
44         tryConnect();
45
46     }
47
48     private void tryConnect() {
49         try {
50             c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
51             adHoc = c.createStatement();
52         } catch (SQLException e) {
53             e.printStackTrace();
54         }
55     }
56
57     protected synchronized PreparedStatement prepareInternal(String query) throws SQLException {
58         ensureOpen();
59         query = preprocessQuery(query);
60         PreparedStatement statement = statements.get(query);
61         if (statement != null) {
62             if (underUse.add(statement)) {
63                 return statement;
64             } else {
65                 throw new Error("Statement in Use");
66             }
67         }
68         statement = c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS);
69         statements.put(query, statement);
70         if (underUse.add(statement)) {
71             return statement;
72         } else {
73             throw new Error("Statement in Use");
74         }
75     }
76
77     protected synchronized PreparedStatement prepareInternalScrollable(String query) throws SQLException {
78         ensureOpen();
79         query = preprocessQuery(query);
80         PreparedStatement statement = statements.get("__SCROLLABLE__! " + query);
81         if (statement != null) {
82             if (underUse.add(statement)) {
83                 return statement;
84             } else {
85                 throw new Error("Statement in Use");
86             }
87         }
88         statement = c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
89         statements.put("__SCROLLABLE__! " + query, statement);
90         if (underUse.add(statement)) {
91             return statement;
92         } else {
93             throw new Error("Statement in Use");
94         }
95     }
96
97     private long lastAction = System.currentTimeMillis();
98
99     private void ensureOpen() {
100         if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) {
101             try {
102                 ResultSet rs = adHoc.executeQuery("SELECT 1");
103                 rs.close();
104                 lastAction = System.currentTimeMillis();
105                 return;
106             } catch (SQLException e) {
107             }
108             statements.clear();
109             tryConnect();
110         }
111         lastAction = System.currentTimeMillis();
112     }
113
114     private static DatabaseConnection instance;
115
116     public static DatabaseConnection getInstance() {
117         if (instance == null) {
118             synchronized (DatabaseConnection.class) {
119                 if (instance == null) {
120                     instance = new DatabaseConnection();
121                 }
122             }
123         }
124         return instance;
125     }
126
127     public static boolean isInited() {
128         return credentials != null;
129     }
130
131     public static void init(Properties conf) {
132         if (credentials != null) {
133             throw new Error("Re-initiaizing is forbidden.");
134         }
135         credentials = conf;
136         int version = 0;
137         try (GigiPreparedStatement gigiPreparedStatement = new GigiPreparedStatement("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;")) {
138             GigiResultSet rs = gigiPreparedStatement.executeQuery();
139             if (rs.next()) {
140                 version = rs.getInt(1);
141             }
142         }
143         if (version == CURRENT_SCHEMA_VERSION) {
144             return; // Good to go
145         }
146         if (version > CURRENT_SCHEMA_VERSION) {
147             throw new Error("Invalid database version. Please fix this.");
148         }
149         upgrade(version);
150     }
151
152     public void beginTransaction() throws SQLException {
153         c.setAutoCommit(false);
154     }
155
156     private static void upgrade(int version) {
157         try {
158             Statement s = getInstance().c.createStatement();
159             try {
160                 while (version < CURRENT_SCHEMA_VERSION) {
161                     try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
162                         if (resourceAsStream == null) {
163                             throw new Error("Upgrade script from version " + version + " was not found.");
164                         }
165                         SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
166                     }
167                     version++;
168                 }
169                 s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
170                 System.out.println("UPGRADING Database to version " + version);
171                 s.executeBatch();
172                 System.out.println("done.");
173             } finally {
174                 s.close();
175             }
176         } catch (SQLException e) {
177             e.printStackTrace();
178         } catch (IOException e) {
179             e.printStackTrace();
180         }
181     }
182
183     public void commitTransaction() throws SQLException {
184         c.commit();
185         c.setAutoCommit(true);
186     }
187
188     public void quitTransaction() {
189         try {
190             if ( !c.getAutoCommit()) {
191                 c.rollback();
192                 c.setAutoCommit(true);
193             }
194         } catch (SQLException e) {
195             e.printStackTrace();
196         }
197     }
198
199     public static final String preprocessQuery(String originalQuery) {
200         originalQuery = originalQuery.replace('`', '"');
201         if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
202             Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
203             Matcher m = p.matcher(originalQuery);
204             if (m.matches()) {
205                 String replacement = "INSERT INTO " + toIdentifier(m.group(1));
206                 String[] parts = m.group(2).split(",");
207                 StringJoiner columns = new StringJoiner(", ");
208                 StringJoiner values = new StringJoiner(", ");
209                 for (int i = 0; i < parts.length; i++) {
210                     String[] split = parts[i].split("=", 2);
211                     columns.add(toIdentifier(split[0]));
212                     values.add(split[1]);
213                 }
214                 replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
215                 return replacement;
216             }
217         }
218
219         //
220         return originalQuery;
221     }
222
223     private static CharSequence toIdentifier(String ident) {
224         ident = ident.trim();
225         if ( !ident.startsWith("\"")) {
226             ident = "\"" + ident;
227         }
228         if ( !ident.endsWith("\"")) {
229             ident = ident + "\"";
230         }
231         return ident;
232     }
233
234     protected synchronized void returnStatement(PreparedStatement target) {
235         underUse.remove(target);
236     }
237
238     public void lockedStatements(PrintWriter writer) {
239         writer.println(underUse.size());
240         for (PreparedStatement ps : underUse) {
241             for (Entry<String, PreparedStatement> e : statements.entrySet()) {
242                 if (e.getValue() == ps) {
243                     writer.println("<br/>");
244                     writer.println(e.getKey());
245                 }
246             }
247         }
248     }
249 }