001package com.nimbusds.jose.proc; 002 003 004import java.security.Key; 005import java.security.PublicKey; 006import java.util.Collections; 007import java.util.LinkedList; 008import java.util.List; 009import javax.crypto.SecretKey; 010 011import com.nimbusds.jose.JWSAlgorithm; 012import com.nimbusds.jose.JWSHeader; 013import com.nimbusds.jose.jwk.*; 014import com.nimbusds.jose.jwk.source.JWKSource; 015import net.jcip.annotations.ThreadSafe; 016 017 018/** 019 * Key selector for verifying JWS objects, where the key candidates are 020 * retrieved from a {@link JWKSource JSON Web Key (JWK) source}. 021 * 022 * @author Vladimir Dzhuvinov 023 * @version 2016-04-10 024 */ 025@ThreadSafe 026public class JWSVerificationKeySelector<C extends SecurityContext> extends AbstractJWKSelectorWithSource<C> implements JWSKeySelector<C> { 027 028 029 /** 030 * The expected JWS algorithm. 031 */ 032 private final JWSAlgorithm jwsAlg; 033 034 035 /** 036 * Creates a new JWS verification key selector. 037 * 038 * @param jwsAlg The expected JWS algorithm for the objects to be 039 * verified. Must not be {@code null}. 040 * @param jwkSource The JWK source. Must not be {@code null}. 041 */ 042 public JWSVerificationKeySelector(final JWSAlgorithm jwsAlg, final JWKSource<C> jwkSource) { 043 super(jwkSource); 044 if (jwsAlg == null) { 045 throw new IllegalArgumentException("The JWS algorithm must not be null"); 046 } 047 this.jwsAlg = jwsAlg; 048 } 049 050 051 /** 052 * Returns the expected JWS algorithm. 053 * 054 * @return The expected JWS algorithm. 055 */ 056 public JWSAlgorithm getExpectedJWSAlgorithm() { 057 058 return jwsAlg; 059 } 060 061 062 /** 063 * Creates a JWK matcher for the expected JWS algorithm and the 064 * specified JWS header. 065 * 066 * @param jwsHeader The JWS header. Must not be {@code null}. 067 * 068 * @return The JWK matcher, {@code null} if none could be created. 069 */ 070 protected JWKMatcher createJWKMatcher(final JWSHeader jwsHeader) { 071 072 if (! getExpectedJWSAlgorithm().equals(jwsHeader.getAlgorithm())) { 073 // Unexpected JWS alg 074 return null; 075 } else if (JWSAlgorithm.Family.RSA.contains(getExpectedJWSAlgorithm()) || JWSAlgorithm.Family.EC.contains(getExpectedJWSAlgorithm())) { 076 // RSA or EC key matcher 077 return new JWKMatcher.Builder() 078 .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm())) 079 .keyID(jwsHeader.getKeyID()) 080 .keyUses(KeyUse.SIGNATURE, null) 081 .algorithms(getExpectedJWSAlgorithm(), null) 082 .build(); 083 } else if (JWSAlgorithm.Family.HMAC_SHA.contains(getExpectedJWSAlgorithm())) { 084 // HMAC secret matcher 085 return new JWKMatcher.Builder() 086 .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm())) 087 .keyID(jwsHeader.getKeyID()) 088 .privateOnly(true) 089 .algorithms(getExpectedJWSAlgorithm(), null) 090 .build(); 091 } else { 092 return null; // Unsupported algorithm 093 } 094 } 095 096 097 @Override 098 public List<Key> selectJWSKeys(final JWSHeader jwsHeader, final C context) { 099 100 if (! jwsAlg.equals(jwsHeader.getAlgorithm())) { 101 // Unexpected JWS alg 102 return Collections.emptyList(); 103 } 104 105 JWKMatcher jwkMatcher = createJWKMatcher(jwsHeader); 106 if (jwkMatcher == null) { 107 return Collections.emptyList(); 108 } 109 110 List<JWK> jwkMatches = getJWKSource().get(new JWKSelector(jwkMatcher), context); 111 112 List<Key> sanitizedKeyList = new LinkedList<>(); 113 114 for (Key key: KeyConverter.toJavaKeys(jwkMatches)) { 115 if (key instanceof PublicKey || key instanceof SecretKey) { 116 sanitizedKeyList.add(key); 117 } // skip asymmetric private keys 118 } 119 120 return sanitizedKeyList; 121 } 122}