001package com.nimbusds.jose.proc;
002
003
import java.security.Key;
import java.security.PublicKey;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import javax.crypto.SecretKey;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.*;
import com.nimbusds.jose.jwk.source.JWKSource;
import net.jcip.annotations.ThreadSafe;
017
018
019/**
020 * Key selector for verifying JWS objects, where the key candidates are
021 * retrieved from a {@link JWKSource JSON Web Key (JWK) source}.
022 *
023 * @author Vladimir Dzhuvinov
024 * @version 2016-06-21
025 */
026@ThreadSafe
027public class JWSVerificationKeySelector<C extends SecurityContext> extends AbstractJWKSelectorWithSource<C> implements JWSKeySelector<C> {
028
029
030        /**
031         * The expected JWS algorithm.
032         */
033        private final JWSAlgorithm jwsAlg;
034
035
036        /**
037         * Creates a new JWS verification key selector.
038         *
039         * @param jwsAlg    The expected JWS algorithm for the objects to be
040         *                  verified. Must not be {@code null}.
041         * @param jwkSource The JWK source. Must not be {@code null}.
042         */
043        public JWSVerificationKeySelector(final JWSAlgorithm jwsAlg, final JWKSource<C> jwkSource) {
044                super(jwkSource);
045                if (jwsAlg == null) {
046                        throw new IllegalArgumentException("The JWS algorithm must not be null");
047                }
048                this.jwsAlg = jwsAlg;
049        }
050
051
052        /**
053         * Returns the expected JWS algorithm.
054         *
055         * @return The expected JWS algorithm.
056         */
057        public JWSAlgorithm getExpectedJWSAlgorithm() {
058
059                return jwsAlg;
060        }
061
062
063        /**
064         * Creates a JWK matcher for the expected JWS algorithm and the
065         * specified JWS header.
066         *
067         * @param jwsHeader The JWS header. Must not be {@code null}.
068         *
069         * @return The JWK matcher, {@code null} if none could be created.
070         */
071        protected JWKMatcher createJWKMatcher(final JWSHeader jwsHeader) {
072
073                if (! getExpectedJWSAlgorithm().equals(jwsHeader.getAlgorithm())) {
074                        // Unexpected JWS alg
075                        return null;
076                } else if (JWSAlgorithm.Family.RSA.contains(getExpectedJWSAlgorithm()) || JWSAlgorithm.Family.EC.contains(getExpectedJWSAlgorithm())) {
077                        // RSA or EC key matcher
078                        return new JWKMatcher.Builder()
079                                        .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm()))
080                                        .keyID(jwsHeader.getKeyID())
081                                        .keyUses(KeyUse.SIGNATURE, null)
082                                        .algorithms(getExpectedJWSAlgorithm(), null)
083                                        .build();
084                } else if (JWSAlgorithm.Family.HMAC_SHA.contains(getExpectedJWSAlgorithm())) {
085                        // HMAC secret matcher
086                        return new JWKMatcher.Builder()
087                                        .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm()))
088                                        .keyID(jwsHeader.getKeyID())
089                                        .privateOnly(true)
090                                        .algorithms(getExpectedJWSAlgorithm(), null)
091                                        .build();
092                } else {
093                        return null; // Unsupported algorithm
094                }
095        }
096
097
098        @Override
099        public List<Key> selectJWSKeys(final JWSHeader jwsHeader, final C context)
100                throws KeySourceException {
101
102                if (! jwsAlg.equals(jwsHeader.getAlgorithm())) {
103                        // Unexpected JWS alg
104                        return Collections.emptyList();
105                }
106
107                JWKMatcher jwkMatcher = createJWKMatcher(jwsHeader);
108                if (jwkMatcher == null) {
109                        return Collections.emptyList();
110                }
111
112                List<JWK> jwkMatches = getJWKSource().get(new JWKSelector(jwkMatcher), context);
113
114                List<Key> sanitizedKeyList = new LinkedList<>();
115
116                for (Key key: KeyConverter.toJavaKeys(jwkMatches)) {
117                        if (key instanceof PublicKey || key instanceof SecretKey) {
118                                sanitizedKeyList.add(key);
119                        } // skip asymmetric private keys
120                }
121
122                return sanitizedKeyList;
123        }
124}