8203228: Branch-free output conversion for X25519 and X448
author apetcher
Tue, 26 Jun 2018 11:14:27 -0400
changeset 50792 59306e5a6cc7
parent 50791 b1e90a8a876c
child 50793 ca4eea543d23
8203228: Branch-free output conversion for X25519 and X448
Summary: Make some field arithmetic operations for X25519/X448 more resilient against side-channel attacks
Reviewed-by: ascarpino
src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java
src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial1305.java
src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial25519.java
src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial448.java
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java	Mon Jun 25 23:04:21 2018 +0200
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java	Tue Jun 26 11:14:27 2018 -0400
@@ -54,6 +54,8 @@
  * setDifference
  * setProduct
  * setSquare
+ * addModPowerTwo
+ * asByteArray
  *
  * All other operations may branch in some subclasses.
  *
@@ -66,6 +68,7 @@
     protected final int numLimbs;
     private final BigInteger modulus;
     protected final int bitsPerLimb;
+    private final long[] posModLimbs;
 
     // must work when a==r
     protected abstract void multByInt(long[] a, long b, long[] r);
@@ -84,6 +87,14 @@
         this.numLimbs = numLimbs;
         this.modulus = modulus;
         this.bitsPerLimb = bitsPerLimb;
+
+        posModLimbs = setPosModLimbs();
+    }
+
+    private long[] setPosModLimbs() {
+        long[] result = new long[numLimbs];
+        setLimbsValuePositive(modulus, result);
+        return result;
     }
 
     protected int getNumLimbs() {
@@ -250,6 +261,58 @@
         }
     }
 
+    protected abstract void finalCarryReduceLast(long[] limbs);
+
+    // Convert reduced limbs into a number between 0 and MODULUS-1
+    protected void finalReduce(long[] limbs) {
+
+        // This method works by doing several full carry/reduce operations.
+        // Some representations have extra high bits, so the carry/reduce out
+        // of the high position is implementation-specific. The "unsigned"
+        // carry operation always carries some (negative) value out of a
+        // position occupied by a negative value. So after a number of
+        // passes, all negative values are removed.
+
+        // The first pass may leave a negative value in the high position, but
+        // this only happens if something was carried out of the previous
+        // position. So the previous position must have a "small" value. The
+        // next full carry is guaranteed not to carry out of that position.
+
+        for (int pass = 0; pass < 2; pass++) {
+            // unsigned carry out of last position and reduce in to
+            // first position
+            finalCarryReduceLast(limbs);
+
+            // unsigned carry on all positions
+            long carry = 0;
+            for (int i = 0; i < numLimbs - 1; i++) {
+                limbs[i] += carry;
+                carry = limbs[i] >> bitsPerLimb;
+                limbs[i] -= carry << bitsPerLimb;
+            }
+            limbs[numLimbs - 1] += carry;
+        }
+
+        // Limbs are positive and all less than 2^bitsPerLimb, and the
+        // high-order limb may be even smaller due to the representation-
+        // specific carry/reduce out of the high position.
+        // The value may still be greater than the modulus.
+        // Subtract the max limb values only if all limbs end up non-negative
+        // This only works if there is at most one position where posModLimbs
+        // is less than 2^bitsPerLimb - 1 (not counting the high-order limb,
+        // if it has extra bits that are cleared by finalCarryReduceLast).
+        int smallerNonNegative = 1;
+        long[] smaller = new long[numLimbs];
+        for (int i = numLimbs - 1; i >= 0; i--) {
+            smaller[i] = limbs[i] - posModLimbs[i];
+            // expression on right is 1 if smaller[i] is nonnegative,
+            // 0 otherwise
+            smallerNonNegative *= (int) (smaller[i] >> 63) + 1;
+        }
+        conditionalSwap(smallerNonNegative, limbs, smaller);
+
+    }
+
     // v must be final reduced. I.e. all limbs in [0, bitsPerLimb)
     // and value in [0, modulus)
     protected void decode(long[] v, byte[] dst, int offset, int length) {
@@ -262,7 +325,10 @@
             int dstIndex = i + offset;
             if (bitPos + 8 >= bitsPerLimb) {
                 dst[dstIndex] = (byte) curLimbValue;
-                curLimbValue = v[nextLimbIndex++];
+                curLimbValue = 0;
+                if (nextLimbIndex < v.length) {
+                    curLimbValue = v[nextLimbIndex++];
+                }
                 int bitsAdded = bitsPerLimb - bitPos;
                 int bitsLeft = 8 - bitsAdded;
 
@@ -293,33 +359,33 @@
         }
     }
 
