< prev index next >

src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java

Print this page

        

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 2003, 2011, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2003, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.  Oracle designates this

@@ -30,14 +30,19 @@
 import java.security.*;
 import java.security.interfaces.*;
 import java.security.spec.*;
 
 import sun.security.action.GetPropertyAction;
+import sun.security.x509.AlgorithmId;
+import static sun.security.rsa.RSAUtil.KeyType;
 
 /**
- * KeyFactory for RSA keys. Keys must be instances of PublicKey or PrivateKey
- * and getAlgorithm() must return "RSA". For such keys, it supports conversion
+ * KeyFactory for RSA keys, e.g. "RSA", "RSASSA-PSS".
+ * Keys must be instances of PublicKey or PrivateKey
+ * and getAlgorithm() must return a value which matches the key type
+ * specified when the KeyFactory object is constructed.
+ * For such keys, it supports conversion
  * between the following:
  *
  * For public keys:
  *  . PublicKey with an X.509 encoding
  *  . RSAPublicKey

@@ -56,25 +61,25 @@
  * Note: as always, RSA keys should be at least 512 bits long
  *
  * @since   1.5
  * @author  Andreas Sterbenz
  */
-public final class RSAKeyFactory extends KeyFactorySpi {
+public class RSAKeyFactory extends KeyFactorySpi {
 
-    private static final Class<?> rsaPublicKeySpecClass =
-                                                RSAPublicKeySpec.class;
-    private static final Class<?> rsaPrivateKeySpecClass =
+    private static final Class<?> RSA_PUB_KEYSPEC_CLS = RSAPublicKeySpec.class;
+    private static final Class<?> RSA_PRIV_KEYSPEC_CLS =
                                                 RSAPrivateKeySpec.class;
-    private static final Class<?> rsaPrivateCrtKeySpecClass =
+    private static final Class<?> RSA_PRIVCRT_KEYSPEC_CLS =
                                                 RSAPrivateCrtKeySpec.class;
-
-    private static final Class<?> x509KeySpecClass  = X509EncodedKeySpec.class;
-    private static final Class<?> pkcs8KeySpecClass = PKCS8EncodedKeySpec.class;
+    private static final Class<?> X509_KEYSPEC_CLS = X509EncodedKeySpec.class;
+    private static final Class<?> PKCS8_KEYSPEC_CLS = PKCS8EncodedKeySpec.class;
 
     public static final int MIN_MODLEN = 512;
     public static final int MAX_MODLEN = 16384;
 
+    private final KeyType type; // key algorithm this factory accepts
+
     /*
      * If the modulus length is above this value, restrict the size of
      * the exponent to something that can be reasonably computed.  We
      * could simply hardcode the exp len to something like 64 bits, but
      * this approach allows flexibility in case impls would like to use

@@ -85,15 +90,22 @@
 
     private static final boolean restrictExpLen =
         "true".equalsIgnoreCase(GetPropertyAction.privilegedGetProperty(
                 "sun.security.rsa.restrictRSAExponent", "true"));
 
-    // instance used for static translateKey();
-    private static final RSAKeyFactory INSTANCE = new RSAKeyFactory();
+    static RSAKeyFactory getInstance(KeyType type) {
+        return new RSAKeyFactory(type);
+    }
 
-    public RSAKeyFactory() {
-        // empty
+    // Rejects key unless its algorithm equals expectedAlg (ignoring case)
+    private static void checkKeyAlgo(Key key, String expectedAlg)
+            throws InvalidKeyException {
+        String keyAlg = key.getAlgorithm();
+        if (!expectedAlg.equalsIgnoreCase(keyAlg)) { // null-safe
+            throw new InvalidKeyException("Expected a " + expectedAlg
+                    + " key, but got " + keyAlg);
+        }
     }
 
     /**
      * Static method to convert Key into an instance of RSAPublicKeyImpl
      * or RSAPrivate(Crt)KeyImpl. If the key is not an RSA key or cannot be

@@ -105,11 +117,18 @@
         if ((key instanceof RSAPrivateKeyImpl) ||
             (key instanceof RSAPrivateCrtKeyImpl) ||
             (key instanceof RSAPublicKeyImpl)) {
             return (RSAKey)key;
         } else {
-            return (RSAKey)INSTANCE.engineTranslateKey(key);
+            try {
+                String keyAlgo = key.getAlgorithm();
+                KeyType type = KeyType.lookup(keyAlgo);
+                RSAKeyFactory kf = RSAKeyFactory.getInstance(type);
+                return (RSAKey) kf.engineTranslateKey(key);
+            } catch (ProviderException e) {
+                throw new InvalidKeyException(e);
+            }
         }
     }
 
     /*
      * Single test entry point for all of the mechanisms in the SunRsaSign

@@ -169,22 +188,36 @@
                 " if modulus is greater than " +
                 MAX_MODLEN_RESTRICT_EXP + " bits");
         }
     }
 
+    // disallowed as KeyType is required
+    private RSAKeyFactory() {
+        this.type = KeyType.RSA;
+    }
+
+    public RSAKeyFactory(KeyType type) {
+        this.type = type;
+    }
+
     /**
      * Translate an RSA key into a SunRsaSign RSA key. If conversion is
      * not possible, throw an InvalidKeyException.
      * See also JCA doc.
      */
     protected Key engineTranslateKey(Key key) throws InvalidKeyException {
         if (key == null) {
             throw new InvalidKeyException("Key must not be null");
         }
-        String keyAlg = key.getAlgorithm();
-        if (keyAlg.equals("RSA") == false) {
-            throw new InvalidKeyException("Not an RSA key: " + keyAlg);
+        // ensure the key algorithm matches the current KeyFactory instance
+        checkKeyAlgo(key, type.keyAlgo());
+
+        // no translation needed if the key is already our own impl
+        if ((key instanceof RSAPrivateKeyImpl) ||
+            (key instanceof RSAPrivateCrtKeyImpl) ||
+            (key instanceof RSAPublicKeyImpl)) {
+            return key;
         }
         if (key instanceof PublicKey) {
             return translatePublicKey((PublicKey)key);
         } else if (key instanceof PrivateKey) {
             return translatePrivateKey((PrivateKey)key);

@@ -219,72 +252,71 @@
 
     // internal implementation of translateKey() for public keys. See JCA doc
     private PublicKey translatePublicKey(PublicKey key)
             throws InvalidKeyException {
         if (key instanceof RSAPublicKey) {
-            if (key instanceof RSAPublicKeyImpl) {
-                return key;
-            }
             RSAPublicKey rsaKey = (RSAPublicKey)key;
             try {
                 return new RSAPublicKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
                     rsaKey.getModulus(),
-                    rsaKey.getPublicExponent()
-                );
+                    rsaKey.getPublicExponent());
             } catch (RuntimeException e) {
                 // catch providers that incorrectly implement RSAPublicKey
                 throw new InvalidKeyException("Invalid key", e);
             }
         } else if ("X.509".equals(key.getFormat())) {
             byte[] encoded = key.getEncoded();
-            return new RSAPublicKeyImpl(encoded);
+            RSAPublicKey translated = new RSAPublicKeyImpl(encoded);
+            // ensure the key algorithm matches the current KeyFactory instance
+            checkKeyAlgo(translated, type.keyAlgo());
+            return translated;
         } else {
             throw new InvalidKeyException("Public keys must be instance "
                 + "of RSAPublicKey or have X.509 encoding");
         }
     }
 
     // internal implementation of translateKey() for private keys. See JCA doc
     private PrivateKey translatePrivateKey(PrivateKey key)
             throws InvalidKeyException {
         if (key instanceof RSAPrivateCrtKey) {
-            if (key instanceof RSAPrivateCrtKeyImpl) {
-                return key;
-            }
             RSAPrivateCrtKey rsaKey = (RSAPrivateCrtKey)key;
             try {
                 return new RSAPrivateCrtKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
                     rsaKey.getModulus(),
                     rsaKey.getPublicExponent(),
                     rsaKey.getPrivateExponent(),
                     rsaKey.getPrimeP(),
                     rsaKey.getPrimeQ(),
                     rsaKey.getPrimeExponentP(),
                     rsaKey.getPrimeExponentQ(),
                     rsaKey.getCrtCoefficient()
                 );
             } catch (RuntimeException e) {
                 // catch providers that incorrectly implement RSAPrivateCrtKey
                 throw new InvalidKeyException("Invalid key", e);
             }
         } else if (key instanceof RSAPrivateKey) {
-            if (key instanceof RSAPrivateKeyImpl) {
-                return key;
-            }
             RSAPrivateKey rsaKey = (RSAPrivateKey)key;
             try {
                 return new RSAPrivateKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
                     rsaKey.getModulus(),
                     rsaKey.getPrivateExponent()
                 );
             } catch (RuntimeException e) {
                 // catch providers that incorrectly implement RSAPrivateKey
                 throw new InvalidKeyException("Invalid key", e);
             }
         } else if ("PKCS#8".equals(key.getFormat())) {
             byte[] encoded = key.getEncoded();
-            return RSAPrivateCrtKeyImpl.newKey(encoded);
+            RSAPrivateKey translated = RSAPrivateCrtKeyImpl.newKey(encoded);
+            // ensure the key algorithm matches the current KeyFactory instance
+            checkKeyAlgo(translated, type.keyAlgo());
+            return translated;
         } else {
             throw new InvalidKeyException("Private keys must be instance "
                 + "of RSAPrivate(Crt)Key or have PKCS#8 encoding");
         }
     }

@@ -292,17 +324,25 @@
     // internal implementation of generatePublic. See JCA doc
     private PublicKey generatePublic(KeySpec keySpec)
             throws GeneralSecurityException {
         if (keySpec instanceof X509EncodedKeySpec) {
             X509EncodedKeySpec x509Spec = (X509EncodedKeySpec)keySpec;
-            return new RSAPublicKeyImpl(x509Spec.getEncoded());
+            RSAPublicKey generated = new RSAPublicKeyImpl(x509Spec.getEncoded());
+            // ensure the key algorithm matches the current KeyFactory instance
+            checkKeyAlgo(generated, type.keyAlgo());
+            return generated;
         } else if (keySpec instanceof RSAPublicKeySpec) {
             RSAPublicKeySpec rsaSpec = (RSAPublicKeySpec)keySpec;
-            return new RSAPublicKeyImpl(
-                rsaSpec.getModulus(),
-                rsaSpec.getPublicExponent()
-            );
+            try {
+                return new RSAPublicKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+                    rsaSpec.getModulus(),
+                    rsaSpec.getPublicExponent()
+                );
+            } catch (ProviderException e) {
+                throw new InvalidKeySpecException(e);
+            }
         } else {
             throw new InvalidKeySpecException("Only RSAPublicKeySpec "
                 + "and X509EncodedKeySpec supported for RSA public keys");
         }
     }

