]> WPIA git - gigi.git/blob - src/org/cacert/gigi/database/DatabaseConnection.java
add: allow un-cached SQL queries
[gigi.git] / src / org / cacert / gigi / database / DatabaseConnection.java
1 package org.cacert.gigi.database;
2
3 import java.io.IOException;
4 import java.io.InputStream;
5 import java.io.PrintWriter;
6 import java.sql.Connection;
7 import java.sql.DriverManager;
8 import java.sql.PreparedStatement;
9 import java.sql.ResultSet;
10 import java.sql.SQLException;
11 import java.sql.Statement;
12 import java.util.HashMap;
13 import java.util.HashSet;
14 import java.util.Map.Entry;
15 import java.util.Properties;
16 import java.util.StringJoiner;
17 import java.util.regex.Matcher;
18 import java.util.regex.Pattern;
19
20 import org.cacert.gigi.database.SQLFileManager.ImportType;
21
22 public class DatabaseConnection {
23
24     public static final int MAX_CACHED_INSTANCES = 3;
25
26     private static class StatementDescriptor {
27
28         String query;
29
30         boolean scrollable;
31
32         int instance;
33
34         PreparedStatement target;
35
36         public StatementDescriptor(String query, boolean scrollable) {
37             this.query = query;
38             this.scrollable = scrollable;
39             this.instance = 0;
40         }
41
42         public synchronized void instanciate(Connection c) throws SQLException {
43             if (scrollable) {
44                 target = c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
45             } else {
46                 target = c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS);
47             }
48
49         }
50
51         public PreparedStatement getTarget() {
52             return target;
53         }
54
55         public synchronized void increase() {
56             if (target != null) {
57                 throw new IllegalStateException();
58             }
59             instance++;
60         }
61
62         @Override
63         public int hashCode() {
64             final int prime = 31;
65             int result = 1;
66             result = prime * result + instance;
67             result = prime * result + ((query == null) ? 0 : query.hashCode());
68             result = prime * result + (scrollable ? 1231 : 1237);
69             return result;
70         }
71
72         @Override
73         public boolean equals(Object obj) {
74             if (this == obj) {
75                 return true;
76             }
77             if (obj == null) {
78                 return false;
79             }
80             if (getClass() != obj.getClass()) {
81                 return false;
82             }
83             StatementDescriptor other = (StatementDescriptor) obj;
84             if (instance != other.instance) {
85                 return false;
86             }
87             if (query == null) {
88                 if (other.query != null) {
89                     return false;
90                 }
91             } else if ( !query.equals(other.query)) {
92                 return false;
93             }
94             if (scrollable != other.scrollable) {
95                 return false;
96             }
97             return true;
98         }
99
100     }
101
102     public static final int CURRENT_SCHEMA_VERSION = 6;
103
104     public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
105
106     private Connection c;
107
108     private HashMap<StatementDescriptor, PreparedStatement> statements = new HashMap<StatementDescriptor, PreparedStatement>();
109
110     HashSet<PreparedStatement> underUse = new HashSet<>();
111
112     private static Properties credentials;
113
114     private Statement adHoc;
115
116     public DatabaseConnection() {
117         try {
118             Class.forName(credentials.getProperty("sql.driver"));
119         } catch (ClassNotFoundException e) {
120             e.printStackTrace();
121         }
122         tryConnect();
123
124     }
125
126     private void tryConnect() {
127         try {
128             c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
129             adHoc = c.createStatement();
130         } catch (SQLException e) {
131             e.printStackTrace();
132         }
133     }
134
135     protected synchronized PreparedStatement prepareInternal(String query) throws SQLException {
136         return prepareInternal(query, false);
137     }
138
139     protected synchronized PreparedStatement prepareInternal(String query, boolean scrollable) throws SQLException {
140
141         ensureOpen();
142         query = preprocessQuery(query);
143         StatementDescriptor searchHead = new StatementDescriptor(query, scrollable);
144         PreparedStatement statement = null;
145         while (statement == null) {
146             statement = statements.get(searchHead);
147             if (statement == null) {
148                 searchHead.instanciate(c);
149                 statement = searchHead.getTarget();
150                 if (searchHead.instance >= MAX_CACHED_INSTANCES) {
151                     return statement;
152                 }
153                 underUse.add(statement);
154                 statements.put(searchHead, statement);
155             } else if (underUse.contains(statement)) {
156                 searchHead.increase();
157                 statement = null;
158             } else {
159                 underUse.add(statement);
160             }
161         }
162         return statement;
163     }
164
165     protected synchronized PreparedStatement prepareInternalScrollable(String query) throws SQLException {
166         return prepareInternal(query, true);
167     }
168
169     private long lastAction = System.currentTimeMillis();
170
171     private void ensureOpen() {
172         if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) {
173             try {
174                 ResultSet rs = adHoc.executeQuery("SELECT 1");
175                 rs.close();
176                 lastAction = System.currentTimeMillis();
177                 return;
178             } catch (SQLException e) {
179             }
180             statements.clear();
181             tryConnect();
182         }
183         lastAction = System.currentTimeMillis();
184     }
185
186     private static DatabaseConnection instance;
187
188     public static DatabaseConnection getInstance() {
189         if (instance == null) {
190             synchronized (DatabaseConnection.class) {
191                 if (instance == null) {
192                     instance = new DatabaseConnection();
193                 }
194             }
195         }
196         return instance;
197     }
198
199     public static boolean isInited() {
200         return credentials != null;
201     }
202
203     public static void init(Properties conf) {
204         if (credentials != null) {
205             throw new Error("Re-initiaizing is forbidden.");
206         }
207         credentials = conf;
208         int version = 0;
209         try (GigiPreparedStatement gigiPreparedStatement = new GigiPreparedStatement("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;")) {
210             GigiResultSet rs = gigiPreparedStatement.executeQuery();
211             if (rs.next()) {
212                 version = rs.getInt(1);
213             }
214         }
215         if (version == CURRENT_SCHEMA_VERSION) {
216             return; // Good to go
217         }
218         if (version > CURRENT_SCHEMA_VERSION) {
219             throw new Error("Invalid database version. Please fix this.");
220         }
221         upgrade(version);
222     }
223
224     public void beginTransaction() throws SQLException {
225         c.setAutoCommit(false);
226     }
227
228     private static void upgrade(int version) {
229         try {
230             Statement s = getInstance().c.createStatement();
231             try {
232                 while (version < CURRENT_SCHEMA_VERSION) {
233                     try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
234                         if (resourceAsStream == null) {
235                             throw new Error("Upgrade script from version " + version + " was not found.");
236                         }
237                         SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
238                     }
239                     version++;
240                 }
241                 s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
242                 System.out.println("UPGRADING Database to version " + version);
243                 s.executeBatch();
244                 System.out.println("done.");
245             } finally {
246                 s.close();
247             }
248         } catch (SQLException e) {
249             e.printStackTrace();
250         } catch (IOException e) {
251             e.printStackTrace();
252         }
253     }
254
255     public void commitTransaction() throws SQLException {
256         c.commit();
257         c.setAutoCommit(true);
258     }
259
260     public void quitTransaction() {
261         try {
262             if ( !c.getAutoCommit()) {
263                 c.rollback();
264                 c.setAutoCommit(true);
265             }
266         } catch (SQLException e) {
267             e.printStackTrace();
268         }
269     }
270
271     public static final String preprocessQuery(String originalQuery) {
272         originalQuery = originalQuery.replace('`', '"');
273         if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
274             Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
275             Matcher m = p.matcher(originalQuery);
276             if (m.matches()) {
277                 String replacement = "INSERT INTO " + toIdentifier(m.group(1));
278                 String[] parts = m.group(2).split(",");
279                 StringJoiner columns = new StringJoiner(", ");
280                 StringJoiner values = new StringJoiner(", ");
281                 for (int i = 0; i < parts.length; i++) {
282                     String[] split = parts[i].split("=", 2);
283                     columns.add(toIdentifier(split[0]));
284                     values.add(split[1]);
285                 }
286                 replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
287                 return replacement;
288             }
289         }
290
291         //
292         return originalQuery;
293     }
294
295     private static CharSequence toIdentifier(String ident) {
296         ident = ident.trim();
297         if ( !ident.startsWith("\"")) {
298             ident = "\"" + ident;
299         }
300         if ( !ident.endsWith("\"")) {
301             ident = ident + "\"";
302         }
303         return ident;
304     }
305
306     protected synchronized void returnStatement(PreparedStatement target) throws SQLException {
307         if ( !underUse.remove(target)) {
308             target.close();
309         }
310     }
311
312     public synchronized int getNumberOfLockedStatements() {
313         return underUse.size();
314     }
315
316     public void lockedStatements(PrintWriter writer) {
317         writer.println(underUse.size());
318         for (PreparedStatement ps : underUse) {
319             for (Entry<StatementDescriptor, PreparedStatement> e : statements.entrySet()) {
320                 if (e.getValue() == ps) {
321                     writer.println("<br/>");
322                     writer.println(e.getKey().instance + ":");
323
324                     writer.println(e.getKey().query);
325                 }
326             }
327         }
328     }
329 }