001package com.nimbusds.oauth2.sdk.jose.jwk; 002 003 004import java.security.Key; 005import java.security.PublicKey; 006import java.util.Collections; 007import java.util.LinkedList; 008import java.util.List; 009import javax.crypto.SecretKey; 010 011import com.nimbusds.jose.JWSAlgorithm; 012import com.nimbusds.jose.JWSHeader; 013import com.nimbusds.jose.jwk.*; 014import com.nimbusds.jose.proc.JWSKeySelector; 015import com.nimbusds.jose.proc.SecurityContext; 016import com.nimbusds.oauth2.sdk.id.Identifier; 017import net.jcip.annotations.ThreadSafe; 018 019 020/** 021 * Key selector for verifying JWS objects used in OpenID Connect. 022 * 023 * <p>Can be used to select RSA and EC key candidates for the verification of: 024 * 025 * <ul> 026 * <li>Signed ID tokens 027 * <li>Signed JWT-encoded UserInfo responses 028 * <li>Signed OpenID request objects 029 * </ul> 030 * 031 * <p>Client secret candidates for the verification of: 032 * 033 * <ul> 034 * <li>HMAC ID tokens 035 * <li>HMAC JWT-encoded UserInfo responses 036 * <li>HMAC OpenID request objects 037 * </ul> 038 */ 039@ThreadSafe 040public class JWSVerificationKeySelector extends AbstractJWKSelectorWithSource implements JWSKeySelector { 041 042 043 /** 044 * The expected JWS algorithm. 045 */ 046 private final JWSAlgorithm jwsAlg; 047 048 049 /** 050 * Creates a new JWS verification key selector. 051 * 052 * @param id Identifier for the JWS originator, typically an 053 * OAuth 2.0 server issuer ID, or client ID. Must not 054 * be {@code null}. 055 * @param jwsAlg The expected JWS algorithm for the objects to be 056 * verified. Must not be {@code null}. 057 * @param jwkSource The JWK source. Must not be {@code null}. 058 */ 059 public JWSVerificationKeySelector(final Identifier id, final JWSAlgorithm jwsAlg, final JWKSource jwkSource) { 060 super(id, jwkSource); 061 if (jwsAlg == null) { 062 throw new IllegalArgumentException("The JWS algorithm must not be null"); 063 } 064 this.jwsAlg = jwsAlg; 065 } 066 067 068 /** 069 * Returns the expected JWS algorithm. 070 * 071 * @return The expected JWS algorithm. 072 */ 073 public JWSAlgorithm getExpectedJWSAlgorithm() { 074 075 return jwsAlg; 076 } 077 078 079 /** 080 * Creates a JWK matcher for the expected JWS algorithm and the 081 * specified JWS header. 082 * 083 * @param jwsHeader The JWS header. Must not be {@code null}. 084 * 085 * @return The JWK matcher, {@code null} if none could be created. 086 */ 087 protected JWKMatcher createJWKMatcher(final JWSHeader jwsHeader) { 088 089 if (! getExpectedJWSAlgorithm().equals(jwsHeader.getAlgorithm())) { 090 // Unexpected JWS alg 091 return null; 092 } else if (JWSAlgorithm.Family.RSA.contains(getExpectedJWSAlgorithm()) || JWSAlgorithm.Family.EC.contains(getExpectedJWSAlgorithm())) { 093 // RSA or EC key matcher 094 return new JWKMatcher.Builder() 095 .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm())) 096 .keyID(jwsHeader.getKeyID()) 097 .keyUses(KeyUse.SIGNATURE, null) 098 .algorithms(getExpectedJWSAlgorithm(), null) 099 .build(); 100 } else if (JWSAlgorithm.Family.HMAC_SHA.contains(getExpectedJWSAlgorithm())) { 101 // Client secret matcher 102 return new JWKMatcher.Builder() 103 .keyType(KeyType.forAlgorithm(getExpectedJWSAlgorithm())) 104 .keyID(jwsHeader.getKeyID()) 105 .privateOnly(true) 106 .algorithms(getExpectedJWSAlgorithm(), null) 107 .build(); 108 } else { 109 return null; // Unsupported algorithm 110 } 111 } 112 113 114 @Override 115 public List<Key> selectJWSKeys(final JWSHeader jwsHeader, final SecurityContext context) { 116 117 if (! jwsAlg.equals(jwsHeader.getAlgorithm())) { 118 // Unexpected JWS alg 119 return Collections.emptyList(); 120 } 121 122 JWKMatcher jwkMatcher = createJWKMatcher(jwsHeader); 123 if (jwkMatcher == null) { 124 return Collections.emptyList(); 125 } 126 127 List<JWK> jwkMatches = getJWKSource().get(getIdentifier(), new JWKSelector(jwkMatcher)); 128 129 List<Key> sanitizedKeyList = new LinkedList<>(); 130 131 for (Key key: KeyConverter.toJavaKeys(jwkMatches)) { 132 if (key instanceof PublicKey || key instanceof SecretKey) { 133 sanitizedKeyList.add(key); 134 } // skip asymmetric private keys 135 } 136 137 return sanitizedKeyList; 138 } 139}