001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2016, Connect2id Ltd and contributors.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.crypto.impl;
019
020
021import java.io.ByteArrayOutputStream;
022import java.io.IOException;
023import javax.crypto.Mac;
024import javax.crypto.SecretKey;
025import javax.crypto.spec.SecretKeySpec;
026
027import com.nimbusds.jose.JOSEException;
028import com.nimbusds.jose.JWEAlgorithm;
029import com.nimbusds.jose.util.ByteUtils;
030import com.nimbusds.jose.util.IntegerUtils;
031import com.nimbusds.jose.util.StandardCharset;
032
033
034/**
035 * Password-Based Key Derivation Function 2 (PBKDF2) utilities. Provides static
036 * methods to generate Key Encryption Keys (KEK) from passwords. Adopted from
037 * jose4j by Brian Campbell.
038 *
039 * @author Brian Campbell
040 * @author Yavor Vassilev
041 * @author Vladimir Dzhuvinov
042 * @version 2021-07-03
043 */
044public class PBKDF2 {
045        
046        
047        /**
048         * The minimum salt length (8 bytes).
049         */
050        public static final int MIN_SALT_LENGTH = 8;
051
052
053        /**
054         * Zero byte array of length one.
055         */
056        static final byte[] ZERO_BYTE = { 0 };// value of (long) Math.pow(2, 32) - 1;
057        
058        
059        /**
060         * Value of {@code (long) Math.pow(2, 32) - 1;}
061         */
062        static final long MAX_DERIVED_KEY_LENGTH = 4294967295L;
063        
064        
065        /**
066         * Formats the specified cryptographic salt for use in PBKDF2.
067         *
068         * <pre>
069         * UTF8(JWE-alg) || 0x00 || Salt Input
070         * </pre>
071         *
072         * @param alg  The JWE algorithm. Must not be {@code null}.
073         * @param salt The cryptographic salt. Must be at least 8 bytes long.
074         *
075         * @return The formatted salt for use in PBKDF2.
076         *
077         * @throws JOSEException If formatting failed.
078         */
079        public static byte[] formatSalt(final JWEAlgorithm alg, final byte[] salt)
080                throws JOSEException {
081
082                byte[] algBytes = alg.toString().getBytes(StandardCharset.UTF_8);
083                
084                if (salt == null) {
085                        throw new JOSEException("The salt must not be null");
086                }
087                
088                if (salt.length < MIN_SALT_LENGTH) {
089                        throw new JOSEException("The salt must be at least " + MIN_SALT_LENGTH + " bytes long");
090                }
091
092                ByteArrayOutputStream out = new ByteArrayOutputStream();
093                try {
094                        out.write(algBytes);
095                        out.write(ZERO_BYTE);
096                        out.write(salt);
097                } catch (IOException e) {
098                        throw new JOSEException(e.getMessage(), e);
099                }
100
101                return out.toByteArray();
102        }
103
104
105        /**
106         * Derives a PBKDF2 key from the specified password and parameters.
107         *
108         * @param password       The password. Must not be {@code null}.
109         * @param formattedSalt  The formatted cryptographic salt. Must not be
110         *                       {@code null}.
111         * @param iterationCount The iteration count. Must be a positive
112         *                       integer.
113         * @param prfParams      The Pseudo-Random Function (PRF) parameters.
114         *                       Must not be {@code null}.
115         *
116         * @return The derived secret key (with "AES" algorithm).
117         *
118         * @throws JOSEException If the key derivation failed.
119         */
120        public static SecretKey deriveKey(final byte[] password,
121                                          final byte[] formattedSalt,
122                                          final int iterationCount,
123                                          final PRFParams prfParams)
124                throws JOSEException {
125                
126                if (formattedSalt == null) {
127                        throw new JOSEException("The formatted salt must not be null");
128                }
129                
130                if (iterationCount < 1) {
131                        throw new JOSEException("The iteration count must be greater than 0");
132                }
133
134                SecretKey macKey = new SecretKeySpec(password, prfParams.getMACAlgorithm());
135
136                Mac prf = HMAC.getInitMac(macKey, prfParams.getMacProvider());
137
138                int hLen = prf.getMacLength();
139
140                //  1. If dkLen > (2^32 - 1) * hLen, output "derived key too long" and
141                //     stop.
142                if (prfParams.getDerivedKeyByteLength() > MAX_DERIVED_KEY_LENGTH) {
143                        throw new JOSEException("Derived key too long: " + prfParams.getDerivedKeyByteLength());
144                }
145
146                //  2. Let l be the number of hLen-octet blocks in the derived key,
147                //     rounding up, and let r be the number of octets in the last
148                //     block:
149                //
150                //               l = CEIL (dkLen / hLen) ,
151                //               r = dkLen - (l - 1) * hLen .
152                //
153                //     Here, CEIL (x) is the "ceiling" function, i.e. the smallest
154                //     integer greater than, or equal to, x.
155                int l = (int) Math.ceil((double) prfParams.getDerivedKeyByteLength() / (double) hLen);
156                int r = prfParams.getDerivedKeyByteLength() - (l - 1) * hLen;
157
158                //  3. For each block of the derived key apply the function F defined
159                //     below to the password P, the salt S, the iteration count c, and
160                //     the block index to compute the block:
161                //
162                //               T_1 = F (P, S, c, 1) ,
163                //               T_2 = F (P, S, c, 2) ,
164                //               ...
165                //               T_l = F (P, S, c, l) ,
166                //
167                //     where the function F is defined as the exclusive-or sum of the
168                //     first c iterates of the underlying pseudorandom function PRF
169                //     applied to the password P and the concatenation of the salt S
170                //     and the block index i:
171                //
172                //               F (P, S, c, i) = U_1 \xor U_2 \xor ... \xor U_c
173                //
174                //     where
175                //
176                //               U_1 = PRF (P, S || INT (i)) ,
177                //               U_2 = PRF (P, U_1) ,
178                //               ...
179                //               U_c = PRF (P, U_{c-1}) .
180                //
181                //     Here, INT (i) is a four-octet encoding of the integer i, most
182                //     significant octet first.
183
184                //  4. Concatenate the blocks and extract the first dkLen octets to
185                //     produce a derived key DK:
186                //
187                //               DK = T_1 || T_2 ||  ...  || T_l<0..r-1>
188                //
189                ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
190                for (int i = 0; i < l; i++) {
191                        byte[] block = extractBlock(formattedSalt, iterationCount, i + 1, prf);
192                        if (i == (l - 1)) {
193                                block = ByteUtils.subArray(block, 0, r);
194                        }
195                        byteArrayOutputStream.write(block, 0, block.length);
196                }
197
198                //  5. Output the derived key DK.
199                return new SecretKeySpec(byteArrayOutputStream.toByteArray(), "AES");
200        }
201
202
203        /**
204         * Block extraction iteration.
205         *
206         * @param formattedSalt  The formatted salt. Must not be {@code null}.
207         * @param iterationCount The iteration count. Must be a positive
208         *                       integer.
209         * @param blockIndex     The block index.
210         * @param prf            The pseudo-random function (HMAC). Must not be
211         *                       {@code null.
212         *
213         * @return The block.
214         *
215         * @throws JOSEException If the block extraction failed.
216         */
217        static byte[] extractBlock(final byte[] formattedSalt, final int iterationCount, final int blockIndex, final Mac prf)
218                throws JOSEException {
219                
220                if (formattedSalt == null) {
221                        throw new JOSEException("The formatted salt must not be null");
222                }
223                
224                if (iterationCount < 1) {
225                        throw new JOSEException("The iteration count must be greater than 0");
226                }
227
228                byte[] currentU;
229                byte[] lastU = null;
230                byte[] xorU = null;
231
232                for (int i = 1; i <= iterationCount; i++)
233                {
234                        byte[] inputBytes;
235                        if (i == 1)
236                        {
237                                inputBytes = ByteUtils.concat(formattedSalt, IntegerUtils.toBytes(blockIndex));
238                                currentU = prf.doFinal(inputBytes);
239                                xorU = currentU;
240                        }
241                        else
242                        {
243                                currentU = prf.doFinal(lastU);
244                                for (int j = 0; j < currentU.length; j++)
245                                {
246                                        xorU[j] = (byte) (currentU[j] ^ xorU[j]);
247                                }
248                        }
249
250                        lastU = currentU;
251                }
252                return xorU;
253        }
254
255
256        /**
257         * Prevents public instantiation.
258         */
259        private PBKDF2() {
260
261        }
262}