From b2028692f1f0a3fa3d78f0dc8b81d7fdc14f6865 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Felix=20D=C3=B6rre?= Date: Fri, 4 Aug 2017 18:31:11 +0200 Subject: [PATCH] add: code to statically verify SQL call patterns Change-Id: Ib5c0e7a76d9a14f318087f092091bdf2cfa3c174 --- build.xml | 7 + .../gigi/localisation/SQLTestingVisitor.java | 406 ++++++++++++++++++ .../TranslationCollectingVisitor.java | 8 +- .../localisation/TranslationCollector.java | 55 ++- 4 files changed, 462 insertions(+), 14 deletions(-) create mode 100644 util-testing/club/wpia/gigi/localisation/SQLTestingVisitor.java diff --git a/build.xml b/build.xml index 6d2587eb..1f123bce 100644 --- a/build.xml +++ b/build.xml @@ -133,6 +133,13 @@ + + + + + + + diff --git a/util-testing/club/wpia/gigi/localisation/SQLTestingVisitor.java b/util-testing/club/wpia/gigi/localisation/SQLTestingVisitor.java new file mode 100644 index 00000000..808aee6a --- /dev/null +++ b/util-testing/club/wpia/gigi/localisation/SQLTestingVisitor.java @@ -0,0 +1,406 @@ +package club.wpia.gigi.localisation; + +import java.sql.Date; +import java.sql.ParameterMetaData; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.util.Arrays; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.LinkedBlockingDeque; + +import org.eclipse.jdt.internal.compiler.ASTVisitor; +import org.eclipse.jdt.internal.compiler.CompilationResult; +import org.eclipse.jdt.internal.compiler.ast.AllocationExpression; +import org.eclipse.jdt.internal.compiler.ast.BinaryExpression; +import org.eclipse.jdt.internal.compiler.ast.CompilationUnitDeclaration; +import org.eclipse.jdt.internal.compiler.ast.ConditionalExpression; +import org.eclipse.jdt.internal.compiler.ast.Expression; +import org.eclipse.jdt.internal.compiler.ast.ExtendedStringLiteral; +import org.eclipse.jdt.internal.compiler.ast.IntLiteral; +import org.eclipse.jdt.internal.compiler.ast.MessageSend; +import org.eclipse.jdt.internal.compiler.ast.SingleNameReference; +import org.eclipse.jdt.internal.compiler.ast.StringLiteral; +import org.eclipse.jdt.internal.compiler.ast.TryStatement; +import org.eclipse.jdt.internal.compiler.ast.TypeDeclaration; +import org.eclipse.jdt.internal.compiler.impl.Constant; +import org.eclipse.jdt.internal.compiler.lookup.BlockScope; +import org.eclipse.jdt.internal.compiler.lookup.ClassScope; +import org.eclipse.jdt.internal.compiler.lookup.CompilationUnitScope; +import org.eclipse.jdt.internal.compiler.lookup.MethodBinding; +import org.eclipse.jdt.internal.compiler.lookup.SourceTypeBinding; +import org.eclipse.jdt.internal.compiler.lookup.TypeBinding; + +import club.wpia.gigi.database.DBEnum; +import club.wpia.gigi.database.DatabaseConnection; +import club.wpia.gigi.database.GigiPreparedStatement; + +public class SQLTestingVisitor extends ASTVisitor { + + private CompilationUnitDeclaration pu; + + private TranslationCollector tc; + + private enum Type { + TIMESTAMP(Types.TIMESTAMP), ENUM(Types.VARCHAR), STRING(Types.VARCHAR), BOOLEAN(Types.BOOLEAN), DATE(Types.DATE), INTEGER(Types.INTEGER), OTHER(0); + + private final int sqltype; + + private Type(int sqltype) { + this.sqltype = sqltype; + + } + + public void set(GigiPreparedStatement ps, int index) { + if (this == TIMESTAMP) { + ps.setTimestamp(index, new Timestamp(System.currentTimeMillis())); + } else if (this == STRING) { + ps.setString(index, "y"); + } else if (this == DATE) { + ps.setDate(index, new Date(System.currentTimeMillis())); + } else if (this == Type.BOOLEAN) { + ps.setBoolean(index, false); + } else if (this == OTHER || this == INTEGER) { + ps.setInt(index, 1000); + } else { + throw new Error(); + } + } + + public boolean isOfSQLType(int i) { + if (i == sqltype) { + return true; + } + if (i == Types.BIT && this == BOOLEAN) { + return true; + } + return false; + } + } + + private class TypeInstantiation { + + Type type; + + String enumValue; + + public TypeInstantiation(Type type) { + this.type = type; + } + + public TypeInstantiation(Type type, String enumValue) { + this.enumValue = enumValue; + this.type = type; + } + + public void set(GigiPreparedStatement ps, int index) { + if (type == Type.ENUM) { + ps.setString(index, enumValue); + } else { + type.set(ps, index); + } + } + + @Override + public String toString() { + return type.toString() + (enumValue != null ? enumValue : ""); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + getOuterType().hashCode(); + result = prime * result + ((enumValue == null) ? 0 : enumValue.hashCode()); + result = prime * result + ((type == null) ? 0 : type.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + TypeInstantiation other = (TypeInstantiation) obj; + if ( !getOuterType().equals(other.getOuterType())) { + return false; + } + if (enumValue == null) { + if (other.enumValue != null) { + return false; + } + } else if ( !enumValue.equals(other.enumValue)) { + return false; + } + if (type != other.type) { + return false; + } + return true; + } + + private SQLTestingVisitor getOuterType() { + return SQLTestingVisitor.this; + } + + } + + public class SQLOccurrence { + + private List query; + + private TryStatement target; + + private CompilationResult source; + + public TypeInstantiation[] types = new TypeInstantiation[10]; + + private int sourceStart; + + public SQLOccurrence(TryStatement target) { + this.target = target; + } + + public void setQuery(List query, CompilationResult compilationResult, int sourceStart) { + this.query = query; + this.source = compilationResult; + this.sourceStart = sourceStart; + } + + public TryStatement getTarget() { + return target; + } + + public List getQuery() { + return query; + } + + public int getSourceStart() { + return sourceStart; + } + + public boolean isQuery() { + return query != null; + } + + public String getPosition() { + int pos = source.lineSeparatorPositions.length + 1; + for (int i = 0; i < source.lineSeparatorPositions.length; i++) { + if (source.lineSeparatorPositions[i] > sourceStart) { + pos = i + 1; + break; + } + } + return new String(source.getFileName()) + ":" + pos; + } + + private void check(String stmt) { + tc.countStatement(); + try (DatabaseConnection.Link l = DatabaseConnection.newLink(true)) { + try (GigiPreparedStatement ps = new GigiPreparedStatement(stmt)) { + ParameterMetaData dt = ps.getParameterMetaData(); + int count = dt.getParameterCount(); + for (int i = 1; i <= types.length; i++) { + if (i > count) { + if (types[i - 1] != null) { + errMsg(stmt, "too many params"); + return; + } + continue; + } + int tp = dt.getParameterType(i); + TypeInstantiation t = types[i - 1]; + if (t == null) { + errMsg(stmt, "arg " + i + " not set"); + return; + } + if ( !t.type.isOfSQLType(tp)) { + errMsg(stmt, "type mismatch. From parameter setting code: " + t + ", in SQL statement: " + tp); + return; + } + } + } catch (SQLException e) { + throw new Error(e); + } + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + private void errMsg(String stmt, String errMsg) { + System.err.println(getPosition()); + System.err.println("Problem with statement: " + stmt); + System.err.println(Arrays.toString(types)); + System.err.println(errMsg); + tc.hadError(); + } + + public void check() { + for (String q : getQuery()) { + check(q); + } + + } + } + + public SQLTestingVisitor(CompilationUnitDeclaration pu, TranslationCollector tc) { + this.pu = pu; + this.tc = tc; + } + + Deque ts = new LinkedBlockingDeque<>(); + + @Override + public boolean visit(TypeDeclaration typeDeclaration, CompilationUnitScope scope) { + return true; + } + + @Override + public boolean visit(TypeDeclaration memberTypeDeclaration, ClassScope scope) { + return true; + } + + @Override + public boolean visit(TryStatement tryStatement, BlockScope scope) { + ts.push(new SQLOccurrence(tryStatement)); + return true; + } + + @Override + public void endVisit(TryStatement tryStatement, BlockScope scope) { + SQLOccurrence occ = ts.pop(); + if (occ.isQuery()) { + occ.check(); + } + if (occ.getTarget() != tryStatement) { + throw new Error(); + } + } + + @Override + public boolean visit(AllocationExpression ae, BlockScope scope) { + MethodBinding mb = ae.binding; + if (new String(mb.declaringClass.qualifiedPackageName()).equals("club.wpia.gigi.database") && new String(mb.declaringClass.qualifiedSourceName()).equals("GigiPreparedStatement")) { + String sig = new String(mb.readableName()); + if (sig.equals("GigiPreparedStatement(String)") || sig.equals("GigiPreparedStatement(String, boolean)")) { + List l = getQueries(ae.arguments[0], scope); + if (l.size() == 0) { + return false; + } + LinkedList qs = new LinkedList<>(); + for (String q : l) { + qs.add(DatabaseConnection.preprocessQuery(q)); + } + ts.peek().setQuery(qs, scope.compilationUnitScope().referenceContext.compilationResult, ae.sourceStart); + } else { + throw new Error(sig); + } + } + return true; + } + + private List getQueries(Expression q, BlockScope scope) { + SourceTypeBinding typ = scope.enclosingSourceType(); + String fullType = new String(typ.qualifiedPackageName()) + "." + new String(typ.qualifiedSourceName()); + if (fullType.equals("club.wpia.gigi.database.IntegrityVerifier")) { + return Arrays.asList(); + } + if (q instanceof StringLiteral) { + String s = new String(((StringLiteral) q).source()); + return Arrays.asList(s); + } else if (q instanceof ExtendedStringLiteral) { + throw new Error(); + } else if (q instanceof BinaryExpression) { + Expression l = ((BinaryExpression) q).left; + Expression r = ((BinaryExpression) q).right; + if ( !((BinaryExpression) q).operatorToString().equals("+")) { + throw new Error(((BinaryExpression) q).operatorToString()); + } + List left = getQueries(l, scope); + List right = getQueries(r, scope); + LinkedList res = new LinkedList<>(); + for (String leftS : left) { + for (String rightS : right) { + res.add(leftS + rightS); + } + } + return res; + } else if (q instanceof ConditionalExpression) { + Expression t = ((ConditionalExpression) q).valueIfTrue; + Expression f = ((ConditionalExpression) q).valueIfFalse; + List res = new LinkedList<>(); + res.addAll(getQueries(t, scope)); + res.addAll(getQueries(f, scope)); + return res; + } else if (q instanceof SingleNameReference) { + SingleNameReference ref = (SingleNameReference) q; + Constant c = ref.constant; + if (c.equals(Constant.NotAConstant)) { + throw new Error(q.toString()); + } + return Arrays.asList(c.stringValue()); + } else { + System.err.println(q.getClass() + ";" + q.toString()); + throw new Error(q.toString()); + } + } + + @Override + public boolean visit(MessageSend messageSend, BlockScope scope) { + Expression r = messageSend.receiver; + String rec = new String(r.resolvedType.readableName()); + if (rec.equals("club.wpia.gigi.database.GigiPreparedStatement")) { + String selector = new String(messageSend.selector); + if (selector.startsWith("set")) { + SQLOccurrence peek = ts.peek(); + if (peek == null) { + throw new Error("setting parameter at bad location"); + } + IntLiteral i = (IntLiteral) messageSend.arguments[0]; + int val = i.constant.intValue(); + TypeInstantiation typeInstantiation = getTypeInstantiation(messageSend, selector); + if (peek.types[val - 1] != null && !peek.types[val - 1].equals(typeInstantiation)) { + throw new Error("multiple different typeInstantiations"); + } + peek.types[val - 1] = typeInstantiation; + } + } + return true; + } + + private TypeInstantiation getTypeInstantiation(MessageSend messageSend, String selector) throws Error { + switch (selector) { + case "setTimestamp": + return new TypeInstantiation(Type.TIMESTAMP); + case "setEnum": + TypeBinding rn = messageSend.arguments[1].resolvedType; + String sn = new String(rn.qualifiedSourceName()); + String enumClass = new String(rn.readableName()); + enumClass = enumClass.substring(0, enumClass.length() - sn.length()) + sn.replace('.', '$'); + String dbn; + try { + dbn = ((DBEnum) ((Object[]) Class.forName(enumClass).getMethod("values").invoke(null))[0]).getDBName(); + } catch (ReflectiveOperationException e) { + throw new Error(e); + } + return new TypeInstantiation(Type.ENUM, dbn); + case "setString": + return new TypeInstantiation(Type.STRING); + case "setDate": + return new TypeInstantiation(Type.DATE); + case "setInt": + return new TypeInstantiation(Type.INTEGER); + case "setBoolean": + return new TypeInstantiation(Type.BOOLEAN); + default: + return new TypeInstantiation(Type.OTHER); + } + } +} diff --git a/util-testing/club/wpia/gigi/localisation/TranslationCollectingVisitor.java b/util-testing/club/wpia/gigi/localisation/TranslationCollectingVisitor.java index ab3719a6..5dbfcaac 100644 --- a/util-testing/club/wpia/gigi/localisation/TranslationCollectingVisitor.java +++ b/util-testing/club/wpia/gigi/localisation/TranslationCollectingVisitor.java @@ -34,12 +34,6 @@ public final class TranslationCollectingVisitor extends ASTVisitor { Stack anonymousConstructorCall = new Stack<>(); - private boolean hadErrors = false; - - public boolean hadErrors() { - return hadErrors; - } - public TranslationCollectingVisitor(CompilationUnitDeclaration unit, TaintSource[] target, TranslationCollector c) { this.unit = unit; ts = target; @@ -194,7 +188,7 @@ public final class TranslationCollectingVisitor extends ASTVisitor { System.err.println("Cannot Handle: " + e + " in " + (call == null ? "constructor" : call.sourceStart) + " => " + caller); System.err.println(e.getClass()); System.err.println("To ignore: " + (b == null ? "don't know" : b.toConfLine())); - hadErrors = true; + translationCollector.hadError(); } private void testEnum(Expression e, MethodBinding binding) { diff --git a/util-testing/club/wpia/gigi/localisation/TranslationCollector.java b/util-testing/club/wpia/gigi/localisation/TranslationCollector.java index 44df294e..239c32fb 100644 --- a/util-testing/club/wpia/gigi/localisation/TranslationCollector.java +++ b/util-testing/club/wpia/gigi/localisation/TranslationCollector.java @@ -13,11 +13,11 @@ import java.util.Collection; import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Properties; import java.util.TreeSet; -import club.wpia.gigi.output.template.Template; - import org.eclipse.jdt.core.compiler.CategorizedProblem; +import org.eclipse.jdt.internal.compiler.ASTVisitor; import org.eclipse.jdt.internal.compiler.CompilationResult; import org.eclipse.jdt.internal.compiler.ast.CompilationUnitDeclaration; import org.eclipse.jdt.internal.compiler.ast.TypeDeclaration; @@ -36,6 +36,9 @@ import org.eclipse.jdt.internal.compiler.lookup.PackageBinding; import org.eclipse.jdt.internal.compiler.parser.Parser; import org.eclipse.jdt.internal.compiler.problem.ProblemReporter; +import club.wpia.gigi.database.DatabaseConnection; +import club.wpia.gigi.output.template.Template; + public class TranslationCollector { static class TranslationEntry implements Comparable { @@ -82,6 +85,8 @@ public class TranslationCollector { private boolean hadErrors = false; + private int statements = 0; + public TranslationCollector(File base, File conf) { this.base = base; taint = new LinkedList<>(); @@ -90,15 +95,37 @@ public class TranslationCollector { } } + interface ASTVisitorFactory { + + public ASTVisitor createVisitor(CompilationUnitDeclaration parsedUnit); + } + public void run(File out) throws IOException { scanTemplates(); - scanCode(taint); + scanCode(new ASTVisitorFactory() { + + @Override + public ASTVisitor createVisitor(CompilationUnitDeclaration parsedUnit) { + return new TranslationCollectingVisitor(parsedUnit, taint.toArray(new TaintSource[taint.size()]), TranslationCollector.this); + } + }); System.err.println("Total Translatable Strings: " + translations.size()); TreeSet trs = new TreeSet<>(translations.values()); writePOFile(out, trs); } + public void runSQLValidation() throws IOException { + scanCode(new ASTVisitorFactory() { + + @Override + public ASTVisitor createVisitor(CompilationUnitDeclaration parsedUnit) { + return new SQLTestingVisitor(parsedUnit, TranslationCollector.this); + } + }); + System.out.println("Validated: " + statements + " SQL statements."); + } + public void add(String text, String line) { if (text.contains("\r") || text.contains("\n")) { throw new Error("Malformed translation in " + line); @@ -111,7 +138,7 @@ public class TranslationCollector { i.add(line); } - private void scanCode(LinkedList taint) throws Error { + private void scanCode(ASTVisitorFactory visitor) throws Error { PrintWriter out = new PrintWriter(System.err); Main m = new Main(out, out, false, null, null); File[] fs = recurse(new File(new File(new File(base, "src"), "club"), "wpia"), new LinkedList(), ".java").toArray(new File[0]); @@ -197,11 +224,10 @@ public class TranslationCollector { System.err.println("No types"); } else { - TranslationCollectingVisitor v = new TranslationCollectingVisitor(parsedUnit, taint.toArray(new TaintSource[taint.size()]), this); + ASTVisitor v = visitor.createVisitor(parsedUnit); for (TypeDeclaration td : parsedUnit.types) { td.traverse(v, td.scope); } - hadErrors |= v.hadErrors(); } parsedUnits[i] = parsedUnit; } @@ -226,8 +252,15 @@ public class TranslationCollector { private LinkedList taint; public static void main(String[] args) throws IOException { + Properties pp = new Properties(); + pp.load(new InputStreamReader(new FileInputStream("config/test.properties"), "UTF-8")); + DatabaseConnection.init(pp); TranslationCollector tc = new TranslationCollector(new File(args[1]), new File(args[0])); - tc.run(new File(args[2])); + if (args[2].equals("SQLValidation")) { + tc.runSQLValidation(); + } else { + tc.run(new File(args[2])); + } if (tc.hadErrors) { System.exit(1); } else { @@ -262,4 +295,12 @@ public class TranslationCollector { } return toAdd; } + + public void hadError() { + hadErrors = true; + } + + public void countStatement() { + statements++; + } } -- 2.39.2