--- old/src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java 2018-05-11 15:04:27.207470800 -0700 +++ new/src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java 2018-05-11 15:04:26.665503900 -0700 @@ -1,5 +1,5 @@ /* - * Copyright (c) 2003, 2011, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2003, 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -32,10 +32,15 @@ import java.security.spec.*; import sun.security.action.GetPropertyAction; +import sun.security.x509.AlgorithmId; +import static sun.security.rsa.RSAUtil.KeyType; /** - * KeyFactory for RSA keys. Keys must be instances of PublicKey or PrivateKey - * and getAlgorithm() must return "RSA". For such keys, it supports conversion + * KeyFactory for RSA keys, e.g. "RSA", "RSASSA-PSS". + * Keys must be instances of PublicKey or PrivateKey + * and getAlgorithm() must return a value which matches the type which are + * specified during construction time of the KeyFactory object. + * For such keys, it supports conversion * between the following: * * For public keys: @@ -58,21 +63,21 @@ * @since 1.5 * @author Andreas Sterbenz */ -public final class RSAKeyFactory extends KeyFactorySpi { +public class RSAKeyFactory extends KeyFactorySpi { - private static final Class rsaPublicKeySpecClass = - RSAPublicKeySpec.class; - private static final Class rsaPrivateKeySpecClass = - RSAPrivateKeySpec.class; - private static final Class rsaPrivateCrtKeySpecClass = - RSAPrivateCrtKeySpec.class; - - private static final Class x509KeySpecClass = X509EncodedKeySpec.class; - private static final Class pkcs8KeySpecClass = PKCS8EncodedKeySpec.class; + private static final Class RSA_PUB_KEYSPEC_CLS = RSAPublicKeySpec.class; + private static final Class RSA_PRIV_KEYSPEC_CLS = + RSAPrivateKeySpec.class; + private static final Class RSA_PRIVCRT_KEYSPEC_CLS = + RSAPrivateCrtKeySpec.class; + private static final Class X509_KEYSPEC_CLS = X509EncodedKeySpec.class; + private static final Class PKCS8_KEYSPEC_CLS = PKCS8EncodedKeySpec.class; public static final int MIN_MODLEN = 512; public static final int MAX_MODLEN = 16384; + private final KeyType type; + /* * If the modulus length is above this value, restrict the size of * the exponent to something that can be reasonably computed. We @@ -87,11 +92,18 @@ "true".equalsIgnoreCase(GetPropertyAction.privilegedGetProperty( "sun.security.rsa.restrictRSAExponent", "true")); - // instance used for static translateKey(); - private static final RSAKeyFactory INSTANCE = new RSAKeyFactory(); + static RSAKeyFactory getInstance(KeyType type) { + return new RSAKeyFactory(type); + } - public RSAKeyFactory() { - // empty + // Internal utility method for checking key algorithm + private static void checkKeyAlgo(Key key, String expectedAlg) + throws InvalidKeyException { + String keyAlg = key.getAlgorithm(); + if (!(keyAlg.equalsIgnoreCase(expectedAlg))) { + throw new InvalidKeyException("Expected a " + expectedAlg + + " key, but got " + keyAlg); + } } /** @@ -107,7 +119,14 @@ (key instanceof RSAPublicKeyImpl)) { return (RSAKey)key; } else { - return (RSAKey)INSTANCE.engineTranslateKey(key); + try { + String keyAlgo = key.getAlgorithm(); + KeyType type = KeyType.lookup(keyAlgo); + RSAKeyFactory kf = RSAKeyFactory.getInstance(type); + return (RSAKey) kf.engineTranslateKey(key); + } catch (ProviderException e) { + throw new InvalidKeyException(e); + } } } @@ -171,6 +190,15 @@ } } + // disallowed as KeyType is required + private RSAKeyFactory() { + this.type = KeyType.RSA; + } + + public RSAKeyFactory(KeyType type) { + this.type = type; + } + /** * Translate an RSA key into a SunRsaSign RSA key. If conversion is * not possible, throw an InvalidKeyException. @@ -180,9 +208,14 @@ if (key == null) { throw new InvalidKeyException("Key must not be null"); } - String keyAlg = key.getAlgorithm(); - if (keyAlg.equals("RSA") == false) { - throw new InvalidKeyException("Not an RSA key: " + keyAlg); + // ensure the key algorithm matches the current KeyFactory instance + checkKeyAlgo(key, type.keyAlgo()); + + // no translation needed if the key is already our own impl + if ((key instanceof RSAPrivateKeyImpl) || + (key instanceof RSAPrivateCrtKeyImpl) || + (key instanceof RSAPublicKeyImpl)) { + return key; } if (key instanceof PublicKey) { return translatePublicKey((PublicKey)key); @@ -221,22 +254,22 @@ private PublicKey translatePublicKey(PublicKey key) throws InvalidKeyException { if (key instanceof RSAPublicKey) { - if (key instanceof RSAPublicKeyImpl) { - return key; - } RSAPublicKey rsaKey = (RSAPublicKey)key; try { return new RSAPublicKeyImpl( + RSAUtil.createAlgorithmId(type, rsaKey.getParams()), rsaKey.getModulus(), - rsaKey.getPublicExponent() - ); - } catch (RuntimeException e) { + rsaKey.getPublicExponent()); + } catch (ProviderException e) { // catch providers that incorrectly implement RSAPublicKey throw new InvalidKeyException("Invalid key", e); } } else if ("X.509".equals(key.getFormat())) { byte[] encoded = key.getEncoded(); - return new RSAPublicKeyImpl(encoded); + RSAPublicKey translated = new RSAPublicKeyImpl(encoded); + // ensure the key algorithm matches the current KeyFactory instance + checkKeyAlgo(translated, type.keyAlgo()); + return translated; } else { throw new InvalidKeyException("Public keys must be instance " + "of RSAPublicKey or have X.509 encoding"); @@ -247,12 +280,10 @@ private PrivateKey translatePrivateKey(PrivateKey key) throws InvalidKeyException { if (key instanceof RSAPrivateCrtKey) { - if (key instanceof RSAPrivateCrtKeyImpl) { - return key; - } RSAPrivateCrtKey rsaKey = (RSAPrivateCrtKey)key; try { return new RSAPrivateCrtKeyImpl( + RSAUtil.createAlgorithmId(type, rsaKey.getParams()), rsaKey.getModulus(), rsaKey.getPublicExponent(), rsaKey.getPrivateExponent(), @@ -262,27 +293,28 @@ rsaKey.getPrimeExponentQ(), rsaKey.getCrtCoefficient() ); - } catch (RuntimeException e) { + } catch (ProviderException e) { // catch providers that incorrectly implement RSAPrivateCrtKey throw new InvalidKeyException("Invalid key", e); } } else if (key instanceof RSAPrivateKey) { - if (key instanceof RSAPrivateKeyImpl) { - return key; - } RSAPrivateKey rsaKey = (RSAPrivateKey)key; try { return new RSAPrivateKeyImpl( + RSAUtil.createAlgorithmId(type, rsaKey.getParams()), rsaKey.getModulus(), rsaKey.getPrivateExponent() ); - } catch (RuntimeException e) { + } catch (ProviderException e) { // catch providers that incorrectly implement RSAPrivateKey throw new InvalidKeyException("Invalid key", e); } } else if ("PKCS#8".equals(key.getFormat())) { byte[] encoded = key.getEncoded(); - return RSAPrivateCrtKeyImpl.newKey(encoded); + RSAPrivateKey translated = RSAPrivateCrtKeyImpl.newKey(encoded); + // ensure the key algorithm matches the current KeyFactory instance + checkKeyAlgo(translated, type.keyAlgo()); + return translated; } else { throw new InvalidKeyException("Private keys must be instance " + "of RSAPrivate(Crt)Key or have PKCS#8 encoding"); @@ -294,13 +326,21 @@ throws GeneralSecurityException { if (keySpec instanceof X509EncodedKeySpec) { X509EncodedKeySpec x509Spec = (X509EncodedKeySpec)keySpec; - return new RSAPublicKeyImpl(x509Spec.getEncoded()); + RSAPublicKey generated = new RSAPublicKeyImpl(x509Spec.getEncoded()); + // ensure the key algorithm matches the current KeyFactory instance + checkKeyAlgo(generated, type.keyAlgo()); + return generated; } else if (keySpec instanceof RSAPublicKeySpec) { RSAPublicKeySpec rsaSpec = (RSAPublicKeySpec)keySpec; - return new RSAPublicKeyImpl( - rsaSpec.getModulus(), - rsaSpec.getPublicExponent() - ); + try { + return new RSAPublicKeyImpl( + RSAUtil.createAlgorithmId(type, rsaSpec.getParams()), + rsaSpec.getModulus(), + rsaSpec.getPublicExponent() + ); + } catch (ProviderException e) { + throw new InvalidKeySpecException(e); + } } else { throw new InvalidKeySpecException("Only RSAPublicKeySpec " + "and X509EncodedKeySpec supported for RSA public keys"); @@ -312,25 +352,38 @@ throws GeneralSecurityException { if (keySpec instanceof PKCS8EncodedKeySpec) { PKCS8EncodedKeySpec pkcsSpec = (PKCS8EncodedKeySpec)keySpec; - return RSAPrivateCrtKeyImpl.newKey(pkcsSpec.getEncoded()); + RSAPrivateKey generated = RSAPrivateCrtKeyImpl.newKey(pkcsSpec.getEncoded()); + // ensure the key algorithm matches the current KeyFactory instance + checkKeyAlgo(generated, type.keyAlgo()); + return generated; } else if (keySpec instanceof RSAPrivateCrtKeySpec) { RSAPrivateCrtKeySpec rsaSpec = (RSAPrivateCrtKeySpec)keySpec; - return new RSAPrivateCrtKeyImpl( - rsaSpec.getModulus(), - rsaSpec.getPublicExponent(), - rsaSpec.getPrivateExponent(), - rsaSpec.getPrimeP(), - rsaSpec.getPrimeQ(), - rsaSpec.getPrimeExponentP(), - rsaSpec.getPrimeExponentQ(), - rsaSpec.getCrtCoefficient() - ); + try { + return new RSAPrivateCrtKeyImpl( + RSAUtil.createAlgorithmId(type, rsaSpec.getParams()), + rsaSpec.getModulus(), + rsaSpec.getPublicExponent(), + rsaSpec.getPrivateExponent(), + rsaSpec.getPrimeP(), + rsaSpec.getPrimeQ(), + rsaSpec.getPrimeExponentP(), + rsaSpec.getPrimeExponentQ(), + rsaSpec.getCrtCoefficient() + ); + } catch (ProviderException e) { + throw new InvalidKeySpecException(e); + } } else if (keySpec instanceof RSAPrivateKeySpec) { RSAPrivateKeySpec rsaSpec = (RSAPrivateKeySpec)keySpec; - return new RSAPrivateKeyImpl( - rsaSpec.getModulus(), - rsaSpec.getPrivateExponent() - ); + try { + return new RSAPrivateKeyImpl( + RSAUtil.createAlgorithmId(type, rsaSpec.getParams()), + rsaSpec.getModulus(), + rsaSpec.getPrivateExponent() + ); + } catch (ProviderException e) { + throw new InvalidKeySpecException(e); + } } else { throw new InvalidKeySpecException("Only RSAPrivate(Crt)KeySpec " + "and PKCS8EncodedKeySpec supported for RSA private keys"); @@ -349,12 +402,13 @@ } if (key instanceof RSAPublicKey) { RSAPublicKey rsaKey = (RSAPublicKey)key; - if (rsaPublicKeySpecClass.isAssignableFrom(keySpec)) { + if (RSA_PUB_KEYSPEC_CLS.isAssignableFrom(keySpec)) { return keySpec.cast(new RSAPublicKeySpec( rsaKey.getModulus(), - rsaKey.getPublicExponent() + rsaKey.getPublicExponent(), + rsaKey.getParams() )); - } else if (x509KeySpecClass.isAssignableFrom(keySpec)) { + } else if (X509_KEYSPEC_CLS.isAssignableFrom(keySpec)) { return keySpec.cast(new X509EncodedKeySpec(key.getEncoded())); } else { throw new InvalidKeySpecException @@ -362,9 +416,9 @@ + "X509EncodedKeySpec for RSA public keys"); } } else if (key instanceof RSAPrivateKey) { - if (pkcs8KeySpecClass.isAssignableFrom(keySpec)) { + if (PKCS8_KEYSPEC_CLS.isAssignableFrom(keySpec)) { return keySpec.cast(new PKCS8EncodedKeySpec(key.getEncoded())); - } else if (rsaPrivateCrtKeySpecClass.isAssignableFrom(keySpec)) { + } else if (RSA_PRIVCRT_KEYSPEC_CLS.isAssignableFrom(keySpec)) { if (key instanceof RSAPrivateCrtKey) { RSAPrivateCrtKey crtKey = (RSAPrivateCrtKey)key; return keySpec.cast(new RSAPrivateCrtKeySpec( @@ -375,17 +429,19 @@ crtKey.getPrimeQ(), crtKey.getPrimeExponentP(), crtKey.getPrimeExponentQ(), - crtKey.getCrtCoefficient() + crtKey.getCrtCoefficient(), + crtKey.getParams() )); } else { throw new InvalidKeySpecException ("RSAPrivateCrtKeySpec can only be used with CRT keys"); } - } else if (rsaPrivateKeySpecClass.isAssignableFrom(keySpec)) { + } else if (RSA_PRIV_KEYSPEC_CLS.isAssignableFrom(keySpec)) { RSAPrivateKey rsaKey = (RSAPrivateKey)key; return keySpec.cast(new RSAPrivateKeySpec( rsaKey.getModulus(), - rsaKey.getPrivateExponent() + rsaKey.getPrivateExponent(), + rsaKey.getParams() )); } else { throw new InvalidKeySpecException @@ -397,4 +453,16 @@ throw new InvalidKeySpecException("Neither public nor private key"); } } + + public static final class Legacy extends RSAKeyFactory { + public Legacy() { + super(KeyType.RSA); + } + } + + public static final class PSS extends RSAKeyFactory { + public PSS() { + super(KeyType.PSS); + } + } }