001package com.nimbusds.jose.proc; 002 003 004import java.security.Key; 005import java.security.PublicKey; 006import java.util.Collections; 007import java.util.LinkedList; 008import java.util.List; 009import javax.crypto.SecretKey; 010 011import com.nimbusds.jose.JWSAlgorithm; 012import com.nimbusds.jose.JWSHeader; 013import com.nimbusds.jose.KeySourceException; 014import com.nimbusds.jose.jwk.*; 015import com.nimbusds.jose.jwk.source.JWKSource; 016import net.jcip.annotations.ThreadSafe; 017 018 019/** 020 * Key selector for verifying JWS objects, where the key candidates are 021 * retrieved from a {@link JWKSource JSON Web Key (JWK) source}. 022 * 023 * @author Vladimir Dzhuvinov 024 * @version 2016-06-21 025 */ 026@ThreadSafe 027public class JWSVerificationKeySelector<C extends SecurityContext> extends AbstractJWKSelectorWithSource<C> implements JWSKeySelector<C> { 028 029 030 /** 031 * The expected JWS algorithm. 032 */ 033 private final JWSAlgorithm jwsAlg; 034 035 036 /** 037 * Creates a new JWS verification key selector. 038 * 039 * @param jwsAlg The expected JWS algorithm for the objects to be 040 * verified. Must not be {@code null}. 041 * @param jwkSource The JWK source. Must not be {@code null}. 042 */ 043 public JWSVerificationKeySelector(final JWSAlgorithm jwsAlg, final JWKSource<C> jwkSource) { 044 super(jwkSource); 045 if (jwsAlg == null) { 046 throw new IllegalArgumentException("The JWS algorithm must not be null"); 047 } 048 this.jwsAlg = jwsAlg; 049 } 050 051 052 /** 053 * Returns the expected JWS algorithm. 054 * 055 * @return The expected JWS algorithm. 056 */ 057 public JWSAlgorithm getExpectedJWSAlgorithm() { 058 059 return jwsAlg; 060 } 061 062 063 /** 064 * Creates a JWK matcher for the expected JWS algorithm and the 065 * specified JWS header. 066 * 067 * @param jwsHeader The JWS header. Must not be {@code null}. 068 * 069 * @return The JWK matcher, {@code null} if none could be created. 070 */ 071 protected JWKMatcher createJWKMatcher(final JWSHeader jwsHeader) { 072 073 if (! getExpectedJWSAlgorithm().equals(jwsHeader.getAlgorithm())) { 074 // Unexpected JWS alg 075 return null; 076 } else if (JWSAlgorithm.Family.RSA.contains(getExpectedJWSAlgorithm()) || JWSAlgorithm.Family.EC.contains(getExpectedJWSAlgorithm())) { 077 // RSA or EC key matcher 078 return new JWKMatcher.Builder() 079 .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm())) 080 .keyID(jwsHeader.getKeyID()) 081 .keyUses(KeyUse.SIGNATURE, null) 082 .algorithms(getExpectedJWSAlgorithm(), null) 083 .build(); 084 } else if (JWSAlgorithm.Family.HMAC_SHA.contains(getExpectedJWSAlgorithm())) { 085 // HMAC secret matcher 086 return new JWKMatcher.Builder() 087 .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm())) 088 .keyID(jwsHeader.getKeyID()) 089 .privateOnly(true) 090 .algorithms(getExpectedJWSAlgorithm(), null) 091 .build(); 092 } else { 093 return null; // Unsupported algorithm 094 } 095 } 096 097 098 @Override 099 public List<Key> selectJWSKeys(final JWSHeader jwsHeader, final C context) 100 throws KeySourceException { 101 102 if (! jwsAlg.equals(jwsHeader.getAlgorithm())) { 103 // Unexpected JWS alg 104 return Collections.emptyList(); 105 } 106 107 JWKMatcher jwkMatcher = createJWKMatcher(jwsHeader); 108 if (jwkMatcher == null) { 109 return Collections.emptyList(); 110 } 111 112 List<JWK> jwkMatches = getJWKSource().get(new JWKSelector(jwkMatcher), context); 113 114 List<Key> sanitizedKeyList = new LinkedList<>(); 115 116 for (Key key: KeyConverter.toJavaKeys(jwkMatches)) { 117 if (key instanceof PublicKey || key instanceof SecretKey) { 118 sanitizedKeyList.add(key); 119 } // skip asymmetric private keys 120 } 121 122 return sanitizedKeyList; 123 } 124}