8205720: KeyFactory#getKeySpec and translateKey thorws NullPointerException with Invalid key
authorvaleriep
Sat, 30 Jun 2018 00:33:05 +0000
changeset 50918 ebff24bd9302
parent 50917 55a43beaa529
child 50919 803cfa425026
8205720: KeyFactory#getKeySpec and translateKey thorws NullPointerException with Invalid key Summary: Updated SunRsaSign provider to check and throw InvalidKeyException for null key algo/format/encoding Reviewed-by: xuelei
src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java
src/java.base/share/classes/sun/security/rsa/RSAPrivateCrtKeyImpl.java
src/java.base/share/classes/sun/security/rsa/RSAPublicKeyImpl.java
src/java.base/share/classes/sun/security/rsa/RSAUtil.java
--- a/src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java	Fri Jun 29 13:58:16 2018 -0700
+++ b/src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java	Sat Jun 30 00:33:05 2018 +0000
@@ -100,7 +100,7 @@
     private static void checkKeyAlgo(Key key, String expectedAlg)
             throws InvalidKeyException {
         String keyAlg = key.getAlgorithm();
-        if (!(keyAlg.equalsIgnoreCase(expectedAlg))) {
+        if (keyAlg == null || !(keyAlg.equalsIgnoreCase(expectedAlg))) {
             throw new InvalidKeyException("Expected a " + expectedAlg
                     + " key, but got " + keyAlg);
         }
@@ -123,8 +123,7 @@
             return (RSAKey)key;
         } else {
             try {
-                String keyAlgo = key.getAlgorithm();
-                KeyType type = KeyType.lookup(keyAlgo);
+                KeyType type = KeyType.lookup(key.getAlgorithm());
                 RSAKeyFactory kf = RSAKeyFactory.getInstance(type);
                 return (RSAKey) kf.engineTranslateKey(key);
             } catch (ProviderException e) {
@@ -268,8 +267,7 @@
                 throw new InvalidKeyException("Invalid key", e);
             }
         } else if ("X.509".equals(key.getFormat())) {
-            byte[] encoded = key.getEncoded();
-            RSAPublicKey translated = new RSAPublicKeyImpl(encoded);
+            RSAPublicKey translated = new RSAPublicKeyImpl(key.getEncoded());
             // ensure the key algorithm matches the current KeyFactory instance
             checkKeyAlgo(translated, type.keyAlgo());
             return translated;
@@ -313,8 +311,8 @@
                 throw new InvalidKeyException("Invalid key", e);
             }
         } else if ("PKCS#8".equals(key.getFormat())) {
-            byte[] encoded = key.getEncoded();
-            RSAPrivateKey translated = RSAPrivateCrtKeyImpl.newKey(encoded);
+            RSAPrivateKey translated =
+                RSAPrivateCrtKeyImpl.newKey(key.getEncoded());
             // ensure the key algorithm matches the current KeyFactory instance
             checkKeyAlgo(translated, type.keyAlgo());
             return translated;
--- a/src/java.base/share/classes/sun/security/rsa/RSAPrivateCrtKeyImpl.java	Fri Jun 29 13:58:16 2018 -0700
+++ b/src/java.base/share/classes/sun/security/rsa/RSAPrivateCrtKeyImpl.java	Sat Jun 30 00:33:05 2018 +0000
@@ -123,6 +123,10 @@
      * Construct a key from its encoding. Called from newKey above.
      */
     RSAPrivateCrtKeyImpl(byte[] encoded) throws InvalidKeyException {
+        if (encoded == null || encoded.length == 0) {
+            throw new InvalidKeyException("Missing key encoding");
+        }
+
         decode(encoded);
         RSAKeyFactory.checkRSAProviderKeyLengths(n.bitLength(), e);
         try {
--- a/src/java.base/share/classes/sun/security/rsa/RSAPublicKeyImpl.java	Fri Jun 29 13:58:16 2018 -0700
+++ b/src/java.base/share/classes/sun/security/rsa/RSAPublicKeyImpl.java	Sat Jun 30 00:33:05 2018 +0000
@@ -116,6 +116,9 @@
      * Construct a key from its encoding. Used by RSAKeyFactory.
      */
     RSAPublicKeyImpl(byte[] encoded) throws InvalidKeyException {
+        if (encoded == null || encoded.length == 0) {
+            throw new InvalidKeyException("Missing key encoding");
+        }
         decode(encoded); // this sets n and e value
         RSAKeyFactory.checkRSAProviderKeyLengths(n.bitLength(), e);
         checkExponentRange(n, e);
--- a/src/java.base/share/classes/sun/security/rsa/RSAUtil.java	Fri Jun 29 13:58:16 2018 -0700
+++ b/src/java.base/share/classes/sun/security/rsa/RSAUtil.java	Sat Jun 30 00:33:05 2018 +0000
@@ -52,7 +52,11 @@
         public String keyAlgo() {
             return algo;
         }
-        public static KeyType lookup(String name) {
+        public static KeyType lookup(String name)
+                throws InvalidKeyException, ProviderException {
+            if (name == null) {
+                throw new InvalidKeyException("Null key algorithm");
+            }
             for (KeyType kt : KeyType.values()) {
                 if (kt.keyAlgo().equalsIgnoreCase(name)) {
                     return kt;
@@ -133,21 +137,24 @@
             throws ProviderException {
         if (params == null) return null;
 
-        String algName = params.getAlgorithm();
-        KeyType type = KeyType.lookup(algName);
-        Class<? extends AlgorithmParameterSpec> specCls;
-        switch (type) {
-            case RSA:
-                throw new ProviderException("No params accepted for " +
-                    type.keyAlgo());
-            case PSS:
-                specCls = PSSParameterSpec.class;
-                break;
-            default:
-                throw new ProviderException("Unsupported RSA algorithm: " + algName);
-        }
         try {
+            String algName = params.getAlgorithm();
+            KeyType type = KeyType.lookup(algName);
+            Class<? extends AlgorithmParameterSpec> specCls;
+            switch (type) {
+                case RSA:
+                    throw new ProviderException("No params accepted for " +
+                        type.keyAlgo());
+                case PSS:
+                    specCls = PSSParameterSpec.class;
+                    break;
+                default:
+                    throw new ProviderException("Unsupported RSA algorithm: " + algName);
+            }
             return params.getParameterSpec(specCls);
+        } catch (ProviderException pe) {
+            // pass it up
+            throw pe;
         } catch (Exception e) {
             throw new ProviderException(e);
         }