add: code to statically verify SQL call patterns
authorFelix Dörre <felix@dogcraft.de>
Fri, 4 Aug 2017 16:31:11 +0000 (18:31 +0200)
committerFelix Dörre <felix@dogcraft.de>
Sun, 6 Aug 2017 22:43:27 +0000 (00:43 +0200)
Change-Id: Ib5c0e7a76d9a14f318087f092091bdf2cfa3c174

build.xml
util-testing/club/wpia/gigi/localisation/SQLTestingVisitor.java [new file with mode: 0644]
util-testing/club/wpia/gigi/localisation/TranslationCollectingVisitor.java
util-testing/club/wpia/gigi/localisation/TranslationCollector.java

index 6d2587e..1f123bc 100644 (file)
--- a/build.xml
+++ b/build.xml
                        <classpath refid="gigi.test.classpath" />
                        <classpath refid="gigi.test.classpath.jdt" />
                </java>
+               <java classname="club.wpia.gigi.localisation.TranslationCollector" failonerror="true">
+                       <arg value="util-testing/club/wpia/gigi/localisation/conf.txt"/>
+                       <arg value="."/>
+                       <arg value="SQLValidation"/>
+                       <classpath refid="gigi.test.classpath" />
+                       <classpath refid="gigi.test.classpath.jdt" />
+               </java>
        </target>
 
        <target name="native">
