001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2016, Connect2id Ltd and contributors.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.crypto.impl;
019
020
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.util.Arrays;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jca.JCAAware;
import com.nimbusds.jose.jca.JCAContext;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jose.util.ByteUtils;
import com.nimbusds.jose.util.IntegerUtils;
import com.nimbusds.jose.util.StandardCharset;
import net.jcip.annotations.ThreadSafe;
037
038
039/**
040 * Concatenation Key Derivation Function (KDF). This class is thread-safe.
041 *
042 * <p>See NIST.800-56A.
043 *
044 * @author Vladimir Dzhuvinov
045 * @version 2017-06-01
046 */
047@ThreadSafe
048public class ConcatKDF implements JCAAware<JCAContext> {
049
050
051        /**
052         * The JCA name of the hash algorithm.
053         */
054        private final String jcaHashAlg;
055
056
057        /**
058         * The JCA context..
059         */
060        private final JCAContext jcaContext = new JCAContext();
061
062
063        /**
064         * Creates a new concatenation Key Derivation Function (KDF) with the
065         * specified hash algorithm.
066         *
067         * @param jcaHashAlg The JCA name of the hash algorithm. Must be
068         *                   supported and not {@code null}.
069         */
070        public ConcatKDF(final String jcaHashAlg) {
071
072                if (jcaHashAlg == null) {
073                        throw new IllegalArgumentException("The JCA hash algorithm must not be null");
074                }
075
076                this.jcaHashAlg = jcaHashAlg;
077        }
078
079
080        /**
081         * Returns the JCA name of the hash algorithm.
082         *
083         * @return The JCA name of the hash algorithm.
084         */
085        public String getHashAlgorithm() {
086
087                return jcaHashAlg;
088        }
089
090
091        @Override
092        public JCAContext getJCAContext() {
093
094                return jcaContext;
095        }
096
097
098        /**
099         * Derives a key from the specified inputs.
100         *
101         * @param sharedSecret  The shared secret. Must not be {@code null}.
102         * @param keyLengthBits The length of the key to derive, in bits.
103         * @param otherInfo     Other info, {@code null} if not specified.
104         *
105         * @return The derived key, with algorithm set to "AES".
106         *
107         * @throws JOSEException If the key derivation failed.
108         */
109        public SecretKey deriveKey(final SecretKey sharedSecret,
110                                   final int keyLengthBits,
111                                   final byte[] otherInfo)
112                throws JOSEException {
113
114                ByteArrayOutputStream baos = new ByteArrayOutputStream();
115
116                final MessageDigest md = getMessageDigest();
117
118                for (int i=1; i <= computeDigestCycles(ByteUtils.safeBitLength(md.getDigestLength()), keyLengthBits); i++) {
119
120                        byte[] counterBytes = IntegerUtils.toBytes(i);
121
122                        md.update(counterBytes);
123                        md.update(sharedSecret.getEncoded());
124
125                        if (otherInfo != null) {
126                                md.update(otherInfo);
127                        }
128
129                        try {
130                                baos.write(md.digest());
131                        } catch (IOException e) {
132                                throw new JOSEException("Couldn't write derived key: " + e.getMessage(), e);
133                        }
134                }
135
136                byte[] derivedKeyMaterial = baos.toByteArray();
137
138                final int keyLengthBytes = ByteUtils.byteLength(keyLengthBits);
139
140                if (derivedKeyMaterial.length == keyLengthBytes) {
141                        // Return immediately
142                        return new SecretKeySpec(derivedKeyMaterial, "AES");
143                }
144
145                return new SecretKeySpec(ByteUtils.subArray(derivedKeyMaterial, 0, keyLengthBytes), "AES");
146        }
147
148
149        /**
150         * Derives a key from the specified inputs.
151         *
152         * @param sharedSecret The shared secret. Must not be {@code null}.
153         * @param keyLength    The length of the key to derive, in bits.
154         * @param algID        The algorithm identifier, {@code null} if not
155         *                     specified.
156         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
157         * @param partyVInfo   The partyVInfo {@code null} if not specified.
158         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
159         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
160         *
161         *  @return The derived key, with algorithm set to "AES".
162         *
163         * @throws JOSEException If the key derivation failed.
164         */
165        public SecretKey deriveKey(final SecretKey sharedSecret,
166                                   final int keyLength,
167                                   final byte[] algID,
168                                   final byte[] partyUInfo,
169                                   final byte[] partyVInfo,
170                                   final byte[] suppPubInfo,
171                                   final byte[] suppPrivInfo)
172                throws JOSEException {
173
174                final byte[] otherInfo = composeOtherInfo(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
175
176                return deriveKey(sharedSecret, keyLength, otherInfo);
177        }
178
179
180        /**
181         * Composes the other info as {@code algID || partyUInfo || partyVInfo
182         * || suppPubInfo || suppPrivInfo}.
183         *
184         * @param algID        The algorithm identifier, {@code null} if not
185         *                     specified.
186         * @param partyUInfo   The partyUInfo, {@code null} if not specified.
187         * @param partyVInfo   The partyVInfo {@code null} if not specified.
188         * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
189         * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
190         *
191         * @return The resulting other info.
192         */
193        public static byte[] composeOtherInfo(final byte[] algID,
194                                              final byte[] partyUInfo,
195                                              final byte[] partyVInfo,
196                                              final byte[] suppPubInfo,
197                                              final byte[] suppPrivInfo) {
198
199                return ByteUtils.concat(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
200        }
201
202
203        /**
204         * Returns a message digest instance for the configured
205         * {@link #jcaHashAlg hash algorithm}.
206         *
207         * @return The message digest instance.
208         *
209         * @throws JOSEException If the message digest algorithm is not
210         *                       supported by the underlying JCA provider.
211         */
212        private MessageDigest getMessageDigest()
213                throws JOSEException {
214
215                final Provider provider = getJCAContext().getProvider();
216
217                try {
218                        if (provider == null)
219                                return MessageDigest.getInstance(jcaHashAlg);
220                        else
221                                return MessageDigest.getInstance(jcaHashAlg, provider);
222                } catch (NoSuchAlgorithmException e) {
223                        throw new JOSEException("Couldn't get message digest for KDF: " + e.getMessage(), e);
224                }
225        }
226
227
228        /**
229         * Computes the required digest (hashing) cycles for the specified
230         * message digest length and derived key length.
231         *
232         * @param digestLengthBits The length of the message digest, in bits.
233         * @param keyLengthBits    The length of the derived key, in bits.
234         *
235         * @return The digest cycles.
236         */
237        public static int computeDigestCycles(final int digestLengthBits, final int keyLengthBits) {
238
239                // return the ceiling of keyLength / digestLength
240                
241                return (keyLengthBits + digestLengthBits - 1) / digestLengthBits;
242        }
243
244
245        /**
246         * Encodes no / empty data as an empty byte array.
247         *
248         * @return The encoded data.
249         */
250        public static byte[] encodeNoData() {
251
252                return new byte[0];
253        }
254
255
256        /**
257         * Encodes the specified integer data as a four byte array.
258         *
259         * @param data The integer data to encode.
260         *
261         * @return The encoded data.
262         */
263        public static byte[] encodeIntData(final int data) {
264
265                return IntegerUtils.toBytes(data);
266        }
267
268
269        /**
270         * Encodes the specified string data as {@code data.length || data}.
271         *
272         * @param data The string data, UTF-8 encoded. May be {@code null}.
273         *
274         * @return The encoded data.
275         */
276        public static byte[] encodeStringData(final String data) {
277
278                byte[] bytes = data != null ? data.getBytes(StandardCharset.UTF_8) : null;
279                return encodeDataWithLength(bytes);
280        }
281
282
283        /**
284         * Encodes the specified data as {@code data.length || data}.
285         *
286         * @param data The data to encode, may be {@code null}.
287         *
288         * @return The encoded data.
289         */
290        public static byte[] encodeDataWithLength(final byte[] data) {
291
292                byte[] bytes = data != null ? data : new byte[0];
293                byte[] length = IntegerUtils.toBytes(bytes.length);
294                return ByteUtils.concat(length, bytes);
295        }
296
297
298        /**
299         * Encodes the specified BASE64URL encoded data
300         * {@code data.length || data}.
301         *
302         * @param data The data to encode, may be {@code null}.
303         *
304         * @return The encoded data.
305         */
306        public static byte[] encodeDataWithLength(final Base64URL data) {
307
308                byte[] bytes = data != null ? data.decode() : null;
309                return encodeDataWithLength(bytes);
310        }
311}
312