/**
 * Resolves the key with the given ID through the JWK provider and returns it as an RSA public key.
 *
 * @param keyId the key ID handed to the JWK provider
 * @return the public key for {@code keyId}, cast to {@link RSAPublicKey}
 * @throws IllegalArgumentException if the provider lookup fails with a {@code JwkException}
 *         (chained as the cause), or if the resolved key is not an {@link RSAPublicKey}
 */
@Override
public RSAPublicKey getPublicKeyById(String keyId) {
    final PublicKey candidate;
    try {
        candidate = jwkProvider.get(keyId).getPublicKey();
    } catch (JwkException lookupFailure) {
        throw new IllegalArgumentException(
                String.format("Key with ID '%s' couldn't be fetched from JWKS.", keyId), lookupFailure);
    }

    // Only RSA keys are acceptable to callers of this method.
    if (candidate instanceof RSAPublicKey) {
        return (RSAPublicKey) candidate;
    }
    throw new IllegalArgumentException(
            String.format("Key with ID '%s' was found in JWKS but is not a RSA-key.", keyId));
}
/**
 * Reads the {@code "keys"} entry of the JWKS document returned by {@code getJwks()} and converts
 * every entry into a {@link Jwk}.
 *
 * @return the parsed keys, in the order they appear under {@code "keys"}
 * @throws SigningKeyNotFoundException if {@code "keys"} is missing or empty, or if converting an
 *         entry raises an {@link IllegalArgumentException} (chained as the cause)
 */
private List<Jwk> getAll() throws SigningKeyNotFoundException {
    @SuppressWarnings("unchecked")
    final List<Map<String, Object>> rawKeys = (List<Map<String, Object>>) getJwks().get("keys");

    final boolean nothingToParse = rawKeys == null || rawKeys.isEmpty();
    if (nothingToParse) {
        throw new SigningKeyNotFoundException("No keys found in " + url.toString(), null);
    }

    final List<Jwk> parsed = Lists.newArrayList();
    try {
        for (Map<String, Object> attributes : rawKeys) {
            parsed.add(Jwk.fromValues(attributes));
        }
    } catch (IllegalArgumentException malformed) {
        throw new SigningKeyNotFoundException("Failed to parse jwk from json", malformed);
    }
    return parsed;
}
/**
 * Finds the key with the given ID among all keys returned by {@code getAll()}.
 *
 * <p>A {@code null} key ID is accepted only when exactly one key is available; that single key
 * is then returned.
 *
 * @param keyId the key ID to look for, or {@code null}
 * @return the first key whose ID equals {@code keyId}, or the only key when {@code keyId} is null
 * @throws JwkException if the keys cannot be loaded or no key matches
 */
@Override
public Jwk get(String keyId) throws JwkException {
    final List<Jwk> candidates = getAll();

    if (keyId == null) {
        // Without an ID the choice is unambiguous only for a single-key set.
        if (candidates.size() == 1) {
            return candidates.get(0);
        }
    } else {
        for (Jwk candidate : candidates) {
            if (keyId.equals(candidate.getId())) {
                return candidate;
            }
        }
    }

    throw new SigningKeyNotFoundException("No key found in " + url.toString() + " with kid " + keyId, null);
}
}
/**
 * Looks up the JWK with the given ID in the key cache and converts it to an
 * {@code OpenIdMetadataKey}.
 *
 * <p>The resulting key carries the JWK's public key (cast to {@link RSAPublicKey}) and the value
 * of its {@code "endorsements"} additional attribute.
 *
 * @param keyId the key ID to look up
 * @return the populated key, or {@code null} if the lookup or key extraction fails with a
 *         {@code JwkException}
 */
@SuppressWarnings("unchecked")
private OpenIdMetadataKey findKey(String keyId) {
    try {
        Jwk jwk = this.cacheKeys.get(keyId);
        OpenIdMetadataKey key = new OpenIdMetadataKey();
        key.key = (RSAPublicKey) jwk.getPublicKey();
        key.endorsements = (List<String>) jwk.getAdditionalAttributes().get("endorsements");
        return key;
    } catch (JwkException e) {
        String errorDescription = String.format("Failed to load keys: %s", e.getMessage());
        // Pass the exception itself so the stack trace and cause chain reach the log,
        // not just the message text.
        LOGGER.log(Level.WARNING, errorDescription, e);
    }
    return null;
}
}
/**
 * An empty key-operations list must still yield a public key, an empty (non-null) operations
 * list, and a null operations string.
 */
@Test
public void shouldReturnPublicKeyForEmptyKeyOpsParam() throws Exception {
    final String keyId = randomKeyId();
    final Map<String, Object> attributes = publicKeyValues(keyId, Lists.newArrayList());

    final Jwk parsed = Jwk.fromValues(attributes);

    assertThat(parsed.getPublicKey(), notNullValue());
    assertThat(parsed.getOperationsAsList(), notNullValue());
    assertThat(parsed.getOperationsAsList().size(), equalTo(0));
    assertThat(parsed.getOperations(), nullValue());
}
/**
 * A Jwk built from a full attribute map must expose every attribute through its getters.
 */
