8069072: GHASH performance improvement
authorfweimer
Mon, 09 Feb 2015 13:32:42 -0800
changeset 28851 e4c16ed7bffa
parent 28850 4996a75e8bfb
child 28852 a581c7868768
8069072: GHASH performance improvement Summary: Eliminate allocations and vectorize Reviewed-by: mullan, ascarpino
jdk/src/java.base/share/classes/com/sun/crypto/provider/GHASH.java
jdk/test/com/sun/crypto/provider/Cipher/AES/TestGHASH.java
--- a/jdk/src/java.base/share/classes/com/sun/crypto/provider/GHASH.java	Mon Feb 09 11:37:56 2015 -0800
+++ b/jdk/src/java.base/share/classes/com/sun/crypto/provider/GHASH.java	Mon Feb 09 13:32:42 2015 -0800
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2015 Red Hat, Inc.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -28,9 +29,7 @@
 
 package com.sun.crypto.provider;
 
-import java.util.Arrays;
-import java.security.*;
-import static com.sun.crypto.provider.AESConstants.AES_BLOCK_SIZE;
+import java.security.ProviderException;
 
 /**
  * This class represents the GHASH function defined in NIST 800-38D
@@ -44,62 +43,90 @@
  */
 final class GHASH {
 
-    private static final byte P128 = (byte) 0xe1; //reduction polynomial
-
-    private static boolean getBit(byte[] b, int pos) {
-        int p = pos / 8;
-        pos %= 8;
-        int i = (b[p] >>> (7 - pos)) & 1;
-        return i != 0;
+    private static long getLong(byte[] buffer, int offset) {
+        long result = 0;
+        int end = offset + 8;
+        for (int i = offset; i < end; ++i) {
+            result = (result << 8) + (buffer[i] & 0xFF);
+        }
+        return result;
     }
 
-    private static void shift(byte[] b) {
-        byte temp, temp2;
-        temp2 = 0;
-        for (int i = 0; i < b.length; i++) {
-            temp = (byte) ((b[i] & 0x01) << 7);
-            b[i] = (byte) ((b[i] & 0xff) >>> 1);
-            b[i] = (byte) (b[i] | temp2);
-            temp2 = temp;
+    private static void putLong(byte[] buffer, int offset, long value) {
+        int end = offset + 8;
+        for (int i = end - 1; i >= offset; --i) {
+            buffer[i] = (byte) value;
+            value >>= 8;
         }
     }
 
-    // Given block X and Y, returns the muliplication of X * Y
-    private static byte[] blockMult(byte[] x, byte[] y) {
-        if (x.length != AES_BLOCK_SIZE || y.length != AES_BLOCK_SIZE) {
-            throw new RuntimeException("illegal input sizes");
-        }
-        byte[] z = new byte[AES_BLOCK_SIZE];
-        byte[] v = y.clone();
-        // calculate Z1-Z127 and V1-V127
-        for (int i = 0; i < 127; i++) {
+    private static final int AES_BLOCK_SIZE = 16;
+
+    // Multiplies state0, state1 by V0, V1.
+    private void blockMult(long V0, long V1) {
+        long Z0 = 0;
+        long Z1 = 0;
+        long X;
+
+        // Separate loops for processing state0 and state1.
+        X = state0;
+        for (int i = 0; i < 64; i++) {
             // Zi+1 = Zi if bit i of x is 0
-            if (getBit(x, i)) {
-                for (int n = 0; n < z.length; n++) {
-                    z[n] ^= v[n];
-                }
-            }
-            boolean lastBitOfV = getBit(v, 127);
-            shift(v);
-            if (lastBitOfV) v[0] ^= P128;
+            long mask = X >> 63;
+            Z0 ^= V0 & mask;
+            Z1 ^= V1 & mask;
+
+            // Save mask for conditional reduction below.
+            mask = (V1 << 63) >> 63;
+
+            // V = rightshift(V)
+            long carry = V0 & 1;
+            V0 = V0 >>> 1;
+            V1 = (V1 >>> 1) | (carry << 63);
+
+            // Conditional reduction modulo P128.
+            V0 ^= 0xe100000000000000L & mask;
+            X <<= 1;
         }
+
+        X = state1;
+        for (int i = 64; i < 127; i++) {
+            // Zi+1 = Zi if bit i of x is 0
+            long mask = X >> 63;
+            Z0 ^= V0 & mask;
+            Z1 ^= V1 & mask;
+
+            // Save mask for conditional reduction below.
+            mask = (V1 << 63) >> 63;
+
+            // V = rightshift(V)
+            long carry = V0 & 1;
+            V0 = V0 >>> 1;
+            V1 = (V1 >>> 1) | (carry << 63);
+
+            // Conditional reduction.
+            V0 ^= 0xe100000000000000L & mask;
+            X <<= 1;
+        }
+
         // calculate Z128
-        if (getBit(x, 127)) {
-            for (int n = 0; n < z.length; n++) {
-                z[n] ^= v[n];
-            }
-        }
-        return z;
+        long mask = X >> 63;
+        Z0 ^= V0 & mask;
+        Z1 ^= V1 & mask;
+
+        // Save result.
+        state0 = Z0;
+        state1 = Z1;
     }
 
     // hash subkey H; should not change after the object has been constructed
-    private final byte[] subkeyH;
+    private final long subkeyH0, subkeyH1;
 
     // buffer for storing hash
-    private byte[] state;
+    private long state0, state1;
 
     // variables for save/restore calls
-    private byte[] stateSave = null;
+    private long stateSave0, stateSave1;
 
     /**
      * Initializes the cipher in the specified mode with the given key
@@ -114,8 +141,8 @@
         if ((subkeyH == null) || subkeyH.length != AES_BLOCK_SIZE) {
             throw new ProviderException("Internal error");
         }
-        this.subkeyH = subkeyH;
-        this.state = new byte[AES_BLOCK_SIZE];
+        this.subkeyH0 = getLong(subkeyH, 0);
+        this.subkeyH1 = getLong(subkeyH, 8);
     }
 
     /**
@@ -124,31 +151,33 @@
      * this object for different data w/ the same H.
      */
     void reset() {
-        Arrays.fill(state, (byte) 0);
+        state0 = 0;
+        state1 = 0;
     }
 
     /**
      * Save the current snapshot of this GHASH object.
      */
     void save() {
-        stateSave = state.clone();
+        stateSave0 = state0;
+        stateSave1 = state1;
     }
 
     /**
      * Restores this object using the saved snapshot.
      */
     void restore() {
-        state = stateSave;
+        state0 = stateSave0;
+        state1 = stateSave1;
     }
 
     private void processBlock(byte[] data, int ofs) {
         if (data.length - ofs < AES_BLOCK_SIZE) {
             throw new RuntimeException("need complete block");
         }
-        for (int n = 0; n < state.length; n++) {
-            state[n] ^= data[ofs + n];
-        }
-        state = blockMult(state, subkeyH);
+        state0 ^= getLong(data, ofs);
+        state1 ^= getLong(data, ofs + 8);
+        blockMult(subkeyH0, subkeyH1);
     }
 
     void update(byte[] in) {
@@ -169,10 +198,10 @@
     }
 
     byte[] digest() {
-        try {
-            return state.clone();
-        } finally {
-            reset();
-        }
+        byte[] result = new byte[AES_BLOCK_SIZE];
+        putLong(result, 0, state0);
+        putLong(result, 8, state1);
+        reset();
+        return result;
     }
 }
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/jdk/test/com/sun/crypto/provider/Cipher/AES/TestGHASH.java	Mon Feb 09 13:32:42 2015 -0800
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2015, Red Hat, Inc.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @bug 8069072
+ * @summary Test vectors for com.sun.crypto.provider.GHASH
+ */
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+
+public class TestGHASH {
+
+    private final Constructor<?> GHASH;
+    private final Method UPDATE;
+    private final Method DIGEST;
+
+    TestGHASH(String className) throws Exception {
+        Class<?> cls = Class.forName(className);
+        GHASH = cls.getDeclaredConstructor(byte[].class);
+        GHASH.setAccessible(true);
+        UPDATE = cls.getDeclaredMethod("update", byte[].class);
+        UPDATE.setAccessible(true);
+        DIGEST = cls.getDeclaredMethod("digest");
+        DIGEST.setAccessible(true);
+    }
+
+
+    private Object newGHASH(byte[] H) throws Exception {
+        return GHASH.newInstance(H);
+    }
+
+    private void updateGHASH(Object hash, byte[] data)
+            throws Exception {
+        UPDATE.invoke(hash, data);
+    }
+
+    private byte[] digestGHASH(Object hash) throws Exception {
+        return (byte[]) DIGEST.invoke(hash);
+    }
+
+    private static final String HEX_DIGITS = "0123456789abcdef";
+
+    private static String hex(byte[] bs) {
+        StringBuilder sb = new StringBuilder(2 * bs.length);
+        for (byte b : bs) {
+            sb.append(HEX_DIGITS.charAt((b >> 4) & 0xF));
+            sb.append(HEX_DIGITS.charAt(b & 0xF));
+        }
+        return sb.toString();
+    }
+
+    private static byte[] bytes(String hex) {
+        if ((hex.length() & 1) != 0) {
+            throw new AssertionError();
+        }
+        byte[] result = new byte[hex.length() / 2];
+        for (int i = 0; i < result.length; ++i) {
+            int a = HEX_DIGITS.indexOf(hex.charAt(2 * i));
+            int b = HEX_DIGITS.indexOf(hex.charAt(2 * i + 1));
+            if ((a | b) < 0) {
+                if (a < 0) {
+                    throw new AssertionError(
+                            "bad character " + (int) hex.charAt(2 * i));
+                }
+                throw new AssertionError(
+                        "bad character " + (int) hex.charAt(2 * i + 1));
+            }
+            result[i] = (byte) ((a << 4) | b);
+        }
+        return result;
+    }
+
+    private static byte[] bytes(long L0, long L1) {
+        return ByteBuffer.allocate(16)
+                .putLong(L0)
+                .putLong(L1)
+                .array();
+    }
+
+    private void check(int testCase, String H, String A,
+            String C, String expected) throws Exception {
+        int lenA = A.length() * 4;
+        while ((A.length() % 32) != 0) {
+            A += '0';
+        }
+        int lenC = C.length() * 4;
+        while ((C.length() % 32) != 0) {
+            C += '0';
+        }
+
+        Object hash = newGHASH(bytes(H));
+        updateGHASH(hash, bytes(A));
+        updateGHASH(hash, bytes(C));
+        updateGHASH(hash, bytes(lenA, lenC));
+        byte[] digest = digestGHASH(hash);
+        String actual = hex(digest);
+        if (!expected.equals(actual)) {
+            throw new AssertionError(String.format("%d: expected %s, got %s",
+                    testCase, expected, actual));
+        }
+    }
+
+    public static void main(String[] args) throws Exception {
+        TestGHASH test;
+        if (args.length == 0) {
+            test = new TestGHASH("com.sun.crypto.provider.GHASH");
+        } else {
+            test = new TestGHASH(args[0]);
+        }
+
+        // Test vectors from David A. McGrew, John Viega,
+        // "The Galois/Counter Mode of Operation (GCM)", 2005.
+        // <http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf>
+
+        test.check(1, "66e94bd4ef8a2c3b884cfa59ca342b2e", "", "",
+                "00000000000000000000000000000000");
+        test.check(2,
+                "66e94bd4ef8a2c3b884cfa59ca342b2e", "",
+                "0388dace60b6a392f328c2b971b2fe78",
+                "f38cbb1ad69223dcc3457ae5b6b0f885");
+        test.check(3,
+                "b83b533708bf535d0aa6e52980d53b78", "",
+                "42831ec2217774244b7221b784d0d49c" +
+                "e3aa212f2c02a4e035c17e2329aca12e" +
+                "21d514b25466931c7d8f6a5aac84aa05" +
+                "1ba30b396a0aac973d58e091473f5985",
+                "7f1b32b81b820d02614f8895ac1d4eac");
+        test.check(4,
+                "b83b533708bf535d0aa6e52980d53b78",
+                "feedfacedeadbeeffeedfacedeadbeef" + "abaddad2",
+                "42831ec2217774244b7221b784d0d49c" +
+                "e3aa212f2c02a4e035c17e2329aca12e" +
+                "21d514b25466931c7d8f6a5aac84aa05" +
+                "1ba30b396a0aac973d58e091",
+                "698e57f70e6ecc7fd9463b7260a9ae5f");
+        test.check(5, "b83b533708bf535d0aa6e52980d53b78",
+                "feedfacedeadbeeffeedfacedeadbeef" + "abaddad2",
+                "61353b4c2806934a777ff51fa22a4755" +
+                "699b2a714fcdc6f83766e5f97b6c7423" +
+                "73806900e49f24b22b097544d4896b42" +
+                "4989b5e1ebac0f07c23f4598",
+                "df586bb4c249b92cb6922877e444d37b");
+    }
+}