upd: store different types of revocation
[gigi.git] / src / club / wpia / gigi / database / DatabaseConnection.java
1 package club.wpia.gigi.database;
2
3 import java.io.IOException;
4 import java.io.InputStream;
5 import java.io.PrintWriter;
6 import java.sql.Connection;
7 import java.sql.DriverManager;
8 import java.sql.PreparedStatement;
9 import java.sql.ResultSet;
10 import java.sql.SQLException;
11 import java.sql.Statement;
12 import java.util.HashMap;
13 import java.util.HashSet;
14 import java.util.Map.Entry;
15 import java.util.Properties;
16 import java.util.StringJoiner;
17 import java.util.concurrent.LinkedBlockingDeque;
18 import java.util.regex.Matcher;
19 import java.util.regex.Pattern;
20
21 import club.wpia.gigi.database.SQLFileManager.ImportType;
22
23 public class DatabaseConnection {
24
25     public static class Link implements AutoCloseable {
26
27         private DatabaseConnection target;
28
29         protected Link(DatabaseConnection target) {
30             this.target = target;
31         }
32
33         @Override
34         public void close() {
35             synchronized (DatabaseConnection.class) {
36                 Link i = instances.get(Thread.currentThread());
37                 if (i != this) {
38                     throw new Error();
39                 }
40                 instances.remove(Thread.currentThread());
41                 pool.add(target);
42             }
43         }
44
45     }
46
47     public static final int MAX_CACHED_INSTANCES = 3;
48
49     private static class StatementDescriptor {
50
51         String query;
52
53         boolean scrollable;
54
55         int instance;
56
57         PreparedStatement target;
58
59         public StatementDescriptor(String query, boolean scrollable) {
60             this.query = query;
61             this.scrollable = scrollable;
62             this.instance = 0;
63         }
64
65         public synchronized void instanciate(Connection c) throws SQLException {
66             if (scrollable) {
67                 target = c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
68             } else {
69                 target = c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS);
70             }
71
72         }
73
74         public synchronized PreparedStatement getTarget() {
75             return target;
76         }
77
78         public synchronized void increase() {
79             if (target != null) {
80                 throw new IllegalStateException();
81             }
82             instance++;
83         }
84
85         @Override
86         public int hashCode() {
87             final int prime = 31;
88             int result = 1;
89             result = prime * result + instance;
90             result = prime * result + ((query == null) ? 0 : query.hashCode());
91             result = prime * result + (scrollable ? 1231 : 1237);
92             return result;
93         }
94
95         @Override
96         public boolean equals(Object obj) {
97             if (this == obj) {
98                 return true;
99             }
100             if (obj == null) {
101                 return false;
102             }
103             if (getClass() != obj.getClass()) {
104                 return false;
105             }
106             StatementDescriptor other = (StatementDescriptor) obj;
107             if (instance != other.instance) {
108                 return false;
109             }
110             if (query == null) {
111                 if (other.query != null) {
112                     return false;
113                 }
114             } else if ( !query.equals(other.query)) {
115                 return false;
116             }
117             if (scrollable != other.scrollable) {
118                 return false;
119             }
120             return true;
121         }
122
123     }
124
125     public static final int CURRENT_SCHEMA_VERSION = 29;
126
127     public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
128
129     private Connection c;
130
131     private HashMap<StatementDescriptor, PreparedStatement> statements = new HashMap<StatementDescriptor, PreparedStatement>();
132
133     private HashSet<PreparedStatement> underUse = new HashSet<>();
134
135     private static Properties credentials;
136
137     private Statement adHoc;
138
139     public DatabaseConnection() {
140         try {
141             Class.forName(credentials.getProperty("sql.driver"));
142         } catch (ClassNotFoundException e) {
143             e.printStackTrace();
144         }
145         tryConnect();
146
147     }
148
149     private void tryConnect() {
150         try {
151             c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
152             adHoc = c.createStatement();
153         } catch (SQLException e) {
154             e.printStackTrace();
155         }
156     }
157
158     protected synchronized PreparedStatement prepareInternal(String query) throws SQLException {
159         return prepareInternal(query, false);
160     }
161
162     protected synchronized PreparedStatement prepareInternal(String query, boolean scrollable) throws SQLException {
163
164         ensureOpen();
165         query = preprocessQuery(query);
166         StatementDescriptor searchHead = new StatementDescriptor(query, scrollable);
167         PreparedStatement statement = null;
168         while (statement == null) {
169             statement = statements.get(searchHead);
170             if (statement == null) {
171                 searchHead.instanciate(c);
172                 statement = searchHead.getTarget();
173                 if (searchHead.instance >= MAX_CACHED_INSTANCES) {
174                     return statement;
175                 }
176                 underUse.add(statement);
177                 statements.put(searchHead, statement);
178             } else if (underUse.contains(statement)) {
179                 searchHead.increase();
180                 statement = null;
181             } else {
182                 underUse.add(statement);
183             }
184         }
185         return statement;
186     }
187
188     protected synchronized PreparedStatement prepareInternalScrollable(String query) throws SQLException {
189         return prepareInternal(query, true);
190     }
191
192     private long lastAction = System.currentTimeMillis();
193
194     private void ensureOpen() {
195         if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) {
196             try {
197                 ResultSet rs = adHoc.executeQuery("SELECT 1");
198                 rs.close();
199                 lastAction = System.currentTimeMillis();
200                 return;
201             } catch (SQLException e) {
202             }
203             statements.clear();
204             tryConnect();
205         }
206         lastAction = System.currentTimeMillis();
207     }
208
209     private static HashMap<Thread, Link> instances = new HashMap<>();
210
211     private static LinkedBlockingDeque<DatabaseConnection> pool = new LinkedBlockingDeque<>();
212
213     private static int connCount = 0;
214
215     public static synchronized DatabaseConnection getInstance() {
216         Link l = instances.get(Thread.currentThread());
217         if (l == null) {
218             throw new Error("No database connection allocated");
219         }
220         return l.target;
221     }
222
223     public static synchronized boolean hasInstance() {
224         Link l = instances.get(Thread.currentThread());
225         return l != null;
226     }
227
228     public static boolean isInited() {
229         return credentials != null;
230     }
231
232     public static void init(Properties conf) {
233         if (credentials != null) {
234             throw new Error("Re-initiaizing is forbidden.");
235         }
236         credentials = conf;
237         try (Link i = newLink(false)) {
238             try (GigiPreparedStatement empty = new GigiPreparedStatement("SELECT * from information_schema.tables WHERE table_schema='public' AND table_name='schemeVersion'")) {
239                 if ( !empty.executeQuery().next()) {
240                     try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("tableStructure.sql")) {
241                         if (resourceAsStream == null) {
242                             throw new Error("DB-Install-Script not found.");
243                         }
244                         try (Statement s = getInstance().c.createStatement()) {
245                             SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
246                             s.executeBatch();
247                         }
248                     }
249                     return;
250                 }
251             } catch (IOException e) {
252                 throw new Error(e);
253             } catch (SQLException e) {
254                 throw new Error(e);
255             }
256             int version = 0;
257             try (GigiPreparedStatement gigiPreparedStatement = new GigiPreparedStatement("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;")) {
258                 GigiResultSet rs = gigiPreparedStatement.executeQuery();
259                 if (rs.next()) {
260                     version = rs.getInt(1);
261                 }
262             }
263             if (version == CURRENT_SCHEMA_VERSION) {
264                 return; // Good to go
265             }
266             if (version > CURRENT_SCHEMA_VERSION) {
267                 throw new Error("Invalid database version. Please fix this.");
268             }
269             upgrade(version);
270         } catch (InterruptedException e) {
271             throw new Error(e);
272         }
273     }
274
275     private static void upgrade(int version) {
276         try {
277             try (Statement s = getInstance().c.createStatement()) {
278                 while (version < CURRENT_SCHEMA_VERSION) {
279                     addUpgradeScript(Integer.toString(version), s);
280                     version++;
281                 }
282                 s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
283                 System.out.println("UPGRADING Database to version " + version);
284                 s.executeBatch();
285                 System.out.println("done.");
286             }
287         } catch (SQLException e) {
288             e.printStackTrace();
289         } catch (IOException e) {
290             e.printStackTrace();
291         }
292     }
293
294     private static void addUpgradeScript(String version, Statement s) throws Error, IOException, SQLException {
295         try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
296             if (resourceAsStream == null) {
297                 throw new Error("Upgrade script from version " + version + " was not found.");
298             }
299             SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
300         }
301     }
302
303     public static final String preprocessQuery(String originalQuery) {
304         originalQuery = originalQuery.replace('`', '"');
305         if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
306             Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
307             Matcher m = p.matcher(originalQuery);
308             if (m.matches()) {
309                 String replacement = "INSERT INTO " + toIdentifier(m.group(1));
310                 String[] parts = m.group(2).split(",");
311                 StringJoiner columns = new StringJoiner(", ");
312                 StringJoiner values = new StringJoiner(", ");
313                 for (int i = 0; i < parts.length; i++) {
314                     String[] split = parts[i].split("=", 2);
315                     columns.add(toIdentifier(split[0]));
316                     values.add(split[1]);
317                 }
318                 replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
319                 return replacement;
320             }
321         }
322
323         //
324         return originalQuery;
325     }
326
327     private static CharSequence toIdentifier(String ident) {
328         ident = ident.trim();
329         if ( !ident.startsWith("\"")) {
330             ident = "\"" + ident;
331         }
332         if ( !ident.endsWith("\"")) {
333             ident = ident + "\"";
334         }
335         return ident;
336     }
337
338     protected synchronized void returnStatement(PreparedStatement target) throws SQLException {
339         if ( !underUse.remove(target)) {
340             target.close();
341         }
342     }
343
344     public synchronized int getNumberOfLockedStatements() {
345         return underUse.size();
346     }
347
348     public synchronized void lockedStatements(PrintWriter writer) {
349         writer.println(underUse.size());
350         for (PreparedStatement ps : underUse) {
351             for (Entry<StatementDescriptor, PreparedStatement> e : statements.entrySet()) {
352                 if (e.getValue() == ps) {
353                     writer.println("<br/>");
354                     writer.println(e.getKey().instance + ":");
355
356                     writer.println(e.getKey().query);
357                 }
358             }
359         }
360     }
361
362     public static Link newLink(boolean readOnly) throws InterruptedException {
363         synchronized (DatabaseConnection.class) {
364
365             if (instances.get(Thread.currentThread()) != null) {
366                 throw new Error("There is already a connection allocated for this thread.");
367             }
368             if (pool.isEmpty() && connCount < 5) {
369                 pool.addLast(new DatabaseConnection());
370                 connCount++;
371             }
372         }
373         DatabaseConnection conn = pool.takeFirst();
374         synchronized (DatabaseConnection.class) {
375             try {
376                 conn.c.setReadOnly(readOnly);
377             } catch (SQLException e) {
378                 throw new Error(e);
379             }
380             Link l = new Link(conn);
381             instances.put(Thread.currentThread(), l);
382             return l;
383         }
384
385     }
386 }