-    private void bigIntToByteArray(BigInteger bi, byte[] result) {
-        byte[] biBytes = bi.toByteArray();
-        // biBytes is backwards and possibly too big
-        // Copy the low-order bytes into result in reverse
-        int sourceIndex = biBytes.length - 1;
-        for (int i = 0; i < result.length; i++) {
-            if (sourceIndex >= 0) {
-                result[i] = biBytes[sourceIndex--];
-            }
-            else {
-                result[i] = 0;
-            }
-        }
-    }
-
     protected void limbsToByteArray(long[] limbs, byte[] result) {
 
-        bigIntToByteArray(evaluate(limbs), result);
+        long[] reducedLimbs = limbs.clone();
+        finalReduce(reducedLimbs);
+
+        decode(reducedLimbs, result, 0, result.length);
     }
 
     protected void addLimbsModPowerTwo(long[] limbs, long[] other,
                                        byte[] result) {
 
-        BigInteger bi1 = evaluate(limbs);
-        BigInteger bi2 = evaluate(other);
-        BigInteger biResult = bi1.add(bi2);
-        bigIntToByteArray(biResult, result);
+        long[] reducedOther = other.clone();
+        long[] reducedLimbs = limbs.clone();
+        finalReduce(reducedOther);
+        finalReduce(reducedLimbs);
+
+        addLimbs(reducedLimbs, reducedOther, reducedLimbs);
+
+        // may carry out a value which can be ignored
+        long carry = 0;
+        for (int i = 0; i < numLimbs; i++) {
+            reducedLimbs[i] += carry;
+            carry  = reducedLimbs[i] >> bitsPerLimb;
+            reducedLimbs[i] -= carry << bitsPerLimb;
+        }
+
+        decode(reducedLimbs, result, 0, result.length);
     }
 
     private abstract class Element implements IntegerModuloP {
@@ -418,11 +484,11 @@
         }
 
         public void addModPowerTwo(IntegerModuloP arg, byte[] result) {
-            if (!summand) {
+
+            Element other = (Element) arg;
+            if (!(summand && other.summand)) {
                 throw new ArithmeticException("Not a valid summand");
             }
-
-            Element other = (Element) arg;
             addLimbsModPowerTwo(limbs, other.limbs, result);
         }
 
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial1305.java	Mon Jun 25 23:04:21 2018 +0200
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial1305.java	Tue Jun 26 11:14:27 2018 -0400
@@ -33,13 +33,6 @@
 /**
  * An IntegerFieldModuloP designed for use with the Poly1305 authenticator.
  * The representation uses 5 signed long values.
- *
- * In addition to the branch-free operations specified in the parent class,
- * the following operations are branch-free:
- *
- * addModPowerTwo
- * asByteArray
- *
  */
 
 public class IntegerPolynomial1305 extends IntegerPolynomial {
@@ -51,17 +44,8 @@
     private static final BigInteger MODULUS
         = TWO.pow(POWER).subtract(BigInteger.valueOf(SUBTRAHEND));
 
-    private final long[] posModLimbs;
-
-    private long[] setPosModLimbs() {
-        long[] result = new long[NUM_LIMBS];
-        setLimbsValuePositive(MODULUS, result);
-        return result;
-    }
-
     public IntegerPolynomial1305() {
         super(BITS_PER_LIMB, NUM_LIMBS, MODULUS);
-        posModLimbs = setPosModLimbs();
     }
 
     protected void mult(long[] a, long[] b, long[] r) {
@@ -181,12 +165,19 @@
         }
     }
 
-    protected void modReduceIn(long[] limbs, int index, long x) {
+    private void modReduceIn(long[] limbs, int index, long x) {
         // this only works when BITS_PER_LIMB * NUM_LIMBS = POWER exactly
         long reducedValue = (x * SUBTRAHEND);
         limbs[index - NUM_LIMBS] += reducedValue;
     }
 
+    @Override
+    protected void finalCarryReduceLast(long[] limbs) {
+        long carry = limbs[numLimbs - 1] >> bitsPerLimb;
+        limbs[numLimbs - 1] -= carry << bitsPerLimb;
+        modReduceIn(limbs, numLimbs, carry);
+    }
+
     protected final void modReduce(long[] limbs, int start, int end) {
 
         for (int i = start; i < end; i++) {
@@ -220,82 +211,5 @@
         carry(limbs);
     }
 
-    // Convert reduced limbs into a number between 0 and MODULUS-1
-    private void finalReduce(long[] limbs) {
-
-        addLimbs(limbs, posModLimbs, limbs);
-        // now all values are positive, so remaining operations will be unsigned
-
-        // unsigned carry out of last position and reduce in to first position
-        long carry = limbs[NUM_LIMBS - 1] >> BITS_PER_LIMB;
-        limbs[NUM_LIMBS - 1] -= carry << BITS_PER_LIMB;
-        modReduceIn(limbs, NUM_LIMBS, carry);
-
-        // unsigned carry on all positions
-        carry = 0;
-        for (int i = 0; i < NUM_LIMBS; i++) {
-            limbs[i] += carry;
-            carry = limbs[i] >> BITS_PER_LIMB;
-            limbs[i] -= carry << BITS_PER_LIMB;
-        }
-        // reduce final carry value back in
-        modReduceIn(limbs, NUM_LIMBS, carry);
-        // we only reduce back in a nonzero value if some value was carried out
-        // of the previous loop. So at least one remaining value is small.
-
-        // One more carry is all that is necessary. Nothing will be carried out
-        // at the end
-        carry = 0;
-        for (int i = 0; i < NUM_LIMBS; i++) {
-            limbs[i] += carry;
-            carry = limbs[i] >> BITS_PER_LIMB;
-            limbs[i] -= carry << BITS_PER_LIMB;
-        }
-
-        // limbs are positive and all less than 2^BITS_PER_LIMB
-        // but the value may be greater than the MODULUS.
-        // Subtract the max limb values only if all limbs end up non-negative
-        int smallerNonNegative = 1;
-        long[] smaller = new long[NUM_LIMBS];
-        for (int i = NUM_LIMBS - 1; i >= 0; i--) {
-            smaller[i] = limbs[i] - posModLimbs[i];
-            // expression on right is 1 if smaller[i] is nonnegative,
-            // 0 otherwise
-            smallerNonNegative *= (int) (smaller[i] >> 63) + 1;
-        }
-        conditionalSwap(smallerNonNegative, limbs, smaller);
-
-    }
-
-    @Override
-    protected void limbsToByteArray(long[] limbs, byte[] result) {
-
-        long[] reducedLimbs = limbs.clone();
-        finalReduce(reducedLimbs);
-
-        decode(reducedLimbs, result, 0, result.length);
-    }
-
-    @Override
-    protected void addLimbsModPowerTwo(long[] limbs, long[] other,
-                                       byte[] result) {
-
-        long[] reducedOther = other.clone();
-        long[] reducedLimbs = limbs.clone();
-        finalReduce(reducedLimbs);
-
-        addLimbs(reducedLimbs, reducedOther, reducedLimbs);
-
-        // may carry out a value which can be ignored
-        long carry = 0;
-        for (int i = 0; i < NUM_LIMBS; i++) {
-            reducedLimbs[i] += carry;
-            carry  = reducedLimbs[i] >> BITS_PER_LIMB;
-            reducedLimbs[i] -= carry << BITS_PER_LIMB;
-        }
-
-        decode(reducedLimbs, result, 0, result.length);
-    }
-
 }
 
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial25519.java	Mon Jun 25 23:04:21 2018 +0200
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial25519.java	Tue Jun 26 11:14:27 2018 -0400
@@ -27,6 +27,11 @@
 
 import java.math.BigInteger;
 
+/**
+ * An IntegerFieldModuloP designed for use with the Curve25519.
+ * The representation uses 10 signed long values.
+ */
+
 public class IntegerPolynomial25519 extends IntegerPolynomial {
 
     private static final int POWER = 255;
@@ -47,6 +52,14 @@
     }
 
     @Override
+    protected void finalCarryReduceLast(long[] limbs) {
+
+        long reducedValue = limbs[numLimbs - 1] >> RIGHT_BIT_OFFSET;
+        limbs[numLimbs - 1] -= reducedValue << RIGHT_BIT_OFFSET;
+        limbs[0] += reducedValue * SUBTRAHEND;
+    }
+
+    @Override
     protected void mult(long[] a, long[] b, long[] r) {
 
         // Use grade-school multiplication into primitives to avoid the
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial448.java	Mon Jun 25 23:04:21 2018 +0200
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial448.java	Tue Jun 26 11:14:27 2018 -0400
@@ -27,6 +27,11 @@
 
 import java.math.BigInteger;
 
+/**
+ * An IntegerFieldModuloP designed for use with the Curve448.
+ * The representation uses 16 signed long values.
+ */
+
 public class IntegerPolynomial448 extends IntegerPolynomial {
 
     private static final int POWER = 448;
@@ -40,6 +45,18 @@
         super(BITS_PER_LIMB, NUM_LIMBS, MODULUS);
     }
 
+    private void modReduceIn(long[] limbs, int index, long x) {
+        limbs[index - NUM_LIMBS] += x;
+        limbs[index - NUM_LIMBS / 2] += x;
+    }
+
+    @Override
+    protected void finalCarryReduceLast(long[] limbs) {
+        long carry = limbs[numLimbs - 1] >> bitsPerLimb;
+        limbs[numLimbs - 1] -= carry << bitsPerLimb;
+        modReduceIn(limbs, numLimbs, carry);
+    }
+
     @Override
     protected void mult(long[] a, long[] b, long[] r) {