@Test
public void shouldBuildWithMap() throws Exception {
    final String keyId = randomKeyId();
    final Map<String, Object> attributes = publicKeyValues(keyId, KEY_OPS_LIST);

    final Jwk parsed = Jwk.fromValues(attributes);

    // Identity and key metadata.
    assertThat(parsed.getId(), equalTo(keyId));
    assertThat(parsed.getAlgorithm(), equalTo(RS_256));
    assertThat(parsed.getType(), equalTo(RSA));
    assertThat(parsed.getUsage(), equalTo(SIG));

    // Key operations, in both list and string form.
    assertThat(parsed.getOperationsAsList(), equalTo(KEY_OPS_LIST));
    assertThat(parsed.getOperations(), is(KEY_OPS_STRING));

    // Certificate details.
    assertThat(parsed.getCertificateThumbprint(), equalTo(THUMBPRINT));
    assertThat(parsed.getCertificateChain(), contains(CERT_CHAIN));
}
/**
 * Dropping the "alg" attribute must not prevent the public key from being produced.
 */
@Test
public void shouldReturnKeyWithMissingAlgParam() throws Exception {
    final String keyId = randomKeyId();
    final Map<String, Object> attributes = publicKeyValues(keyId, KEY_OPS_LIST);
    attributes.remove("alg");

    final Jwk parsed = Jwk.fromValues(attributes);

    assertThat(parsed.getPublicKey(), notNullValue());
}
/**
 * Creates a {@code Jwk} from a map of JWK attributes.
 *
 * <p>The members {@code kid}, {@code kty}, {@code alg}, {@code use}, {@code key_ops},
 * {@code x5u}, {@code x5c} and {@code x5t} are extracted from a copy of {@code map}; whatever is
 * left in the copy is passed to the constructor as its final argument. {@code key_ops} may be
 * given either as a single string or as a list of strings.
 *
 * @param map the JWK attributes; the map itself is not modified
 * @return the Jwk built from those attributes
 * @throws IllegalArgumentException if {@code map} has no non-null {@code kty} value
 */
@SuppressWarnings("unchecked")
static Jwk fromValues(Map<String, Object> map) {
    final Map<String, Object> remaining = Maps.newHashMap(map);

    final String keyId = (String) remaining.remove("kid");
    final String keyType = (String) remaining.remove("kty");
    final String algorithm = (String) remaining.remove("alg");
    final String usage = (String) remaining.remove("use");
    final Object operations = remaining.remove("key_ops");
    final String certificateUrl = (String) remaining.remove("x5u");
    final List<String> certificateChain = (List<String>) remaining.remove("x5c");
    final String certificateThumbprint = (String) remaining.remove("x5t");

    if (keyType == null) {
        throw new IllegalArgumentException("Attributes " + map + " are not from a valid jwk");
    }

    // key_ops arrives either as one string or as a list; each form has its own constructor.
    if (!(operations instanceof String)) {
        return new Jwk(keyId, keyType, algorithm, usage, (List<String>) operations,
                certificateUrl, certificateChain, certificateThumbprint, remaining);
    }
    return new Jwk(keyId, keyType, algorithm, usage, (String) operations,
            certificateUrl, certificateChain, certificateThumbprint, remaining);
}
/**
 * Key operations supplied as a single string must be readable both as a list and as a string.
 */
@Test
public void shouldReturnPublicKeyForStringKeyOpsParam() throws Exception {
    final String keyId = randomKeyId();
    final Map<String, Object> attributes = publicKeyValues(keyId, KEY_OPS_STRING);

    final Jwk parsed = Jwk.fromValues(attributes);

    assertThat(parsed.getPublicKey(), notNullValue());
    assertThat(parsed.getOperationsAsList(), is(KEY_OPS_LIST));
    assertThat(parsed.getOperations(), is(KEY_OPS_STRING));
}
/**
 * Requesting the public key of a non-RSA Jwk must fail with an InvalidPublicKeyException.
 */
@Test
public void shouldThrowForNonRSAKey() throws Exception {
    final String keyId = randomKeyId();
    final Map<String, Object> attributes = nonRSAValues(keyId);
    final Jwk parsed = Jwk.fromValues(attributes);

    // The expectation is armed only after construction: building the Jwk itself must succeed.
    expectedException.expect(InvalidPublicKeyException.class);
    expectedException.expectMessage("The key is not of type RSA");

    parsed.getPublicKey();
}
/**
 * Looks up the JWK with the given ID in the key cache and converts it to an
 * {@code OpenIdMetadataKey}.
 *
 * <p>The resulting key carries the JWK's public key (cast to {@link RSAPublicKey}) and the value
 * of its {@code "endorsements"} additional attribute.
 *
 * @param keyId the key ID to look up
 * @return the populated key, or {@code null} if the lookup or key extraction fails with a
 *         {@code JwkException}
 */
