8130150: Implement BigInteger.montgomeryMultiply intrinsic
authoraph
Tue, 16 Jun 2015 14:58:30 +0100
changeset 31667 15a14e8fcfb0
parent 31253 5efb78c8a07d
child 31668 042a51bddfa5
8130150: Implement BigInteger.montgomeryMultiply intrinsic Summary: Add montgomeryMultiply intrinsics Reviewed-by: kvn
jdk/src/java.base/share/classes/java/math/BigInteger.java
jdk/src/java.base/share/classes/java/math/MutableBigInteger.java
--- a/jdk/src/java.base/share/classes/java/math/BigInteger.java	Thu Jun 11 14:20:01 2015 +0300
+++ b/jdk/src/java.base/share/classes/java/math/BigInteger.java	Tue Jun 16 14:58:30 2015 +0100
@@ -262,6 +262,15 @@
      */
     private static final int MULTIPLY_SQUARE_THRESHOLD = 20;
 
+    /**
+     * The threshold for using an intrinsic version of
+     * implMontgomeryXXX to perform Montgomery multiplication.  If the
+     * number of ints in the number is more than this value we do not
+     * use the intrinsic.
+     */
+    private static final int MONTGOMERY_INTRINSIC_THRESHOLD = 512;
+
+
     // Constructors
 
     /**
@@ -1639,7 +1648,7 @@
      * Multiplies int arrays x and y to the specified lengths and places
      * the result into z. There will be no leading zeros in the resultant array.
      */
-    private int[] multiplyToLen(int[] x, int xlen, int[] y, int ylen, int[] z) {
+    private static int[] multiplyToLen(int[] x, int xlen, int[] y, int ylen, int[] z) {
         int xstart = xlen - 1;
         int ystart = ylen - 1;
 
@@ -2601,6 +2610,75 @@
         return (invertResult ? result.modInverse(m) : result);
     }
 
+    // Montgomery multiplication.  These are wrappers for
+    // implMontgomeryXX routines which are expected to be replaced by
+    // virtual machine intrinsics.  We don't use the intrinsics for
+    // very large operands: MONTGOMERY_INTRINSIC_THRESHOLD should be
+    // larger than any reasonable crypto key.
+    private static int[] montgomeryMultiply(int[] a, int[] b, int[] n, int len, long inv,
+                                            int[] product) {
+        implMontgomeryMultiplyChecks(a, b, n, len, product);
+        if (len > MONTGOMERY_INTRINSIC_THRESHOLD) {
+            // Very long argument: do not use an intrinsic
+            product = multiplyToLen(a, len, b, len, product);
+            return montReduce(product, n, len, (int)inv);
+        } else {
+            return implMontgomeryMultiply(a, b, n, len, inv, materialize(product, len));
+        }
+    }
+    private static int[] montgomerySquare(int[] a, int[] n, int len, long inv,
+                                          int[] product) {
+        implMontgomeryMultiplyChecks(a, a, n, len, product);
+        if (len > MONTGOMERY_INTRINSIC_THRESHOLD) {
+            // Very long argument: do not use an intrinsic
+            product = squareToLen(a, len, product);
+            return montReduce(product, n, len, (int)inv);
+        } else {
+            return implMontgomerySquare(a, n, len, inv, materialize(product, len));
+        }
+    }
+
+    // Range-check everything.
+    private static void implMontgomeryMultiplyChecks
+        (int[] a, int[] b, int[] n, int len, int[] product) throws RuntimeException {
+        if (len % 2 != 0) {
+            throw new IllegalArgumentException("input array length must be even: " + len);
+        }
+
+        if (len < 1) {
+            throw new IllegalArgumentException("invalid input length: " + len);
+        }
+
+        if (len > a.length ||
+            len > b.length ||
+            len > n.length ||
+            (product != null && len > product.length)) {
+            throw new IllegalArgumentException("input array length out of bound: " + len);
+        }
+    }
+
+    // Make sure that the int array z (which is expected to contain
+    // the result of a Montgomery multiplication) is present and
+    // sufficiently large.
+    private static int[] materialize(int[] z, int len) {
+         if (z == null || z.length < len)
+             z = new int[len];
+         return z;
+    }
+
+    // These methods are intended to be be replaced by virtual machine
+    // intrinsics.
+    private static int[] implMontgomeryMultiply(int[] a, int[] b, int[] n, int len,
+                                         long inv, int[] product) {
+        product = multiplyToLen(a, len, b, len, product);
+        return montReduce(product, n, len, (int)inv);
+    }
+    private static int[] implMontgomerySquare(int[] a, int[] n, int len,
+                                       long inv, int[] product) {
+        product = squareToLen(a, len, product);
+        return montReduce(product, n, len, (int)inv);
+    }
+
     static int[] bnExpModThreshTable = {7, 25, 81, 241, 673, 1793,
                                                 Integer.MAX_VALUE}; // Sentinel
 
@@ -2679,6 +2757,17 @@
         int[] mod = z.mag;
         int modLen = mod.length;
 
+        // Make modLen even. It is conventional to use a cryptographic
+        // modulus that is 512, 768, 1024, or 2048 bits, so this code
+        // will not normally be executed. However, it is necessary for
+        // the correct functioning of the HotSpot intrinsics.
+        if ((modLen & 1) != 0) {
+            int[] x = new int[modLen + 1];
+            System.arraycopy(mod, 0, x, 1, modLen);
+            mod = x;
+            modLen++;
+        }
+
         // Select an appropriate window size
         int wbits = 0;
         int ebits = bitLength(exp, exp.length);
@@ -2697,8 +2786,10 @@
         for (int i=0; i < tblmask; i++)
             table[i] = new int[modLen];
 
-        // Compute the modular inverse
-        int inv = -MutableBigInteger.inverseMod32(mod[modLen-1]);
+        // Compute the modular inverse of the least significant 64-bit
+        // digit of the modulus
+        long n0 = (mod[modLen-1] & LONG_MASK) + ((mod[modLen-2] & LONG_MASK) << 32);
+        long inv = -MutableBigInteger.inverseMod64(n0);
 
         // Convert base to Montgomery form
         int[] a = leftShift(base, base.length, modLen << 5);
@@ -2706,6 +2797,8 @@
         MutableBigInteger q = new MutableBigInteger(),
                           a2 = new MutableBigInteger(a),
                           b2 = new MutableBigInteger(mod);
+        b2.normalize(); // MutableBigInteger.divide() assumes that its
+                        // divisor is in normal form.
 
         MutableBigInteger r= a2.divide(b2, q);
         table[0] = r.toIntArray();
@@ -2714,22 +2807,19 @@
         if (table[0].length < modLen) {
            int offset = modLen - table[0].length;
            int[] t2 = new int[modLen];
-           for (int i=0; i < table[0].length; i++)
-               t2[i+offset] = table[0][i];
+           System.arraycopy(table[0], 0, t2, offset, table[0].length);
            table[0] = t2;
         }
 
         // Set b to the square of the base
-        int[] b = squareToLen(table[0], modLen, null);
-        b = montReduce(b, mod, modLen, inv);
+        int[] b = montgomerySquare(table[0], mod, modLen, inv, null);
 
         // Set t to high half of b
         int[] t = Arrays.copyOf(b, modLen);
 
         // Fill in the table with odd powers of the base
         for (int i=1; i < tblmask; i++) {
-            int[] prod = multiplyToLen(t, modLen, table[i-1], modLen, null);
-            table[i] = montReduce(prod, mod, modLen, inv);
+            table[i] = montgomeryMultiply(t, table[i-1], mod, modLen, inv, null);
         }
 
         // Pre load the window that slides over the exponent
@@ -2800,8 +2890,7 @@
                     isone = false;
                 } else {
                     t = b;
-                    a = multiplyToLen(t, modLen, mult, modLen, a);
-                    a = montReduce(a, mod, modLen, inv);
+                    a = montgomeryMultiply(t, mult, mod, modLen, inv, a);
                     t = a; a = b; b = t;
                 }
             }
