/*
 * nimbus-jose-jwt
 *
 * Copyright 2012-2016, Connect2id Ltd and contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
 * this file except in compliance with the License. You may obtain a copy of the
 * License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed
 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */

package com.nimbusds.jose.crypto.impl;


import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jca.JCAAware;
import com.nimbusds.jose.jca.JCAContext;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jose.util.ByteUtils;
import com.nimbusds.jose.util.IntegerUtils;
import com.nimbusds.jose.util.StandardCharset;
import net.jcip.annotations.ThreadSafe;


/**
 * Concatenation Key Derivation Function (KDF). This class is thread-safe.
 *
 * <p>See NIST.800-56A.
 *
 * @author Vladimir Dzhuvinov
 * @version 2017-06-01
 */
@ThreadSafe
public class ConcatKDF implements JCAAware<JCAContext> {


	/**
	 * The JCA name of the hash algorithm.
	 */
	private final String jcaHashAlg;


	/**
	 * The JCA context.
	 */
	private final JCAContext jcaContext = new JCAContext();


	/**
	 * Creates a new concatenation Key Derivation Function (KDF) with the
	 * specified hash algorithm.
	 *
	 * @param jcaHashAlg The JCA name of the hash algorithm. Must be
	 *                   supported and not {@code null}.
	 *
	 * @throws IllegalArgumentException If the JCA hash algorithm name is
	 *                                  {@code null}.
	 */
	public ConcatKDF(final String jcaHashAlg) {

		if (jcaHashAlg == null) {
			throw new IllegalArgumentException("The JCA hash algorithm must not be null");
		}

		this.jcaHashAlg = jcaHashAlg;
	}


	/**
	 * Returns the JCA name of the hash algorithm.
	 *
	 * @return The JCA name of the hash algorithm.
	 */
	public String getHashAlgorithm() {

		return jcaHashAlg;
	}


	@Override
	public JCAContext getJCAContext() {

		return jcaContext;
	}


	/**
	 * Derives a key from the specified inputs.
	 *
	 * <p>For each digest cycle {@code i = 1..n} the hash of
	 * {@code counter(i) || sharedSecret || otherInfo} is computed and the
	 * digests are concatenated; the leading {@code keyLengthBits} bits of
	 * the concatenation form the derived key.
	 *
	 * @param sharedSecret  The shared secret. Must not be {@code null}.
	 * @param keyLengthBits The length of the key to derive, in bits.
	 * @param otherInfo     Other info, {@code null} if not specified.
	 *
	 * @return The derived key, with algorithm set to "AES".
	 *
	 * @throws JOSEException If the key derivation failed, including when
	 *                       the shared secret returns no encoded key
	 *                       material.
	 */
	public SecretKey deriveKey(final SecretKey sharedSecret,
				   final int keyLengthBits,
				   final byte[] otherInfo)
		throws JOSEException {

		final MessageDigest md = getMessageDigest();

		// Loop invariant: the number of hash rounds depends only on
		// the digest size and the requested key length
		final int digestCycles = computeDigestCycles(
			ByteUtils.safeBitLength(md.getDigestLength()),
			keyLengthBits);

		// Fetch the key material once instead of on every round
		final byte[] sharedSecretBytes = sharedSecret.getEncoded();

		if (sharedSecretBytes == null) {
			// Key.getEncoded() may return null when the key does
			// not support encoding; fail with a clear message
			// rather than a NullPointerException from the digest
			throw new JOSEException("Couldn't derive key: The shared secret has no encoded key material");
		}

		final ByteArrayOutputStream baos = new ByteArrayOutputStream();

		for (int i = 1; i <= digestCycles; i++) {

			// 32-bit counter || shared secret || other info
			md.update(IntegerUtils.toBytes(i));
			md.update(sharedSecretBytes);

			if (otherInfo != null) {
				md.update(otherInfo);
			}

			try {
				// digest() also resets the digest for the next round
				baos.write(md.digest());
			} catch (IOException e) {
				throw new JOSEException("Couldn't write derived key: " + e.getMessage(), e);
			}
		}

		final byte[] derivedKeyMaterial = baos.toByteArray();

		final int keyLengthBytes = ByteUtils.byteLength(keyLengthBits);

		if (derivedKeyMaterial.length == keyLengthBytes) {
			// Exact fit, no truncation required
			return new SecretKeySpec(derivedKeyMaterial, "AES");
		}

		// Keep the leading bytes only
		return new SecretKeySpec(ByteUtils.subArray(derivedKeyMaterial, 0, keyLengthBytes), "AES");
	}


	/**
	 * Derives a key from the specified inputs.
	 *
	 * @param sharedSecret The shared secret. Must not be {@code null}.
	 * @param keyLength    The length of the key to derive, in bits.
	 * @param algID        The algorithm identifier, {@code null} if not
	 *                     specified.
	 * @param partyUInfo   The partyUInfo, {@code null} if not specified.
	 * @param partyVInfo   The partyVInfo, {@code null} if not specified.
	 * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
	 * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
	 *
	 * @return The derived key, with algorithm set to "AES".
	 *
	 * @throws JOSEException If the key derivation failed.
	 */
	public SecretKey deriveKey(final SecretKey sharedSecret,
				   final int keyLength,
				   final byte[] algID,
				   final byte[] partyUInfo,
				   final byte[] partyVInfo,
				   final byte[] suppPubInfo,
				   final byte[] suppPrivInfo)
		throws JOSEException {

		final byte[] otherInfo = composeOtherInfo(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);

		return deriveKey(sharedSecret, keyLength, otherInfo);
	}