@SuppressWarnings("unchecked")
private OpenIdMetadataKey findKey(String keyId) {
    try {
        Jwk jwk = cacheKeys.get(keyId);
        OpenIdMetadataKey key = new OpenIdMetadataKey();
        key.key = (RSAPublicKey) jwk.getPublicKey();
        key.endorsements = (List<String>) jwk.getAdditionalAttributes().get("endorsements");
        return key;
    } catch (JwkException e) {
        String errorDescription = String.format("Failed to load keys: %s", e.getMessage());
        // Pass the exception itself so the stack trace and cause chain reach the log,
        // not just the message text.
        LOGGER.log(Level.WARNING, errorDescription, e);
    }
    return null;
}
}
/**
 * With null key operations the public key is still available, while both operations accessors
 * return null.
 */
@Test
public void shouldReturnPublicKeyForNullKeyOpsParam() throws Exception {
    final String keyId = randomKeyId();
    final Map<String, Object> attributes = publicKeyValues(keyId, null);

    final Jwk parsed = Jwk.fromValues(attributes);

    assertThat(parsed.getPublicKey(), notNullValue());
    assertThat(parsed.getOperationsAsList(), nullValue());
    assertThat(parsed.getOperations(), nullValue());
}
private void verifyJwt(JWT decoded) { try { Jwk jwk = jwkProvider.get(decoded.getKeyId()); // TODO check for Algorithm JWTVerifier verifier = JWT.require(Algorithm.RSA256((RSAKey) jwk.getPublicKey())).build(); verifier.verify(decoded.getToken()); } catch (Exception e) { e.printStackTrace(); throw new IllegalStateException("Bad token!"); } }
@Test public void shouldNotThrowInvalidArgumentExceptionOnMissingKidParam() throws Exception { //kid is optional - https://tools.ietf.org/html/rfc7517#section-4.5 final String kid = randomKeyId(); Map<String, Object> values = publicKeyValues(kid, KEY_OPS_LIST); values.remove("kid"); Jwk.fromValues(values); }
/**
 * A Jwk built from list-form key operations must expose its public key and report the operations
 * both as a list and as a string.
 */
@Test
public void shouldReturnPublicKey() throws Exception {
    final String keyId = randomKeyId();
    final Map<String, Object> attributes = publicKeyValues(keyId, KEY_OPS_LIST);

    final Jwk parsed = Jwk.fromValues(attributes);

    assertThat(parsed.getPublicKey(), notNullValue());
    assertThat(parsed.getOperationsAsList(), is(KEY_OPS_LIST));
    assertThat(parsed.getOperations(), is(KEY_OPS_STRING));
}
// Take the encoded bytes of the JWK's public key, wrap them in an X509EncodedKeySpec, and
// obtain a KeyFactory for the "RSA" algorithm.
// NOTE(review): presumably the factory is then used to rebuild the key from the spec; the
// continuation of this fragment is not visible here, so confirm against the enclosing method.
byte[] publicKeyBytes = jwk.getPublicKey().getEncoded(); X509EncodedKeySpec keySpec = new X509EncodedKeySpec(publicKeyBytes); KeyFactory keyFactory = KeyFactory.getInstance("RSA");
/**
 * Building a Jwk from attributes without "kty" must fail with an IllegalArgumentException.
 */
@Test
public void shouldThrowInvalidArgumentExceptionOnMissingKtyParam() throws Exception {
    final String keyId = randomKeyId();
    final Map<String, Object> attributes = publicKeyValues(keyId, KEY_OPS_LIST);
    attributes.remove("kty");

    expectedException.expect(IllegalArgumentException.class);

    Jwk.fromValues(attributes);
}
/**
 * Resolves the public key for the key ID carried in the given JWS header.
 *
 * @param header the JWS header whose key ID selects the key
 * @return the public key that the JWK provider returns for the header's key ID
 * @throws JwtException if the header has no key ID, or if the provider lookup fails with a
 *         {@code JwkException} (chained as the cause)
 */
private synchronized PublicKey getJwtPublicKey(JwsHeader<?> header) {
    final String keyId = header.getKeyId();

    // Without a key ID there is no way to choose among the provider's keys.
    if (header.getKeyId() == null) {
        LOG.warn(
            "'kid' is missing in the JWT token header. This is not possible to validate the token with OIDC provider keys");
        throw new JwtException("'kid' is missing in the JWT token header.");
    }

    try {
        return jwkProvider.get(keyId).getPublicKey();
    } catch (JwkException lookupFailure) {
        throw new JwtException(
            "Error during the retrieval of the public key during JWT token validation", lookupFailure);
    }
}
}