8181386: CipherSpi ByteBuffer to byte array conversion fails for certain data overlap conditions
authorvaleriep
Wed, 10 Jul 2019 18:43:45 +0000
changeset 55661 b32b6ffb221b
parent 55660 fe5dcb38a26a
child 55662 1d5ce4787723
8181386: CipherSpi ByteBuffer to byte array conversion fails for certain data overlap conditions Summary: Detect potential buffer overlap and use extra buffer if necessary Reviewed-by: xuelei
src/java.base/share/classes/javax/crypto/CipherSpi.java
test/jdk/javax/crypto/CipherSpi/CipherByteBufferOverwriteTest.java
--- a/src/java.base/share/classes/javax/crypto/CipherSpi.java	Wed Jul 10 18:48:05 2019 +0200
+++ b/src/java.base/share/classes/javax/crypto/CipherSpi.java	Wed Jul 10 18:43:45 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1997, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1997, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -761,78 +761,87 @@
                 + " bytes of space in output buffer");
         }
 
+        // detecting input and output buffer overlap may be tricky
+        // we can only write directly into output buffer when we
+        // are 100% sure it's safe to do so
+
         boolean a1 = input.hasArray();
         boolean a2 = output.hasArray();
         int total = 0;
-        byte[] inArray, outArray;
-        if (a2) { // output has an accessible byte[]
-            outArray = output.array();
-            int outPos = output.position();
-            int outOfs = output.arrayOffset() + outPos;
+
+        if (a1) { // input has an accessible byte[]
+            byte[] inArray = input.array();
+            int inOfs = input.arrayOffset() + inPos;
+
+            if (a2) { // output has an accessible byte[]
+                byte[] outArray = output.array();
+                int outPos = output.position();
+                int outOfs = output.arrayOffset() + outPos;
 
-            if (a1) { // input also has an accessible byte[]
-                inArray = input.array();
-                int inOfs = input.arrayOffset() + inPos;
+                // check array address and offsets and use temp output buffer
+                // if output offset is larger than input offset and
+                // falls within the range of input data
+                boolean useTempOut = false;
+                if (inArray == outArray &&
+                    ((inOfs < outOfs) && (outOfs < inOfs + inLen))) {
+                    useTempOut = true;
+                    outArray = new byte[outLenNeeded];
+                    outOfs = 0;
+                }
                 if (isUpdate) {
                     total = engineUpdate(inArray, inOfs, inLen, outArray, outOfs);
                 } else {
                     total = engineDoFinal(inArray, inOfs, inLen, outArray, outOfs);
                 }
+                if (useTempOut) {
+                    output.put(outArray, outOfs, total);
+                } else {
+                    // adjust output position manually
+                    output.position(outPos + total);
+                }
+                // adjust input position manually
                 input.position(inLimit);
-            } else { // input does not have accessible byte[]
-                inArray = new byte[getTempArraySize(inLen)];
-                do {
-                    int chunk = Math.min(inLen, inArray.length);
-                    if (chunk > 0) {
-                        input.get(inArray, 0, chunk);
-                    }
-                    int n;
-                    if (isUpdate || (inLen > chunk)) {
-                        n = engineUpdate(inArray, 0, chunk, outArray, outOfs);
-                    } else {
-                        n = engineDoFinal(inArray, 0, chunk, outArray, outOfs);
-                    }
-                    total += n;
-                    outOfs += n;
-                    inLen -= chunk;
-                } while (inLen > 0);
-            }
-            output.position(outPos + total);
-        } else { // output does not have an accessible byte[]
-            if (a1) { // but input has an accessible byte[]
-                inArray = input.array();
-                int inOfs = input.arrayOffset() + inPos;
+            } else { // output does not have an accessible byte[]
+                byte[] outArray = null;
                 if (isUpdate) {
                     outArray = engineUpdate(inArray, inOfs, inLen);
                 } else {
                     outArray = engineDoFinal(inArray, inOfs, inLen);
                 }
-                input.position(inLimit);
                 if (outArray != null && outArray.length != 0) {
                     output.put(outArray);
                     total = outArray.length;
                 }
-            } else { // input also does not have an accessible byte[]
-                inArray = new byte[getTempArraySize(inLen)];
-                do {
-                    int chunk = Math.min(inLen, inArray.length);
-                    if (chunk > 0) {
-                        input.get(inArray, 0, chunk);
-                    }
-                    int n;
-                    if (isUpdate || (inLen > chunk)) {
-                        outArray = engineUpdate(inArray, 0, chunk);
-                    } else {
-                        outArray = engineDoFinal(inArray, 0, chunk);
-                    }
-                    if (outArray != null && outArray.length != 0) {
-                        output.put(outArray);
-                        total += outArray.length;
-                    }
-                    inLen -= chunk;
-                } while (inLen > 0);
+                // adjust input position manually
+                input.position(inLimit);
+            }
+        } else { // input does not have an accessible byte[]
+            // have to assume the worst, since we have no way of determine
+            // if input and output overlaps or not
+            byte[] tempOut = new byte[outLenNeeded];
+            int outOfs = 0;
+
+            byte[] tempIn = new byte[getTempArraySize(inLen)];
+            do {
+                int chunk = Math.min(inLen, tempIn.length);
+                if (chunk > 0) {
+                    input.get(tempIn, 0, chunk);
+                }
+                int n;
+                if (isUpdate || (inLen > chunk)) {
+                    n = engineUpdate(tempIn, 0, chunk, tempOut, outOfs);
+                } else {
+                    n = engineDoFinal(tempIn, 0, chunk, tempOut, outOfs);
+                }
+                outOfs += n;
+                total += n;
+                inLen -= chunk;
+            } while (inLen > 0);
+            if (total > 0) {
+                output.put(tempOut, 0, total);
             }
         }
+
         return total;
     }
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/javax/crypto/CipherSpi/CipherByteBufferOverwriteTest.java	Wed Jul 10 18:43:45 2019 +0000
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/**
+ * @test
+ * @bug 8181386
+ * @summary CipherSpi ByteBuffer to byte array conversion fails for
+ *          certain data overlap conditions
+ * @run main CipherByteBufferOverwriteTest 0 false
+ * @run main CipherByteBufferOverwriteTest 0 true
+ * @run main CipherByteBufferOverwriteTest 4 false
+ * @run main CipherByteBufferOverwriteTest 4 true
+ */
+
+import java.security.spec.AlgorithmParameterSpec;
+import javax.crypto.Cipher;
+import javax.crypto.SecretKey;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+public class CipherByteBufferOverwriteTest {
+
+    private static final boolean DEBUG = false;
+
+    private static final String TRANSFORMATION = "AES/CBC/PKCS5Padding";
+
+    // must be larger than the temp array size, i.e. 4096, hardcoded in
+    // javax.crypto.CipherSpi class
+    private static final int PLAINTEXT_SIZE = 8192;
+    // leave room for padding
+    private static final int CIPHERTEXT_BUFFER_SIZE = PLAINTEXT_SIZE + 32;
+
+    private static final SecretKey KEY = new SecretKeySpec(new byte[16], "AES");
+    private static final AlgorithmParameterSpec PARAMS =
+            new IvParameterSpec(new byte[16]);
+
+    private static ByteBuffer inBuf;
+    private static ByteBuffer outBuf;
+
+    private enum BufferType {
+        ALLOCATE, DIRECT, WRAP;
+    }
+
+    public static void main(String[] args) throws Exception {
+
+        int offset = Integer.parseInt(args[0]);
+        boolean useRO = Boolean.parseBoolean(args[1]);
+
+        // an all-zeros plaintext is the easiest way to demonstrate the issue,
+        // but it fails with any plaintext, of course
+        byte[] expectedPT = new byte[PLAINTEXT_SIZE];
+        byte[] buf = new byte[offset + CIPHERTEXT_BUFFER_SIZE];
+        System.arraycopy(expectedPT, 0, buf, 0, PLAINTEXT_SIZE);
+
+        // generate expected cipher text using byte[] methods
+        Cipher c = Cipher.getInstance(TRANSFORMATION);
+        c.init(Cipher.ENCRYPT_MODE, KEY, PARAMS);
+        byte[] expectedCT = c.doFinal(expectedPT);
+
+        // Test#1: against ByteBuffer generated with allocate(int) call
+        prepareBuffers(BufferType.ALLOCATE, useRO, buf.length,
+                buf, 0, PLAINTEXT_SIZE, offset);
+
+        runTest(offset, expectedPT, expectedCT);
+        System.out.println("\tALLOCATE: passed");
+
+        // Test#2: against direct ByteBuffer
+        prepareBuffers(BufferType.DIRECT, useRO, buf.length,
+                buf, 0, PLAINTEXT_SIZE, offset);
+        System.out.println("\tDIRECT: passed");
+
+        runTest(offset, expectedPT, expectedCT);
+
+        // Test#3: against ByteBuffer wrapping existing array
+        prepareBuffers(BufferType.WRAP, useRO, buf.length,
+                buf, 0, PLAINTEXT_SIZE, offset);
+
+        runTest(offset, expectedPT, expectedCT);
+        System.out.println("\tWRAP: passed");
+
+        System.out.println("All Tests Passed");
+    }
+
+    private static void prepareBuffers(BufferType type,
+            boolean useRO, int bufSz, byte[] in, int inOfs, int inLen,
+            int outOfs) {
+        switch (type) {
+            case ALLOCATE:
+                outBuf = ByteBuffer.allocate(bufSz);
+                inBuf = outBuf.slice();
+                inBuf.put(in, inOfs, inLen);
+                inBuf.rewind();
+                inBuf.limit(inLen);
+                outBuf.position(outOfs);
+                break;
+            case DIRECT:
+                outBuf = ByteBuffer.allocateDirect(bufSz);
+                inBuf = outBuf.slice();
+                inBuf.put(in, inOfs, inLen);
+                inBuf.rewind();
+                inBuf.limit(inLen);
+                outBuf.position(outOfs);
+                break;
+            case WRAP:
+                if (in.length < bufSz) {
+                    throw new RuntimeException("ERROR: Input buffer too small");
+                }
+                outBuf = ByteBuffer.wrap(in);
+                inBuf = ByteBuffer.wrap(in, inOfs, inLen);
+                outBuf.position(outOfs);
+                break;
+        }
+        if (useRO) {
+            inBuf = inBuf.asReadOnlyBuffer();
+        }
+        if (DEBUG) {
+            System.out.println("inBuf, pos = " + inBuf.position() +
+                ", capacity = " + inBuf.capacity() +
+                ", limit = " + inBuf.limit() +
+                ", remaining = " + inBuf.remaining());
+            System.out.println("outBuf, pos = " + outBuf.position() +
+                ", capacity = " + outBuf.capacity() +
+                ", limit = " + outBuf.limit() +
+                ", remaining = " + outBuf.remaining());
+        }
+    }
+
+    private static void runTest(int ofs, byte[] expectedPT, byte[] expectedCT)
+            throws Exception {
+
+        Cipher c = Cipher.getInstance(TRANSFORMATION);
+        c.init(Cipher.ENCRYPT_MODE, KEY, PARAMS);
+        int ciphertextSize = c.doFinal(inBuf, outBuf);
+
+        // read out the encrypted result
+        outBuf.position(ofs);
+        byte[] finalCT = new byte[ciphertextSize];
+        if (DEBUG) {
+            System.out.println("runTest, ciphertextSize = " + ciphertextSize);
+            System.out.println("runTest, ofs = " + ofs +
+                ", remaining = " + finalCT.length +
+                ", limit = " + outBuf.limit());
+        }
+        outBuf.get(finalCT);
+
+        if (!Arrays.equals(finalCT, expectedCT)) {
+            throw new Exception("ERROR: Ciphertext does not match");
+        }
+
+        // now do decryption
+        outBuf.position(ofs);
+        outBuf.limit(ofs + ciphertextSize);
+
+        c.init(Cipher.DECRYPT_MODE, KEY, PARAMS);
+        ByteBuffer finalPTBuf = ByteBuffer.allocate(
+                c.getOutputSize(outBuf.remaining()));
+        c.doFinal(outBuf, finalPTBuf);
+
+        // read out the decrypted result
+        finalPTBuf.flip();
+        byte[] finalPT = new byte[finalPTBuf.remaining()];
+        finalPTBuf.get(finalPT);
+
+        if (!Arrays.equals(finalPT, expectedPT)) {
+            throw new Exception("ERROR: Plaintext does not match");
+        }
+    }
+}
+