001package com.nimbusds.srp6;
002
003
004import java.math.BigInteger;
005import java.security.MessageDigest;
006import java.security.SecureRandom;
007
008
009/**
010 * Secure Remote Password (SRP-6a) routines for computing the various protocol 
011 * variables and messages.
012 *
013 * <p>The routines comply with RFC 5054 (SRP for TLS), with the following 
014 * exceptions:
015 *
016 * <ul>
017 *     <li>The computation of the password key 'x' is modified to omit the user 
018 *         identity 'I' in order to allow for server-side user identity renaming
019 *         as well as authentication with multiple alternate identities. 
020 *     <li>The evidence messages 'M1' and 'M2' are computed according to Tom 
021 *         Wu's paper "SRP-6: Improvements and refinements to the Secure Remote 
022 *         Password protocol", table 5, from 2002.
023 * </ul>
024 *
025 * <p>This class contains portions of code from Bouncy Castle's SRP6 
026 * implementation.
027 *
028 * @author Vladimir Dzhuvinov
029 */
030public class SRP6Routines {
031
032        
033        /**
034         * Computes the SRP-6 multiplier k = H(N | PAD(g))
035         *
036         * <p>Specification: RFC 5054.
037         *
038         * @param digest The hash function 'H'. Must not be {@code null}.
039         * @param N      The prime parameter 'N'. Must not be {@code null}.
040         * @param g      The generator parameter 'g'. Must not be {@code null}.
041         *
042         * @return The resulting multiplier 'k'.
043         */
044        public static BigInteger computeK(final MessageDigest digest, 
045                                          final BigInteger N, 
046                                          final BigInteger g) {
047        
048                return hashPaddedPair(digest, N, N, g);
049        }
050        
051        
052        /**
053         * Generates a random salt 's'.
054         *
055         * @param numBytes The number of bytes the salt 's' must have.
056         *
057         * @return The salt 's' as a byte array.
058         */
059        public static byte[] generateRandomSalt(final int numBytes) {
060        
061                SecureRandom random = new SecureRandom();
062                
063                byte[] salt = new byte[numBytes];
064                
065                random.nextBytes(salt);
066                
067                return salt;
068        }
069        
070        
071        /**
072         * Computes x = H(s | H(P))
073         *
074         * <p>Note that this method differs from the RFC 5054 recommendation 
075         * which includes the user identity 'I', i.e. x = H(s | H(I | ":" | P))
076         *
077         * @param digest   The hash function 'H'. Must not be {@code null}.
078         * @param salt     The salt 's'. Must not be {@code null}.
079         * @param password The user password 'P'. Must not be {@code null}.
080         *
081         * @return The resulting 'x' value.
082         */
083        public static BigInteger computeX(final MessageDigest digest,
084                                          final byte[] salt,
085                                          final byte[] password) {         
086                                            
087                byte[] output = digest.digest(password);
088
089                digest.update(salt);
090                digest.update(output);
091                
092                return BigIntegerUtils.bigIntegerFromBytes(digest.digest());
093        }
094        
095        
096        /**
097         * Computes a verifier v = g^x (mod N)
098         *
099         * <p>Specification: RFC 5054.
100         *
101         * @param N The prime parameter 'N'. Must not be {@code null}.
102         * @param g The generator parameter 'g'. Must not be {@code null}.
103         * @param x The password key 'x', see {@link #computeX}. Must not be 
104         *          {@code null}.
105         *
106         * @return The resulting verifier 'v'.
107         */
108        public static BigInteger computeVerifier(final BigInteger N,
109                                                 final BigInteger g,
110                                                 final BigInteger x) {
111        
112                return g.modPow(x, N);
113        }                  
114        
115        
116        /**
117         * Generates a random SRP-6a client or server private value ('a' or 
118         * 'b') which is 256 bits long.
119         *
120         * <p>Specification: RFC 5054.
121         *
122         * @param N      The prime parameter 'N'. Must not be {@code null}.
123         * @param random Source of randomness. Must not be {@code null}.
124         *
125         * @return The resulting client or server private value ('a' or 'b').
126         */
127        public static BigInteger generatePrivateValue(final BigInteger N,
128                                                      final SecureRandom random) {
129         
130                final int minBits = Math.min(256, N.bitLength() / 2);
131                
132                BigInteger min = BigInteger.ONE.shiftLeft(minBits - 1);
133                BigInteger max = N.subtract(BigInteger.ONE);
134                
135                return createRandomBigIntegerInRange(min, max, random);               
136        }
137        
138        
139        /**
140         * Computes the public client value A = g^a (mod N)
141         *
142         * <p>Specification: RFC 5054.
143         *
144         * @param N The prime parameter 'N'. Must not be {@code null}.
145         * @param g The generator parameter 'g'. Must not be {@code null}.
146         * @param a The private client value 'a'. Must not be {@code null}.
147         *
148         * @return The public client value 'A'.
149         */
150        public static BigInteger computePublicClientValue(final BigInteger N,
151                                                          final BigInteger g,
152                                                          final BigInteger a) {
153                                                            
154                return g.modPow(a, N);
155        }
156        
157        
158        
159        /**
160         * Computes the public server value B = k * v + g^b (mod N)
161         *
162         * <p>Specification: RFC 5054.
163         *
164         * @param N The prime parameter 'N'. Must not be {@code null}.
165         * @param g The generator parameter 'g'. Must not be {@code null}.
166         * @param k The SRP-6a multiplier 'k'. Must not be {@code null}.
167         * @param v The password verifier 'v'. Must not be {@code null}.
168         * @param b The private server value 'b'. Must not be {@code null}.
169         *
170         * @return The public server value 'B'.
171         */
172        public static BigInteger computePublicServerValue(final BigInteger N,
173                                                          final BigInteger g,
174                                                          final BigInteger k,
175                                                          final BigInteger v,
176                                                          final BigInteger b) {
177        
178                // Original from Bouncy Castle, modified:
179                // return k.multiply(v).add(g.modPow(b, N));
180                
181                // Below from http://srp.stanford.edu/demo/demo.html
182                return g.modPow(b, N).add(v.multiply(k)).mod(N);
183        }
184        
185        
186        /**
187         * Validates an SRP6 client or server public value ('A' or 'B').
188         *
189         * <p>Specification: RFC 5054.
190         *
191         * @param N     The prime parameter 'N'. Must not be {@code null}.
192         * @param value The public value ('A' or 'B') to validate.
193         *
194         * @return {@code true} on successful validation, else {@code false}.
195         */
196        public static boolean isValidPublicValue(final BigInteger N,
197                                                 final BigInteger value) {
198                
199                // check that value % N != 0
200                return !value.mod(N).equals(BigInteger.ZERO);
201        }
202        
203        
204        /**
205         * Computes the random scrambling parameter u = H(PAD(A) | PAD(B))
206         *
207         * <p>Specification: RFC 5054.
208         *
209         * @param digest The hash function 'H'. Must not be {@code null}.
210         * @param N      The prime parameter 'N'. Must not be {@code null}.
211         * @param A      The public client value 'A'. Must not be {@code null}.
212         * @param B      The public server value 'B'. Must not be {@code null}.
213         *
214         * @return The resulting 'u' value.
215         */
216        public static BigInteger computeU(final MessageDigest digest, 
217                                          final BigInteger N, 
218                                          final BigInteger A,
219                                          final BigInteger B) {
220                                           
221                                        
222                return hashPaddedPair(digest, N, A, B);
223        }
224        
225        
226        /**
227         * Computes the session key S = (B - k * g^x) ^ (a + u * x) (mod N)
228         * from client-side parameters.
229         * 
230         * <p>Specification: RFC 5054
231         *
232         * @param N The prime parameter 'N'. Must not be {@code null}.
233         * @param g The generator parameter 'g'. Must not be {@code null}.
234         * @param k The SRP-6a multiplier 'k'. Must not be {@code null}.
235         * @param x The 'x' value, see {@link #computeX}. Must not be 
236         *          {@code null}.
237         * @param u The random scrambling parameter 'u'. Must not be 
238         *          {@code null}.
239         * @param a The private client value 'a'. Must not be {@code null}.
240         * @param B The public server value 'B'. Must note be {@code null}.
241         *
242         * @return The resulting session key 'S'.
243         */
244        public static BigInteger computeSessionKey(final BigInteger N,
245                                                   final BigInteger g,
246                                                   final BigInteger k,
247                                                   final BigInteger x,
248                                                   final BigInteger u,
249                                                   final BigInteger a,
250                                                   final BigInteger B) {
251                
252                final BigInteger exp = u.multiply(x).add(a);
253                final BigInteger tmp = g.modPow(x, N).multiply(k);
254                return B.subtract(tmp).modPow(exp, N);
255        }
256        
257        
258        /**
259         * Computes the session key S = (A * v^u) ^ b (mod N) from server-side
260         * parameters.
261         *
262         * <p>Specification: RFC 5054
263         *
264         * @param N The prime parameter 'N'. Must not be {@code null}.
265         * @param v The password verifier 'v'. Must not be {@code null}.
266         * @param u The random scrambling parameter 'u'. Must not be 
267         *          {@code null}.
268         * @param A The public client value 'A'. Must not be {@code null}.
269         * @param b The private server value 'b'. Must not be {@code null}.
270         *
271         * @return The resulting session key 'S'.
272         */
273        public static BigInteger computeSessionKey(final BigInteger N,
274                                                   final BigInteger v,
275                                                   final BigInteger u,
276                                                   final BigInteger A,
277                                                   final BigInteger b) {
278        
279                return v.modPow(u, N).multiply(A).modPow(b, N);
280        }
281        
282        
283        /**
284         * Computes the client evidence message M1 = H(A | B | S)
285         *
286         * <p>Specification: Tom Wu's paper "SRP-6: Improvements and 
287         * refinements to the Secure Remote Password protocol", table 5, from 
288         * 2002.
289         *
290         * @param digest The hash function 'H'. Must not be {@code null}.
291         * @param A      The public client value 'A'. Must not be {@code null}.
292         * @param B      The public server value 'B'. Must note be {@code null}.
293         * @param S      The session key 'S'. Must not be {@code null}.
294         *
295         * @return The resulting client evidence message 'M1'.
296         */
297        public static BigInteger computeClientEvidence(final MessageDigest digest,
298                                                       final BigInteger A,
299                                                       final BigInteger B,
300                                                       final BigInteger S) {
301                
302                digest.update(BigIntegerUtils.bigIntegerToBytes(A));
303                digest.update(BigIntegerUtils.bigIntegerToBytes(B));
304                digest.update(BigIntegerUtils.bigIntegerToBytes(S));
305
306                return BigIntegerUtils.bigIntegerFromBytes(digest.digest());
307        }
308        
309        
310        /**
311         * Computes the server evidence message M2 = H(A | M1 | S)
312         *
313         * <p>Specification: Tom Wu's paper "SRP-6: Improvements and 
314         * refinements to the Secure Remote Password protocol", table 5, from 
315         * 2002.
316         *
317         * @param digest The hash function 'H'. Must not be {@code null}.
318         * @param A      The public client value 'A'. Must not be {@code null}.
319         * @param M1     The client evidence message 'M1'. Must not be 
320         *               {@code null}.
321         * @param S      The session key 'S'. Must not be {@code null}.
322         *
323         * @return The resulting server evidence message 'M2'.
324         */
325        protected static BigInteger computeServerEvidence(final MessageDigest digest,
326                                                          final BigInteger A,
327                                                          final BigInteger M1,
328                                                          final BigInteger S) {
329        
330                digest.update(BigIntegerUtils.bigIntegerToBytes(A));
331                digest.update(BigIntegerUtils.bigIntegerToBytes(M1));
332                digest.update(BigIntegerUtils.bigIntegerToBytes(S));
333                
334                return BigIntegerUtils.bigIntegerFromBytes(digest.digest());
335        }
336        
337        
338        /**
339         * Hashes two padded values 'n1' and 'n2' where the total length is
340         * determined by the size of N.
341         *
342         * <p>H(PAD(n1) | PAD(n2))
343         *
344         * @param digest The hash function 'H'. Must not be {@code null}.
345         * @param N      Its size determines the pad length. Must not be 
346         *               {@code null}.
347         * @param n1     The first value to pad and hash.
348         * @param n2     The second value to pad and hash.
349         *
350         * @return The resulting hashed padded pair.
351         */
352        protected static BigInteger hashPaddedPair(final MessageDigest digest,
353                                                   final BigInteger N,
354                                                   final BigInteger n1,
355                                                   final BigInteger n2) {
356                                                   
357                final int padLength = (N.bitLength() + 7) / 8;
358                
359                byte[] n1_bytes = getPadded(n1, padLength);
360
361                byte[] n2_bytes = getPadded(n2, padLength);
362
363                digest.update(n1_bytes);
364                digest.update(n2_bytes);
365                
366                byte[] output = digest.digest();
367                
368                return BigIntegerUtils.bigIntegerFromBytes(output);
369        }
370        
371        
372        /**
373         * Pads a big integer with leading zeros up to the specified length.
374         *
375         * @param n      The big integer to pad. Must not be {@code null}.
376         * @param length The required length of the padded big integer as a
377         *               byte array.
378         *
379         * @return The padded big integer as a byte array.
380         */
381        protected static byte[] getPadded(final BigInteger n, final int length) {
382
383                byte[] bs = BigIntegerUtils.bigIntegerToBytes(n);
384                
385                if (bs.length < length) {
386                
387                        byte[] tmp = new byte[length];
388                        System.arraycopy(bs, 0, tmp, length - bs.length, bs.length);
389                        bs = tmp;
390                }
391                
392                return bs;
393        }
394
395
396        /**
397         * Returns a random big integer in the specified range [min, max].
398         *
399         * @param min    The least value that may be generated. Must not be
400         *               {@code null}.
401         * @param max    The greatest value that may be generated. Must not be
402         *               {@code null}.
403         * @param random Source of randomness. Must not be {@code null}.
404         *
405         * @return A random big integer in the range [min, max].
406         */
407        protected static BigInteger createRandomBigIntegerInRange(final BigInteger min, 
408                                                                  final BigInteger max,
409                                                                  final SecureRandom random) {
410        
411                final int cmp = min.compareTo(max);
412                
413                if (cmp >= 0) {
414                        
415                        if (cmp > 0)
416                                throw new IllegalArgumentException("'min' may not be greater than 'max'");
417                
418                        return min;
419                }
420
421                if (min.bitLength() > max.bitLength() / 2)
422                        return createRandomBigIntegerInRange(BigInteger.ZERO, max.subtract(min), random).add(min);
423                
424                final int MAX_ITERATIONS = 1000;
425                
426                for (int i = 0; i < MAX_ITERATIONS; ++i) {
427                
428                        BigInteger x = new BigInteger(max.bitLength(), random);
429                        
430                        if (x.compareTo(min) >= 0 && x.compareTo(max) <= 0)
431                                return x;
432                }
433
434                // fall back to a faster (restricted) method
435                return new BigInteger(max.subtract(min).bitLength() - 1, random).add(min);
436        }
437
438        private SRP6Routines() {
439                // empty
440        }
441}