001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2021, Connect2id Ltd and contributors.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.crypto.impl;
019
020
import com.nimbusds.jose.*;
import com.nimbusds.jose.crypto.utils.ECChecks;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.OctetKeyPair;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jose.util.ByteUtils;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.Provider;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
035
036
037/**
038 * Elliptic Curve Diffie-Hellman One-Pass Unified Model (ECDH-1PU) key
039 * agreement functions and utilities.
040 *
041 * @see <a href="https://datatracker.ietf.org/doc/html/draft-madden-jose-ecdh-1pu-04">Public
042 * Key Authenticated Encryption for JOSE: ECDH-1PU</a>
043 *
044 * @author Alexander Martynov
045 * @version 2021-08-03
046 */
047public class ECDH1PU {
048
049    /**
050     * Resolves the ECDH algorithm mode.
051     *
052     * @param alg The JWE algorithm. Must be supported and not {@code null}.
053     *
054     * @return The algorithm mode.
055     *
056     * @throws JOSEException If the JWE algorithm is not supported.
057     */
058    public static ECDH.AlgorithmMode resolveAlgorithmMode(final JWEAlgorithm alg)
059        throws JOSEException {
060
061        Objects.requireNonNull(alg, "The parameter \"alg\" must not be null");
062
063        if (alg.equals(JWEAlgorithm.ECDH_1PU)) {
064
065            return ECDH.AlgorithmMode.DIRECT;
066        }
067
068        if (alg.equals(JWEAlgorithm.ECDH_1PU_A128KW) ||
069                alg.equals(JWEAlgorithm.ECDH_1PU_A192KW) ||
070                alg.equals(JWEAlgorithm.ECDH_1PU_A256KW)
071        ) {
072
073            return ECDH.AlgorithmMode.KW;
074        }
075
076        throw new JOSEException(AlgorithmSupportMessage.unsupportedJWEAlgorithm(
077                alg,
078                ECDHCryptoProvider.SUPPORTED_ALGORITHMS));
079    }
080
081
082    /**
083     * Returns the bit length of the shared key (derived via concat KDF)
084     * for the specified JWE ECDH algorithm.
085     *
086     * @param alg The JWE ECDH algorithm. Must be supported and not
087     *            {@code null}.
088     * @param enc The encryption method. Must be supported and not
089     *            {@code null}.
090     *
091     * @return The bit length of the shared key.
092     *
093     * @throws JOSEException If the JWE algorithm or encryption method is
094     *                       not supported.
095     */
096    public static int sharedKeyLength(final JWEAlgorithm alg, final EncryptionMethod enc)
097        throws JOSEException {
098
099        Objects.requireNonNull(alg, "The parameter \"alg\" must not be null");
100        Objects.requireNonNull(enc, "The parameter \"enc\" must not be null");
101
102        if (alg.equals(JWEAlgorithm.ECDH_1PU)) {
103
104            int length = enc.cekBitLength();
105
106            if (length == 0) {
107                throw new JOSEException("Unsupported JWE encryption method " + enc);
108            }
109
110            return length;
111        }
112
113        if (alg.equals(JWEAlgorithm.ECDH_1PU_A128KW)) {
114            return 128;
115        }
116
117        if (alg.equals(JWEAlgorithm.ECDH_1PU_A192KW)) {
118            return  192;
119        }
120
121        if (alg.equals(JWEAlgorithm.ECDH_1PU_A256KW)) {
122            return  256;
123        }
124
125        throw new JOSEException(AlgorithmSupportMessage.unsupportedJWEAlgorithm(
126                alg, ECDHCryptoProvider.SUPPORTED_ALGORITHMS));
127    }
128
129    /**
130     * Derives a shared key (via concat KDF).
131     *
132     * The method should only be called in the
133     * {@link ECDH.AlgorithmMode#DIRECT} mode.
134     *
135     * The method derives the Content Encryption Key (CEK) for the "enc"
136     * algorithm, in the {@link ECDH.AlgorithmMode#DIRECT} mode.
137     *
138     * The method does not take the auth tag because the auth tag will be
139     * generated using a CEK derived as an output of this method.
140     *
141     * @param header    The JWE header. Its algorithm and encryption method
142     *                  must be supported. Must not be {@code null}.
143     * @param Z         The derived shared secret ('Z'). Must not be
144     *                  {@code null}.
145     * @param concatKDF The concat KDF. Must be initialised and not
146     *                  {@code null}.
147     *
148     * @return The derived shared key.
149     *
150     * @throws JOSEException If derivation of the shared key failed.
151     */
152    public static SecretKey deriveSharedKey(final JWEHeader header,
153                                            final SecretKey Z,
154                                            final ConcatKDF concatKDF)
155            throws JOSEException {
156
157        Objects.requireNonNull(header, "The parameter \"header\" must not be null");
158        Objects.requireNonNull(Z, "The parameter \"Z\" must not be null");
159        Objects.requireNonNull(concatKDF, "The parameter \"concatKDF\" must not be null");
160
161        final int sharedKeyLength = sharedKeyLength(header.getAlgorithm(), header.getEncryptionMethod());
162
163        // Set the alg ID for the concat KDF
164        ECDH.AlgorithmMode algMode = resolveAlgorithmMode(header.getAlgorithm());
165
166        final String algID;
167
168        if (algMode == ECDH.AlgorithmMode.DIRECT) {
169            // algID = enc
170            algID = header.getEncryptionMethod().getName();
171        } else if (algMode == ECDH.AlgorithmMode.KW) {
172            // algID = alg
173            algID = header.getAlgorithm().getName();
174        } else {
175            throw new JOSEException("Unsupported JWE ECDH algorithm mode: " + algMode);
176        }
177
178        return concatKDF.deriveKey(
179                Z,
180                sharedKeyLength,
181                ConcatKDF.encodeDataWithLength(algID.getBytes(StandardCharsets.US_ASCII)),
182                ConcatKDF.encodeDataWithLength(header.getAgreementPartyUInfo()),
183                ConcatKDF.encodeDataWithLength(header.getAgreementPartyVInfo()),
184                ConcatKDF.encodeIntData(sharedKeyLength),
185                ConcatKDF.encodeNoData()
186        );
187    }
188
189    /**
190     * Derives a shared key (via concat KDF).
191     *
192     * The method should only be called in {@link ECDH.AlgorithmMode#KW}.
193     *
194     * In Key Agreement with {@link ECDH.AlgorithmMode#KW} mode, the JWE
195     * Authentication Tag is included in the input to the KDF. This ensures
196     * that the content of the JWE was produced by the original sender and not
197     * by another recipient.
198     *
199     *
200     * @param header    The JWE header. Its algorithm and encryption method
201     *                  must be supported. Must not be {@code null}.
202     * @param Z         The derived shared secret ('Z'). Must not be
203     *                  {@code null}.
204     * @param tag       In Direct Key Agreement mode this is set to an empty
205     *                  octet string. In Key Agreement with Key Wrapping mode,
206     *                  this is set to a value of the form Data, where Data is
207     *                  the raw octets of the JWE Authentication Tag.
208     * @param concatKDF The concat KDF. Must be initialised and not
209     *                  {@code null}.
210     *
211     * @return The derived shared key.
212     *
213     * @throws JOSEException If derivation of the shared key failed.
214     */
215    public static SecretKey deriveSharedKey(final JWEHeader header,
216                        final SecretKey Z,
217                        final Base64URL tag,
218                        final ConcatKDF concatKDF)
219        throws JOSEException {
220
221        Objects.requireNonNull(header, "The parameter \"header\" must not be null");
222        Objects.requireNonNull(Z, "The parameter \"Z\" must not be null");
223        Objects.requireNonNull(tag, "The parameter \"tag\" must not be null");
224        Objects.requireNonNull(concatKDF, "The parameter \"concatKDF\" must not be null");
225
226        final int sharedKeyLength = sharedKeyLength(header.getAlgorithm(), header.getEncryptionMethod());
227
228        // Set the alg ID for the concat KDF
229        ECDH.AlgorithmMode algMode = resolveAlgorithmMode(header.getAlgorithm());
230
231        final String algID;
232
233        if (algMode == ECDH.AlgorithmMode.DIRECT) {
234            // algID = enc
235            algID = header.getEncryptionMethod().getName();
236        } else if (algMode == ECDH.AlgorithmMode.KW) {
237            // algID = alg
238            algID = header.getAlgorithm().getName();
239        } else {
240            throw new JOSEException("Unsupported JWE ECDH algorithm mode: " + algMode);
241        }
242
243        return concatKDF.deriveKey(
244            Z,
245            sharedKeyLength,
246            ConcatKDF.encodeDataWithLength(algID.getBytes(StandardCharsets.US_ASCII)),
247            ConcatKDF.encodeDataWithLength(header.getAgreementPartyUInfo()),
248            ConcatKDF.encodeDataWithLength(header.getAgreementPartyVInfo()),
249            ConcatKDF.encodeIntData(sharedKeyLength),
250            ConcatKDF.encodeNoData(),
251            ConcatKDF.encodeDataWithLength(tag)
252        );
253    }
254
255    /**
256     * Derives a shared secret (also called 'Z') where Z is the concatenation
257     * of Ze and Zs.
258     *
259     * @param Ze The shared secret derived from applying the ECDH primitive to
260     *           the sender's ephemeral private key and the recipient's static
261     *           public key (when sending) or the recipient's static private
262     *           key and the sender's ephemeral public key (when receiving).
263     *           Must not be {@code null}.
264     * @param Zs The shared secret derived from applying the ECDH primitive to
265     *           the sender's static private key and the recipient's static
266     *           public key (when sending) or the recipient's static private
267     *           key and the sender's static public key (when receiving). Must
268     *           not be {@code null}.
269     *
270     * @return The derived shared key.
271     */
272    public static SecretKey deriveZ(final SecretKey Ze, final SecretKey Zs) {
273        Objects.requireNonNull(Ze, "The parameter \"Ze\" must not be null");
274        Objects.requireNonNull(Zs, "The parameter \"Zs\" must not be null");
275
276        byte[] encodedKey = ByteUtils.concat(Ze.getEncoded(), Zs.getEncoded());
277        return new SecretKeySpec(encodedKey, 0, encodedKey.length, "AES");
278    }
279
280
281    /**
282     * Derives a shared secret (also called 'Z') for sender where Z is the
283     * concatenation of Ze and Zs. Where Ze is shared secret from applying
284     * the ECDH primitive to the sender's ephemeral private key and the recipient's
285     * static public key, Zs is the shared secret derived from
286     * applying the ECDH primitive to the sender's static private key and
287     * the recipient's static public key.
288     *
289     * @param privateKey The sender EC private key.
290     * @param publicKey  The recipient EC public key.
291     * @param epk        The sender EC ephemeral private key.
292     * @param provider   The specific JCA provider for the ECDH key
293     *                   agreement, {@code null} to use the default one.
294     *
295     * @return The derived shared secret ('Z'), with algorithm "AES".
296     *
297     * @throws JOSEException If derivation of the shared secret failed.
298     */
299    public static SecretKey deriveSenderZ(
300            final ECPrivateKey privateKey,
301            final ECPublicKey publicKey,
302            final ECPrivateKey epk,
303            final Provider provider) throws JOSEException {
304
305        validateSameCurve(privateKey, publicKey);
306        validateSameCurve(epk, publicKey);
307
308        SecretKey Ze = ECDH.deriveSharedSecret(
309                publicKey,
310                epk,
311                provider
312        );
313
314        SecretKey Zs = ECDH.deriveSharedSecret(
315                publicKey,
316                privateKey,
317                provider
318        );
319
320        return deriveZ(Ze, Zs);
321    }
322
323    /**
324     * Derives a shared secret (also called 'Z') for sender where Z is the
325     * concatenation of Ze and Zs. Where Ze is shared secret from applying
326     * the ECDH primitive to the sender's ephemeral public key and the recipient's
327     * static private key, Zs is the shared secret derived from
328     * applying the ECDH primitive to the sender's static public key and
329     * the recipient's static private key.
330     *
331     * @param privateKey The sender OctetKey private key.
332     * @param publicKey  The recipient OctetKey public key.
333     * @param epk        The sender OctetKey ephemeral private key.
334     *
335     * @return The derived shared secret ('Z'), with algorithm "AES".
336     *
337     * @throws JOSEException If derivation of the shared secret failed.
338     */
339    public static SecretKey deriveSenderZ(
340            final OctetKeyPair privateKey,
341            final OctetKeyPair publicKey,
342            final OctetKeyPair epk) throws JOSEException {
343
344        validateSameCurve(privateKey, publicKey);
345        validateSameCurve(epk, publicKey);
346
347        SecretKey Ze = ECDH.deriveSharedSecret(publicKey, epk);
348        SecretKey Zs = ECDH.deriveSharedSecret(publicKey, privateKey);
349
350        return deriveZ(Ze, Zs);
351    }
352
353    /**
354     * Derives a shared secret (also called 'Z') for sender where Z is the
355     * concatenation of Ze and Zs. Where Ze is shared secret from applying
356     * the ECDH primitive to the sender's ephemeral public key and the recipient's
357     * static private key, Zs is the shared secret derived from
358     * applying the ECDH primitive to the sender's static public key and
359     * the recipient's static private key.
360     *
361     * @param privateKey The sender EC private key.
362     * @param publicKey  The recipient EC public key.
363     * @param epk        The sender EC ephemeral public key.
364     * @param provider   The specific JCA provider for the ECDH key
365     *                   agreement, {@code null} to use the default one.
366     *
367     * @return The derived shared secret ('Z'), with algorithm "AES".
368     *
369     * @throws JOSEException If derivation of the shared secret failed.
370     */
371    public static SecretKey deriveRecipientZ(
372            final ECPrivateKey privateKey,
373            final ECPublicKey publicKey,
374            final ECPublicKey epk,
375            final Provider provider) throws JOSEException {
376
377        validateSameCurve(privateKey, publicKey);
378        validateSameCurve(privateKey, epk);
379
380        SecretKey Ze = ECDH.deriveSharedSecret(
381                epk,
382                privateKey,
383                provider
384        );
385
386        SecretKey Zs = ECDH.deriveSharedSecret(
387                publicKey,
388                privateKey,
389                provider
390        );
391
392        return deriveZ(Ze, Zs);
393    }
394
395    /**
396     * Derives a shared secret (also called 'Z') for recipient where Z is the
397     * concatenation of Ze and Zs.
398     *
399     * @param privateKey The sender OctetKey private key.
400     * @param publicKey  The recipient OctetKey public key.
401     * @param epk        The sender OctetKey ephemeral private key.
402     *
403     * @return The derived shared secret ('Z'), with algorithm "AES".
404     *
405     * @throws JOSEException If derivation of the shared secret failed.
406     */
407    public static SecretKey deriveRecipientZ(
408            final OctetKeyPair privateKey,
409            final OctetKeyPair publicKey,
410            final OctetKeyPair epk) throws JOSEException {
411
412        validateSameCurve(privateKey, publicKey);
413        validateSameCurve(privateKey, epk);
414
415        SecretKey Ze = ECDH.deriveSharedSecret(
416                epk,
417                privateKey
418        );
419
420        SecretKey Zs = ECDH.deriveSharedSecret(
421                publicKey,
422                privateKey
423        );
424
425        return deriveZ(Ze, Zs);
426    }
427
428    /**
429     * Ensures the private key and public key are from the same curve.
430     *
431     * @param privateKey EC private key. Must not be {@code null}.
432     * @param publicKey  EC public key. Must not be {@code null}.
433     *
434     * @throws JOSEException If the key curves don't match.
435     */
436    public static void validateSameCurve(final ECPrivateKey privateKey, final ECPublicKey publicKey)
437            throws JOSEException{
438
439        Objects.requireNonNull(privateKey, "The parameter \"privateKey\" must not be null");
440        Objects.requireNonNull(publicKey, "The parameter \"publicKey\" must not be null");
441
442        if (!privateKey.getParams().getCurve().equals(publicKey.getParams().getCurve())) {
443            throw new JOSEException("Curve of public key does not match curve of private key");
444        }
445
446        if (!ECChecks.isPointOnCurve(publicKey, privateKey)) {
447            throw new JOSEException("Invalid public EC key: Point(s) not on the expected curve");
448        }
449    }
450
451    /**
452     * Ensures the private key and public key are from the same curve.
453     *
454     * @param privateKey OKP private key. Must not be {@code null}.
455     * @param publicKey  OKP public key. Must not be {@code null}.
456     *
457     * @throws JOSEException If the curves don't match.
458     */
459    public static void validateSameCurve(final OctetKeyPair privateKey, final OctetKeyPair publicKey)
460            throws JOSEException {
461
462        Objects.requireNonNull(privateKey, "The parameter \"privateKey\" must not be null");
463        Objects.requireNonNull(publicKey, "The parameter \"publicKey\" must not be null");
464
465        if (!privateKey.isPrivate()) {
466            throw new JOSEException("OKP private key should be a private key");
467        }
468
469        if (publicKey.isPrivate()) {
470            throw new JOSEException("OKP public key should not be a private key");
471        }
472
473        if (!publicKey.getCurve().equals(Curve.X25519)) {
474            throw new JOSEException("Only supports OctetKeyPairs with crv=X25519");
475        }
476
477        if (!privateKey.getCurve().equals(publicKey.getCurve())) {
478            throw new JOSEException("Curve of public key does not match curve of private key");
479        }
480    }
481
482    /**
483     * Prevents public instantiation.
484     */
485    private ECDH1PU() {
486
487    }
488}