@@ -2813,8 +2902,7 @@
             // Square the input
             if (!isone) {
                 t = b;
-                a = squareToLen(t, modLen, a);
-                a = montReduce(a, mod, modLen, inv);
+                a = montgomerySquare(t, mod, modLen, inv, a);
                 t = a; a = b; b = t;
             }
         }
@@ -2823,7 +2911,7 @@
         int[] t2 = new int[2*modLen];
         System.arraycopy(b, 0, t2, modLen, modLen);
 
-        b = montReduce(t2, mod, modLen, inv);
+        b = montReduce(t2, mod, modLen, (int)inv);
 
         t2 = Arrays.copyOf(b, modLen);
 
--- a/jdk/src/java.base/share/classes/java/math/MutableBigInteger.java	Thu Jun 11 14:20:01 2015 +0300
+++ b/jdk/src/java.base/share/classes/java/math/MutableBigInteger.java	Tue Jun 16 14:58:30 2015 +0100
@@ -2065,6 +2065,21 @@
     }
 
     /**
+     * Returns the multiplicative inverse of val mod 2^64.  Assumes val is odd.
+     */
+    static long inverseMod64(long val) {
+        // Newton's iteration!
+        long t = val;
+        t *= 2 - val*t;
+        t *= 2 - val*t;
+        t *= 2 - val*t;
+        t *= 2 - val*t;
+        t *= 2 - val*t;
+        assert(t * val == 1);
+        return t;
+    }
+
+    /**
      * Calculate the multiplicative inverse of 2^k mod mod, where mod is odd.
      */
     static MutableBigInteger modInverseBP2(MutableBigInteger mod, int k) {