001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2016, Connect2id Ltd and contributors.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.crypto.impl;
019
020
021import java.io.ByteArrayOutputStream;
022import java.io.IOException;
023import java.security.MessageDigest;
024import java.security.NoSuchAlgorithmException;
025import java.security.Provider;
026import javax.crypto.SecretKey;
027import javax.crypto.spec.SecretKeySpec;
028
029import com.nimbusds.jose.JOSEException;
030import com.nimbusds.jose.jca.JCAAware;
031import com.nimbusds.jose.jca.JCAContext;
032import com.nimbusds.jose.util.Base64URL;
033import com.nimbusds.jose.util.ByteUtils;
034import com.nimbusds.jose.util.IntegerUtils;
035import com.nimbusds.jose.util.StandardCharset;
036import net.jcip.annotations.ThreadSafe;
037
038
039/**
040 * Concatenation Key Derivation Function (KDF). This class is thread-safe.
041 *
042 * <p>See NIST.800-56A.
043 *
044 * @author Vladimir Dzhuvinov
045 * @version 2017-06-01
046 */
047@ThreadSafe
048public class ConcatKDF implements JCAAware<JCAContext> {
049
050
051        /**
052         * The JCA name of the hash algorithm.
053         */
054        private final String jcaHashAlg;
055
056
057        /**
058         * The JCA context..
059         */
060        private final JCAContext jcaContext = new JCAContext();
061
062
063        /**
064         * Creates a new concatenation Key Derivation Function (KDF) with the
065         * specified hash algorithm.
066         *
067         * @param jcaHashAlg The JCA name of the hash algorithm. Must be
068         *                   supported and not {@code null}.
069         */
070        public ConcatKDF(final String jcaHashAlg) {
071
072                if (jcaHashAlg == null) {
073                        throw new IllegalArgumentException("The JCA hash algorithm must not be null");
074                }
075
076                this.jcaHashAlg = jcaHashAlg;
077        }
078
079
080        /**
081         * Returns the JCA name of the hash algorithm.
082         *
083         * @return The JCA name of the hash algorithm.
084         */
085        public String getHashAlgorithm() {
086
087                return jcaHashAlg;
088        }
089
090
091        @Override
092        public JCAContext getJCAContext() {
093
094                return jcaContext;
095        }
096
097
098        /**
099         * Derives a key from the specified inputs.
100         *
101         * @param sharedSecret  The shared secret. Must not be {@code null}.
102         * @param keyLengthBits The length of the key to derive, in bits.
103         * @param otherInfo     Other info, {@code null} if not specified.
104         *
105         * @return The derived key, with algorithm set to "AES".
106         *
107         * @throws JOSEException If the key derivation failed.
108         */
109        public SecretKey deriveKey(final SecretKey sharedSecret,
110                                   final int keyLengthBits,
111                                   final byte[] otherInfo)
112                throws JOSEException {
113
114                ByteArrayOutputStream baos = new ByteArrayOutputStream();
115
116                final MessageDigest md = getMessageDigest();
117
118                for (int i=1; i <= computeDigestCycles(ByteUtils.safeBitLength(md.getDigestLength()), keyLengthBits); i++) {
119
120                        byte[] counterBytes = IntegerUtils.toBytes(i);
121
122                        md.update(counterBytes);
123                        md.update(sharedSecret.getEncoded());
124
125                        if (otherInfo != null) {
126                                md.update(otherInfo);
127                        }
128
129                        try {
130                                baos.write(md.digest());
131                        } catch (IOException e) {
132                                throw new JOSEException("Couldn't write derived key: " + e.getMessage(), e);
133                        }
134                }
135
136                byte[] derivedKeyMaterial = baos.toByteArray();
137
138                final int keyLengthBytes = ByteUtils.byteLength(keyLengthBits);
139
140                if (derivedKeyMaterial.length == keyLengthBytes) {
141                        // Return immediately
142                        return new SecretKeySpec(derivedKeyMaterial, "AES");
143                }
144
145                return new SecretKeySpec(ByteUtils.subArray(derivedKeyMaterial, 0, keyLengthBytes), "AES");
146        }
147
148
149        /**
150         * Derives a key from the specified inputs.
151         *
152         * @param sharedSecret The shared secret. Must not be {@code null}.
153         * @param keyLength    The length of the key to derive, in bits.
154         * @param algID        The algorithm identifier, {@code null} if not
155         *                     specified.
156         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
157         * @param partyVInfo   The partyVInfo {@code null} if not specified.
158         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
159         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
160         *
161         *  @return The derived key, with algorithm set to "AES".
162         *
163         * @throws JOSEException If the key derivation failed.
164         */
165        public SecretKey deriveKey(final SecretKey sharedSecret,
166                                   final int keyLength,
167                                   final byte[] algID,
168                                   final byte[] partyUInfo,
169                                   final byte[] partyVInfo,
170                                   final byte[] suppPubInfo,
171                                   final byte[] suppPrivInfo)
172                throws JOSEException {
173
174                final byte[] otherInfo = composeOtherInfo(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
175
176                return deriveKey(sharedSecret, keyLength, otherInfo);
177        }
178
179        /**
180         * Derives a key from the specified inputs.
181         *
182         * @param sharedSecret The shared secret. Must not be {@code null}.
183         * @param keyLength    The length of the key to derive, in bits.
184         * @param algID        The algorithm identifier, {@code null} if not
185         *                     specified.
186         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
187         * @param partyVInfo   The partyVInfo {@code null} if not specified.
188         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
189         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
190         * @param tag          The cctag, {@code null} if not specified.
191         *
192         * @return The derived key, with algorithm set to "AES".
193         *
194         * @throws JOSEException If the key derivation failed.
195         */
196        public SecretKey deriveKey(final SecretKey sharedSecret,
197                                   final int keyLength,
198                                   final byte[] algID,
199                                   final byte[] partyUInfo,
200                                   final byte[] partyVInfo,
201                                   final byte[] suppPubInfo,
202                                   final byte[] suppPrivInfo,
203                                   final byte[] tag)
204                        throws JOSEException {
205
206                final byte[] otherInfo = composeOtherInfo(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo, tag);
207
208                return deriveKey(sharedSecret, keyLength, otherInfo);
209        }
210
211        /**
212         * Composes the other info as {@code algID || partyUInfo || partyVInfo
213         * || suppPubInfo || suppPrivInfo}.
214         *
215         * @param algID        The algorithm identifier, {@code null} if not
216         *                     specified.
217         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
218         * @param partyVInfo   The partyVInfo {@code null} if not specified.
219         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
220         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
221         *
222         * @return The resulting other info.
223         */
224        public static byte[] composeOtherInfo(final byte[] algID,
225                                              final byte[] partyUInfo,
226                                              final byte[] partyVInfo,
227                                              final byte[] suppPubInfo,
228                                              final byte[] suppPrivInfo) {
229
230                return ByteUtils.concat(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
231        }
232
233        /**
234         * Composes the other info as {@code algID || partyUInfo || partyVInfo
235         * || suppPubInfo || suppPrivInfo || tag}.
236         *
237         * @param algID        The algorithm identifier, {@code null} if not
238         *                     specified.
239         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
240         * @param partyVInfo   The partyVInfo {@code null} if not specified.
241         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
242         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
243         * @param tag          The cctag, {@code null} if not specified.
244         *
245         * @return The resulting other info.
246         */
247        public static byte[] composeOtherInfo(final byte[] algID,
248                                                  final byte[] partyUInfo,
249                                                  final byte[] partyVInfo,
250                                                  final byte[] suppPubInfo,
251                                                  final byte[] suppPrivInfo,
252                                                  final byte[] tag) {
253
254                return ByteUtils.concat(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo, tag);
255        }
256
257
258        /**
259         * Returns a message digest instance for the configured
260         * {@link #jcaHashAlg hash algorithm}.
261         *
262         * @return The message digest instance.
263         *
264         * @throws JOSEException If the message digest algorithm is not
265         *                       supported by the underlying JCA provider.
266         */
267        private MessageDigest getMessageDigest()
268                throws JOSEException {
269
270                final Provider provider = getJCAContext().getProvider();
271
272                try {
273                        if (provider == null)
274                                return MessageDigest.getInstance(jcaHashAlg);
275                        else
276                                return MessageDigest.getInstance(jcaHashAlg, provider);
277                } catch (NoSuchAlgorithmException e) {
278                        throw new JOSEException("Couldn't get message digest for KDF: " + e.getMessage(), e);
279                }
280        }
281
282
283        /**
284         * Computes the required digest (hashing) cycles for the specified
285         * message digest length and derived key length.
286         *
287         * @param digestLengthBits The length of the message digest, in bits.
288         * @param keyLengthBits    The length of the derived key, in bits.
289         *
290         * @return The digest cycles.
291         */
292        public static int computeDigestCycles(final int digestLengthBits, final int keyLengthBits) {
293
294                // return the ceiling of keyLength / digestLength
295                
296                return (keyLengthBits + digestLengthBits - 1) / digestLengthBits;
297        }
298
299
300        /**
301         * Encodes no / empty data as an empty byte array.
302         *
303         * @return The encoded data.
304         */
305        public static byte[] encodeNoData() {
306
307                return new byte[0];
308        }
309
310
311        /**
312         * Encodes the specified integer data as a four byte array.
313         *
314         * @param data The integer data to encode.
315         *
316         * @return The encoded data.
317         */
318        public static byte[] encodeIntData(final int data) {
319
320                return IntegerUtils.toBytes(data);
321        }
322
323
324        /**
325         * Encodes the specified string data as {@code data.length || data}.
326         *
327         * @param data The string data, UTF-8 encoded. May be {@code null}.
328         *
329         * @return The encoded data.
330         */
331        public static byte[] encodeStringData(final String data) {
332
333                byte[] bytes = data != null ? data.getBytes(StandardCharset.UTF_8) : null;
334                return encodeDataWithLength(bytes);
335        }
336
337
338        /**
339         * Encodes the specified data as {@code data.length || data}.
340         *
341         * @param data The data to encode, may be {@code null}.
342         *
343         * @return The encoded data.
344         */
345        public static byte[] encodeDataWithLength(final byte[] data) {
346
347                byte[] bytes = data != null ? data : new byte[0];
348                byte[] length = IntegerUtils.toBytes(bytes.length);
349                return ByteUtils.concat(length, bytes);
350        }
351
352
353        /**
354         * Encodes the specified BASE64URL encoded data
355         * {@code data.length || data}.
356         *
357         * @param data The data to encode, may be {@code null}.
358         *
359         * @return The encoded data.
360         */
361        public static byte[] encodeDataWithLength(final Base64URL data) {
362
363                byte[] bytes = data != null ? data.decode() : null;
364                return encodeDataWithLength(bytes);
365        }
366}
367