8031346: Enhance RSA key handling
authorxuelei
Sat, 29 Mar 2014 23:24:26 +0000
changeset 25541 df83fb1a542e
parent 25540 021f6cd857f5
child 25542 05badfb785b2
8031346: Enhance RSA key handling Reviewed-by: ahgross, ascarpino, asmotrak, robm, weijun, wetmore
jdk/src/share/classes/sun/security/rsa/RSACore.java
--- a/jdk/src/share/classes/sun/security/rsa/RSACore.java	Thu Apr 24 21:04:16 2014 +0400
+++ b/jdk/src/share/classes/sun/security/rsa/RSACore.java	Sat Mar 29 23:24:26 2014 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2003, 2011, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2003, 2014, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -50,6 +50,15 @@
  */
 public final class RSACore {
 
+    // globally enable/disable use of blinding
+    private final static boolean ENABLE_BLINDING = true;
+
+    // cache for blinding parameters. Map<BigInteger, BlindingParameters>
+    // use a weak hashmap so that cached values are automatically cleared
+    // when the modulus is GC'ed
+    private final static Map<BigInteger, BlindingParameters>
+                blindingCache = new WeakHashMap<>();
+
     private RSACore() {
         // empty
     }
@@ -100,12 +109,12 @@
         if (key instanceof RSAPrivateCrtKey) {
             return crtCrypt(msg, (RSAPrivateCrtKey)key);
         } else {
-            return crypt(msg, key.getModulus(), key.getPrivateExponent());
+            return priCrypt(msg, key.getModulus(), key.getPrivateExponent());
         }
     }
 
     /**
-     * RSA public key ops and non-CRT private key ops. Simple modPow().
+     * RSA public key ops. Simple modPow().
      */
     private static byte[] crypt(byte[] msg, BigInteger n, BigInteger exp)
             throws BadPaddingException {
@@ -115,22 +124,29 @@
     }
 
     /**
+     * RSA non-CRT private key operations.
+     */
+    private static byte[] priCrypt(byte[] msg, BigInteger n, BigInteger exp)
+            throws BadPaddingException {
+
+        BigInteger c = parseMsg(msg, n);
+        BlindingRandomPair brp = null;
+        BigInteger m;
+        if (ENABLE_BLINDING) {
+            brp = getBlindingRandomPair(null, exp, n);
+            c = c.multiply(brp.u).mod(n);
+            m = c.modPow(exp, n);
+            m = m.multiply(brp.v).mod(n);
+        } else {
+            m = c.modPow(exp, n);
+        }
+
+        return toByteArray(m, getByteLength(n));
+    }
+
+    /**
      * RSA private key operations with CRT. Algorithm and variable naming
      * are taken from PKCS#1 v2.1, section 5.1.2.
-     *
-     * The only difference is the addition of blinding to twart timing attacks.
-     * This is described in the RSA Bulletin#2 (Jan 96) among other places.
-     * This means instead of implementing RSA as
-     *   m = c ^ d mod n (or RSA in CRT variant)
-     * we do
-     *   r  = random(0, n-1)
-     *   c' = c  * r^e  mod n
-     *   m' = c' ^ d    mod n (or RSA in CRT variant)
-     *   m  = m' * r^-1 mod n (where r^-1 is the modular inverse of r mod n)
-     * This works because r^(e*d) * r^-1 = r * r^-1 = 1 (all mod n)
-     *
-     * We do not generate new blinding parameters for each operation but reuse
-     * them BLINDING_MAX_REUSE times (see definition below).
      */
     private static byte[] crtCrypt(byte[] msg, RSAPrivateCrtKey key)
             throws BadPaddingException {
@@ -141,13 +157,13 @@
         BigInteger dP = key.getPrimeExponentP();
         BigInteger dQ = key.getPrimeExponentQ();
         BigInteger qInv = key.getCrtCoefficient();
+        BigInteger e = key.getPublicExponent();
+        BigInteger d = key.getPrivateExponent();
 
-        BlindingParameters params;
+        BlindingRandomPair brp;
         if (ENABLE_BLINDING) {
-            params = getBlindingParameters(key);
-            c = c.multiply(params.re).mod(n);
-        } else {
-            params = null;
+            brp = getBlindingRandomPair(e, d, n);
+            c = c.multiply(brp.u).mod(n);
         }
 
         // m1 = c ^ dP mod p
@@ -165,8 +181,8 @@
         // m = m2 + q * h
         BigInteger m = h.multiply(q).add(m2);
 
-        if (params != null) {
-            m = m.multiply(params.rInv).mod(n);
+        if (ENABLE_BLINDING) {
+            m = m.multiply(brp.v).mod(n);
         }
 
         return toByteArray(m, getByteLength(n));
@@ -208,82 +224,217 @@
         return t;
     }
 
-    // globally enable/disable use of blinding
-    private final static boolean ENABLE_BLINDING = true;
+    /**
+     * Parameters (u,v) for RSA Blinding.  This is described in the RSA
+     * Bulletin#2 (Jan 96) and other places:
+     *
+     *     ftp://ftp.rsa.com/pub/pdfs/bull-2.pdf
+     *
+     * The standard RSA Blinding decryption requires the public key exponent
+     * (e) and modulus (n), and converts ciphertext (c) to plaintext (p).
+     *
+     * Before the modular exponentiation operation, the input message should
+     * be multiplied by (u (mod n)), and afterward the result is corrected
+     * by multiplying with (v (mod n)).  The system should reject messages
+     * equal to (0 (mod n)).  That is:
+     *
+     *     1.  Generate r between 0 and n-1, relatively prime to n.
+     *     2.  Compute x = (c*u) mod n
+     *     3.  Compute y = (x^d) mod n
+     *     4.  Compute p = (y*v) mod n
+     *
+     * The Java APIs allows for either standard RSAPrivateKey or
+     * RSAPrivateCrtKey RSA keys.
+     *
+     * If the public exponent is available to us (e.g. RSAPrivateCrtKey),
+     * choose a random r, then let (u, v):
+     *
+     *     u = r ^ e mod n
+     *     v = r ^ (-1) mod n
+     *
+     * The proof follows:
+     *
+     *     p = (((c * u) ^ d mod n) * v) mod n
+     *       = ((c ^ d) * (u ^ d) * v) mod n
+     *       = ((c ^ d) * (r ^ e) ^ d) * (r ^ (-1))) mod n
+     *       = ((c ^ d) * (r ^ (e * d)) * (r ^ (-1))) mod n
+     *       = ((c ^ d) * (r ^ 1) * (r ^ (-1))) mod n  (see below)
+     *       = (c ^ d) mod n
+     *
+     * because in RSA cryptosystem, d is the multiplicative inverse of e:
+     *
+     *    (r^(e * d)) mod n
+     *       = (r ^ 1) mod n
+     *       = r mod n
+     *
+     * However, if the public exponent is not available (e.g. RSAPrivateKey),
+     * we mitigate the timing issue by using a similar random number blinding
+     * approach using the private key:
+     *
+     *     u = r
+     *     v = ((r ^ (-1)) ^ d) mod n
+     *
+     * This returns the same plaintext because:
+     *
+     *     p = (((c * u) ^ d mod n) * v) mod n
+     *       = ((c ^ d) * (u ^ d) * v) mod n
+     *       = ((c ^ d) * (u ^ d) * ((u ^ (-1)) ^d)) mod n
+     *       = (c ^ d) mod n
+     *
+     * Computing inverses mod n and random number generation is slow, so
+     * it is often not practical to generate a new random (u, v) pair for
+     * each new exponentiation.  The calculation of parameters might even be
+     * subject to timing attacks.  However, (u, v) pairs should not be
+     * reused since they themselves might be compromised by timing attacks,
+     * leaving the private exponent vulnerable.  An efficient solution to
+     * this problem is update u and v before each modular exponentiation
+     * step by computing:
+     *
+     *     u = u ^ 2
+     *     v = v ^ 2
+     *
+     * The total performance cost is small.
+     */
+    private final static class BlindingRandomPair {
+        final BigInteger u;
+        final BigInteger v;
 
-    // maximum number of times that we will use a set of blinding parameters
-    // value suggested by Paul Kocher (quoted by NSS)
-    private final static int BLINDING_MAX_REUSE = 50;
-
-    // cache for blinding parameters. Map<BigInteger, BlindingParameters>
-    // use a weak hashmap so that cached values are automatically cleared
-    // when the modulus is GC'ed
-    private final static Map<BigInteger, BlindingParameters> blindingCache =
-                new WeakHashMap<>();
+        BlindingRandomPair(BigInteger u, BigInteger v) {
+            this.u = u;
+            this.v = v;
+        }
+    }
 
     /**
      * Set of blinding parameters for a given RSA key.
      *
      * The RSA modulus is usually unique, so we index by modulus in
-     * blindingCache. However, to protect against the unlikely case of two
-     * keys sharing the same modulus, we also store the public exponent.
-     * This means we cannot cache blinding parameters for multiple keys that
-     * share the same modulus, but since sharing moduli is fundamentally broken
-     * an insecure, this does not matter.
+     * {@code blindingCache}.  However, to protect against the unlikely
+     * case of two keys sharing the same modulus, we also store the public
+     * or the private exponent.  This means we cannot cache blinding
+     * parameters for multiple keys that share the same modulus, but
+     * since sharing moduli is fundamentally broken and insecure, this
+     * does not matter.
      */
-    private static final class BlindingParameters {
-        // e (RSA public exponent)
-        final BigInteger e;
-        // r ^ e mod n
-        final BigInteger re;
-        // inverse of r mod n
-        final BigInteger rInv;
-        // how many more times this parameter object can be used
-        private volatile int remainingUses;
-        BlindingParameters(BigInteger e, BigInteger re, BigInteger rInv) {
+    private final static class BlindingParameters {
+        private final static BigInteger BIG_TWO = BigInteger.valueOf(2L);
+
+        // RSA public exponent
+        private final BigInteger e;
+
+        // hash code of RSA private exponent
+        private final BigInteger d;
+
+        // r ^ e mod n (CRT), or r mod n (Non-CRT)
+        private BigInteger u;
+
+        // r ^ (-1) mod n (CRT) , or ((r ^ (-1)) ^ d) mod n (Non-CRT)
+        private BigInteger v;
+
+        // e: the public exponent
+        // d: the private exponent
+        // n: the modulus
+        BlindingParameters(BigInteger e, BigInteger d, BigInteger n) {
+            this.u = null;
+            this.v = null;
             this.e = e;
-            this.re = re;
-            this.rInv = rInv;
-            // initialize remaining uses, subtract current use now
-            remainingUses = BLINDING_MAX_REUSE - 1;
+            this.d = d;
+
+            int len = n.bitLength();
+            SecureRandom random = JCAUtil.getSecureRandom();
+            u = new BigInteger(len, random).mod(n);
+            // Although the possibility is very much limited that u is zero
+            // or is not relatively prime to n, we still want to be careful
+            // about the special value.
+            //
+            // Secure random generation is expensive, try to use BigInteger.ONE
+            // this time if this new generated random number is zero or is not
+            // relatively prime to n.  Next time, new generated secure random
+            // number will be used instead.
+            if (u.equals(BigInteger.ZERO)) {
+                u = BigInteger.ONE;     // use 1 this time
+            }
+
+            try {
+                // The call to BigInteger.modInverse() checks that u is
+                // relatively prime to n.  Otherwise, ArithmeticException is
+                // thrown.
+                v = u.modInverse(n);
+            } catch (ArithmeticException ae) {
+                // if u is not relatively prime to n, use 1 this time
+                u = BigInteger.ONE;
+                v = BigInteger.ONE;
+            }
+
+            if (e != null) {
+                u = u.modPow(e, n);   // e: the public exponent
+                                      // u: random ^ e
+                                      // v: random ^ (-1)
+            } else {
+                v = v.modPow(d, n);   // d: the private exponent
+                                      // u: random
+                                      // v: random ^ (-d)
+            }
         }
-        boolean valid(BigInteger e) {
-            int k = remainingUses--;
-            return (k > 0) && this.e.equals(e);
+
+        // return null if need to reset the parameters
+        BlindingRandomPair getBlindingRandomPair(
+                BigInteger e, BigInteger d, BigInteger n) {
+
+            if ((this.e != null && this.e.equals(e)) ||
+                (this.d != null && this.d.equals(d))) {
+
+                BlindingRandomPair brp = null;
+                synchronized (this) {
+                    if (!u.equals(BigInteger.ZERO) &&
+                        !v.equals(BigInteger.ZERO)) {
+
+                        brp = new BlindingRandomPair(u, v);
+                        if (u.compareTo(BigInteger.ONE) <= 0 ||
+                            v.compareTo(BigInteger.ONE) <= 0) {
+
+                            // need to reset the random pair next time
+                            u = BigInteger.ZERO;
+                            v = BigInteger.ZERO;
+                        } else {
+                            u = u.modPow(BIG_TWO, n);
+                            v = v.modPow(BIG_TWO, n);
+                        }
+                    } // Otherwise, need to reset the random pair.
+                }
+                return brp;
+            }
+
+            return null;
         }
     }
 
-    /**
-     * Return valid RSA blinding parameters for the given private key.
-     * Use cached parameters if available. If not, generate new parameters
-     * and cache.
-     */
-    private static BlindingParameters getBlindingParameters
-            (RSAPrivateCrtKey key) {
-        BigInteger modulus = key.getModulus();
-        BigInteger e = key.getPublicExponent();
-        BlindingParameters params;
-        // we release the lock between get() and put()
-        // that means threads might concurrently generate new blinding
-        // parameters for the same modulus. this is only a slight waste
-        // of cycles and seems preferable in terms of scalability
-        // to locking out all threads while generating new parameters
+    private static BlindingRandomPair getBlindingRandomPair(
+            BigInteger e, BigInteger d, BigInteger n) {
+
+        BlindingParameters bps = null;
         synchronized (blindingCache) {
-            params = blindingCache.get(modulus);
+            bps = blindingCache.get(n);
         }
-        if ((params != null) && params.valid(e)) {
-            return params;
+
+        if (bps == null) {
+            bps = new BlindingParameters(e, d, n);
+            synchronized (blindingCache) {
+                blindingCache.putIfAbsent(n, bps);
+            }
         }
-        int len = modulus.bitLength();
-        SecureRandom random = JCAUtil.getSecureRandom();
-        BigInteger r = new BigInteger(len, random).mod(modulus);
-        BigInteger re = r.modPow(e, modulus);
-        BigInteger rInv = r.modInverse(modulus);
-        params = new BlindingParameters(e, re, rInv);
-        synchronized (blindingCache) {
-            blindingCache.put(modulus, params);
+
+        BlindingRandomPair brp = bps.getBlindingRandomPair(e, d, n);
+        if (brp == null) {
+            // need to reset the blinding parameters
+            bps = new BlindingParameters(e, d, n);
+            synchronized (blindingCache) {
+                blindingCache.replace(n, bps);
+            }
+            brp = bps.getBlindingRandomPair(e, d, n);
         }
-        return params;
+
+        return brp;
     }
 
 }