8203228: Branch-free output conversion for X25519 and X448
Summary: Make some field arithmetic operations for X25519/X448 more resilient against side-channel attacks
Reviewed-by: ascarpino
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java Mon Jun 25 23:04:21 2018 +0200
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java Tue Jun 26 11:14:27 2018 -0400
@@ -54,6 +54,8 @@
* setDifference
* setProduct
* setSquare
+ * addModPowerTwo
+ * asByteArray
*
* All other operations may branch in some subclasses.
*
@@ -66,6 +68,7 @@
protected final int numLimbs;
private final BigInteger modulus;
protected final int bitsPerLimb;
+ private final long[] posModLimbs;
// must work when a==r
protected abstract void multByInt(long[] a, long b, long[] r);
@@ -84,6 +87,14 @@
this.numLimbs = numLimbs;
this.modulus = modulus;
this.bitsPerLimb = bitsPerLimb;
+
+ posModLimbs = setPosModLimbs();
+ }
+
+ private long[] setPosModLimbs() {
+ long[] result = new long[numLimbs];
+ setLimbsValuePositive(modulus, result);
+ return result;
}
protected int getNumLimbs() {
@@ -250,6 +261,58 @@
}
}
+ protected abstract void finalCarryReduceLast(long[] limbs);
+
+ // Convert reduced limbs into a number between 0 and MODULUS-1
+ protected void finalReduce(long[] limbs) {
+
+ // This method works by doing several full carry/reduce operations.
+ // Some representations have extra high bits, so the carry/reduce out
+ // of the high position is implementation-specific. The "unsigned"
+ // carry operation always carries some (negative) value out of a
+ // position occupied by a negative value. So after a number of
+ // passes, all negative values are removed.
+
+ // The first pass may leave a negative value in the high position, but
+ // this only happens if something was carried out of the previous
+ // position. So the previous position must have a "small" value. The
+ // next full carry is guaranteed not to carry out of that position.
+
+ for (int pass = 0; pass < 2; pass++) {
+ // unsigned carry out of last position and reduce in to
+ // first position
+ finalCarryReduceLast(limbs);
+
+ // unsigned carry on all positions
+ long carry = 0;
+ for (int i = 0; i < numLimbs - 1; i++) {
+ limbs[i] += carry;
+ carry = limbs[i] >> bitsPerLimb;
+ limbs[i] -= carry << bitsPerLimb;
+ }
+ limbs[numLimbs - 1] += carry;
+ }
+
+ // Limbs are positive and all less than 2^bitsPerLimb, and the
+ // high-order limb may be even smaller due to the representation-
+ // specific carry/reduce out of the high position.
+ // The value may still be greater than the modulus.
+ // Subtract the max limb values only if all limbs end up non-negative
+ // This only works if there is at most one position where posModLimbs
+ // is less than 2^bitsPerLimb - 1 (not counting the high-order limb,
+ // if it has extra bits that are cleared by finalCarryReduceLast).
+ int smallerNonNegative = 1;
+ long[] smaller = new long[numLimbs];
+ for (int i = numLimbs - 1; i >= 0; i--) {
+ smaller[i] = limbs[i] - posModLimbs[i];
+ // expression on right is 1 if smaller[i] is nonnegative,
+ // 0 otherwise
+ smallerNonNegative *= (int) (smaller[i] >> 63) + 1;
+ }
+ conditionalSwap(smallerNonNegative, limbs, smaller);
+
+ }
+
// v must be final reduced. I.e. all limbs in [0, bitsPerLimb)
// and value in [0, modulus)
protected void decode(long[] v, byte[] dst, int offset, int length) {
@@ -262,7 +325,10 @@
int dstIndex = i + offset;
if (bitPos + 8 >= bitsPerLimb) {
dst[dstIndex] = (byte) curLimbValue;
- curLimbValue = v[nextLimbIndex++];
+ curLimbValue = 0;
+ if (nextLimbIndex < v.length) {
+ curLimbValue = v[nextLimbIndex++];
+ }
int bitsAdded = bitsPerLimb - bitPos;
int bitsLeft = 8 - bitsAdded;
@@ -293,33 +359,33 @@
}
}
- private void bigIntToByteArray(BigInteger bi, byte[] result) {
- byte[] biBytes = bi.toByteArray();
- // biBytes is backwards and possibly too big
- // Copy the low-order bytes into result in reverse
- int sourceIndex = biBytes.length - 1;
- for (int i = 0; i < result.length; i++) {
- if (sourceIndex >= 0) {
- result[i] = biBytes[sourceIndex--];
- }
- else {
- result[i] = 0;
- }
- }
- }
-
protected void limbsToByteArray(long[] limbs, byte[] result) {
- bigIntToByteArray(evaluate(limbs), result);
+ long[] reducedLimbs = limbs.clone();
+ finalReduce(reducedLimbs);
+
+ decode(reducedLimbs, result, 0, result.length);
}
protected void addLimbsModPowerTwo(long[] limbs, long[] other,
byte[] result) {
- BigInteger bi1 = evaluate(limbs);
- BigInteger bi2 = evaluate(other);
- BigInteger biResult = bi1.add(bi2);
- bigIntToByteArray(biResult, result);
+ long[] reducedOther = other.clone();
+ long[] reducedLimbs = limbs.clone();
+ finalReduce(reducedOther);
+ finalReduce(reducedLimbs);
+
+ addLimbs(reducedLimbs, reducedOther, reducedLimbs);
+
+ // may carry out a value which can be ignored
+ long carry = 0;
+ for (int i = 0; i < numLimbs; i++) {
+ reducedLimbs[i] += carry;
+ carry = reducedLimbs[i] >> bitsPerLimb;
+ reducedLimbs[i] -= carry << bitsPerLimb;
+ }
+
+ decode(reducedLimbs, result, 0, result.length);
}
private abstract class Element implements IntegerModuloP {
@@ -418,11 +484,11 @@
}
public void addModPowerTwo(IntegerModuloP arg, byte[] result) {
- if (!summand) {
+
+ Element other = (Element) arg;
+ if (!(summand && other.summand)) {
throw new ArithmeticException("Not a valid summand");
}
-
- Element other = (Element) arg;
addLimbsModPowerTwo(limbs, other.limbs, result);
}
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial1305.java Mon Jun 25 23:04:21 2018 +0200
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial1305.java Tue Jun 26 11:14:27 2018 -0400
@@ -33,13 +33,6 @@
/**
* An IntegerFieldModuloP designed for use with the Poly1305 authenticator.
* The representation uses 5 signed long values.
- *
- * In addition to the branch-free operations specified in the parent class,
- * the following operations are branch-free:
- *
- * addModPowerTwo
- * asByteArray
- *
*/
public class IntegerPolynomial1305 extends IntegerPolynomial {
@@ -51,17 +44,8 @@
private static final BigInteger MODULUS
= TWO.pow(POWER).subtract(BigInteger.valueOf(SUBTRAHEND));
- private final long[] posModLimbs;
-
- private long[] setPosModLimbs() {
- long[] result = new long[NUM_LIMBS];
- setLimbsValuePositive(MODULUS, result);
- return result;
- }
-
public IntegerPolynomial1305() {
super(BITS_PER_LIMB, NUM_LIMBS, MODULUS);
- posModLimbs = setPosModLimbs();
}
protected void mult(long[] a, long[] b, long[] r) {
@@ -181,12 +165,19 @@
}
}
- protected void modReduceIn(long[] limbs, int index, long x) {
+ private void modReduceIn(long[] limbs, int index, long x) {
// this only works when BITS_PER_LIMB * NUM_LIMBS = POWER exactly
long reducedValue = (x * SUBTRAHEND);
limbs[index - NUM_LIMBS] += reducedValue;
}
+ @Override
+ protected void finalCarryReduceLast(long[] limbs) {
+ long carry = limbs[numLimbs - 1] >> bitsPerLimb;
+ limbs[numLimbs - 1] -= carry << bitsPerLimb;
+ modReduceIn(limbs, numLimbs, carry);
+ }
+
protected final void modReduce(long[] limbs, int start, int end) {
for (int i = start; i < end; i++) {
@@ -220,82 +211,5 @@
carry(limbs);
}
- // Convert reduced limbs into a number between 0 and MODULUS-1
- private void finalReduce(long[] limbs) {
-
- addLimbs(limbs, posModLimbs, limbs);
- // now all values are positive, so remaining operations will be unsigned
-
- // unsigned carry out of last position and reduce in to first position
- long carry = limbs[NUM_LIMBS - 1] >> BITS_PER_LIMB;
- limbs[NUM_LIMBS - 1] -= carry << BITS_PER_LIMB;
- modReduceIn(limbs, NUM_LIMBS, carry);
-
- // unsigned carry on all positions
- carry = 0;
- for (int i = 0; i < NUM_LIMBS; i++) {
- limbs[i] += carry;
- carry = limbs[i] >> BITS_PER_LIMB;
- limbs[i] -= carry << BITS_PER_LIMB;
- }
- // reduce final carry value back in
- modReduceIn(limbs, NUM_LIMBS, carry);
- // we only reduce back in a nonzero value if some value was carried out
- // of the previous loop. So at least one remaining value is small.
-
- // One more carry is all that is necessary. Nothing will be carried out
- // at the end
- carry = 0;
- for (int i = 0; i < NUM_LIMBS; i++) {
- limbs[i] += carry;
- carry = limbs[i] >> BITS_PER_LIMB;
- limbs[i] -= carry << BITS_PER_LIMB;
- }
-
- // limbs are positive and all less than 2^BITS_PER_LIMB
- // but the value may be greater than the MODULUS.
- // Subtract the max limb values only if all limbs end up non-negative
- int smallerNonNegative = 1;
- long[] smaller = new long[NUM_LIMBS];
- for (int i = NUM_LIMBS - 1; i >= 0; i--) {
- smaller[i] = limbs[i] - posModLimbs[i];
- // expression on right is 1 if smaller[i] is nonnegative,
- // 0 otherwise
- smallerNonNegative *= (int) (smaller[i] >> 63) + 1;
- }
- conditionalSwap(smallerNonNegative, limbs, smaller);
-
- }
-
- @Override
- protected void limbsToByteArray(long[] limbs, byte[] result) {
-
- long[] reducedLimbs = limbs.clone();
- finalReduce(reducedLimbs);
-
- decode(reducedLimbs, result, 0, result.length);
- }
-
- @Override
- protected void addLimbsModPowerTwo(long[] limbs, long[] other,
- byte[] result) {
-
- long[] reducedOther = other.clone();
- long[] reducedLimbs = limbs.clone();
- finalReduce(reducedLimbs);
-
- addLimbs(reducedLimbs, reducedOther, reducedLimbs);
-
- // may carry out a value which can be ignored
- long carry = 0;
- for (int i = 0; i < NUM_LIMBS; i++) {
- reducedLimbs[i] += carry;
- carry = reducedLimbs[i] >> BITS_PER_LIMB;
- reducedLimbs[i] -= carry << BITS_PER_LIMB;
- }
-
- decode(reducedLimbs, result, 0, result.length);
- }
-
}
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial25519.java Mon Jun 25 23:04:21 2018 +0200
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial25519.java Tue Jun 26 11:14:27 2018 -0400
@@ -27,6 +27,11 @@
import java.math.BigInteger;
+/**
+ * An IntegerFieldModuloP designed for use with the Curve25519.
+ * The representation uses 10 signed long values.
+ */
+
public class IntegerPolynomial25519 extends IntegerPolynomial {
private static final int POWER = 255;
@@ -47,6 +52,14 @@
}
@Override
+ protected void finalCarryReduceLast(long[] limbs) {
+
+ long reducedValue = limbs[numLimbs - 1] >> RIGHT_BIT_OFFSET;
+ limbs[numLimbs - 1] -= reducedValue << RIGHT_BIT_OFFSET;
+ limbs[0] += reducedValue * SUBTRAHEND;
+ }
+
+ @Override
protected void mult(long[] a, long[] b, long[] r) {
// Use grade-school multiplication into primitives to avoid the
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial448.java Mon Jun 25 23:04:21 2018 +0200
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial448.java Tue Jun 26 11:14:27 2018 -0400
@@ -27,6 +27,11 @@
import java.math.BigInteger;
+/**
+ * An IntegerFieldModuloP designed for use with the Curve448.
+ * The representation uses 16 signed long values.
+ */
+
public class IntegerPolynomial448 extends IntegerPolynomial {
private static final int POWER = 448;
@@ -40,6 +45,18 @@
super(BITS_PER_LIMB, NUM_LIMBS, MODULUS);
}
+ private void modReduceIn(long[] limbs, int index, long x) {
+ limbs[index - NUM_LIMBS] += x;
+ limbs[index - NUM_LIMBS / 2] += x;
+ }
+
+ @Override
+ protected void finalCarryReduceLast(long[] limbs) {
+ long carry = limbs[numLimbs - 1] >> bitsPerLimb;
+ limbs[numLimbs - 1] -= carry << bitsPerLimb;
+ modReduceIn(limbs, numLimbs, carry);
+ }
+
@Override
protected void mult(long[] a, long[] b, long[] r) {