8157495: SHA-3 Hash algorithm performance improvements (~12x speedup)
authorvaleriep
Fri, 10 Jun 2016 22:39:40 +0000
changeset 38877 2c4dd22f0629
parent 38876 ae60b602b87d
child 38878 7213c4599103
8157495: SHA-3 Hash algorithm performance improvements (~12x speedup) Summary: Various improvements on performance and memory footprint Reviewed-by: ascarpino
jdk/src/java.base/share/classes/sun/security/provider/SHA3.java
--- a/jdk/src/java.base/share/classes/sun/security/provider/SHA3.java	Fri Jun 10 10:54:34 2016 -0700
+++ b/jdk/src/java.base/share/classes/sun/security/provider/SHA3.java	Fri Jun 10 22:39:40 2016 +0000
@@ -61,14 +61,14 @@
         0x8000000000008080L, 0x80000001L, 0x8000000080008008L,
     };
 
-    private byte[] state;
+    private byte[] state = new byte[WIDTH];
+    private final long[] lanes = new long[DM*DM];
 
     /**
      * Creates a new SHA-3 object.
      */
     SHA3(String name, int digestLength) {
         super(name, digestLength, (WIDTH - (2 * digestLength)));
-        implReset();
     }
 
     /**
@@ -79,7 +79,7 @@
         for (int i = 0; i < buffer.length; i++) {
             state[i] ^= b[ofs++];
         }
-        state = keccak(state);
+        keccak();
     }
 
     /**
@@ -95,7 +95,7 @@
         for (int i = 0; i < buffer.length; i++) {
             state[i] ^= buffer[i];
         }
-        state = keccak(state);
+        keccak();
         System.arraycopy(state, 0, out, ofs, engineGetDigestLength());
     }
 
@@ -103,15 +103,8 @@
      * Resets the internal state to start a new hash.
      */
     void implReset() {
-        state = new byte[WIDTH];
-    }
-
-    /**
-     * Utility function for circular shift the specified long
-     * value to the left for n bits.
-     */
-    private static long circularShiftLeft(long lane, int n) {
-        return ((lane << n) | (lane >>> (64 - n)));
+        Arrays.fill(state, (byte)0);
+        Arrays.fill(lanes, 0L);
     }
 
     /**
@@ -132,115 +125,119 @@
     }
 
     /**
-     * Utility function for transforming the specified state from
-     * the byte array format into array of lanes as defined in
-     * section 3.1.2.
+     * Utility function for transforming the specified byte array 's'
+     * into array of lanes 'm' as defined in section 3.1.2.
      */