@@ -310,29 +350,42 @@
     // internal implementation of generatePrivate. See JCA doc
     private PrivateKey generatePrivate(KeySpec keySpec)
             throws GeneralSecurityException {
         if (keySpec instanceof PKCS8EncodedKeySpec) {
             PKCS8EncodedKeySpec pkcsSpec = (PKCS8EncodedKeySpec)keySpec;
-            return RSAPrivateCrtKeyImpl.newKey(pkcsSpec.getEncoded());
+            RSAPrivateKey generated = RSAPrivateCrtKeyImpl.newKey(pkcsSpec.getEncoded());
+            // ensure the key algorithm matches the current KeyFactory instance
+            checkKeyAlgo(generated, type.keyAlgo());
+            return generated;
         } else if (keySpec instanceof RSAPrivateCrtKeySpec) {
             RSAPrivateCrtKeySpec rsaSpec = (RSAPrivateCrtKeySpec)keySpec;
-            return new RSAPrivateCrtKeyImpl(
-                rsaSpec.getModulus(),
-                rsaSpec.getPublicExponent(),
-                rsaSpec.getPrivateExponent(),
-                rsaSpec.getPrimeP(),
-                rsaSpec.getPrimeQ(),
-                rsaSpec.getPrimeExponentP(),
-                rsaSpec.getPrimeExponentQ(),
-                rsaSpec.getCrtCoefficient()
-            );
+            try {
+                return new RSAPrivateCrtKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+                    rsaSpec.getModulus(),
+                    rsaSpec.getPublicExponent(),
+                    rsaSpec.getPrivateExponent(),
+                    rsaSpec.getPrimeP(),
+                    rsaSpec.getPrimeQ(),
+                    rsaSpec.getPrimeExponentP(),
+                    rsaSpec.getPrimeExponentQ(),
+                    rsaSpec.getCrtCoefficient()
+                );
+            } catch (ProviderException e) {
+                throw new InvalidKeySpecException(e);
+            }
         } else if (keySpec instanceof RSAPrivateKeySpec) {
             RSAPrivateKeySpec rsaSpec = (RSAPrivateKeySpec)keySpec;
-            return new RSAPrivateKeyImpl(
-                rsaSpec.getModulus(),
-                rsaSpec.getPrivateExponent()
-            );
+            try {
+                return new RSAPrivateKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+                    rsaSpec.getModulus(),
+                    rsaSpec.getPrivateExponent()
+                );
+            } catch (ProviderException e) {
+                throw new InvalidKeySpecException(e);
+            }
         } else {
             throw new InvalidKeySpecException("Only RSAPrivate(Crt)KeySpec "
                 + "and PKCS8EncodedKeySpec supported for RSA private keys");
         }
     }

