001package com.nimbusds.jose.proc;
002
003
import java.security.Key;
import java.security.PublicKey;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import javax.crypto.SecretKey;
010
011import com.nimbusds.jose.JWSAlgorithm;
012import com.nimbusds.jose.JWSHeader;
013import com.nimbusds.jose.jwk.*;
014import com.nimbusds.jose.jwk.source.JWKSource;
015import net.jcip.annotations.ThreadSafe;
016
017
018/**
019 * Key selector for verifying JWS objects, where the key candidates are
020 * retrieved from a {@link JWKSource JSON Web Key (JWK) source}.
021 *
022 * @author Vladimir Dzhuvinov
023 * @version 2016-04-10
024 */
025@ThreadSafe
026public class JWSVerificationKeySelector<C extends SecurityContext> extends AbstractJWKSelectorWithSource<C> implements JWSKeySelector<C> {
027
028
029        /**
030         * The expected JWS algorithm.
031         */
032        private final JWSAlgorithm jwsAlg;
033
034
035        /**
036         * Creates a new JWS verification key selector.
037         *
038         * @param jwsAlg    The expected JWS algorithm for the objects to be
039         *                  verified. Must not be {@code null}.
040         * @param jwkSource The JWK source. Must not be {@code null}.
041         */
042        public JWSVerificationKeySelector(final JWSAlgorithm jwsAlg, final JWKSource<C> jwkSource) {
043                super(jwkSource);
044                if (jwsAlg == null) {
045                        throw new IllegalArgumentException("The JWS algorithm must not be null");
046                }
047                this.jwsAlg = jwsAlg;
048        }
049
050
051        /**
052         * Returns the expected JWS algorithm.
053         *
054         * @return The expected JWS algorithm.
055         */
056        public JWSAlgorithm getExpectedJWSAlgorithm() {
057
058                return jwsAlg;
059        }
060
061
062        /**
063         * Creates a JWK matcher for the expected JWS algorithm and the
064         * specified JWS header.
065         *
066         * @param jwsHeader The JWS header. Must not be {@code null}.
067         *
068         * @return The JWK matcher, {@code null} if none could be created.
069         */
070        protected JWKMatcher createJWKMatcher(final JWSHeader jwsHeader) {
071
072                if (! getExpectedJWSAlgorithm().equals(jwsHeader.getAlgorithm())) {
073                        // Unexpected JWS alg
074                        return null;
075                } else if (JWSAlgorithm.Family.RSA.contains(getExpectedJWSAlgorithm()) || JWSAlgorithm.Family.EC.contains(getExpectedJWSAlgorithm())) {
076                        // RSA or EC key matcher
077                        return new JWKMatcher.Builder()
078                                        .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm()))
079                                        .keyID(jwsHeader.getKeyID())
080                                        .keyUses(KeyUse.SIGNATURE, null)
081                                        .algorithms(getExpectedJWSAlgorithm(), null)
082                                        .build();
083                } else if (JWSAlgorithm.Family.HMAC_SHA.contains(getExpectedJWSAlgorithm())) {
084                        // HMAC secret matcher
085                        return new JWKMatcher.Builder()
086                                        .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm()))
087                                        .keyID(jwsHeader.getKeyID())
088                                        .privateOnly(true)
089                                        .algorithms(getExpectedJWSAlgorithm(), null)
090                                        .build();
091                } else {
092                        return null; // Unsupported algorithm
093                }
094        }
095
096
097        @Override
098        public List<Key> selectJWSKeys(final JWSHeader jwsHeader, final C context) {
099
100                if (! jwsAlg.equals(jwsHeader.getAlgorithm())) {
101                        // Unexpected JWS alg
102                        return Collections.emptyList();
103                }
104
105                JWKMatcher jwkMatcher = createJWKMatcher(jwsHeader);
106                if (jwkMatcher == null) {
107                        return Collections.emptyList();
108                }
109
110                List<JWK> jwkMatches = getJWKSource().get(new JWKSelector(jwkMatcher), context);
111
112                List<Key> sanitizedKeyList = new LinkedList<>();
113
114                for (Key key: KeyConverter.toJavaKeys(jwkMatches)) {
115                        if (key instanceof PublicKey || key instanceof SecretKey) {
116                                sanitizedKeyList.add(key);
117                        } // skip asymmetric private keys
118                }
119
120                return sanitizedKeyList;
121        }
122}