-    private static long[][] bytes2Lanes(byte[] s) {
-        if (s.length != WIDTH) {
-            throw new ProviderException("Error: incorrect input size " +
-                s.length);
+    private static void bytes2Lanes(byte[] s, long[] m) {
+        int sOfs = 0;
+        // Conversion traverses along x-axis before y-axis
+        for (int y = 0; y < DM; y++, sOfs += 40) {
+            b2lLittle(s, sOfs, m, DM*y, 40);
         }
-        // The conversion traverses along x-axis before y-axis. So, y is the
-        // first dimension and x is the second dimension.
-        long[][] s2 = new long[DM][DM];
-        int sOfs = 0;
-        for (int y = 0; y < DM; y++, sOfs += 40) {
-            b2lLittle(s, sOfs, s2[y], 0, 40);
-        }
-        return s2;
     }
 
     /**
-     * Utility function for transforming the specified arrays of
-     * lanes into a byte array as defined in section 3.1.3.
+     * Utility function for transforming the specified array of
+     * lanes 'm' into a byte array 's' as defined in section 3.1.3.
      */
-    private static byte[] lanes2Bytes(long[][] m) {
-        byte[] s = new byte[WIDTH];
+    private static void lanes2Bytes(long[] m, byte[] s) {
         int sOfs = 0;
-        // The conversion traverses along x-axis before y-axis. So, y is the
-        // first dimension and x is the second dimension.
+        // Conversion traverses along x-axis before y-axis
         for (int y = 0; y < DM; y++, sOfs += 40) {
-            l2bLittle(m[y], 0, s, sOfs, 40);
+            l2bLittle(m, DM*y, s, sOfs, 40);
         }
-        return s;
     }
 
     /**
      * Step mapping Theta as defined in section 3.2.1 .
      */
-    private static long[][] smTheta(long[][] a) {
-        long[] c = new long[DM];
-        for (int i = 0; i < DM; i++) {
-            c[i] = a[0][i]^a[1][i]^a[2][i]^a[3][i]^a[4][i];
-        }
-        long[] d = new long[DM];
-        for (int i = 0; i < DM; i++) {
-            long c1 = c[(i + 4) % DM];
-            // left shift and wrap the leftmost bit into the rightmost bit
-            long c2 = circularShiftLeft(c[(i + 1) % DM], 1);
-            d[i] = c1^c2;
-        }
-        for (int y = 0; y < DM; y++) {
-            for (int x = 0; x < DM; x++) {
-                a[y][x] ^= d[x];
-            }
+    private static long[] smTheta(long[] a) {
+        long c0 = a[0]^a[5]^a[10]^a[15]^a[20];
+        long c1 = a[1]^a[6]^a[11]^a[16]^a[21];
+        long c2 = a[2]^a[7]^a[12]^a[17]^a[22];
+        long c3 = a[3]^a[8]^a[13]^a[18]^a[23];
+        long c4 = a[4]^a[9]^a[14]^a[19]^a[24];
+        long d0 = c4 ^ Long.rotateLeft(c1, 1);
+        long d1 = c0 ^ Long.rotateLeft(c2, 1);
+        long d2 = c1 ^ Long.rotateLeft(c3, 1);
+        long d3 = c2 ^ Long.rotateLeft(c4, 1);
+        long d4 = c3 ^ Long.rotateLeft(c0, 1);
+        for (int y = 0; y < a.length; y += DM) {
+            a[y] ^= d0;
+            a[y+1] ^= d1;
+            a[y+2] ^= d2;
+            a[y+3] ^= d3;
+            a[y+4] ^= d4;
         }
         return a;
     }
 
     /**
-     * Step mapping Rho as defined in section 3.2.2.
+     * Merged Step mapping Rho (section 3.2.2) and Pi (section 3.2.3).
+     * for performance. Optimization is achieved by precalculating
+     * shift constants for the following loop
+     *   int xNext, yNext;
+     *   for (int t = 0, x = 1, y = 0; t <= 23; t++, x = xNext, y = yNext) {
+     *        int numberOfShift = ((t + 1)*(t + 2)/2) % 64;
+     *        a[y][x] = Long.rotateLeft(a[y][x], numberOfShift);
+     *        xNext = y;
+     *        yNext = (2 * x + 3 * y) % DM;
+     *   }
+     * and with inplace permutation.
      */
-    private static long[][] smRho(long[][] a) {
-        long[][] a2 = new long[DM][DM];
-        a2[0][0] = a[0][0];
-        int xNext, yNext;
-        for (int t = 0, x = 1, y = 0; t <= 23; t++, x = xNext, y = yNext) {
-            int numberOfShift = ((t + 1)*(t + 2)/2) % 64;
-            a2[y][x] = circularShiftLeft(a[y][x], numberOfShift);
-            xNext = y;
-            yNext = (2 * x + 3 * y) % DM;
-        }
-        return a2;
-    }
-
-    /**
-     * Step mapping Pi as defined in section 3.2.3.
-     */
-    private static long[][] smPi(long[][] a) {
-        long[][] a2 = new long[DM][DM];
-        for (int y = 0; y < DM; y++) {
-            for (int x = 0; x < DM; x++) {
-                a2[y][x] = a[x][(x + 3 * y) % DM];
-            }
-        }
-        return a2;
+    private static long[] smPiRho(long[] a) {
+        long tmp = Long.rotateLeft(a[10], 3);
+        a[10] = Long.rotateLeft(a[1], 1);
+        a[1] = Long.rotateLeft(a[6], 44);
+        a[6] = Long.rotateLeft(a[9], 20);
+        a[9] = Long.rotateLeft(a[22], 61);
+        a[22] = Long.rotateLeft(a[14], 39);
+        a[14] = Long.rotateLeft(a[20], 18);
+        a[20] = Long.rotateLeft(a[2], 62);
+        a[2] = Long.rotateLeft(a[12], 43);
+        a[12] = Long.rotateLeft(a[13], 25);
+        a[13] = Long.rotateLeft(a[19], 8);
+        a[19] = Long.rotateLeft(a[23], 56);
+        a[23] = Long.rotateLeft(a[15], 41);
+        a[15] = Long.rotateLeft(a[4], 27);
+        a[4] = Long.rotateLeft(a[24], 14);
+        a[24] = Long.rotateLeft(a[21], 2);
+        a[21] = Long.rotateLeft(a[8], 55);
+        a[8] = Long.rotateLeft(a[16], 45);
+        a[16] = Long.rotateLeft(a[5], 36);
+        a[5] = Long.rotateLeft(a[3], 28);
+        a[3] = Long.rotateLeft(a[18], 21);
+        a[18] = Long.rotateLeft(a[17], 15);
+        a[17] = Long.rotateLeft(a[11], 10);
+        a[11] = Long.rotateLeft(a[7], 6);
+        a[7] = tmp;
+        return a;
     }
 
     /**
      * Step mapping Chi as defined in section 3.2.4.
      */
-    private static long[][] smChi(long[][] a) {
-        long[][] a2 = new long[DM][DM];
-        for (int y = 0; y < DM; y++) {
-            for (int x = 0; x < DM; x++) {
-                a2[y][x] = a[y][x] ^
-                    ((a[y][(x + 1) % DM] ^ 0xFFFFFFFFFFFFFFFFL) &
-                     a[y][(x + 2) % DM]);
-            }
+    private static long[] smChi(long[] a) {
+        for (int y = 0; y < a.length; y+=DM) {
+            long ay0 = a[y];
+            long ay1 = a[y+1];
+            long ay2 = a[y+2];
+            long ay3 = a[y+3];
+            long ay4 = a[y+4];
+            a[y] = ay0 ^ ((~ay1) & ay2);
+            a[y+1] = ay1 ^ ((~ay2) & ay3);
+            a[y+2] = ay2 ^ ((~ay3) & ay4);
+            a[y+3] = ay3 ^ ((~ay4) & ay0);
+            a[y+4] = ay4 ^ ((~ay0) & ay1);
         }
-        return a2;
+        return a;
     }
 
     /**
      * Step mapping Iota as defined in section 3.2.5.
-     *
-     * @return the processed state array
-     * @param state the state array to be processed
      */
-    private static long[][] smIota(long[][] a, int rndIndex) {
-        a[0][0] ^= RC_CONSTANTS[rndIndex];
+    private static long[] smIota(long[] a, int rndIndex) {
+        a[0] ^= RC_CONSTANTS[rndIndex];
         return a;
     }
 
@@ -248,12 +245,15 @@
      * The function Keccak as defined in section 5.2 with
      * rate r = 1600 and capacity c = (digest length x 2).
      */
-    private static byte[] keccak(byte[] state) {
-        long[][] lanes = bytes2Lanes(state);
+    private void keccak() {
+        // convert the 200-byte state into 25 lanes
+        bytes2Lanes(state, lanes);
+        // process the lanes through step mappings
         for (int ir = 0; ir < NR; ir++) {
-            lanes = smIota(smChi(smPi(smRho(smTheta(lanes)))), ir);
+            smIota(smChi(smPiRho(smTheta(lanes))), ir);
         }
-        return lanes2Bytes(lanes);
+        // convert the resulting 25 lanes back into 200-byte state
+        lanes2Bytes(lanes, state);
     }
 
     public Object clone() throws CloneNotSupportedException {