import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpSession;
import org.cacert.gigi.Language;
import org.cacert.gigi.pages.Page;
public abstract class Form implements Outputable {
String csrf;
- public Form() {
+
+ public Form(HttpServletRequest hsr) {
csrf = RandomToken.generateToken(32);
+ HttpSession hs = hsr.getSession();
+ hs.setAttribute("form/" + getClass().getName() + "/" + csrf, this);
+
}
public abstract boolean submit(PrintWriter out, HttpServletRequest req);
+
@Override
- public final void output(PrintWriter out, Language l,
- Map<String, Object> vars) {
+ public final void output(PrintWriter out, Language l, Map<String, Object> vars) {
out.println("<form method='POST' autocomplete='off'>");
outputContent(out, l, vars);
- out.println("<input type='csrf' value='");
+ out.print("<input type='hidden' name='csrf' value='");
out.print(getCSRFToken());
out.println("'></form>");
}
- public abstract void outputContent(PrintWriter out, Language l,
- Map<String, Object> vars);
+ protected abstract void outputContent(PrintWriter out, Language l, Map<String, Object> vars);
protected void outputError(PrintWriter out, ServletRequest req, String text) {
out.print("<div>");
out.println("</div>");
}
- public String getCSRFToken() {
+ protected String getCSRFToken() {
return csrf;
}
+ protected void checkCSRF(HttpServletRequest req) {
+ if (!csrf.equals(req.getParameter("csrf"))) {
+ throw new CSRFError();
+ }
+ }
+
+ public static <T extends Form> T getForm(HttpServletRequest req, Class<T> target) {
+ String csrf = req.getParameter("csrf");
+ if (csrf == null) {
+ throw new CSRFError();
+ }
+ HttpSession hs = req.getSession();
+ if (hs == null) {
+ throw new CSRFError();
+ }
+ Form f = (Form) hs.getAttribute("form/" + target.getName() + "/" + csrf);
+ if (f == null) {
+ throw new CSRFError();
+ }
+ return (T) f;
+ }
+
+ public static class CSRFError extends Error {
+
+ }
}