diff --git a/util-testing/club/wpia/gigi/localisation/SQLTestingVisitor.java b/util-testing/club/wpia/gigi/localisation/SQLTestingVisitor.java
new file mode 100644 (file)
index 0000000..808aee6
--- /dev/null
@@ -0,0 +1,406 @@
+package club.wpia.gigi.localisation;
+
+import java.sql.Date;
+import java.sql.ParameterMetaData;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.sql.Types;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.concurrent.LinkedBlockingDeque;
+
+import org.eclipse.jdt.internal.compiler.ASTVisitor;
+import org.eclipse.jdt.internal.compiler.CompilationResult;
+import org.eclipse.jdt.internal.compiler.ast.AllocationExpression;
+import org.eclipse.jdt.internal.compiler.ast.BinaryExpression;
+import org.eclipse.jdt.internal.compiler.ast.CompilationUnitDeclaration;
+import org.eclipse.jdt.internal.compiler.ast.ConditionalExpression;
+import org.eclipse.jdt.internal.compiler.ast.Expression;
+import org.eclipse.jdt.internal.compiler.ast.ExtendedStringLiteral;
+import org.eclipse.jdt.internal.compiler.ast.IntLiteral;
+import org.eclipse.jdt.internal.compiler.ast.MessageSend;
+import org.eclipse.jdt.internal.compiler.ast.SingleNameReference;
+import org.eclipse.jdt.internal.compiler.ast.StringLiteral;
+import org.eclipse.jdt.internal.compiler.ast.TryStatement;
+import org.eclipse.jdt.internal.compiler.ast.TypeDeclaration;
+import org.eclipse.jdt.internal.compiler.impl.Constant;
+import org.eclipse.jdt.internal.compiler.lookup.BlockScope;
+import org.eclipse.jdt.internal.compiler.lookup.ClassScope;
+import org.eclipse.jdt.internal.compiler.lookup.CompilationUnitScope;
+import org.eclipse.jdt.internal.compiler.lookup.MethodBinding;
+import org.eclipse.jdt.internal.compiler.lookup.SourceTypeBinding;
+import org.eclipse.jdt.internal.compiler.lookup.TypeBinding;
+
+import club.wpia.gigi.database.DBEnum;
+import club.wpia.gigi.database.DatabaseConnection;
+import club.wpia.gigi.database.GigiPreparedStatement;
+
+public class SQLTestingVisitor extends ASTVisitor {
+
+    private CompilationUnitDeclaration pu;
+
+    private TranslationCollector tc;
+
+    private enum Type {
+        TIMESTAMP(Types.TIMESTAMP), ENUM(Types.VARCHAR), STRING(Types.VARCHAR), BOOLEAN(Types.BOOLEAN), DATE(Types.DATE), INTEGER(Types.INTEGER), OTHER(0);
+
+        private final int sqltype;
+
+        private Type(int sqltype) {
+            this.sqltype = sqltype;
+
+        }
+
+        public void set(GigiPreparedStatement ps, int index) {
+            if (this == TIMESTAMP) {
+                ps.setTimestamp(index, new Timestamp(System.currentTimeMillis()));
+            } else if (this == STRING) {
+                ps.setString(index, "y");
+            } else if (this == DATE) {
+                ps.setDate(index, new Date(System.currentTimeMillis()));
+            } else if (this == Type.BOOLEAN) {
+                ps.setBoolean(index, false);
+            } else if (this == OTHER || this == INTEGER) {
+                ps.setInt(index, 1000);
+            } else {
+                throw new Error();
+            }
+        }
+
+        public boolean isOfSQLType(int i) {
+            if (i == sqltype) {
+                return true;
+            }
+            if (i == Types.BIT && this == BOOLEAN) {
+                return true;
+            }
+            return false;
+        }
+    }
+
+    private class TypeInstantiation {
+
+        Type type;
+
+        String enumValue;
+
+        public TypeInstantiation(Type type) {
+            this.type = type;
+        }
+
+        public TypeInstantiation(Type type, String enumValue) {
+            this.enumValue = enumValue;
+            this.type = type;
+        }
+
+        public void set(GigiPreparedStatement ps, int index) {
+            if (type == Type.ENUM) {
+                ps.setString(index, enumValue);
+            } else {
+                type.set(ps, index);
+            }
+        }
+
+        @Override
+        public String toString() {
+            return type.toString() + (enumValue != null ? enumValue : "");
+        }
+
+        @Override
+        public int hashCode() {
+            final int prime = 31;
+            int result = 1;
+            result = prime * result + getOuterType().hashCode();
+            result = prime * result + ((enumValue == null) ? 0 : enumValue.hashCode());
+            result = prime * result + ((type == null) ? 0 : type.hashCode());
+            return result;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj) {
+                return true;
+            }
+            if (obj == null) {
+                return false;
+            }
+            if (getClass() != obj.getClass()) {
+                return false;
+            }
+            TypeInstantiation other = (TypeInstantiation) obj;
+            if ( !getOuterType().equals(other.getOuterType())) {
+                return false;
+            }
+            if (enumValue == null) {
+                if (other.enumValue != null) {
+                    return false;
+                }
+            } else if ( !enumValue.equals(other.enumValue)) {
+                return false;
+            }
+            if (type != other.type) {
+                return false;
+            }
+            return true;
+        }
+
+        private SQLTestingVisitor getOuterType() {
+            return SQLTestingVisitor.this;
+        }
+
+    }
+
+    public class SQLOccurrence {
+
+        private List<String> query;
+
+        private TryStatement target;
+
+        private CompilationResult source;
+
+        public TypeInstantiation[] types = new TypeInstantiation[10];
+
+        private int sourceStart;
+
+        public SQLOccurrence(TryStatement target) {
+            this.target = target;
+        }
+
+        public void setQuery(List<String> query, CompilationResult compilationResult, int sourceStart) {
+            this.query = query;
+            this.source = compilationResult;
+            this.sourceStart = sourceStart;
+        }
+
+        public TryStatement getTarget() {
+            return target;
+        }
+
+        public List<String> getQuery() {
+            return query;
+        }
+
+        public int getSourceStart() {
+            return sourceStart;
+        }
+
+        public boolean isQuery() {
+            return query != null;
+        }
+
+        public String getPosition() {
+            int pos = source.lineSeparatorPositions.length + 1;
+            for (int i = 0; i < source.lineSeparatorPositions.length; i++) {
+                if (source.lineSeparatorPositions[i] > sourceStart) {
+                    pos = i + 1;
+                    break;
+                }
+            }
+            return new String(source.getFileName()) + ":" + pos;
+        }
+
+        private void check(String stmt) {
+            tc.countStatement();
+            try (DatabaseConnection.Link l = DatabaseConnection.newLink(true)) {
+                try (GigiPreparedStatement ps = new GigiPreparedStatement(stmt)) {
+                    ParameterMetaData dt = ps.getParameterMetaData();
+                    int count = dt.getParameterCount();
+                    for (int i = 1; i <= types.length; i++) {
+                        if (i > count) {
+                            if (types[i - 1] != null) {
+                                errMsg(stmt, "too many params");
+                                return;
+                            }
+                            continue;
+                        }
+                        int tp = dt.getParameterType(i);
+                        TypeInstantiation t = types[i - 1];
+                        if (t == null) {
+                            errMsg(stmt, "arg " + i + " not set");
+                            return;
+                        }
+                        if ( !t.type.isOfSQLType(tp)) {
+                            errMsg(stmt, "type mismatch. From parameter setting code: " + t + ", in SQL statement: " + tp);
+                            return;
+                        }
+                    }
+                } catch (SQLException e) {
+                    throw new Error(e);
+                }
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
+        }
+
+        private void errMsg(String stmt, String errMsg) {
+            System.err.println(getPosition());
+            System.err.println("Problem with statement: " + stmt);
+            System.err.println(Arrays.toString(types));
+            System.err.println(errMsg);
+            tc.hadError();
+        }
+
+        public void check() {
+            for (String q : getQuery()) {
+                check(q);
+            }
+
+        }
+    }
+
+    public SQLTestingVisitor(CompilationUnitDeclaration pu, TranslationCollector tc) {
+        this.pu = pu;
+        this.tc = tc;
+    }
+
+    Deque<SQLOccurrence> ts = new LinkedBlockingDeque<>();
+
+    @Override
+    public boolean visit(TypeDeclaration typeDeclaration, CompilationUnitScope scope) {
+        return true;
+    }
+
+    @Override
+    public boolean visit(TypeDeclaration memberTypeDeclaration, ClassScope scope) {
+        return true;
+    }
+
+    @Override
+    public boolean visit(TryStatement tryStatement, BlockScope scope) {
+        ts.push(new SQLOccurrence(tryStatement));
+        return true;
+    }
+
+    @Override
+    public void endVisit(TryStatement tryStatement, BlockScope scope) {
+        SQLOccurrence occ = ts.pop();
+        if (occ.isQuery()) {
+            occ.check();
+        }
+        if (occ.getTarget() != tryStatement) {
+            throw new Error();
+        }
+    }
+
+    @Override
+    public boolean visit(AllocationExpression ae, BlockScope scope) {
+        MethodBinding mb = ae.binding;
+        if (new String(mb.declaringClass.qualifiedPackageName()).equals("club.wpia.gigi.database") && new String(mb.declaringClass.qualifiedSourceName()).equals("GigiPreparedStatement")) {
+            String sig = new String(mb.readableName());
+            if (sig.equals("GigiPreparedStatement(String)") || sig.equals("GigiPreparedStatement(String, boolean)")) {
+                List<String> l = getQueries(ae.arguments[0], scope);
+                if (l.size() == 0) {
+                    return false;
+                }
+                LinkedList<String> qs = new LinkedList<>();
+                for (String q : l) {
+                    qs.add(DatabaseConnection.preprocessQuery(q));
+                }
+                ts.peek().setQuery(qs, scope.compilationUnitScope().referenceContext.compilationResult, ae.sourceStart);
+            } else {
+                throw new Error(sig);
+            }
+        }
+        return true;
+    }
+
+    private List<String> getQueries(Expression q, BlockScope scope) {
+        SourceTypeBinding typ = scope.enclosingSourceType();
+        String fullType = new String(typ.qualifiedPackageName()) + "." + new String(typ.qualifiedSourceName());
+        if (fullType.equals("club.wpia.gigi.database.IntegrityVerifier")) {
+            return Arrays.asList();
+        }
+        if (q instanceof StringLiteral) {
+            String s = new String(((StringLiteral) q).source());
+            return Arrays.asList(s);
+        } else if (q instanceof ExtendedStringLiteral) {
+            throw new Error();
+        } else if (q instanceof BinaryExpression) {
+            Expression l = ((BinaryExpression) q).left;
+            Expression r = ((BinaryExpression) q).right;
+            if ( !((BinaryExpression) q).operatorToString().equals("+")) {
+                throw new Error(((BinaryExpression) q).operatorToString());
+            }
+            List<String> left = getQueries(l, scope);
+            List<String> right = getQueries(r, scope);
+            LinkedList<String> res = new LinkedList<>();
+            for (String leftS : left) {
+                for (String rightS : right) {
+                    res.add(leftS + rightS);
+                }
+            }
+            return res;
+        } else if (q instanceof ConditionalExpression) {
+            Expression t = ((ConditionalExpression) q).valueIfTrue;
+            Expression f = ((ConditionalExpression) q).valueIfFalse;
+            List<String> res = new LinkedList<>();
+            res.addAll(getQueries(t, scope));
+            res.addAll(getQueries(f, scope));
+            return res;
+        } else if (q instanceof SingleNameReference) {
+            SingleNameReference ref = (SingleNameReference) q;
+            Constant c = ref.constant;
+            if (c.equals(Constant.NotAConstant)) {
+                throw new Error(q.toString());
+            }
+            return Arrays.asList(c.stringValue());
+        } else {
+            System.err.println(q.getClass() + ";" + q.toString());
+            throw new Error(q.toString());
+        }
+    }
+
+    @Override
+    public boolean visit(MessageSend messageSend, BlockScope scope) {
+        Expression r = messageSend.receiver;
+        String rec = new String(r.resolvedType.readableName());
+        if (rec.equals("club.wpia.gigi.database.GigiPreparedStatement")) {
+            String selector = new String(messageSend.selector);
+            if (selector.startsWith("set")) {
+                SQLOccurrence peek = ts.peek();
+                if (peek == null) {
+                    throw new Error("setting parameter at bad location");
+                }
+                IntLiteral i = (IntLiteral) messageSend.arguments[0];
+                int val = i.constant.intValue();
+                TypeInstantiation typeInstantiation = getTypeInstantiation(messageSend, selector);
+                if (peek.types[val - 1] != null && !peek.types[val - 1].equals(typeInstantiation)) {
+                    throw new Error("multiple different typeInstantiations");
+                }
+                peek.types[val - 1] = typeInstantiation;
+            }
+        }
+        return true;
+    }
+
+    private TypeInstantiation getTypeInstantiation(MessageSend messageSend, String selector) throws Error {
+        switch (selector) {
+        case "setTimestamp":
+            return new TypeInstantiation(Type.TIMESTAMP);
+        case "setEnum":
+            TypeBinding rn = messageSend.arguments[1].resolvedType;
+            String sn = new String(rn.qualifiedSourceName());
+            String enumClass = new String(rn.readableName());
+            enumClass = enumClass.substring(0, enumClass.length() - sn.length()) + sn.replace('.', '$');
+            String dbn;
+            try {
+                dbn = ((DBEnum) ((Object[]) Class.forName(enumClass).getMethod("values").invoke(null))[0]).getDBName();
+            } catch (ReflectiveOperationException e) {
+                throw new Error(e);
+            }
+            return new TypeInstantiation(Type.ENUM, dbn);
+        case "setString":
+            return new TypeInstantiation(Type.STRING);
+        case "setDate":
+            return new TypeInstantiation(Type.DATE);
+        case "setInt":
+            return new TypeInstantiation(Type.INTEGER);
+        case "setBoolean":
+            return new TypeInstantiation(Type.BOOLEAN);
+        default:
+            return new TypeInstantiation(Type.OTHER);
+        }
+    }
+}
index ab3719a..5dbfcaa 100644 (file)
@@ -34,12 +34,6 @@ public final class TranslationCollectingVisitor extends ASTVisitor {
 
     Stack<QualifiedAllocationExpression> anonymousConstructorCall = new Stack<>();
 
-    private boolean hadErrors = false;
-
-    public boolean hadErrors() {
-        return hadErrors;
-    }
-
     public TranslationCollectingVisitor(CompilationUnitDeclaration unit, TaintSource[] target, TranslationCollector c) {
         this.unit = unit;
         ts = target;
@@ -194,7 +188,7 @@ public final class TranslationCollectingVisitor extends ASTVisitor {
         System.err.println("Cannot Handle: " + e + " in " + (call == null ? "constructor" : call.sourceStart) + " => " + caller);
         System.err.println(e.getClass());
         System.err.println("To ignore: " + (b == null ? "don't know" : b.toConfLine()));
-        hadErrors = true;
+        translationCollector.hadError();
     }
 
     private void testEnum(Expression e, MethodBinding binding) {
index 44df294..239c32f 100644 (file)
@@ -13,11 +13,11 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Properties;
 import java.util.TreeSet;
 
-import club.wpia.gigi.output.template.Template;
-
 import org.eclipse.jdt.core.compiler.CategorizedProblem;
+import org.eclipse.jdt.internal.compiler.ASTVisitor;
 import org.eclipse.jdt.internal.compiler.CompilationResult;
 import org.eclipse.jdt.internal.compiler.ast.CompilationUnitDeclaration;
 import org.eclipse.jdt.internal.compiler.ast.TypeDeclaration;
@@ -36,6 +36,9 @@ import org.eclipse.jdt.internal.compiler.lookup.PackageBinding;
 import org.eclipse.jdt.internal.compiler.parser.Parser;
 import org.eclipse.jdt.internal.compiler.problem.ProblemReporter;
 
+import club.wpia.gigi.database.DatabaseConnection;
+import club.wpia.gigi.output.template.Template;
+
 public class TranslationCollector {
 
     static class TranslationEntry implements Comparable<TranslationEntry> {
@@ -82,6 +85,8 @@ public class TranslationCollector {
 
     private boolean hadErrors = false;
 
+    private int statements = 0;
+
     public TranslationCollector(File base, File conf) {
         this.base = base;
         taint = new LinkedList<>();
@@ -90,15 +95,37 @@ public class TranslationCollector {
         }
     }
 
+    interface ASTVisitorFactory {
+
+        public ASTVisitor createVisitor(CompilationUnitDeclaration parsedUnit);
+    }
+
     public void run(File out) throws IOException {
         scanTemplates();
-        scanCode(taint);
+        scanCode(new ASTVisitorFactory() {
+
+            @Override
+            public ASTVisitor createVisitor(CompilationUnitDeclaration parsedUnit) {
+                return new TranslationCollectingVisitor(parsedUnit, taint.toArray(new TaintSource[taint.size()]), TranslationCollector.this);
+            }
+        });
 
         System.err.println("Total Translatable Strings: " + translations.size());
         TreeSet<TranslationEntry> trs = new TreeSet<>(translations.values());
         writePOFile(out, trs);
     }
 
+    public void runSQLValidation() throws IOException {
+        scanCode(new ASTVisitorFactory() {
+
+            @Override
+            public ASTVisitor createVisitor(CompilationUnitDeclaration parsedUnit) {
+                return new SQLTestingVisitor(parsedUnit, TranslationCollector.this);
+            }
+        });
+        System.out.println("Validated: " + statements + " SQL statements.");
+    }
+
     public void add(String text, String line) {
         if (text.contains("\r") || text.contains("\n")) {
             throw new Error("Malformed translation in " + line);
@@ -111,7 +138,7 @@ public class TranslationCollector {
         i.add(line);
     }
 
-    private void scanCode(LinkedList<TaintSource> taint) throws Error {
+    private void scanCode(ASTVisitorFactory visitor) throws Error {
         PrintWriter out = new PrintWriter(System.err);
         Main m = new Main(out, out, false, null, null);
         File[] fs = recurse(new File(new File(new File(base, "src"), "club"), "wpia"), new LinkedList<File>(), ".java").toArray(new File[0]);
@@ -197,11 +224,10 @@ public class TranslationCollector {
                 System.err.println("No types");
 
             } else {
-                TranslationCollectingVisitor v = new TranslationCollectingVisitor(parsedUnit, taint.toArray(new TaintSource[taint.size()]), this);
+                ASTVisitor v = visitor.createVisitor(parsedUnit);
                 for (TypeDeclaration td : parsedUnit.types) {
                     td.traverse(v, td.scope);
                 }
-                hadErrors |= v.hadErrors();
             }
             parsedUnits[i] = parsedUnit;
         }
@@ -226,8 +252,15 @@ public class TranslationCollector {
     private LinkedList<TaintSource> taint;
 
     public static void main(String[] args) throws IOException {
+        Properties pp = new Properties();
+        pp.load(new InputStreamReader(new FileInputStream("config/test.properties"), "UTF-8"));
+        DatabaseConnection.init(pp);
         TranslationCollector tc = new TranslationCollector(new File(args[1]), new File(args[0]));
-        tc.run(new File(args[2]));
+        if (args[2].equals("SQLValidation")) {
+            tc.runSQLValidation();
+        } else {
+            tc.run(new File(args[2]));
+        }
         if (tc.hadErrors) {
             System.exit(1);
         } else {
@@ -262,4 +295,12 @@ public class TranslationCollector {
         }
         return toAdd;
     }
+
+    public void hadError() {
+        hadErrors = true;
+    }
+
+    public void countStatement() {
+        statements++;
+    }
 }