src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java
changeset 52942 746602d9682f
parent 51569 46ec360a7014
--- a/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java	Tue Dec 11 15:21:50 2018 +0100
+++ b/src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java	Tue Dec 11 09:36:49 2018 -0500
@@ -69,14 +69,25 @@
     private final BigInteger modulus;
     protected final int bitsPerLimb;
     private final long[] posModLimbs;
+    private final int maxAdds;
+
+    /**
+     * Reduce an IntegerPolynomial representation (a) and store the result
+     * in a. Requires that a.length == numLimbs.
+     */
+    protected abstract void reduce(long[] a);
 
     /**
      * Multiply an IntegerPolynomial representation (a) with a long (b) and
-     * store the result in an IntegerPolynomial representation (r). Requires
-     * that a.length == r.length == numLimbs. It is allowed for a and r to be
-     * the same array.
+     * store the result in an IntegerPolynomial representation in a. Requires
+     * that a.length == numLimbs.
      */
-    protected abstract void multByInt(long[] a, long b, long[] r);
+    protected void multByInt(long[] a, long b) {
+        for (int i = 0; i < a.length; i++) {
+            a[i] *= b;
+        }
+        reduce(a);
+    }
 
     /**
      * Multiply two IntegerPolynomial representations (a and b) and store the
@@ -96,12 +107,14 @@
 
     IntegerPolynomial(int bitsPerLimb,
                       int numLimbs,
+                      int maxAdds,
                       BigInteger modulus) {
 
 
         this.numLimbs = numLimbs;
         this.modulus = modulus;
         this.bitsPerLimb = bitsPerLimb;
+        this.maxAdds = maxAdds;
 
         posModLimbs = setPosModLimbs();
     }
@@ -116,6 +129,10 @@
         return numLimbs;
     }
 
+    public int getMaxAdds() {
+        return maxAdds;
+    }
+
     @Override
     public BigInteger getSize() {
         return modulus;
@@ -155,12 +172,22 @@
      */
     protected void encode(ByteBuffer buf, int length, byte highByte,
                           long[] result) {
+
         int numHighBits = 32 - Integer.numberOfLeadingZeros(highByte);
         int numBits = 8 * length + numHighBits;
-        int maxBits = bitsPerLimb * result.length;
-        if (numBits > maxBits) {
-            throw new ArithmeticException("Value is too large.");
+        int requiredLimbs = (numBits + bitsPerLimb - 1) / bitsPerLimb;
+        if (requiredLimbs > numLimbs) {
+            long[] temp = new long[requiredLimbs];
+            encodeSmall(buf, length, highByte, temp);
+            // encode does a full carry/reduce
+            System.arraycopy(temp, 0, result, 0, result.length);
+        } else {
+            encodeSmall(buf, length, highByte, result);
         }
+    }
+
+    protected void encodeSmall(ByteBuffer buf, int length, byte highByte,
+                               long[] result) {
 
         int limbIndex = 0;
         long curLimbValue = 0;
@@ -195,10 +222,10 @@
             }
         }
 
-        if (limbIndex < numLimbs) {
+        if (limbIndex < result.length) {
             result[limbIndex++] = curLimbValue;
         }
-        Arrays.fill(result, limbIndex, numLimbs, 0);
+        Arrays.fill(result, limbIndex, result.length, 0);
 
         postEncodeCarry(result);
     }
@@ -211,8 +238,10 @@
         encode(buf, length, highByte, result);
     }
 
