001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2016, Connect2id Ltd.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.proc;
019
020
import java.security.Key;
import java.security.PublicKey;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import javax.crypto.SecretKey;
027
028import com.nimbusds.jose.JWSAlgorithm;
029import com.nimbusds.jose.JWSHeader;
030import com.nimbusds.jose.KeySourceException;
031import com.nimbusds.jose.jwk.*;
032import com.nimbusds.jose.jwk.source.JWKSource;
033import net.jcip.annotations.ThreadSafe;
034
035
036/**
037 * Key selector for verifying JWS objects, where the key candidates are
038 * retrieved from a {@link JWKSource JSON Web Key (JWK) source}.
039 *
040 * @author Vladimir Dzhuvinov
041 * @version 2016-06-21
042 */
043@ThreadSafe
044public class JWSVerificationKeySelector<C extends SecurityContext> extends AbstractJWKSelectorWithSource<C> implements JWSKeySelector<C> {
045
046
047        /**
048         * The expected JWS algorithm.
049         */
050        private final JWSAlgorithm jwsAlg;
051
052
053        /**
054         * Creates a new JWS verification key selector.
055         *
056         * @param jwsAlg    The expected JWS algorithm for the objects to be
057         *                  verified. Must not be {@code null}.
058         * @param jwkSource The JWK source. Must not be {@code null}.
059         */
060        public JWSVerificationKeySelector(final JWSAlgorithm jwsAlg, final JWKSource<C> jwkSource) {
061                super(jwkSource);
062                if (jwsAlg == null) {
063                        throw new IllegalArgumentException("The JWS algorithm must not be null");
064                }
065                this.jwsAlg = jwsAlg;
066        }
067
068
069        /**
070         * Returns the expected JWS algorithm.
071         *
072         * @return The expected JWS algorithm.
073         */
074        public JWSAlgorithm getExpectedJWSAlgorithm() {
075
076                return jwsAlg;
077        }
078
079
080        /**
081         * Creates a JWK matcher for the expected JWS algorithm and the
082         * specified JWS header.
083         *
084         * @param jwsHeader The JWS header. Must not be {@code null}.
085         *
086         * @return The JWK matcher, {@code null} if none could be created.
087         */
088        protected JWKMatcher createJWKMatcher(final JWSHeader jwsHeader) {
089
090                if (! getExpectedJWSAlgorithm().equals(jwsHeader.getAlgorithm())) {
091                        // Unexpected JWS alg
092                        return null;
093                } else if (JWSAlgorithm.Family.RSA.contains(getExpectedJWSAlgorithm()) || JWSAlgorithm.Family.EC.contains(getExpectedJWSAlgorithm())) {
094                        // RSA or EC key matcher
095                        return new JWKMatcher.Builder()
096                                        .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm()))
097                                        .keyID(jwsHeader.getKeyID())
098                                        .keyUses(KeyUse.SIGNATURE, null)
099                                        .algorithms(getExpectedJWSAlgorithm(), null)
100                                        .build();
101                } else if (JWSAlgorithm.Family.HMAC_SHA.contains(getExpectedJWSAlgorithm())) {
102                        // HMAC secret matcher
103                        return new JWKMatcher.Builder()
104                                        .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm()))
105                                        .keyID(jwsHeader.getKeyID())
106                                        .privateOnly(true)
107                                        .algorithms(getExpectedJWSAlgorithm(), null)
108                                        .build();
109                } else {
110                        return null; // Unsupported algorithm
111                }
112        }
113
114
115        @Override
116        public List<Key> selectJWSKeys(final JWSHeader jwsHeader, final C context)
117                throws KeySourceException {
118
119                if (! jwsAlg.equals(jwsHeader.getAlgorithm())) {
120                        // Unexpected JWS alg
121                        return Collections.emptyList();
122                }
123
124                JWKMatcher jwkMatcher = createJWKMatcher(jwsHeader);
125                if (jwkMatcher == null) {
126                        return Collections.emptyList();
127                }
128
129                List<JWK> jwkMatches = getJWKSource().get(new JWKSelector(jwkMatcher), context);
130
131                List<Key> sanitizedKeyList = new LinkedList<>();
132
133                for (Key key: KeyConverter.toJavaKeys(jwkMatches)) {
134                        if (key instanceof PublicKey || key instanceof SecretKey) {
135                                sanitizedKeyList.add(key);
136                        } // skip asymmetric private keys
137                }
138
139                return sanitizedKeyList;
140        }
141}