]> WPIA git - gigi.git/blob - src/club/wpia/gigi/database/SQLFileManager.java
upd: rename package name and all references to it
[gigi.git] / src / club / wpia / gigi / database / SQLFileManager.java
1 package club.wpia.gigi.database;
2
3 import java.io.ByteArrayOutputStream;
4 import java.io.IOException;
5 import java.io.InputStream;
6 import java.sql.SQLException;
7 import java.sql.Statement;
8 import java.util.regex.Matcher;
9 import java.util.regex.Pattern;
10
11 public class SQLFileManager {
12
13     public static enum ImportType {
14         /**
15          * Execute Script as-is
16          */
17         PRODUCTION,
18         /**
19          * Execute Script, but change Engine=InnoDB to Engine=Memory
20          */
21         TEST,
22         /**
23          * Execute INSERT statements as-is, and TRUNCATE instead of DROPPING
24          */
25         TRUNCATE,
26         /**
27          * Execute Script as-is if db version is >= specified version in
28          * optional header
29          */
30         SAMPLE_DATA,
31     }
32
33     public static void addFile(Statement stmt, InputStream f, ImportType type) throws IOException, SQLException {
34         String sql = readFile(f);
35         if (type == ImportType.SAMPLE_DATA) {
36             String fl = sql.split("\n")[0];
37             if (fl.matches("--Version: ([0-9]+)")) {
38                 int v0 = Integer.parseInt(fl.substring(11));
39                 if (DatabaseConnection.CURRENT_SCHEMA_VERSION < v0) {
40                     System.out.println("skipping sample data (data has version " + v0 + ", db has version " + DatabaseConnection.CURRENT_SCHEMA_VERSION + ")");
41                     return;
42                 }
43             }
44         }
45         sql = sql.replaceAll("--[^\n]*\n", "\n");
46         sql = sql.replaceAll("#[^\n]*\n", "\n");
47         String[] stmts = sql.split(";");
48         Pattern p = Pattern.compile("\\s*DROP TABLE IF EXISTS \"([^\"]+)\"");
49         for (String string : stmts) {
50             Matcher m = p.matcher(string);
51             string = string.trim();
52             if (string.equals("")) {
53                 continue;
54             }
55             if ((string.contains("profiles") || string.contains("cacerts") || string.contains("cats_type") || string.contains("countryIsoCode")) && type == ImportType.TRUNCATE) {
56                 continue;
57             }
58             string = DatabaseConnection.preprocessQuery(string);
59             if (m.matches() && type == ImportType.TRUNCATE) {
60                 String sql2 = "DELETE FROM \"" + m.group(1) + "\"";
61                 stmt.addBatch(sql2);
62                 continue;
63             }
64             if (type == ImportType.PRODUCTION || type == ImportType.SAMPLE_DATA || string.startsWith("INSERT")) {
65                 stmt.addBatch(string);
66             } else if (type == ImportType.TEST) {
67                 stmt.addBatch(string.replace("ENGINE=InnoDB", "ENGINE=Memory"));
68             }
69         }
70     }
71
72     private static String readFile(InputStream f) throws IOException {
73         ByteArrayOutputStream baos = new ByteArrayOutputStream();
74         int len;
75         byte[] buf = new byte[4096];
76         while ((len = f.read(buf)) > 0) {
77             baos.write(buf, 0, len);
78         }
79         return new String(baos.toByteArray(), "UTF-8");
80     }
81 }