	/**
	 * Derives a key from the specified inputs.
	 *
	 * @param sharedSecret The shared secret. Must not be {@code null}.
	 * @param keyLength    The length of the key to derive, in bits.
	 * @param algID        The algorithm identifier, {@code null} if not
	 *                     specified.
	 * @param partyUInfo   The partyUInfo, {@code null} if not specified.
	 * @param partyVInfo   The partyVInfo, {@code null} if not specified.
	 * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
	 * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
	 * @param tag          The cctag, {@code null} if not specified.
	 *
	 * @return The derived key, with algorithm set to "AES".
	 *
	 * @throws JOSEException If the key derivation failed.
	 */
	public SecretKey deriveKey(final SecretKey sharedSecret,
				   final int keyLength,
				   final byte[] algID,
				   final byte[] partyUInfo,
				   final byte[] partyVInfo,
				   final byte[] suppPubInfo,
				   final byte[] suppPrivInfo,
				   final byte[] tag)
		throws JOSEException {

		final byte[] otherInfo = composeOtherInfo(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo, tag);

		return deriveKey(sharedSecret, keyLength, otherInfo);
	}


	/**
	 * Composes the other info as {@code algID || partyUInfo || partyVInfo
	 * || suppPubInfo || suppPrivInfo}.
	 *
	 * @param algID        The algorithm identifier, {@code null} if not
	 *                     specified.
	 * @param partyUInfo   The partyUInfo, {@code null} if not specified.
	 * @param partyVInfo   The partyVInfo, {@code null} if not specified.
	 * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
	 * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
	 *
	 * @return The resulting other info.
	 */
	public static byte[] composeOtherInfo(final byte[] algID,
					      final byte[] partyUInfo,
					      final byte[] partyVInfo,
					      final byte[] suppPubInfo,
					      final byte[] suppPrivInfo) {

		return ByteUtils.concat(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
	}


	/**
	 * Composes the other info as {@code algID || partyUInfo || partyVInfo
	 * || suppPubInfo || suppPrivInfo || tag}.
	 *
	 * @param algID        The algorithm identifier, {@code null} if not
	 *                     specified.
	 * @param partyUInfo   The partyUInfo, {@code null} if not specified.
	 * @param partyVInfo   The partyVInfo, {@code null} if not specified.
	 * @param suppPubInfo  The suppPubInfo, {@code null} if not specified.
	 * @param suppPrivInfo The suppPrivInfo, {@code null} if not specified.
	 * @param tag          The cctag, {@code null} if not specified.
	 *
	 * @return The resulting other info.
	 */
	public static byte[] composeOtherInfo(final byte[] algID,
					      final byte[] partyUInfo,
					      final byte[] partyVInfo,
					      final byte[] suppPubInfo,
					      final byte[] suppPrivInfo,
					      final byte[] tag) {

		return ByteUtils.concat(algID, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo, tag);
	}


	/**
	 * Returns a message digest instance for the configured
	 * {@link #jcaHashAlg hash algorithm}. The provider set in the JCA
	 * context is used if there is one, otherwise the default JCA provider
	 * lookup applies.
	 *
	 * @return The message digest instance.
	 *
	 * @throws JOSEException If the message digest algorithm is not
	 *                       supported by the underlying JCA provider.
	 */
	private MessageDigest getMessageDigest()
		throws JOSEException {

		final Provider provider = getJCAContext().getProvider();

		try {
			if (provider == null) {
				return MessageDigest.getInstance(jcaHashAlg);
			} else {
				return MessageDigest.getInstance(jcaHashAlg, provider);
			}
		} catch (NoSuchAlgorithmException e) {
			throw new JOSEException("Couldn't get message digest for KDF: " + e.getMessage(), e);
		}
	}


	/**
	 * Computes the required digest (hashing) cycles for the specified
	 * message digest length and derived key length.
	 *
	 * @param digestLengthBits The length of the message digest, in bits.
	 * @param keyLengthBits    The length of the derived key, in bits.
	 *
	 * @return The digest cycles.
	 */
	public static int computeDigestCycles(final int digestLengthBits, final int keyLengthBits) {

		// return the ceiling of keyLength / digestLength

		return (keyLengthBits + digestLengthBits - 1) / digestLengthBits;
	}


	/**
	 * Encodes no / empty data as an empty byte array.
	 *
	 * @return The encoded data.
	 */
	public static byte[] encodeNoData() {

		return new byte[0];
	}


	/**
	 * Encodes the specified integer data as a four byte array.
	 *
	 * @param data The integer data to encode.
	 *
	 * @return The encoded data.
	 */
	public static byte[] encodeIntData(final int data) {

		return IntegerUtils.toBytes(data);
	}


	/**
	 * Encodes the specified string data as {@code data.length || data}.
	 *
	 * @param data The string data, UTF-8 encoded. May be {@code null}.
	 *
	 * @return The encoded data.
	 */
	public static byte[] encodeStringData(final String data) {

		byte[] bytes = data != null ? data.getBytes(StandardCharset.UTF_8) : null;
		return encodeDataWithLength(bytes);
	}


	/**
	 * Encodes the specified data as {@code data.length || data}.
	 *
	 * @param data The data to encode, may be {@code null}.
	 *
	 * @return The encoded data.
	 */
	public static byte[] encodeDataWithLength(final byte[] data) {

		// null is treated as empty data (zero length prefix, no payload)
		byte[] bytes = data != null ? data : new byte[0];
		byte[] length = IntegerUtils.toBytes(bytes.length);
		return ByteUtils.concat(length, bytes);
	}


	/**
	 * Encodes the specified BASE64URL encoded data
	 * {@code data.length || data}.
	 *
	 * @param data The data to encode, may be {@code null}.
	 *
	 * @return The encoded data.
	 */
	public static byte[] encodeDataWithLength(final Base64URL data) {

		byte[] bytes = data != null ? data.decode() : null;
		return encodeDataWithLength(bytes);
	}
}