8178374: Problematic ByteBuffer handling in CipherSpi.bufferCrypt method
authorvaleriep
Mon, 11 Jun 2018 21:56:58 +0000
changeset 50510 e93ba293e962
parent 50509 2b940ad6816f
child 50511 075e9982b409
8178374: Problematic ByteBuffer handling in CipherSpi.bufferCrypt method Summary: Updated the impl and added reg test to cover all 4 combinations of ByteBuffers Reviewed-by: ascarpino
src/java.base/share/classes/javax/crypto/CipherSpi.java
test/jdk/javax/crypto/CipherSpi/TestGCMWithByteBuffer.java
--- a/src/java.base/share/classes/javax/crypto/CipherSpi.java	Mon Jun 11 14:29:38 2018 -0700
+++ b/src/java.base/share/classes/javax/crypto/CipherSpi.java	Mon Jun 11 21:56:58 2018 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1997, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1997, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -755,6 +755,7 @@
             return 0;
         }
         int outLenNeeded = engineGetOutputSize(inLen);
+
         if (output.remaining() < outLenNeeded) {
             throw new ShortBufferException("Need at least " + outLenNeeded
                 + " bytes of space in output buffer");
@@ -762,98 +763,77 @@
 
         boolean a1 = input.hasArray();
         boolean a2 = output.hasArray();
-
-        if (a1 && a2) {
-            byte[] inArray = input.array();
-            int inOfs = input.arrayOffset() + inPos;
-            byte[] outArray = output.array();
+        int total = 0;
+        byte[] inArray, outArray;
+        if (a2) { // output has an accessible byte[]
+            outArray = output.array();
             int outPos = output.position();
             int outOfs = output.arrayOffset() + outPos;
-            int n;
-            if (isUpdate) {
-                n = engineUpdate(inArray, inOfs, inLen, outArray, outOfs);
-            } else {
-                n = engineDoFinal(inArray, inOfs, inLen, outArray, outOfs);
-            }
-            input.position(inLimit);
-            output.position(outPos + n);
-            return n;
-        } else if (!a1 && a2) {
-            int outPos = output.position();
-            byte[] outArray = output.array();
-            int outOfs = output.arrayOffset() + outPos;
-            byte[] inArray = new byte[getTempArraySize(inLen)];
-            int total = 0;
-            do {
-                int chunk = Math.min(inLen, inArray.length);
-                if (chunk > 0) {
-                    input.get(inArray, 0, chunk);
+
+            if (a1) { // input also has an accessible byte[]
+                inArray = input.array();
+                int inOfs = input.arrayOffset() + inPos;
+                if (isUpdate) {
+                    total = engineUpdate(inArray, inOfs, inLen, outArray, outOfs);
+                } else {
+                    total = engineDoFinal(inArray, inOfs, inLen, outArray, outOfs);
                 }
-                int n;
-                if (isUpdate || (inLen != chunk)) {
-                    n = engineUpdate(inArray, 0, chunk, outArray, outOfs);
-                } else {
-                    n = engineDoFinal(inArray, 0, chunk, outArray, outOfs);
-                }
-                total += n;
-                outOfs += n;
-                inLen -= chunk;
-            } while (inLen > 0);
-            output.position(outPos + total);
-            return total;
-        } else { // output is not backed by an accessible byte[]
-            byte[] inArray;
-            int inOfs;
-            if (a1) {
-                inArray = input.array();
-                inOfs = input.arrayOffset() + inPos;
-            } else {
+                input.position(inLimit);
+            } else { // input does not have accessible byte[]
                 inArray = new byte[getTempArraySize(inLen)];
-                inOfs = 0;
+                do {
+                    int chunk = Math.min(inLen, inArray.length);
+                    if (chunk > 0) {
+                        input.get(inArray, 0, chunk);
+                    }
+                    int n;
+                    if (isUpdate || (inLen > chunk)) {
+                        n = engineUpdate(inArray, 0, chunk, outArray, outOfs);
+                    } else {
+                        n = engineDoFinal(inArray, 0, chunk, outArray, outOfs);
+                    }
+                    total += n;
+                    outOfs += n;
+                    inLen -= chunk;
+                } while (inLen > 0);
             }
-            byte[] outArray = new byte[getTempArraySize(outLenNeeded)];
-            int outSize = outArray.length;
-            int total = 0;
-            boolean resized = false;
-            do {
-                int chunk =
-                    Math.min(inLen, (outSize == 0? inArray.length : outSize));
-                if (!a1 && !resized && chunk > 0) {
-                    input.get(inArray, 0, chunk);
-                    inOfs = 0;
+            output.position(outPos + total);
+        } else { // output does not have an accessible byte[]
+            if (a1) { // but input has an accessible byte[]
+                inArray = input.array();
+                int inOfs = input.arrayOffset() + inPos;
+                if (isUpdate) {
+                    outArray = engineUpdate(inArray, inOfs, inLen);
+                } else {
+                    outArray = engineDoFinal(inArray, inOfs, inLen);
+                }
+                input.position(inLimit);
+                if (outArray != null && outArray.length != 0) {
+                    output.put(outArray);
+                    total = outArray.length;
                 }
-                try {
+            } else { // input also does not have an accessible byte[]
+                inArray = new byte[getTempArraySize(inLen)];
+                do {
+                    int chunk = Math.min(inLen, inArray.length);
+                    if (chunk > 0) {
+                        input.get(inArray, 0, chunk);
+                    }
                     int n;
-                    if (isUpdate || (inLen != chunk)) {
-                        n = engineUpdate(inArray, inOfs, chunk, outArray, 0);
+                    if (isUpdate || (inLen > chunk)) {
+                        outArray = engineUpdate(inArray, 0, chunk);
                     } else {
-                        n = engineDoFinal(inArray, inOfs, chunk, outArray, 0);
+                        outArray = engineDoFinal(inArray, 0, chunk);
                     }
-                    resized = false;
-                    inOfs += chunk;
+                    if (outArray != null && outArray.length != 0) {
+                        output.put(outArray);
+                        total += outArray.length;
+                    }
                     inLen -= chunk;
-                    if (n > 0) {
-                        output.put(outArray, 0, n);
-                        total += n;
-                    }
-                } catch (ShortBufferException e) {
-                    if (resized) {
-                        // we just resized the output buffer, but it still
-                        // did not work. Bug in the provider, abort
-                        throw (ProviderException)new ProviderException
-                            ("Could not determine buffer size").initCause(e);
-                    }
-                    // output buffer is too small, realloc and try again
-                    resized = true;
-                    outSize = engineGetOutputSize(chunk);
-                    outArray = new byte[outSize];
-                }
-            } while (inLen > 0);
-            if (a1) {
-                input.position(inLimit);
+                } while (inLen > 0);
             }
-            return total;
         }
+        return total;
     }
 
     /**
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/javax/crypto/CipherSpi/TestGCMWithByteBuffer.java	Mon Jun 11 21:56:58 2018 +0000
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/**
+ * @test
+ * @bug 8178374
+ * @summary Test GCM decryption with various types of input/output
+ *     ByteBuffer objects
+ * @key randomness
+ */
+
+import java.nio.ByteBuffer;
+import java.security.*;
+import java.util.Random;
+
+import javax.crypto.Cipher;
+import javax.crypto.SecretKey;
+import javax.crypto.AEADBadTagException;
+import javax.crypto.spec.*;
+
+public class TestGCMWithByteBuffer {
+
+    private static Random random = new SecureRandom();
+    private static int dataSize = 4096; // see javax.crypto.CipherSpi
+    private static int multiples = 3;
+
+    public static void main(String args[]) throws Exception {
+        Provider[] provs = Security.getProviders();
+        for (Provider p : provs) {
+            try {
+                Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding", p);
+                test(cipher);
+            } catch (NoSuchAlgorithmException nsae) {
+                // skip testing due to no support
+                continue;
+            }
+        }
+    }
+
+    private static void test(Cipher cipher) throws Exception {
+        System.out.println("Testing " + cipher.getProvider());
+
+        boolean failedOnce = false;
+        Exception failedReason = null;
+
+        int tagLen = 96; // in bits
+        byte[] keyBytes = new byte[16];
+        random.nextBytes(keyBytes);
+        byte[] dataChunk = new byte[dataSize];
+        random.nextBytes(dataChunk);
+
+        SecretKey key = new SecretKeySpec(keyBytes, "AES");
+        // re-use key bytes as IV as the real test is buffer calculation
+        GCMParameterSpec s = new GCMParameterSpec(tagLen, keyBytes);
+
+        /*
+         * Iterate through various sizes to make sure that the code works with
+         * internal temp buffer size 4096.
+         */
+        for (int t = 1; t <= multiples; t++) {
+            int size = t * dataSize;
+
+            System.out.println("\nTesting data size: " + size);
+
+            try {
+                decrypt(cipher, key, s, dataChunk, t,
+                        ByteBuffer.allocate(dataSize),
+                        ByteBuffer.allocate(size),
+                        ByteBuffer.allocateDirect(dataSize),
+                        ByteBuffer.allocateDirect(size));
+            } catch (Exception e) {
+                System.out.println("\tFailed with data size " + size);
+                failedOnce = true;
+                failedReason = e;
+            }
+        }
+        if (failedOnce) {
+            throw failedReason;
+        }
+        System.out.println("\n=> Passed...");
+    }
+
+    private enum TestVariant {
+        HEAP_HEAP, HEAP_DIRECT, DIRECT_HEAP, DIRECT_DIRECT
+    };
+
+    private static void decrypt(Cipher cipher, SecretKey key,
+            GCMParameterSpec s, byte[] dataChunk, int multiples,
+            ByteBuffer heapIn, ByteBuffer heapOut, ByteBuffer directIn,
+            ByteBuffer directOut) throws Exception {
+
+        ByteBuffer inBB = null;
+        ByteBuffer outBB = null;
+
+        // try various combinations of input/output
+        for (TestVariant tv : TestVariant.values()) {
+            System.out.println(" " + tv);
+
+            switch (tv) {
+            case HEAP_HEAP:
+                inBB = heapIn;
+                outBB = heapOut;
+                break;
+            case HEAP_DIRECT:
+                inBB = heapIn;
+                outBB = directOut;
+                break;
+            case DIRECT_HEAP:
+                inBB = directIn;
+                outBB = heapOut;
+                break;
+            case DIRECT_DIRECT:
+                inBB = directIn;
+                outBB = directOut;
+                break;
+            }
+
+            // prepare input and output buffers
+            inBB.clear();
+            inBB.put(dataChunk);
+
+            outBB.clear();
+
+            try {
+                // Always re-init the Cipher object so cipher is in
+                // a good state for future testing
+                cipher.init(Cipher.DECRYPT_MODE, key, s);
+
+                for (int i = 0; i < multiples; i++) {
+                    inBB.flip();
+                    cipher.update(inBB, outBB);
+                    if (inBB.hasRemaining()) {
+                        throw new Exception("buffer not empty");
+                    }
+                }
+                // finish decryption and process all data buffered
+                cipher.doFinal(inBB, outBB);
+                throw new RuntimeException("Error: doFinal completed without exception");
+            } catch (AEADBadTagException ex) {
+                System.out.println("Expected AEADBadTagException thrown");
+                continue;
+            }
+        }
+    }
+}