]> WPIA git - gigi.git/blob - src/org/cacert/gigi/database/DatabaseConnection.java
fix: transactions this way are buggy, we need a better way... will come.
[gigi.git] / src / org / cacert / gigi / database / DatabaseConnection.java
1 package org.cacert.gigi.database;
2
3 import java.io.IOException;
4 import java.io.InputStream;
5 import java.io.PrintWriter;
6 import java.sql.Connection;
7 import java.sql.DriverManager;
8 import java.sql.PreparedStatement;
9 import java.sql.ResultSet;
10 import java.sql.SQLException;
11 import java.sql.Statement;
12 import java.util.HashMap;
13 import java.util.HashSet;
14 import java.util.Map.Entry;
15 import java.util.Properties;
16 import java.util.StringJoiner;
17 import java.util.regex.Matcher;
18 import java.util.regex.Pattern;
19
20 import org.cacert.gigi.database.SQLFileManager.ImportType;
21
22 public class DatabaseConnection {
23
24     public static final int MAX_CACHED_INSTANCES = 3;
25
26     private static class StatementDescriptor {
27
28         String query;
29
30         boolean scrollable;
31
32         int instance;
33
34         PreparedStatement target;
35
36         public StatementDescriptor(String query, boolean scrollable) {
37             this.query = query;
38             this.scrollable = scrollable;
39             this.instance = 0;
40         }
41
42         public synchronized void instanciate(Connection c) throws SQLException {
43             if (scrollable) {
44                 target = c.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
45             } else {
46                 target = c.prepareStatement(query, query.startsWith("SELECT ") ? Statement.NO_GENERATED_KEYS : Statement.RETURN_GENERATED_KEYS);
47             }
48
49         }
50
51         public synchronized PreparedStatement getTarget() {
52             return target;
53         }
54
55         public synchronized void increase() {
56             if (target != null) {
57                 throw new IllegalStateException();
58             }
59             instance++;
60         }
61
62         @Override
63         public int hashCode() {
64             final int prime = 31;
65             int result = 1;
66             result = prime * result + instance;
67             result = prime * result + ((query == null) ? 0 : query.hashCode());
68             result = prime * result + (scrollable ? 1231 : 1237);
69             return result;
70         }
71
72         @Override
73         public boolean equals(Object obj) {
74             if (this == obj) {
75                 return true;
76             }
77             if (obj == null) {
78                 return false;
79             }
80             if (getClass() != obj.getClass()) {
81                 return false;
82             }
83             StatementDescriptor other = (StatementDescriptor) obj;
84             if (instance != other.instance) {
85                 return false;
86             }
87             if (query == null) {
88                 if (other.query != null) {
89                     return false;
90                 }
91             } else if ( !query.equals(other.query)) {
92                 return false;
93             }
94             if (scrollable != other.scrollable) {
95                 return false;
96             }
97             return true;
98         }
99
100     }
101
102     public static final int CURRENT_SCHEMA_VERSION = 9;
103
104     public static final int CONNECTION_TIMEOUT = 24 * 60 * 60;
105
106     private Connection c;
107
108     private HashMap<StatementDescriptor, PreparedStatement> statements = new HashMap<StatementDescriptor, PreparedStatement>();
109
110     HashSet<PreparedStatement> underUse = new HashSet<>();
111
112     private static Properties credentials;
113
114     private Statement adHoc;
115
116     public DatabaseConnection() {
117         try {
118             Class.forName(credentials.getProperty("sql.driver"));
119         } catch (ClassNotFoundException e) {
120             e.printStackTrace();
121         }
122         tryConnect();
123
124     }
125
126     private void tryConnect() {
127         try {
128             c = DriverManager.getConnection(credentials.getProperty("sql.url") + "?socketTimeout=" + CONNECTION_TIMEOUT, credentials.getProperty("sql.user"), credentials.getProperty("sql.password"));
129             adHoc = c.createStatement();
130         } catch (SQLException e) {
131             e.printStackTrace();
132         }
133     }
134
135     protected synchronized PreparedStatement prepareInternal(String query) throws SQLException {
136         return prepareInternal(query, false);
137     }
138
139     protected synchronized PreparedStatement prepareInternal(String query, boolean scrollable) throws SQLException {
140
141         ensureOpen();
142         query = preprocessQuery(query);
143         StatementDescriptor searchHead = new StatementDescriptor(query, scrollable);
144         PreparedStatement statement = null;
145         while (statement == null) {
146             statement = statements.get(searchHead);
147             if (statement == null) {
148                 searchHead.instanciate(c);
149                 statement = searchHead.getTarget();
150                 if (searchHead.instance >= MAX_CACHED_INSTANCES) {
151                     return statement;
152                 }
153                 underUse.add(statement);
154                 statements.put(searchHead, statement);
155             } else if (underUse.contains(statement)) {
156                 searchHead.increase();
157                 statement = null;
158             } else {
159                 underUse.add(statement);
160             }
161         }
162         return statement;
163     }
164
165     protected synchronized PreparedStatement prepareInternalScrollable(String query) throws SQLException {
166         return prepareInternal(query, true);
167     }
168
169     private long lastAction = System.currentTimeMillis();
170
171     private void ensureOpen() {
172         if (System.currentTimeMillis() - lastAction > CONNECTION_TIMEOUT * 1000L) {
173             try {
174                 ResultSet rs = adHoc.executeQuery("SELECT 1");
175                 rs.close();
176                 lastAction = System.currentTimeMillis();
177                 return;
178             } catch (SQLException e) {
179             }
180             statements.clear();
181             tryConnect();
182         }
183         lastAction = System.currentTimeMillis();
184     }
185
186     private static volatile DatabaseConnection instance;
187
188     public static synchronized DatabaseConnection getInstance() {
189         if (instance == null) {
190             instance = new DatabaseConnection();
191         }
192         return instance;
193     }
194
195     public static boolean isInited() {
196         return credentials != null;
197     }
198
199     public static void init(Properties conf) {
200         if (credentials != null) {
201             throw new Error("Re-initiaizing is forbidden.");
202         }
203         credentials = conf;
204         int version = 0;
205         try (GigiPreparedStatement gigiPreparedStatement = new GigiPreparedStatement("SELECT version FROM \"schemeVersion\" ORDER BY version DESC LIMIT 1;")) {
206             GigiResultSet rs = gigiPreparedStatement.executeQuery();
207             if (rs.next()) {
208                 version = rs.getInt(1);
209             }
210         }
211         if (version == CURRENT_SCHEMA_VERSION) {
212             return; // Good to go
213         }
214         if (version > CURRENT_SCHEMA_VERSION) {
215             throw new Error("Invalid database version. Please fix this.");
216         }
217         upgrade(version);
218     }
219
220     private static void upgrade(int version) {
221         try {
222             Statement s = getInstance().c.createStatement();
223             try {
224                 while (version < CURRENT_SCHEMA_VERSION) {
225                     try (InputStream resourceAsStream = DatabaseConnection.class.getResourceAsStream("upgrade/from_" + version + ".sql")) {
226                         if (resourceAsStream == null) {
227                             throw new Error("Upgrade script from version " + version + " was not found.");
228                         }
229                         SQLFileManager.addFile(s, resourceAsStream, ImportType.PRODUCTION);
230                     }
231                     version++;
232                 }
233                 s.addBatch("UPDATE \"schemeVersion\" SET version='" + version + "'");
234                 System.out.println("UPGRADING Database to version " + version);
235                 s.executeBatch();
236                 System.out.println("done.");
237             } finally {
238                 s.close();
239             }
240         } catch (SQLException e) {
241             e.printStackTrace();
242         } catch (IOException e) {
243             e.printStackTrace();
244         }
245     }
246
247     public static final String preprocessQuery(String originalQuery) {
248         originalQuery = originalQuery.replace('`', '"');
249         if (originalQuery.matches("^INSERT INTO [^ ]+ SET .*")) {
250             Pattern p = Pattern.compile("INSERT INTO ([^ ]+) SET (.*)");
251             Matcher m = p.matcher(originalQuery);
252             if (m.matches()) {
253                 String replacement = "INSERT INTO " + toIdentifier(m.group(1));
254                 String[] parts = m.group(2).split(",");
255                 StringJoiner columns = new StringJoiner(", ");
256                 StringJoiner values = new StringJoiner(", ");
257                 for (int i = 0; i < parts.length; i++) {
258                     String[] split = parts[i].split("=", 2);
259                     columns.add(toIdentifier(split[0]));
260                     values.add(split[1]);
261                 }
262                 replacement += "(" + columns.toString() + ") VALUES(" + values.toString() + ")";
263                 return replacement;
264             }
265         }
266
267         //
268         return originalQuery;
269     }
270
271     private static CharSequence toIdentifier(String ident) {
272         ident = ident.trim();
273         if ( !ident.startsWith("\"")) {
274             ident = "\"" + ident;
275         }
276         if ( !ident.endsWith("\"")) {
277             ident = ident + "\"";
278         }
279         return ident;
280     }
281
282     protected synchronized void returnStatement(PreparedStatement target) throws SQLException {
283         if ( !underUse.remove(target)) {
284             target.close();
285         }
286     }
287
288     public synchronized int getNumberOfLockedStatements() {
289         return underUse.size();
290     }
291
292     public synchronized void lockedStatements(PrintWriter writer) {
293         writer.println(underUse.size());
294         for (PreparedStatement ps : underUse) {
295             for (Entry<StatementDescriptor, PreparedStatement> e : statements.entrySet()) {
296                 if (e.getValue() == ps) {
297                     writer.println("<br/>");
298                     writer.println(e.getKey().instance + ":");
299
300                     writer.println(e.getKey().query);
301                 }
302             }
303         }
304     }
305 }