]> WPIA git - gigi.git/blob - src/org/cacert/gigi/database/DatabaseConnection.java
add: internal api for password reset (with assurance)
[gigi.git] / src / org / cacert / gigi / database / DatabaseConnection.java
1 package org.cacert.gigi.database;
2
3 import java.io.IOException;
4 import java.io.InputStream;
5 import java.sql.Connection;
6 import java.sql.DriverManager;
7 import java.sql.ResultSet;
8 import java.sql.SQLException;
9 import java.sql.Statement;
10 import java.util.HashMap;
11 import java.util.Properties;
12 import java.util.StringJoiner;
13 import java.util.regex.Matcher;
14 import java.util.regex.Pattern;
15
16 import org.cacert.gigi.database.SQLFileManager.ImportType;
17
18 public class DatabaseConnection {
19
20     public static final int CURRENT_SCHEMA_VERSION = 6;
21
22     public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
23
24     private Connection c;
25
26     private HashMap<String, GigiPreparedStatement> statements = new HashMap<String, GigiPreparedStatement>();
27
28     private static Properties credentials;
29
30     private Statement adHoc;
31
32     public DatabaseConnection() {
33         try {
34             Class.forName(credentials.getProperty("sql.driver"));
35         } catch (ClassNotFoundException e) {
36             e.printStackTrace();
37         }
38         tryConnect();
39
40     }
41
42     private void tryConnect() {
43         try {
44             c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
45             adHoc = c.createStatement();
46         } catch (SQLException e) {
47             e.printStackTrace();
48         }
49     }
50
51     public GigiPreparedStatement prepare(String query) {
52         ensureOpen();
53         query = preprocessQuery(query);
54         GigiPreparedStatement statement = statements.get(query);
55         if (statement == null) {
56             try {
57                 statement = new GigiPreparedStatement(c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS));
58             } catch (SQLException e) {
59                 throw new Error(e);
60             }
61             statements.put(query, statement);
62         }
63         return statement;
64     }
65
66     public GigiPreparedStatement prepareScrollable(String query) {
67         ensureOpen();
68         query = preprocessQuery(query);
69         GigiPreparedStatement statement = statements.get(query);
70         if (statement == null) {
71             try {
72                 statement = new GigiPreparedStatement(c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY));
73             } catch (SQLException e) {
74                 throw new Error(e);
75             }
76             statements.put(query, statement);
77         }
78         return statement;
79     }
80
81     private long lastAction = System.currentTimeMillis();
82
83     private void ensureOpen() {
84         if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) {
85             try {
86                 ResultSet rs = adHoc.executeQuery("SELECT 1");
87                 rs.close();
88                 lastAction = System.currentTimeMillis();
89                 return;
90             } catch (SQLException e) {
91             }
92             statements.clear();
93             tryConnect();
94         }
95         lastAction = System.currentTimeMillis();
96     }
97
98     private static ThreadLocal<DatabaseConnection> instances = new ThreadLocal<DatabaseConnection>() {
99
100         @Override
101         protected DatabaseConnection initialValue() {
102             return new DatabaseConnection();
103         }
104     };
105
106     public static DatabaseConnection getInstance() {
107         return instances.get();
108     }
109
110     public static boolean isInited() {
111         return credentials != null;
112     }
113
114     public static void init(Properties conf) {
115         if (credentials != null) {
116             throw new Error("Re-initiaizing is forbidden.");
117         }
118         credentials = conf;
119         GigiResultSet rs = getInstance().prepare("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;").executeQuery();
120         int version = 0;
121         if (rs.next()) {
122             version = rs.getInt(1);
123         }
124         if (version == CURRENT_SCHEMA_VERSION) {
125             return; // Good to go
126         }
127         if (version > CURRENT_SCHEMA_VERSION) {
128             throw new Error("Invalid database version. Please fix this.");
129         }
130         upgrade(version);
131     }
132
133     public void beginTransaction() throws SQLException {
134         c.setAutoCommit(false);
135     }
136
137     private static void upgrade(int version) {
138         try {
139             Statement s = getInstance().c.createStatement();
140             try {
141                 while (version < CURRENT_SCHEMA_VERSION) {
142                     try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
143                         if (resourceAsStream == null) {
144                             throw new Error("Upgrade script from version " + version + " was not found.");
145                         }
146                         SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
147                     }
148                     version++;
149                 }
150                 s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
151                 System.out.println("UPGRADING Database to version " + version);
152                 s.executeBatch();
153                 System.out.println("done.");
154             } finally {
155                 s.close();
156             }
157         } catch (SQLException e) {
158             e.printStackTrace();
159         } catch (IOException e) {
160             e.printStackTrace();
161         }
162     }
163
164     public void commitTransaction() throws SQLException {
165         c.commit();
166         c.setAutoCommit(true);
167     }
168
169     public void quitTransaction() {
170         try {
171             if ( !c.getAutoCommit()) {
172                 c.rollback();
173                 c.setAutoCommit(true);
174             }
175         } catch (SQLException e) {
176             e.printStackTrace();
177         }
178     }
179
180     public static final String preprocessQuery(String originalQuery) {
181         originalQuery = originalQuery.replace('`', '"');
182         if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
183             Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
184             Matcher m = p.matcher(originalQuery);
185             if (m.matches()) {
186                 String replacement = "INSERT INTO " + toIdentifier(m.group(1));
187                 String[] parts = m.group(2).split(",");
188                 StringJoiner columns = new StringJoiner(", ");
189                 StringJoiner values = new StringJoiner(", ");
190                 for (int i = 0; i < parts.length; i++) {
191                     String[] split = parts[i].split("=", 2);
192                     columns.add(toIdentifier(split[0]));
193                     values.add(split[1]);
194                 }
195                 replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
196                 return replacement;
197             }
198         }
199
200         //
201         return originalQuery;
202     }
203
204     private static CharSequence toIdentifier(String ident) {
205         ident = ident.trim();
206         if ( !ident.startsWith("\"")) {
207             ident = "\"" + ident;
208         }
209         if ( !ident.endsWith("\"")) {
210             ident = ident + "\"";
211         }
212         return ident;
213     }
214 }