src/java.base/share/classes/com/sun/crypto/provider/PBKDF2KeyImpl.java
changeset 51504 c9a3e3cac9c7
parent 48560 46e99460e8c9
child 54182 2e586b74722e
--- a/src/java.base/share/classes/com/sun/crypto/provider/PBKDF2KeyImpl.java	Thu Aug 23 10:52:27 2018 +0200
+++ b/src/java.base/share/classes/com/sun/crypto/provider/PBKDF2KeyImpl.java	Thu Aug 23 11:37:14 2018 +0100
@@ -93,46 +93,50 @@
         }
         // Convert the password from char[] to byte[]
         byte[] passwdBytes = getPasswordBytes(this.passwd);
+        // remove local copy
+        if (passwd != null) Arrays.fill(passwd, '\0');
 
-        this.salt = keySpec.getSalt();
-        if (salt == null) {
-            throw new InvalidKeySpecException("Salt not found");
-        }
-        this.iterCount = keySpec.getIterationCount();
-        if (iterCount == 0) {
-            throw new InvalidKeySpecException("Iteration count not found");
-        } else if (iterCount < 0) {
-            throw new InvalidKeySpecException("Iteration count is negative");
-        }
-        int keyLength = keySpec.getKeyLength();
-        if (keyLength == 0) {
-            throw new InvalidKeySpecException("Key length not found");
-        } else if (keyLength < 0) {
-            throw new InvalidKeySpecException("Key length is negative");
-        }
         try {
+            this.salt = keySpec.getSalt();
+            if (salt == null) {
+                throw new InvalidKeySpecException("Salt not found");
+            }
+            this.iterCount = keySpec.getIterationCount();
+            if (iterCount == 0) {
+                throw new InvalidKeySpecException("Iteration count not found");
+            } else if (iterCount < 0) {
+                throw new InvalidKeySpecException("Iteration count is negative");
+            }
+            int keyLength = keySpec.getKeyLength();
+            if (keyLength == 0) {
+                throw new InvalidKeySpecException("Key length not found");
+            } else if (keyLength < 0) {
+                throw new InvalidKeySpecException("Key length is negative");
+            }
             this.prf = Mac.getInstance(prfAlgo);
             // SunPKCS11 requires a non-empty PBE password
             if (passwdBytes.length == 0 &&
-                this.prf.getProvider().getName().startsWith("SunPKCS11")) {
+                    this.prf.getProvider().getName().startsWith("SunPKCS11")) {
                 this.prf = Mac.getInstance(prfAlgo, SunJCE.getInstance());
             }
+            this.key = deriveKey(prf, passwdBytes, salt, iterCount, keyLength);
         } catch (NoSuchAlgorithmException nsae) {
             // not gonna happen; re-throw just in case
             InvalidKeySpecException ike = new InvalidKeySpecException();
             ike.initCause(nsae);
             throw ike;
-        }
-        this.key = deriveKey(prf, passwdBytes, salt, iterCount, keyLength);
+        } finally {
+            Arrays.fill(passwdBytes, (byte) 0x00);
 
-        // Use the cleaner to zero the key when no longer referenced
-        final byte[] k = this.key;
-        final char[] p = this.passwd;
-        CleanerFactory.cleaner().register(this,
-                () -> {
-                    java.util.Arrays.fill(k, (byte)0x00);
-                    java.util.Arrays.fill(p, '0');
-                });
+            // Use the cleaner to zero the key when no longer referenced
+            final byte[] k = this.key;
+            final char[] p = this.passwd;
+            CleanerFactory.cleaner().register(this,
+                    () -> {
+                        Arrays.fill(k, (byte) 0x00);
+                        Arrays.fill(p, '\0');
+                    });
+        }
     }
 
     private static byte[] deriveKey(final Mac prf, final byte[] password,
@@ -266,8 +270,8 @@
         if (!(that.getFormat().equalsIgnoreCase("RAW")))
             return false;
         byte[] thatEncoded = that.getEncoded();
-        boolean ret = MessageDigest.isEqual(key, that.getEncoded());
-        java.util.Arrays.fill(thatEncoded, (byte)0x00);
+        boolean ret = MessageDigest.isEqual(key, thatEncoded);
+        Arrays.fill(thatEncoded, (byte)0x00);
         return ret;
     }