001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2016, Connect2id Ltd and contributors.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.crypto.impl;
019
020
021import java.io.ByteArrayOutputStream;
022import java.io.IOException;
023import java.security.MessageDigest;
024import java.security.NoSuchAlgorithmException;
025import java.security.Provider;
026import javax.crypto.SecretKey;
027import javax.crypto.spec.SecretKeySpec;
028
029import com.nimbusds.jose.JOSEException;
030import com.nimbusds.jose.jca.JCAAware;
031import com.nimbusds.jose.jca.JCAContext;
032import com.nimbusds.jose.util.Base64URL;
033import com.nimbusds.jose.util.ByteUtils;
034import com.nimbusds.jose.util.IntegerUtils;
035import com.nimbusds.jose.util.StandardCharset;
036import net.jcip.annotations.ThreadSafe;
037
038
039/**
040 * Concatenation Key Derivation Function (KDF). This class is thread-safe.
041 *
042 * <p>See NIST.800-56A.
043 *
044 * @author Vladimir Dzhuvinov
045 * @version 2017-06-01
046 */
047@ThreadSafe
048public class ConcatKDF implements JCAAware<JCAContext> {
049
050
051        /**
052         * The JCA name of the hash algorithm.
053         */
054        private final String jcaHashAlg;
055
056
057        /**
058         * The JCA context..
059         */
060        private final JCAContext jcaContext = new JCAContext();
061
062
063        /**
064         * Creates a new concatenation Key Derivation Function (KDF) with the
065         * specified hash algorithm.
066         *
067         * @param jcaHashAlg The JCA name of the hash algorithm. Must be
068         *                   supported and not {@code null}.
069         */
070        public ConcatKDF(final String jcaHashAlg) {
071
072                if (jcaHashAlg == null) {
073                        throw new IllegalArgumentException("The JCA hash algorithm must not be null");
074                }
075
076                this.jcaHashAlg = jcaHashAlg;
077        }
078
079
080        /**
081         * Returns the JCA name of the hash algorithm.
082         *
083         * @return The JCA name of the hash algorithm.
084         */
085        public String getHashAlgorithm() {
086
087                return jcaHashAlg;
088        }
089
090
091        @Override
092        public JCAContext getJCAContext() {
093
094                return jcaContext;
095        }
096
097
098        /**
099         * Derives a key from the specified inputs.
100         *
101         * @param sharedSecret  The shared secret. Must not be {@code null}.
102         * @param keyLengthBits The length of the key to derive, in bits.
103         * @param otherInfo     Other info, {@code null} if not specified.
104         *
105         * @return The derived key, with algorithm set to "AES".
106         *
107         * @throws JOSEException If the key derivation failed.
108         */
109        public SecretKey deriveKey(final SecretKey sharedSecret,
110                                   final int keyLengthBits,
111                                   final byte[] otherInfo)
112                throws JOSEException {
113
114                ByteArrayOutputStream baos = new ByteArrayOutputStream();
115
116                final MessageDigest md = getMessageDigest();
117
118                for (int i=1; i <= computeDigestCycles(ByteUtils.safeBitLength(md.getDigestLength()), keyLengthBits); i++) {
119
120                        byte[] counterBytes = IntegerUtils.toBytes(i);
121
122                        md.update(counterBytes);
123                        md.update(sharedSecret.getEncoded());
124
125                        if (otherInfo != null) {
126                                md.update(otherInfo);
127                        }
128
129                        try {
130                                baos.write(md.digest());
131                        } catch (IOException e) {
132                                throw new JOSEException("Couldn't write derived key: " + e.getMessage(), e);
133                        }
134                }
135
136                byte[] derivedKeyMaterial = baos.toByteArray();
137
138                final int keyLengthBytes = ByteUtils.byteLength(keyLengthBits);
139
140                if (derivedKeyMaterial.length == keyLengthBytes) {
141                        // Return immediately
142                        return new SecretKeySpec(derivedKeyMaterial, "AES");
143                }
144
145                return new SecretKeySpec(ByteUtils.subArray(derivedKeyMaterial, 0, keyLengthBytes), "AES");
146        }
147
148
149        /**
150         * Derives a key from the specified inputs.
151         *
152         * @param sharedSecret The shared secret. Must not be {@code null}.
153         * @param keyLength    The length of the key to derive, in bits.
154         * @param algID        The algorithm identifier, {@code null} if not
155         *                     specified.
156         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
157         * @param partyVInfo   The partyVInfo {@code null} if not specified.
158         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
159         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
160         *
161         *  @return The derived key, with algorithm set to "AES".
162         *
163         * @throws JOSEException If the key derivation failed.
164         */
165        public SecretKey deriveKey(final SecretKey sharedSecret,
166                                   final int keyLength,
167                                   final byte[] algID,
168                                   final byte[] partyUInfo,
169                                   final byte[] partyVInfo,
170                                   final byte[] suppPubInfo,
171                                   final byte[] suppPrivInfo)
172                throws JOSEException {
173
174                final byte[] otherInfo = composeOtherInfo(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
175
176                return deriveKey(sharedSecret, keyLength, otherInfo);
177        }
178
179        /**
180         * Derives a key from the specified inputs.
181         *
182         * @param sharedSecret The shared secret. Must not be {@code null}.
183         * @param keyLength    The length of the key to derive, in bits.
184         * @param algID        The algorithm identifier, {@code null} if not
185         *                     specified.
186         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
187         * @param partyVInfo   The partyVInfo {@code null} if not specified.
188         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
189         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
190         *
191         * @return The derived key, with algorithm set to "AES".
192         *
193         * @throws JOSEException If the key derivation failed.
194         */
195        public SecretKey deriveKey(final SecretKey sharedSecret,
196                   final int keyLength,
197                   final byte[] algID,
198                   final byte[] partyUInfo,
199                   final byte[] partyVInfo,
200                   final byte[] suppPubInfo,
201                   final byte[] suppPrivInfo,
202                   final byte[] tag)
203                        throws JOSEException {
204
205                final byte[] otherInfo = composeOtherInfo(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo, tag);
206
207                return deriveKey(sharedSecret, keyLength, otherInfo);
208        }
209
210        /**
211         * Composes the other info as {@code algID || partyUInfo || partyVInfo
212         * || suppPubInfo || suppPrivInfo}.
213         *
214         * @param algID        The algorithm identifier, {@code null} if not
215         *                     specified.
216         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
217         * @param partyVInfo   The partyVInfo {@code null} if not specified.
218         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
219         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
220         *
221         * @return The resulting other info.
222         */
223        public static byte[] composeOtherInfo(final byte[] algID,
224                                              final byte[] partyUInfo,
225                                              final byte[] partyVInfo,
226                                              final byte[] suppPubInfo,
227                                              final byte[] suppPrivInfo) {
228
229                return ByteUtils.concat(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
230        }
231
232        /**
233         * Composes the other info as {@code algID || partyUInfo || partyVInfo
234         * || suppPubInfo || suppPrivInfo || tag}.
235         *
236         * @param algID        The algorithm identifier, {@code null} if not
237         *                     specified.
238         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
239         * @param partyVInfo   The partyVInfo {@code null} if not specified.
240         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
241         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
242         * @param tag          The cctag, {@code null} if not specified.
243         *
244         * @return The resulting other info.
245         */
246        public static byte[] composeOtherInfo(final byte[] algID,
247                                                  final byte[] partyUInfo,
248                                                  final byte[] partyVInfo,
249                                                  final byte[] suppPubInfo,
250                                                  final byte[] suppPrivInfo,
251                                                  final byte[] tag) {
252
253                return ByteUtils.concat(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo, tag);
254        }
255
256
257        /**
258         * Returns a message digest instance for the configured
259         * {@link #jcaHashAlg hash algorithm}.
260         *
261         * @return The message digest instance.
262         *
263         * @throws JOSEException If the message digest algorithm is not
264         *                       supported by the underlying JCA provider.
265         */
266        private MessageDigest getMessageDigest()
267                throws JOSEException {
268
269                final Provider provider = getJCAContext().getProvider();
270
271                try {
272                        if (provider == null)
273                                return MessageDigest.getInstance(jcaHashAlg);
274                        else
275                                return MessageDigest.getInstance(jcaHashAlg, provider);
276                } catch (NoSuchAlgorithmException e) {
277                        throw new JOSEException("Couldn't get message digest for KDF: " + e.getMessage(), e);
278                }
279        }
280
281
282        /**
283         * Computes the required digest (hashing) cycles for the specified
284         * message digest length and derived key length.
285         *
286         * @param digestLengthBits The length of the message digest, in bits.
287         * @param keyLengthBits    The length of the derived key, in bits.
288         *
289         * @return The digest cycles.
290         */
291        public static int computeDigestCycles(final int digestLengthBits, final int keyLengthBits) {
292
293                // return the ceiling of keyLength / digestLength
294                
295                return (keyLengthBits + digestLengthBits - 1) / digestLengthBits;
296        }
297
298
299        /**
300         * Encodes no / empty data as an empty byte array.
301         *
302         * @return The encoded data.
303         */
304        public static byte[] encodeNoData() {
305
306                return new byte[0];
307        }
308
309
310        /**
311         * Encodes the specified integer data as a four byte array.
312         *
313         * @param data The integer data to encode.
314         *
315         * @return The encoded data.
316         */
317        public static byte[] encodeIntData(final int data) {
318
319                return IntegerUtils.toBytes(data);
320        }
321
322
323        /**
324         * Encodes the specified string data as {@code data.length || data}.
325         *
326         * @param data The string data, UTF-8 encoded. May be {@code null}.
327         *
328         * @return The encoded data.
329         */
330        public static byte[] encodeStringData(final String data) {
331
332                byte[] bytes = data != null ? data.getBytes(StandardCharset.UTF_8) : null;
333                return encodeDataWithLength(bytes);
334        }
335
336
337        /**
338         * Encodes the specified data as {@code data.length || data}.
339         *
340         * @param data The data to encode, may be {@code null}.
341         *
342         * @return The encoded data.
343         */
344        public static byte[] encodeDataWithLength(final byte[] data) {
345
346                byte[] bytes = data != null ? data : new byte[0];
347                byte[] length = IntegerUtils.toBytes(bytes.length);
348                return ByteUtils.concat(length, bytes);
349        }
350
351
352        /**
353         * Encodes the specified BASE64URL encoded data
354         * {@code data.length || data}.
355         *
356         * @param data The data to encode, may be {@code null}.
357         *
358         * @return The encoded data.
359         */
360        public static byte[] encodeDataWithLength(final Base64URL data) {
361
362                byte[] bytes = data != null ? data.decode() : null;
363                return encodeDataWithLength(bytes);
364        }
365}
366