b5e78eaaffa902764e0a29f938d3744848511b67
[gigi.git] / src / club / wpia / gigi / database / DatabaseConnection.java
1 package club.wpia.gigi.database;
2
3 import java.io.File;
4 import java.io.FileInputStream;
5 import java.io.IOException;
6 import java.io.InputStream;
7 import java.io.InputStreamReader;
8 import java.io.PrintWriter;
9 import java.io.Reader;
10 import java.sql.Connection;
11 import java.sql.DriverManager;
12 import java.sql.PreparedStatement;
13 import java.sql.ResultSet;
14 import java.sql.SQLException;
15 import java.sql.Statement;
16 import java.util.HashMap;
17 import java.util.HashSet;
18 import java.util.Map.Entry;
19 import java.util.Properties;
20 import java.util.StringJoiner;
21 import java.util.concurrent.LinkedBlockingDeque;
22 import java.util.regex.Matcher;
23 import java.util.regex.Pattern;
24
25 import club.wpia.gigi.database.SQLFileManager.ImportType;
26 import club.wpia.gigi.dbObjects.Certificate;
27 import club.wpia.gigi.dbObjects.Certificate.AttachmentType;
28
29 public class DatabaseConnection {
30
31     public static final class Upgrade32 {
32
33         public static void execute() throws IOException {
34             // "csr_name" varchar(255) NOT NULL DEFAULT '',
35             // "csr_type" "csrType" NOT NULL,
36             // "crt_name" varchar(255) NOT NULL DEFAULT '',
37             try (GigiPreparedStatement ps = new GigiPreparedStatement("SELECT `id`, `csr_name`, `crt_name` FROM `certs`")) {
38                 GigiResultSet rs = ps.executeQuery();
39                 while (rs.next()) {
40                     // Load CSR
41                     load(rs.getInt(1), rs.getString(2), Certificate.AttachmentType.CSR);
42                     load(rs.getInt(1), rs.getString(3), Certificate.AttachmentType.CRT);
43                 }
44             }
45         }
46
47         private static void load(int id, String file, AttachmentType type) throws IOException {
48             if ("".equals(file) && type == AttachmentType.CRT) {
49                 // this is ok, certificates might be in DRAFT state
50                 return;
51             }
52             File f = new File(file);
53             System.out.println("Upgrade 32: loading " + f);
54             if (f.exists()) {
55                 StringBuilder sb = new StringBuilder();
56                 try (FileInputStream fis = new FileInputStream(f); Reader r = new InputStreamReader(fis, "UTF-8")) {
57                     int len;
58                     char[] buf = new char[4096];
59                     while ((len = r.read(buf)) > 0) {
60                         sb.append(buf, 0, len);
61                     }
62                 }
63                 String csrS = sb.toString();
64                 try (GigiPreparedStatement ps1 = new GigiPreparedStatement("INSERT INTO `certificateAttachment` SET `certid`=?, `type`=?::`certificateAttachmentType`, `content`=?")) {
65                     ps1.setInt(1, id);
66                     ps1.setEnum(2, type);
67                     ps1.setString(3, csrS);
68                     ps1.execute();
69                 }
70                 f.delete();
71             } else {
72                 try (GigiPreparedStatement ps1 = new GigiPreparedStatement("SELECT 1 FROM `certificateAttachment` WHERE `certid`=? AND `type`=?::`certificateAttachmentType`")) {
73                     ps1.setInt(1, id);
74                     ps1.setEnum(2, type);
75                     GigiResultSet rs1 = ps1.executeQuery();
76                     if ( !rs1.next()) {
77                         throw new Error("file " + f + " not found, and attachment is missing as well.");
78                     }
79                 }
80             }
81         }
82     }
83
84     public static class Link implements AutoCloseable {
85
86         private DatabaseConnection target;
87
88         protected Link(DatabaseConnection target) {
89             this.target = target;
90         }
91
92         @Override
93         public void close() {
94             synchronized (DatabaseConnection.class) {
95                 Link i = instances.get(Thread.currentThread());
96                 if (i != this) {
97                     throw new Error();
98                 }
99                 instances.remove(Thread.currentThread());
100                 pool.add(target);
101             }
102         }
103
104     }
105
106     public static final int MAX_CACHED_INSTANCES = 3;
107
108     private static class StatementDescriptor {
109
110         String query;
111
112         boolean scrollable;
113
114         int instance;
115
116         PreparedStatement target;
117
118         public StatementDescriptor(String query, boolean scrollable) {
119             this.query = query;
120             this.scrollable = scrollable;
121             this.instance = 0;
122         }
123
124         public synchronized void instanciate(Connection c) throws SQLException {
125             if (scrollable) {
126                 target = c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
127             } else {
128                 target = c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS);
129             }
130
131         }
132
133         public synchronized PreparedStatement getTarget() {
134             return target;
135         }
136
137         public synchronized void increase() {
138             if (target != null) {
139                 throw new IllegalStateException();
140             }
141             instance++;
142         }
143
144         @Override
145         public int hashCode() {
146             final int prime = 31;
147             int result = 1;
148             result = prime * result + instance;
149             result = prime * result + ((query == null) ? 0 : query.hashCode());
150             result = prime * result + (scrollable ? 1231 : 1237);
151             return result;
152         }
153
154         @Override
155         public boolean equals(Object obj) {
156             if (this == obj) {
157                 return true;
158             }
159             if (obj == null) {
160                 return false;
161             }
162             if (getClass() != obj.getClass()) {
163                 return false;
164             }
165             StatementDescriptor other = (StatementDescriptor) obj;
166             if (instance != other.instance) {
167                 return false;
168             }
169             if (query == null) {
170                 if (other.query != null) {
171                     return false;
172                 }
173             } else if ( !query.equals(other.query)) {
174                 return false;
175             }
176             if (scrollable != other.scrollable) {
177                 return false;
178             }
179             return true;
180         }
181
182     }
183
184     public static final int CURRENT_SCHEMA_VERSION = 37;
185
186     public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
187
188     private Connection c;
189
190     private HashMap<StatementDescriptor, PreparedStatement> statements = new HashMap<StatementDescriptor, PreparedStatement>();
191
192     private HashSet<PreparedStatement> underUse = new HashSet<>();
193
194     private static Properties credentials;
195
196     private Statement adHoc;
197
198     public DatabaseConnection() {
199         try {
200             Class.forName(credentials.getProperty("sql.driver"));
201         } catch (ClassNotFoundException e) {
202             e.printStackTrace();
203         }
204         tryConnect();
205
206     }
207
208     private void tryConnect() {
209         try {
210             c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
211             adHoc = c.createStatement();
212         } catch (SQLException e) {
213             e.printStackTrace();
214         }
215     }
216
217     protected synchronized PreparedStatement prepareInternal(String query) throws SQLException {
218         return prepareInternal(query, false);
219     }
220
221     protected synchronized PreparedStatement prepareInternal(String query, boolean scrollable) throws SQLException {
222
223         ensureOpen();
224         query = preprocessQuery(query);
225         StatementDescriptor searchHead = new StatementDescriptor(query, scrollable);
226         PreparedStatement statement = null;
227         while (statement == null) {
228             statement = statements.get(searchHead);
229             if (statement == null) {
230                 searchHead.instanciate(c);
231                 statement = searchHead.getTarget();
232                 if (searchHead.instance >= MAX_CACHED_INSTANCES) {
233                     return statement;
234                 }
235                 underUse.add(statement);
236                 statements.put(searchHead, statement);
237             } else if (underUse.contains(statement)) {
238                 searchHead.increase();
239                 statement = null;
240             } else {
241                 underUse.add(statement);
242             }
243         }
244         return statement;
245     }
246
247     protected synchronized PreparedStatement prepareInternalScrollable(String query) throws SQLException {
248         return prepareInternal(query, true);
249     }
250
251     private long lastAction = System.currentTimeMillis();
252
253     private void ensureOpen() {
254         if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) {
255             try {
256                 ResultSet rs = adHoc.executeQuery("SELECT 1");
257                 rs.close();
258                 lastAction = System.currentTimeMillis();
259                 return;
260             } catch (SQLException e) {
261             }
262             statements.clear();
263             tryConnect();
264         }
265         lastAction = System.currentTimeMillis();
266     }
267
268     private static HashMap<Thread, Link> instances = new HashMap<>();
269
270     private static LinkedBlockingDeque<DatabaseConnection> pool = new LinkedBlockingDeque<>();
271
272     private static int connCount = 0;
273
274     public static synchronized DatabaseConnection getInstance() {
275         Link l = instances.get(Thread.currentThread());
276         if (l == null) {
277             throw new Error("No database connection allocated");
278         }
279         return l.target;
280     }
281
282     public static synchronized boolean hasInstance() {
283         Link l = instances.get(Thread.currentThread());
284         return l != null;
285     }
286
287     public static boolean isInited() {
288         return credentials != null;
289     }
290
291     public static void init(Properties conf) {
292         if (credentials != null) {
293             throw new Error("Re-initiaizing is forbidden.");
294         }
295         credentials = conf;
296         try (Link i = newLink(false)) {
297             try (GigiPreparedStatement empty = new GigiPreparedStatement("SELECT * from information_schema.tables WHERE table_schema='public' AND table_name='schemeVersion'")) {
298                 if ( !empty.executeQuery().next()) {
299                     try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("tableStructure.sql")) {
300                         if (resourceAsStream == null) {
301                             throw new Error("DB-Install-Script not found.");
302                         }
303                         try (Statement s = getInstance().c.createStatement()) {
304                             SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
305                             s.executeBatch();
306                         }
307                     }
308                     return;
309                 }
310             } catch (IOException e) {
311                 throw new Error(e);
312             } catch (SQLException e) {
313                 throw new Error(e);
314             }
315             int version = 0;
316             try (GigiPreparedStatement gigiPreparedStatement = new GigiPreparedStatement("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;")) {
317                 GigiResultSet rs = gigiPreparedStatement.executeQuery();
318                 if (rs.next()) {
319                     version = rs.getInt(1);
320                 }
321             }
322             if (version == CURRENT_SCHEMA_VERSION) {
323                 return; // Good to go
324             }
325             if (version > CURRENT_SCHEMA_VERSION) {
326                 throw new Error("Invalid database version. Please fix this.");
327             }
328             upgrade(version);
329         } catch (InterruptedException e) {
330             throw new Error(e);
331         }
332     }
333
334     private static void upgrade(int version) {
335         try {
336             try (Statement s = getInstance().c.createStatement()) {
337                 while (version < CURRENT_SCHEMA_VERSION) {
338                     addUpgradeScript(Integer.toString(version), s);
339                     version++;
340                     System.out.println("UPGRADING Database to version " + version);
341                     s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
342                     s.executeBatch();
343                 }
344                 System.out.println("done.");
345             }
346         } catch (SQLException e) {
347             e.printStackTrace();
348         } catch (IOException e) {
349             e.printStackTrace();
350         }
351     }
352
353     private static void addUpgradeScript(String version, Statement s) throws Error, IOException, SQLException {
354         if (version.equals("31")) {
355             Upgrade32.execute();
356         } else {
357             try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
358                 if (resourceAsStream == null) {
359                     throw new Error("Upgrade script from version " + version + " was not found.");
360                 }
361                 SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
362             }
363         }
364     }
365
366     public static final String preprocessQuery(String originalQuery) {
367         originalQuery = originalQuery.replace('`', '"');
368         if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
369             Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
370             Matcher m = p.matcher(originalQuery);
371             if (m.matches()) {
372                 String replacement = "INSERT INTO " + toIdentifier(m.group(1));
373                 String[] parts = m.group(2).split(",");
374                 StringJoiner columns = new StringJoiner(", ");
375                 StringJoiner values = new StringJoiner(", ");
376                 for (int i = 0; i < parts.length; i++) {
377                     String[] split = parts[i].split("=", 2);
378                     columns.add(toIdentifier(split[0]));
379                     values.add(split[1]);
380                 }
381                 replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
382                 return replacement;
383             }
384         }
385
386         //
387         return originalQuery;
388     }
389
390     private static CharSequence toIdentifier(String ident) {
391         ident = ident.trim();
392         if ( !ident.startsWith("\"")) {
393             ident = "\"" + ident;
394         }
395         if ( !ident.endsWith("\"")) {
396             ident = ident + "\"";
397         }
398         return ident;
399     }
400
401     protected synchronized void returnStatement(PreparedStatement target) throws SQLException {
402         if ( !underUse.remove(target)) {
403             target.close();
404         }
405     }
406
407     public synchronized int getNumberOfLockedStatements() {
408         return underUse.size();
409     }
410
411     public synchronized void lockedStatements(PrintWriter writer) {
412         writer.println(underUse.size());
413         for (PreparedStatement ps : underUse) {
414             for (Entry<StatementDescriptor, PreparedStatement> e : statements.entrySet()) {
415                 if (e.getValue() == ps) {
416                     writer.println("<br/>");
417                     writer.println(e.getKey().instance + ":");
418
419                     writer.println(e.getKey().query);
420                 }
421             }
422         }
423     }
424
425     public static Link newLink(boolean readOnly) throws InterruptedException {
426         synchronized (DatabaseConnection.class) {
427
428             if (instances.get(Thread.currentThread()) != null) {
429                 throw new Error("There is already a connection allocated for this thread.");
430             }
431             if (pool.isEmpty() && connCount < 5) {
432                 pool.addLast(new DatabaseConnection());
433                 connCount++;
434             }
435         }
436         DatabaseConnection conn = pool.takeFirst();
437         synchronized (DatabaseConnection.class) {
438             try {
439                 conn.c.setReadOnly(readOnly);
440             } catch (SQLException e) {
441                 throw new Error(e);
442             }
443             Link l = new Link(conn);
444             instances.put(Thread.currentThread(), l);
445             return l;
446         }
447
448     }
449 }