]> WPIA git - gigi.git/blobdiff - src/org/cacert/gigi/output/Form.java
Build out certificate issuing.
[gigi.git] / src / org / cacert / gigi / output / Form.java
index 69fb22878e3fa6a76b8944e8c0d586c62717d870..2ffb873171216efa91148d7798ab3b69e7029638 100644 (file)
@@ -5,30 +5,35 @@ import java.util.Map;
 
 import javax.servlet.ServletRequest;
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpSession;
 
 import org.cacert.gigi.Language;
 import org.cacert.gigi.pages.Page;
 import org.cacert.gigi.util.RandomToken;
 
 public abstract class Form implements Outputable {
+       public static final String CSRF_FIELD = "csrf";
        String csrf;
-       public Form() {
+
+       public Form(HttpServletRequest hsr) {
                csrf = RandomToken.generateToken(32);
+               HttpSession hs = hsr.getSession();
+               hs.setAttribute("form/" + getClass().getName() + "/" + csrf, this);
+
        }
 
        public abstract boolean submit(PrintWriter out, HttpServletRequest req);
+
        @Override
-       public final void output(PrintWriter out, Language l,
-                       Map<String, Object> vars) {
+       public final void output(PrintWriter out, Language l, Map<String, Object> vars) {
                out.println("<form method='POST' autocomplete='off'>");
                outputContent(out, l, vars);
-               out.print("<input type='csrf' value='");
+               out.print("<input type='hidden' name='" + CSRF_FIELD + "' value='");
                out.print(getCSRFToken());
                out.println("'></form>");
        }
 
-       protected abstract void outputContent(PrintWriter out, Language l,
-                       Map<String, Object> vars);
+       protected abstract void outputContent(PrintWriter out, Language l, Map<String, Object> vars);
 
        protected void outputError(PrintWriter out, ServletRequest req, String text) {
                out.print("<div>");
@@ -39,13 +44,30 @@ public abstract class Form implements Outputable {
        protected String getCSRFToken() {
                return csrf;
        }
+
        protected void checkCSRF(HttpServletRequest req) {
-               if (!csrf.equals(req.getParameter("csrf"))) {
+               if (!csrf.equals(req.getParameter(CSRF_FIELD))) {
+                       throw new CSRFError();
+               }
+       }
+
+       public static <T extends Form> T getForm(HttpServletRequest req, Class<T> target) {
+               String csrf = req.getParameter(CSRF_FIELD);
+               if (csrf == null) {
+                       throw new CSRFError();
+               }
+               HttpSession hs = req.getSession();
+               if (hs == null) {
+                       throw new CSRFError();
+               }
+               Form f = (Form) hs.getAttribute("form/" + target.getName() + "/" + csrf);
+               if (f == null) {
                        throw new CSRFError();
                }
+               return (T) f;
        }
 
-       public class CSRFError extends Error {
+       public static class CSRFError extends Error {
 
        }
 }