Guard division by zero in SCrypt verification.
[gigi.git] / lib / scrypt / com / lambdaworks / crypto / SCrypt.java
1 // Copyright (C) 2011 - Will Glozer. All rights reserved.
2
3 package com.lambdaworks.crypto;
4
5 import static java.lang.Integer.*;
6 import static java.lang.System.*;
7
8 import java.security.GeneralSecurityException;
9
10 import javax.crypto.Mac;
11 import javax.crypto.spec.SecretKeySpec;
12
13 /**
14  * An implementation of the <a
15  * href="http://www.tarsnap.com/scrypt/scrypt.pdf"/>scrypt</a> key derivation
16  * function. This class will attempt to load a native library containing the
17  * optimized C implementation from <a
18  * href="http://www.tarsnap.com/scrypt.html">http
19  * ://www.tarsnap.com/scrypt.html<a> and fall back to the pure Java version if
20  * that fails.
21  *
22  * @author Will Glozer
23  */
24 public class SCrypt {
25
26     private static final boolean native_library_loaded;
27
28     static {
29         // do not load native library
30         native_library_loaded = false;
31     }
32
33     /**
34      * Implementation of the <a
35      * href="http://www.tarsnap.com/scrypt/scrypt.pdf"/>scrypt KDF</a>. Calls
36      * the native implementation {@link #scryptN} when the native library was
37      * successfully loaded, otherwise calls {@link #scryptJ}.
38      *
39      * @param passwd
40      *            Password.
41      * @param salt
42      *            Salt.
43      * @param N
44      *            CPU cost parameter.
45      * @param r
46      *            Memory cost parameter.
47      * @param p
48      *            Parallelization parameter.
49      * @param dkLen
50      *            Intended length of the derived key.
51      * @return The derived key.
52      * @throws GeneralSecurityException
53      *             when HMAC_SHA256 is not available.
54      */
55     public static byte[] scrypt(byte[] passwd, byte[] salt, int N, int r, int p, int dkLen) throws GeneralSecurityException {
56         return native_library_loaded ? scryptN(passwd, salt, N, r, p, dkLen) : scryptJ(passwd, salt, N, r, p, dkLen);
57     }
58
59     /**
60      * Native C implementation of the <a
61      * href="http://www.tarsnap.com/scrypt/scrypt.pdf"/>scrypt KDF</a> using the
62      * code from <a
63      * href="http://www.tarsnap.com/scrypt.html">http://www.tarsnap.
64      * com/scrypt.html<a>.
65      *
66      * @param passwd
67      *            Password.
68      * @param salt
69      *            Salt.
70      * @param N
71      *            CPU cost parameter.
72      * @param r
73      *            Memory cost parameter.
74      * @param p
75      *            Parallelization parameter.
76      * @param dkLen
77      *            Intended length of the derived key.
78      * @return The derived key.
79      */
80     public static native byte[] scryptN(byte[] passwd, byte[] salt, int N, int r, int p, int dkLen);
81
82     /**
83      * Pure Java implementation of the <a
84      * href="http://www.tarsnap.com/scrypt/scrypt.pdf"/>scrypt KDF</a>.
85      *
86      * @param passwd
87      *            Password.
88      * @param salt
89      *            Salt.
90      * @param N
91      *            CPU cost parameter.
92      * @param r
93      *            Memory cost parameter.
94      * @param p
95      *            Parallelization parameter.
96      * @param dkLen
97      *            Intended length of the derived key.
98      * @return The derived key.
99      * @throws GeneralSecurityException
100      *             when HMAC_SHA256 is not available.
101      */
102     public static byte[] scryptJ(byte[] passwd, byte[] salt, int N, int r, int p, int dkLen) throws GeneralSecurityException {
103         if (N < 2 || (N & (N - 1)) != 0) {
104             throw new IllegalArgumentException("N must be a power of 2 greater than 1");
105         }
106         if (r <= 0) {
107             throw new IllegalArgumentException("Parameter r zero or negative");
108         }
109         if (p <= 0) {
110             throw new IllegalArgumentException("Parameter p zero or negative");
111         }
112
113         if (N > MAX_VALUE / 128 / r) {
114             throw new IllegalArgumentException("Parameter N is too large");
115         }
116         if (r > MAX_VALUE / 128 / p) {
117             throw new IllegalArgumentException("Parameter r is too large");
118         }
119
120         Mac mac = Mac.getInstance("HmacSHA256");
121         mac.init(new SecretKeySpec(passwd, "HmacSHA256"));
122
123         byte[] DK = new byte[dkLen];
124
125         byte[] B = new byte[128 * r * p];
126         byte[] XY = new byte[256 * r];
127         byte[] V = new byte[128 * r * N];
128         int i;
129
130         PBKDF.pbkdf2(mac, salt, 1, B, p * 128 * r);
131
132         for (i = 0; i < p; i++) {
133             smix(B, i * 128 * r, r, N, V, XY);
134         }
135
136         PBKDF.pbkdf2(mac, B, 1, DK, dkLen);
137
138         return DK;
139     }
140
141     public static void smix(byte[] B, int Bi, int r, int N, byte[] V, byte[] XY) {
142         int Xi = 0;
143         int Yi = 128 * r;
144         int i;
145
146         arraycopy(B, Bi, XY, Xi, 128 * r);
147
148         for (i = 0; i < N; i++) {
149             arraycopy(XY, Xi, V, i * (128 * r), 128 * r);
150             blockmix_salsa8(XY, Xi, Yi, r);
151         }
152
153         for (i = 0; i < N; i++) {
154             int j = integerify(XY, Xi, r) & (N - 1);
155             blockxor(V, j * (128 * r), XY, Xi, 128 * r);
156             blockmix_salsa8(XY, Xi, Yi, r);
157         }
158
159         arraycopy(XY, Xi, B, Bi, 128 * r);
160     }
161
162     public static void blockmix_salsa8(byte[] BY, int Bi, int Yi, int r) {
163         byte[] X = new byte[64];
164         int i;
165
166         arraycopy(BY, Bi + (2 * r - 1) * 64, X, 0, 64);
167
168         for (i = 0; i < 2 * r; i++) {
169             blockxor(BY, i * 64, X, 0, 64);
170             salsa20_8(X);
171             arraycopy(X, 0, BY, Yi + (i * 64), 64);
172         }
173
174         for (i = 0; i < r; i++) {
175             arraycopy(BY, Yi + (i * 2) * 64, BY, Bi + (i * 64), 64);
176         }
177
178         for (i = 0; i < r; i++) {
179             arraycopy(BY, Yi + (i * 2 + 1) * 64, BY, Bi + (i + r) * 64, 64);
180         }
181     }
182
183     public static int R(int a, int b) {
184         return (a << b) | (a >>> (32 - b));
185     }
186
187     public static void salsa20_8(byte[] B) {
188         int[] B32 = new int[16];
189         int[] x = new int[16];
190         int i;
191
192         for (i = 0; i < 16; i++) {
193             B32[i] = (B[i * 4 + 0] & 0xff) << 0;
194             B32[i] |= (B[i * 4 + 1] & 0xff) << 8;
195             B32[i] |= (B[i * 4 + 2] & 0xff) << 16;
196             B32[i] |= (B[i * 4 + 3] & 0xff) << 24;
197         }
198
199         arraycopy(B32, 0, x, 0, 16);
200
201         for (i = 8; i > 0; i -= 2) {
202             x[4] ^= R(x[0] + x[12], 7);
203             x[8] ^= R(x[4] + x[0], 9);
204             x[12] ^= R(x[8] + x[4], 13);
205             x[0] ^= R(x[12] + x[8], 18);
206             x[9] ^= R(x[5] + x[1], 7);
207             x[13] ^= R(x[9] + x[5], 9);
208             x[1] ^= R(x[13] + x[9], 13);
209             x[5] ^= R(x[1] + x[13], 18);
210             x[14] ^= R(x[10] + x[6], 7);
211             x[2] ^= R(x[14] + x[10], 9);
212             x[6] ^= R(x[2] + x[14], 13);
213             x[10] ^= R(x[6] + x[2], 18);
214             x[3] ^= R(x[15] + x[11], 7);
215             x[7] ^= R(x[3] + x[15], 9);
216             x[11] ^= R(x[7] + x[3], 13);
217             x[15] ^= R(x[11] + x[7], 18);
218             x[1] ^= R(x[0] + x[3], 7);
219             x[2] ^= R(x[1] + x[0], 9);
220             x[3] ^= R(x[2] + x[1], 13);
221             x[0] ^= R(x[3] + x[2], 18);
222             x[6] ^= R(x[5] + x[4], 7);
223             x[7] ^= R(x[6] + x[5], 9);
224             x[4] ^= R(x[7] + x[6], 13);
225             x[5] ^= R(x[4] + x[7], 18);
226             x[11] ^= R(x[10] + x[9], 7);
227             x[8] ^= R(x[11] + x[10], 9);
228             x[9] ^= R(x[8] + x[11], 13);
229             x[10] ^= R(x[9] + x[8], 18);
230             x[12] ^= R(x[15] + x[14], 7);
231             x[13] ^= R(x[12] + x[15], 9);
232             x[14] ^= R(x[13] + x[12], 13);
233             x[15] ^= R(x[14] + x[13], 18);
234         }
235
236         for (i = 0; i < 16; ++i) {
237             B32[i] = x[i] + B32[i];
238         }
239
240         for (i = 0; i < 16; i++) {
241             B[i * 4 + 0] = (byte) (B32[i] >> 0 & 0xff);
242             B[i * 4 + 1] = (byte) (B32[i] >> 8 & 0xff);
243             B[i * 4 + 2] = (byte) (B32[i] >> 16 & 0xff);
244             B[i * 4 + 3] = (byte) (B32[i] >> 24 & 0xff);
245         }
246     }
247
248     public static void blockxor(byte[] S, int Si, byte[] D, int Di, int len) {
249         for (int i = 0; i < len; i++) {
250             D[Di + i] ^= S[Si + i];
251         }
252     }
253
254     public static int integerify(byte[] B, int Bi, int r) {
255         int n;
256
257         Bi += (2 * r - 1) * 64;
258
259         n = (B[Bi + 0] & 0xff) << 0;
260         n |= (B[Bi + 1] & 0xff) << 8;
261         n |= (B[Bi + 2] & 0xff) << 16;
262         n |= (B[Bi + 3] & 0xff) << 24;
263
264         return n;
265     }
266 }