8143925: Enhancing CounterMode.crypt() for AES
Summary: Add intrinsic for CounterMode.crypt() to leverage the parallel nature of AES in Counter(CTR) Mode.
Reviewed-by: kvn, ascarpino
Contributed-by: kishor.kharbas@intel.com
--- a/jdk/src/java.base/share/classes/com/sun/crypto/provider/CounterMode.java Tue Dec 22 13:41:12 2015 -0800
+++ b/jdk/src/java.base/share/classes/com/sun/crypto/provider/CounterMode.java Mon Dec 28 22:28:49 2015 -0800
@@ -26,7 +26,9 @@
package com.sun.crypto.provider;
import java.security.InvalidKeyException;
+import java.util.Objects;
+import jdk.internal.HotSpotIntrinsicCandidate;
/**
* This class represents ciphers in counter (CTR) mode.
@@ -138,7 +140,7 @@
* <code>cipherOffset</code>.
*
* @param in the buffer with the input data to be encrypted
- * @param inOffset the offset in <code>plain</code>
+ * @param inOff the offset in <code>plain</code>
* @param len the length of the input data
* @param out the buffer for the result
* @param outOff the offset in <code>cipher</code>
@@ -170,6 +172,15 @@
* are encrypted on demand.
*/
private int crypt(byte[] in, int inOff, int len, byte[] out, int outOff) {
+
+ cryptBlockCheck(in, inOff, len);
+ cryptBlockCheck(out, outOff, len);
+ return implCrypt(in, inOff, len, out, outOff);
+ }
+
+ // Implementation of crpyt() method. Possibly replaced with a compiler intrinsic.
+ @HotSpotIntrinsicCandidate
+ private int implCrypt(byte[] in, int inOff, int len, byte[] out, int outOff) {
int result = len;
while (len-- > 0) {
if (used >= blockSize) {
@@ -181,4 +192,23 @@
}
return result;
}
+
+ // Used to perform all checks required by the Java semantics
+ // (i.e., null checks and bounds checks) on the input parameters to crypt().
+ // Normally, the Java Runtime performs these checks, however, as crypt() is
+ // possibly replaced with compiler intrinsic, the JDK performs the
+ // required checks instead.
+ // Does not check accesses to class-internal (private) arrays.
+ private static void cryptBlockCheck(byte[] array, int offset, int len) {
+ Objects.requireNonNull(array);
+
+ if (offset < 0 || len < 0 || offset >= array.length) {
+ throw new ArrayIndexOutOfBoundsException(offset);
+ }
+
+ int largestIndex = offset + len - 1;
+ if (largestIndex < 0 || largestIndex >= array.length) {
+ throw new ArrayIndexOutOfBoundsException(largestIndex);
+ }
+ }
}