+    // Encode does not produce compressed limbs. A simplified carry/reduce
+    // operation can be used to compress the limbs.
     protected void postEncodeCarry(long[] v) {
-        carry(v);
+        reduce(v);
     }
 
     public ImmutableElement getElement(byte[] v, int offset, int length,
@@ -222,7 +251,7 @@
 
         encode(v, offset, length, highByte, result);
 
-        return new ImmutableElement(result, true);
+        return new ImmutableElement(result, 0);
     }
 
     protected BigInteger evaluate(long[] limbs) {
@@ -387,6 +416,20 @@
     }
 
     /**
+     * Branch-free conditional assignment of b to a. Requires that set is 0 or
+     * 1, and that a.length == b.length. If set==0, then the values of a and b
+     * will be unchanged. If set==1, then the values of b will be assigned to a.
+     * The behavior is undefined if swap has any value other than 0 or 1.
+     */
+    protected static void conditionalAssign(int set, long[] a, long[] b) {
+        int maskValue = 0 - set;
+        for (int i = 0; i < a.length; i++) {
+            long dummyLimbs = maskValue & (a[i] ^ b[i]);
+            a[i] = dummyLimbs ^ a[i];
+        }
+    }
+
+    /**
      * Branch-free conditional swap of a and b. Requires that swap is 0 or 1,
      * and that a.length == b.length. If swap==0, then the values of a and b
      * will be unchanged. If swap==1, then the values of a and b will be
@@ -442,7 +485,7 @@
     private abstract class Element implements IntegerModuloP {
 
         protected long[] limbs;
-        protected boolean summand = false;
+        protected int numAdds;
 
         public Element(BigInteger v) {
             limbs = new long[numLimbs];
@@ -450,19 +493,19 @@
         }
 
         public Element(boolean v) {
-            limbs = new long[numLimbs];
-            limbs[0] = v ? 1l : 0l;
-            summand = true;
+            this.limbs = new long[numLimbs];
+            this.limbs[0] = v ? 1l : 0l;
+            this.numAdds = 0;
         }
 
-        private Element(long[] limbs, boolean summand) {
+        private Element(long[] limbs, int numAdds) {
             this.limbs = limbs;
-            this.summand = summand;
+            this.numAdds = numAdds;
         }
 
         private void setValue(BigInteger v) {
             setLimbsValue(v, limbs);
-            summand = true;
+            this.numAdds = 0;
         }
 
         @Override
@@ -477,14 +520,18 @@
 
         @Override
         public MutableElement mutable() {
-            return new MutableElement(limbs.clone(), summand);
+            return new MutableElement(limbs.clone(), numAdds);
+        }
+
+        protected boolean isSummand() {
+            return numAdds < maxAdds;
         }
 
         @Override
         public ImmutableElement add(IntegerModuloP genB) {
 
             Element b = (Element) genB;
-            if (!(summand && b.summand)) {
+            if (!(isSummand() && b.isSummand())) {
                 throw new ArithmeticException("Not a valid summand");
             }
 
@@ -493,7 +540,8 @@
                 newLimbs[i] = limbs[i] + b.limbs[i];
             }
 
-            return new ImmutableElement(newLimbs, false);
+            int newNumAdds = Math.max(numAdds, b.numAdds) + 1;
+            return new ImmutableElement(newLimbs, newNumAdds);
         }
 
         @Override
@@ -504,7 +552,7 @@
                 newLimbs[i] = -limbs[i];
             }
 
-            ImmutableElement result = new ImmutableElement(newLimbs, summand);
+            ImmutableElement result = new ImmutableElement(newLimbs, numAdds);
             return result;
         }
 
@@ -524,43 +572,52 @@
 
             long[] newLimbs = new long[limbs.length];
             mult(limbs, b.limbs, newLimbs);
-            return new ImmutableElement(newLimbs, true);
+            return new ImmutableElement(newLimbs, 0);
         }
 
         @Override
         public ImmutableElement square() {
             long[] newLimbs = new long[limbs.length];
             IntegerPolynomial.this.square(limbs, newLimbs);
-            return new ImmutableElement(newLimbs, true);
+            return new ImmutableElement(newLimbs, 0);
         }
 
         public void addModPowerTwo(IntegerModuloP arg, byte[] result) {
 
             Element other = (Element) arg;
-            if (!(summand && other.summand)) {
+            if (!(isSummand() && other.isSummand())) {
                 throw new ArithmeticException("Not a valid summand");
             }
             addLimbsModPowerTwo(limbs, other.limbs, result);
         }
 
         public void asByteArray(byte[] result) {
-            if (!summand) {
+            if (!isSummand()) {
                 throw new ArithmeticException("Not a valid summand");
             }
             limbsToByteArray(limbs, result);
         }
     }
 
-    private class MutableElement extends Element
+    protected class MutableElement extends Element
         implements MutableIntegerModuloP {
 
-        protected MutableElement(long[] limbs, boolean summand) {
-            super(limbs, summand);
+        protected MutableElement(long[] limbs, int numAdds) {
+            super(limbs, numAdds);
         }
 
         @Override
         public ImmutableElement fixed() {
-            return new ImmutableElement(limbs.clone(), summand);
+            return new ImmutableElement(limbs.clone(), numAdds);
+        }
+
+        @Override
+        public void conditionalSet(IntegerModuloP b, int set) {
+
+            Element other = (Element) b;
+
+            conditionalAssign(set, limbs, other.limbs);
+            numAdds = other.numAdds;
         }
 
         @Override
@@ -569,9 +626,9 @@
             MutableElement other = (MutableElement) b;
 
             conditionalSwap(swap, limbs, other.limbs);
-            boolean summandTemp = summand;
-            summand = other.summand;
-            other.summand = summandTemp;
+            int numAddsTemp = numAdds;
+            numAdds = other.numAdds;
+            other.numAdds = numAddsTemp;
         }
 
 
@@ -580,7 +637,7 @@
             Element other = (Element) v;
 
             System.arraycopy(other.limbs, 0, limbs, 0, other.limbs.length);
-            summand = other.summand;
+            numAdds = other.numAdds;
             return this;
         }
 
@@ -589,7 +646,7 @@
                                        int length, byte highByte) {
 
             encode(arr, offset, length, highByte, limbs);
-            summand = true;
+            this.numAdds = 0;
 
             return this;
         }
@@ -599,7 +656,7 @@
                                        byte highByte) {
 
             encode(buf, length, highByte, limbs);
-            summand = true;
+            numAdds = 0;
 
             return this;
         }
@@ -608,15 +665,15 @@
         public MutableElement setProduct(IntegerModuloP genB) {
             Element b = (Element) genB;
             mult(limbs, b.limbs, limbs);
-            summand = true;
+            numAdds = 0;
             return this;
         }
 
         @Override
         public MutableElement setProduct(SmallValue v) {
             int value = ((Limb) v).value;
-            multByInt(limbs, value, limbs);
-            summand = true;
+            multByInt(limbs, value);
+            numAdds = 0;
             return this;
         }
 
@@ -624,7 +681,7 @@
         public MutableElement setSum(IntegerModuloP genB) {
 
             Element b = (Element) genB;
-            if (!(summand && b.summand)) {
+            if (!(isSummand() && b.isSummand())) {
                 throw new ArithmeticException("Not a valid summand");
             }
 
@@ -632,7 +689,7 @@
                 limbs[i] = limbs[i] + b.limbs[i];
             }
 
-            summand = false;
+            numAdds = Math.max(numAdds, b.numAdds) + 1;
             return this;
         }
 
@@ -640,7 +697,7 @@
         public MutableElement setDifference(IntegerModuloP genB) {
 
             Element b = (Element) genB;
-            if (!(summand && b.summand)) {
+            if (!(isSummand() && b.isSummand())) {
                 throw new ArithmeticException("Not a valid summand");
             }
 
@@ -648,16 +705,33 @@
                 limbs[i] = limbs[i] - b.limbs[i];
             }
 
+            numAdds = Math.max(numAdds, b.numAdds) + 1;
             return this;
         }
 
         @Override
         public MutableElement setSquare() {
             IntegerPolynomial.this.square(limbs, limbs);
-            summand = true;
+            numAdds = 0;
             return this;
         }
 
+        @Override
+        public MutableElement setAdditiveInverse() {
+
+            for (int i = 0; i < limbs.length; i++) {
+                limbs[i] = -limbs[i];
+            }
+            return this;
+        }
+
+        @Override
+        public MutableElement setReduced() {
+
+            reduce(limbs);
+            numAdds = 0;
+            return this;
+        }
     }
 
     class ImmutableElement extends Element implements ImmutableIntegerModuloP {
@@ -670,8 +744,8 @@
             super(v);
         }
 
-        protected ImmutableElement(long[] limbs, boolean summand) {
-            super(limbs, summand);
+        protected ImmutableElement(long[] limbs, int numAdds) {
+            super(limbs, numAdds);
         }
 
         @Override