001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2016, Connect2id Ltd and contributors.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.crypto.impl;
019
020
021import java.io.ByteArrayOutputStream;
022import java.io.IOException;
023import javax.crypto.Mac;
024import javax.crypto.SecretKey;
025import javax.crypto.spec.SecretKeySpec;
026
027import com.nimbusds.jose.JOSEException;
028import com.nimbusds.jose.JWEAlgorithm;
029import com.nimbusds.jose.util.ByteUtils;
030import com.nimbusds.jose.util.IntegerUtils;
031import com.nimbusds.jose.util.StandardCharset;
032
033
034/**
035 * Password-Based Key Derivation Function 2 (PBKDF2) utilities. Provides static
036 * methods to generate Key Encryption Keys (KEK) from passwords. Adopted from
037 * jose4j by Brian Campbell.
038 *
039 * @author Brian Campbell
040 * @author Yavor Vassilev
041 * @version 2016-07-26
042 */
043public class PBKDF2 {
044
045
046        /**
047         * Zero byte array of length one.
048         */
049        public static final byte[] ZERO_BYTE = { 0 };
050
051
052        /**
053         * Formats the specified cryptographic salt for use in PBKDF2.
054         *
055         * <pre>
056         * UTF8(JWE-alg) || 0x00 || Salt Input
057         * </pre>
058         *
059         * @param alg  The JWE algorithm. Must not be {@code null}.
060         * @param salt The cryptographic salt. Must not be empty or null.
061         *
062         * @return The formatted salt for use in PBKDF2.
063         */
064        public static byte[] formatSalt(final JWEAlgorithm alg, final byte[] salt)
065                throws JOSEException {
066
067                byte[] algBytes = alg.toString().getBytes(StandardCharset.UTF_8);
068
069                ByteArrayOutputStream out = new ByteArrayOutputStream();
070
071                try {
072                        out.write(algBytes);
073                        out.write(ZERO_BYTE);
074                        out.write(salt);
075
076                } catch (IOException e) {
077
078                        throw new JOSEException(e.getMessage(), e);
079                }
080
081                return out.toByteArray();
082        }
083
084
085        /**
086         * Derives a PBKDF2 key from the specified password and parameters.
087         *
088         * @param password       The password. Must not be {@code null}.
089         * @param formattedSalt  The formatted cryptographic salt. Must not be
090         *                       {@code null}.
091         * @param iterationCount The iteration count. Must be positive.
092         * @param prfParams      The Pseudo-Random Function (PRF) parameters.
093         *                       Must not be {@code null}.
094         *
095         * @return The derived secret key (with "AES" algorithm).
096         *
097         * @throws JOSEException If the key derivation failed.
098         */
099        public static SecretKey deriveKey(final byte[] password,
100                                          final byte[] formattedSalt,
101                                          final int iterationCount,
102                                          final PRFParams prfParams)
103                throws JOSEException {
104
105                SecretKey macKey = new SecretKeySpec(password, prfParams.getMACAlgorithm());
106
107                Mac prf = HMAC.getInitMac(macKey, prfParams.getMacProvider());
108
109                int hLen = prf.getMacLength();
110
111                //  1. If dkLen > (2^32 - 1) * hLen, output "derived key too long" and
112                //     stop.
113                long maxDerivedKeyLength = 4294967295L; // value of (long) Math.pow(2, 32) - 1;
114                if (prfParams.getDerivedKeyByteLength() > maxDerivedKeyLength) {
115                        throw new JOSEException("derived key too long " + prfParams.getDerivedKeyByteLength());
116                }
117
118                //  2. Let l be the number of hLen-octet blocks in the derived key,
119                //     rounding up, and let r be the number of octets in the last
120                //     block:
121                //
122                //               l = CEIL (dkLen / hLen) ,
123                //               r = dkLen - (l - 1) * hLen .
124                //
125                //     Here, CEIL (x) is the "ceiling" function, i.e. the smallest
126                //     integer greater than, or equal to, x.
127                int l = (int) Math.ceil((double) prfParams.getDerivedKeyByteLength() / (double) hLen);
128                int r = prfParams.getDerivedKeyByteLength() - (l - 1) * hLen;
129
130                //  3. For each block of the derived key apply the function F defined
131                //     below to the password P, the salt S, the iteration count c, and
132                //     the block index to compute the block:
133                //
134                //               T_1 = F (P, S, c, 1) ,
135                //               T_2 = F (P, S, c, 2) ,
136                //               ...
137                //               T_l = F (P, S, c, l) ,
138                //
139                //     where the function F is defined as the exclusive-or sum of the
140                //     first c iterates of the underlying pseudorandom function PRF
141                //     applied to the password P and the concatenation of the salt S
142                //     and the block index i:
143                //
144                //               F (P, S, c, i) = U_1 \xor U_2 \xor ... \xor U_c
145                //
146                //     where
147                //
148                //               U_1 = PRF (P, S || INT (i)) ,
149                //               U_2 = PRF (P, U_1) ,
150                //               ...
151                //               U_c = PRF (P, U_{c-1}) .
152                //
153                //     Here, INT (i) is a four-octet encoding of the integer i, most
154                //     significant octet first.
155
156                //  4. Concatenate the blocks and extract the first dkLen octets to
157                //     produce a derived key DK:
158                //
159                //               DK = T_1 || T_2 ||  ...  || T_l<0..r-1>
160                //
161                ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
162                for (int i = 0; i < l; i++) {
163                        byte[] block = extractBlock(formattedSalt, iterationCount, i + 1, prf);
164                        if (i == (l - 1)) {
165                                block = ByteUtils.subArray(block, 0, r);
166                        }
167                        byteArrayOutputStream.write(block, 0, block.length);
168                }
169
170                //  5. Output the derived key DK.
171                return new SecretKeySpec(byteArrayOutputStream.toByteArray(), "AES");
172        }
173
174
175        /**
176         * Block extraction iteration.
177         *
178         * @param salt           The cryptographic salt. Must not be
179         *                       {@code null}.
180         * @param iterationCount The iteration count.
181         * @param blockIndex     The block index.
182         * @param prf            The pseudo-random function (HMAC). Must not be
183         *                       {@code null.
184         *
185         * @return The block.
186         */
187        private static byte[] extractBlock(byte[] salt, int iterationCount, int blockIndex, Mac prf) {
188
189                byte[] currentU;
190                byte[] lastU = null;
191                byte[] xorU = null;
192
193                for (int i = 1; i <= iterationCount; i++)
194                {
195                        byte[] inputBytes;
196                        if (i == 1)
197                        {
198                                inputBytes = ByteUtils.concat(salt, IntegerUtils.toBytes(blockIndex));
199                                currentU = prf.doFinal(inputBytes);
200                                xorU = currentU;
201                        }
202                        else
203                        {
204                                currentU = prf.doFinal(lastU);
205                                for (int j = 0; j < currentU.length; j++)
206                                {
207                                        xorU[j] = (byte) (currentU[j] ^ xorU[j]);
208                                }
209                        }
210
211                        lastU = currentU;
212                }
213                return xorU;
214        }
215
216
217        /**
218         * Prevents public instantiation.
219         */
220        private PBKDF2() {
221
222        }
223}