@@ -347,47 +400,50 @@
         } catch (InvalidKeyException e) {
             throw new InvalidKeySpecException(e);
         }
         if (key instanceof RSAPublicKey) {
             RSAPublicKey rsaKey = (RSAPublicKey)key;
-            if (rsaPublicKeySpecClass.isAssignableFrom(keySpec)) {
+            if (RSA_PUB_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 return keySpec.cast(new RSAPublicKeySpec(
                     rsaKey.getModulus(),
-                    rsaKey.getPublicExponent()
+                    rsaKey.getPublicExponent(),
+                    rsaKey.getParams()
                 ));
-            } else if (x509KeySpecClass.isAssignableFrom(keySpec)) {
+            } else if (X509_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 return keySpec.cast(new X509EncodedKeySpec(key.getEncoded()));
             } else {
                 throw new InvalidKeySpecException
                         ("KeySpec must be RSAPublicKeySpec or "
                         + "X509EncodedKeySpec for RSA public keys");
             }
         } else if (key instanceof RSAPrivateKey) {
-            if (pkcs8KeySpecClass.isAssignableFrom(keySpec)) {
+            if (PKCS8_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 return keySpec.cast(new PKCS8EncodedKeySpec(key.getEncoded()));
-            } else if (rsaPrivateCrtKeySpecClass.isAssignableFrom(keySpec)) {
+            } else if (RSA_PRIVCRT_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 if (key instanceof RSAPrivateCrtKey) {
                     RSAPrivateCrtKey crtKey = (RSAPrivateCrtKey)key;
                     return keySpec.cast(new RSAPrivateCrtKeySpec(
                         crtKey.getModulus(),
                         crtKey.getPublicExponent(),
                         crtKey.getPrivateExponent(),
                         crtKey.getPrimeP(),
                         crtKey.getPrimeQ(),
                         crtKey.getPrimeExponentP(),
                         crtKey.getPrimeExponentQ(),
-                        crtKey.getCrtCoefficient()
+                        crtKey.getCrtCoefficient(),
+                        crtKey.getParams()
                     ));
                 } else {
                     throw new InvalidKeySpecException
                     ("RSAPrivateCrtKeySpec can only be used with CRT keys");
                 }
-            } else if (rsaPrivateKeySpecClass.isAssignableFrom(keySpec)) {
+            } else if (RSA_PRIV_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 RSAPrivateKey rsaKey = (RSAPrivateKey)key;
                 return keySpec.cast(new RSAPrivateKeySpec(
                     rsaKey.getModulus(),
-                    rsaKey.getPrivateExponent()
+                    rsaKey.getPrivateExponent(),
+                    rsaKey.getParams()
                 ));
             } else {
                 throw new InvalidKeySpecException
                         ("KeySpec must be RSAPrivate(Crt)KeySpec or "
                         + "PKCS8EncodedKeySpec for RSA private keys");

@@ -395,6 +451,24 @@
         } else {
             // should not occur, caught in engineTranslateKey()
             throw new InvalidKeySpecException("Neither public nor private key");
         }
     }
+
+    /**
+     * KeyFactory for keys of type {@code KeyType.RSA}, i.e. "RSA" keys.
+     */
+    public static final class Legacy extends RSAKeyFactory {
+        public Legacy() {
+            super(KeyType.RSA);
+        }
+    }
+
+    /**
+     * KeyFactory for keys of type {@code KeyType.PSS} (RSASSA-PSS).
+     */
+    public static final class PSS extends RSAKeyFactory {
+        public PSS() {
+            super(KeyType.PSS);
+        }
+    }
 }
< prev index next >