src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java
branchJDK-8145252-TLS13-branch
changeset 56542 56aaa6cb3693
parent 47216 71c04702a3d5
child 56592 b1902b22005e
--- a/src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java	Fri May 11 14:55:56 2018 -0700
+++ b/src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java	Fri May 11 15:53:12 2018 -0700
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2003, 2011, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2003, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -32,10 +32,15 @@
 import java.security.spec.*;
 
 import sun.security.action.GetPropertyAction;
+import sun.security.x509.AlgorithmId;
+import static sun.security.rsa.RSAUtil.KeyType;
 
 /**
- * KeyFactory for RSA keys. Keys must be instances of PublicKey or PrivateKey
- * and getAlgorithm() must return "RSA". For such keys, it supports conversion
+ * KeyFactory for RSA keys, e.g. "RSA", "RSASSA-PSS".
+ * Keys must be instances of PublicKey or PrivateKey
+ * and getAlgorithm() must return a value which matches the type which are
+ * specified during construction time of the KeyFactory object.
+ * For such keys, it supports conversion
  * between the following:
  *
  * For public keys:
@@ -58,21 +63,21 @@
  * @since   1.5
  * @author  Andreas Sterbenz
  */
-public final class RSAKeyFactory extends KeyFactorySpi {
+public class RSAKeyFactory extends KeyFactorySpi {
 
-    private static final Class<?> rsaPublicKeySpecClass =
-                                                RSAPublicKeySpec.class;
-    private static final Class<?> rsaPrivateKeySpecClass =
-                                                RSAPrivateKeySpec.class;
-    private static final Class<?> rsaPrivateCrtKeySpecClass =
-                                                RSAPrivateCrtKeySpec.class;
-
-    private static final Class<?> x509KeySpecClass  = X509EncodedKeySpec.class;
-    private static final Class<?> pkcs8KeySpecClass = PKCS8EncodedKeySpec.class;
+    private static final Class<?> RSA_PUB_KEYSPEC_CLS = RSAPublicKeySpec.class;
+    private static final Class<?> RSA_PRIV_KEYSPEC_CLS =
+            RSAPrivateKeySpec.class;
+    private static final Class<?> RSA_PRIVCRT_KEYSPEC_CLS =
+            RSAPrivateCrtKeySpec.class;
+    private static final Class<?> X509_KEYSPEC_CLS = X509EncodedKeySpec.class;
+    private static final Class<?> PKCS8_KEYSPEC_CLS = PKCS8EncodedKeySpec.class;
 
     public static final int MIN_MODLEN = 512;
     public static final int MAX_MODLEN = 16384;
 
+    private final KeyType type;
+
     /*
      * If the modulus length is above this value, restrict the size of
      * the exponent to something that can be reasonably computed.  We
@@ -87,11 +92,18 @@
         "true".equalsIgnoreCase(GetPropertyAction.privilegedGetProperty(
                 "sun.security.rsa.restrictRSAExponent", "true"));
 
-    // instance used for static translateKey();
-    private static final RSAKeyFactory INSTANCE = new RSAKeyFactory();
+    static RSAKeyFactory getInstance(KeyType type) {
+        return new RSAKeyFactory(type);
+    }
 
-    public RSAKeyFactory() {
-        // empty
+    // Internal utility method for checking key algorithm
+    private static void checkKeyAlgo(Key key, String expectedAlg)
+            throws InvalidKeyException {
+        String keyAlg = key.getAlgorithm();
+        if (!(keyAlg.equalsIgnoreCase(expectedAlg))) {
+            throw new InvalidKeyException("Expected a " + expectedAlg
+                    + " key, but got " + keyAlg);
+        }
     }
 
     /**
@@ -107,7 +119,14 @@
             (key instanceof RSAPublicKeyImpl)) {
             return (RSAKey)key;
         } else {
-            return (RSAKey)INSTANCE.engineTranslateKey(key);
+            try {
+                String keyAlgo = key.getAlgorithm();
+                KeyType type = KeyType.lookup(keyAlgo);
+                RSAKeyFactory kf = RSAKeyFactory.getInstance(type);
+                return (RSAKey) kf.engineTranslateKey(key);
+            } catch (ProviderException e) {
+                throw new InvalidKeyException(e);
+            }
         }
     }
 
@@ -171,6 +190,15 @@
         }
     }
 
+    // disallowed as KeyType is required
+    private RSAKeyFactory() {
+        this.type = KeyType.RSA;
+    }
+
+    public RSAKeyFactory(KeyType type) {
+        this.type = type;
+    }
+
     /**
      * Translate an RSA key into a SunRsaSign RSA key. If conversion is
      * not possible, throw an InvalidKeyException.
@@ -180,9 +208,14 @@
         if (key == null) {
             throw new InvalidKeyException("Key must not be null");
         }
-        String keyAlg = key.getAlgorithm();
-        if (keyAlg.equals("RSA") == false) {
-            throw new InvalidKeyException("Not an RSA key: " + keyAlg);
+        // ensure the key algorithm matches the current KeyFactory instance
+        checkKeyAlgo(key, type.keyAlgo());
+
+        // no translation needed if the key is already our own impl 
+        if ((key instanceof RSAPrivateKeyImpl) ||
+            (key instanceof RSAPrivateCrtKeyImpl) ||
+            (key instanceof RSAPublicKeyImpl)) {
+            return key;
         }
         if (key instanceof PublicKey) {
             return translatePublicKey((PublicKey)key);
@@ -221,22 +254,22 @@
     private PublicKey translatePublicKey(PublicKey key)
             throws InvalidKeyException {
         if (key instanceof RSAPublicKey) {
-            if (key instanceof RSAPublicKeyImpl) {
-                return key;
-            }
             RSAPublicKey rsaKey = (RSAPublicKey)key;
             try {
                 return new RSAPublicKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
                     rsaKey.getModulus(),
-                    rsaKey.getPublicExponent()
-                );
-            } catch (RuntimeException e) {
+                    rsaKey.getPublicExponent());
+            } catch (ProviderException e) {
                 // catch providers that incorrectly implement RSAPublicKey
                 throw new InvalidKeyException("Invalid key", e);
             }
         } else if ("X.509".equals(key.getFormat())) {
             byte[] encoded = key.getEncoded();
-            return new RSAPublicKeyImpl(encoded);
+            RSAPublicKey translated = new RSAPublicKeyImpl(encoded);
+            // ensure the key algorithm matches the current KeyFactory instance
+            checkKeyAlgo(translated, type.keyAlgo());
+            return translated;
         } else {
             throw new InvalidKeyException("Public keys must be instance "
                 + "of RSAPublicKey or have X.509 encoding");
@@ -247,12 +280,10 @@
     private PrivateKey translatePrivateKey(PrivateKey key)
             throws InvalidKeyException {
         if (key instanceof RSAPrivateCrtKey) {
-            if (key instanceof RSAPrivateCrtKeyImpl) {
-                return key;
-            }
             RSAPrivateCrtKey rsaKey = (RSAPrivateCrtKey)key;
             try {
                 return new RSAPrivateCrtKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
                     rsaKey.getModulus(),
                     rsaKey.getPublicExponent(),
                     rsaKey.getPrivateExponent(),
@@ -262,27 +293,28 @@
                     rsaKey.getPrimeExponentQ(),
                     rsaKey.getCrtCoefficient()
                 );
-            } catch (RuntimeException e) {
+            } catch (ProviderException e) {
                 // catch providers that incorrectly implement RSAPrivateCrtKey
                 throw new InvalidKeyException("Invalid key", e);
             }
         } else if (key instanceof RSAPrivateKey) {
-            if (key instanceof RSAPrivateKeyImpl) {
-                return key;
-            }
             RSAPrivateKey rsaKey = (RSAPrivateKey)key;
             try {
                 return new RSAPrivateKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
                     rsaKey.getModulus(),
                     rsaKey.getPrivateExponent()
                 );
-            } catch (RuntimeException e) {
+            } catch (ProviderException e) {
                 // catch providers that incorrectly implement RSAPrivateKey
                 throw new InvalidKeyException("Invalid key", e);
             }
         } else if ("PKCS#8".equals(key.getFormat())) {
             byte[] encoded = key.getEncoded();
-            return RSAPrivateCrtKeyImpl.newKey(encoded);
+            RSAPrivateKey translated = RSAPrivateCrtKeyImpl.newKey(encoded);
+            // ensure the key algorithm matches the current KeyFactory instance
+            checkKeyAlgo(translated, type.keyAlgo());
+            return translated;
         } else {
             throw new InvalidKeyException("Private keys must be instance "
                 + "of RSAPrivate(Crt)Key or have PKCS#8 encoding");
@@ -294,13 +326,21 @@
             throws GeneralSecurityException {
         if (keySpec instanceof X509EncodedKeySpec) {
             X509EncodedKeySpec x509Spec = (X509EncodedKeySpec)keySpec;
-            return new RSAPublicKeyImpl(x509Spec.getEncoded());
+            RSAPublicKey generated = new RSAPublicKeyImpl(x509Spec.getEncoded());
+            // ensure the key algorithm matches the current KeyFactory instance
+            checkKeyAlgo(generated, type.keyAlgo());
+            return generated;
         } else if (keySpec instanceof RSAPublicKeySpec) {
             RSAPublicKeySpec rsaSpec = (RSAPublicKeySpec)keySpec;
-            return new RSAPublicKeyImpl(
-                rsaSpec.getModulus(),
-                rsaSpec.getPublicExponent()
-            );
+            try {
+                return new RSAPublicKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+                    rsaSpec.getModulus(),
+                    rsaSpec.getPublicExponent()
+                );
+            } catch (ProviderException e) {
+                throw new InvalidKeySpecException(e);
+            }
         } else {
             throw new InvalidKeySpecException("Only RSAPublicKeySpec "
                 + "and X509EncodedKeySpec supported for RSA public keys");
@@ -312,25 +352,38 @@
             throws GeneralSecurityException {
         if (keySpec instanceof PKCS8EncodedKeySpec) {
             PKCS8EncodedKeySpec pkcsSpec = (PKCS8EncodedKeySpec)keySpec;
-            return RSAPrivateCrtKeyImpl.newKey(pkcsSpec.getEncoded());
+            RSAPrivateKey generated = RSAPrivateCrtKeyImpl.newKey(pkcsSpec.getEncoded());
+            // ensure the key algorithm matches the current KeyFactory instance
+            checkKeyAlgo(generated, type.keyAlgo());
+            return generated;
         } else if (keySpec instanceof RSAPrivateCrtKeySpec) {
             RSAPrivateCrtKeySpec rsaSpec = (RSAPrivateCrtKeySpec)keySpec;
-            return new RSAPrivateCrtKeyImpl(
-                rsaSpec.getModulus(),
-                rsaSpec.getPublicExponent(),
-                rsaSpec.getPrivateExponent(),
-                rsaSpec.getPrimeP(),
-                rsaSpec.getPrimeQ(),
-                rsaSpec.getPrimeExponentP(),
-                rsaSpec.getPrimeExponentQ(),
-                rsaSpec.getCrtCoefficient()
-            );
+            try {
+                return new RSAPrivateCrtKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+                    rsaSpec.getModulus(),
+                    rsaSpec.getPublicExponent(),
+                    rsaSpec.getPrivateExponent(),
+                    rsaSpec.getPrimeP(),
+                    rsaSpec.getPrimeQ(),
+                    rsaSpec.getPrimeExponentP(),
+                    rsaSpec.getPrimeExponentQ(),
+                    rsaSpec.getCrtCoefficient()
+                );
+            } catch (ProviderException e) {
+                throw new InvalidKeySpecException(e);
+            }
         } else if (keySpec instanceof RSAPrivateKeySpec) {
             RSAPrivateKeySpec rsaSpec = (RSAPrivateKeySpec)keySpec;
-            return new RSAPrivateKeyImpl(
-                rsaSpec.getModulus(),
-                rsaSpec.getPrivateExponent()
-            );
+            try {
+                return new RSAPrivateKeyImpl(
+                    RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+                    rsaSpec.getModulus(),
+                    rsaSpec.getPrivateExponent()
+                );
+            } catch (ProviderException e) {
+                throw new InvalidKeySpecException(e);
+            }
         } else {
             throw new InvalidKeySpecException("Only RSAPrivate(Crt)KeySpec "
                 + "and PKCS8EncodedKeySpec supported for RSA private keys");
@@ -349,12 +402,13 @@
         }
         if (key instanceof RSAPublicKey) {
             RSAPublicKey rsaKey = (RSAPublicKey)key;
-            if (rsaPublicKeySpecClass.isAssignableFrom(keySpec)) {
+            if (RSA_PUB_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 return keySpec.cast(new RSAPublicKeySpec(
                     rsaKey.getModulus(),
-                    rsaKey.getPublicExponent()
+                    rsaKey.getPublicExponent(),
+                    rsaKey.getParams()
                 ));
-            } else if (x509KeySpecClass.isAssignableFrom(keySpec)) {
+            } else if (X509_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 return keySpec.cast(new X509EncodedKeySpec(key.getEncoded()));
             } else {
                 throw new InvalidKeySpecException
@@ -362,9 +416,9 @@
                         + "X509EncodedKeySpec for RSA public keys");
             }
         } else if (key instanceof RSAPrivateKey) {
-            if (pkcs8KeySpecClass.isAssignableFrom(keySpec)) {
+            if (PKCS8_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 return keySpec.cast(new PKCS8EncodedKeySpec(key.getEncoded()));
-            } else if (rsaPrivateCrtKeySpecClass.isAssignableFrom(keySpec)) {
+            } else if (RSA_PRIVCRT_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 if (key instanceof RSAPrivateCrtKey) {
                     RSAPrivateCrtKey crtKey = (RSAPrivateCrtKey)key;
                     return keySpec.cast(new RSAPrivateCrtKeySpec(
@@ -375,17 +429,19 @@
                         crtKey.getPrimeQ(),
                         crtKey.getPrimeExponentP(),
                         crtKey.getPrimeExponentQ(),
-                        crtKey.getCrtCoefficient()
+                        crtKey.getCrtCoefficient(),
+                        crtKey.getParams()
                     ));
                 } else {
                     throw new InvalidKeySpecException
                     ("RSAPrivateCrtKeySpec can only be used with CRT keys");
                 }
-            } else if (rsaPrivateKeySpecClass.isAssignableFrom(keySpec)) {
+            } else if (RSA_PRIV_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
                 RSAPrivateKey rsaKey = (RSAPrivateKey)key;
                 return keySpec.cast(new RSAPrivateKeySpec(
                     rsaKey.getModulus(),
-                    rsaKey.getPrivateExponent()
+                    rsaKey.getPrivateExponent(),
+                    rsaKey.getParams()
                 ));
             } else {
                 throw new InvalidKeySpecException
@@ -397,4 +453,16 @@
             throw new InvalidKeySpecException("Neither public nor private key");
         }
     }
+
+    public static final class Legacy extends RSAKeyFactory {
+        public Legacy() {
+            super(KeyType.RSA);
+        }
+    }
+
+    public static final class PSS extends RSAKeyFactory {
+        public PSS() {
+            super(KeyType.PSS);
+        }
+    }
 }