src/java.base/share/classes/javax/crypto/CipherSpi.java
changeset 55661 b32b6ffb221b
parent 50510 e93ba293e962
--- a/src/java.base/share/classes/javax/crypto/CipherSpi.java	Wed Jul 10 18:48:05 2019 +0200
+++ b/src/java.base/share/classes/javax/crypto/CipherSpi.java	Wed Jul 10 18:43:45 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1997, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1997, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -761,78 +761,87 @@
                 + " bytes of space in output buffer");
         }
 
+        // detecting input and output buffer overlap may be tricky
+        // we can only write directly into output buffer when we
+        // are 100% sure it's safe to do so
+
         boolean a1 = input.hasArray();
         boolean a2 = output.hasArray();
         int total = 0;
-        byte[] inArray, outArray;
-        if (a2) { // output has an accessible byte[]
-            outArray = output.array();
-            int outPos = output.position();
-            int outOfs = output.arrayOffset() + outPos;
+
+        if (a1) { // input has an accessible byte[]
+            byte[] inArray = input.array();
+            int inOfs = input.arrayOffset() + inPos;
+
+            if (a2) { // output has an accessible byte[]
+                byte[] outArray = output.array();
+                int outPos = output.position();
+                int outOfs = output.arrayOffset() + outPos;
 
-            if (a1) { // input also has an accessible byte[]
-                inArray = input.array();
-                int inOfs = input.arrayOffset() + inPos;
+                // check array address and offsets and use temp output buffer
+                // if output offset is larger than input offset and
+                // falls within the range of input data
+                boolean useTempOut = false;
+                if (inArray == outArray &&
+                    ((inOfs < outOfs) && (outOfs < inOfs + inLen))) {
+                    useTempOut = true;
+                    outArray = new byte[outLenNeeded];
+                    outOfs = 0;
+                }
                 if (isUpdate) {
                     total = engineUpdate(inArray, inOfs, inLen, outArray, outOfs);
                 } else {
                     total = engineDoFinal(inArray, inOfs, inLen, outArray, outOfs);
                 }
+                if (useTempOut) {
+                    output.put(outArray, outOfs, total);
+                } else {
+                    // adjust output position manually
+                    output.position(outPos + total);
+                }
+                // adjust input position manually
                 input.position(inLimit);
-            } else { // input does not have accessible byte[]
-                inArray = new byte[getTempArraySize(inLen)];
-                do {
-                    int chunk = Math.min(inLen, inArray.length);
-                    if (chunk > 0) {
-                        input.get(inArray, 0, chunk);
-                    }
-                    int n;
-                    if (isUpdate || (inLen > chunk)) {
-                        n = engineUpdate(inArray, 0, chunk, outArray, outOfs);
-                    } else {
-                        n = engineDoFinal(inArray, 0, chunk, outArray, outOfs);
-                    }
-                    total += n;
-                    outOfs += n;
-                    inLen -= chunk;
-                } while (inLen > 0);
-            }
-            output.position(outPos + total);
-        } else { // output does not have an accessible byte[]
-            if (a1) { // but input has an accessible byte[]
-                inArray = input.array();
-                int inOfs = input.arrayOffset() + inPos;
+            } else { // output does not have an accessible byte[]
+                byte[] outArray = null;
                 if (isUpdate) {
                     outArray = engineUpdate(inArray, inOfs, inLen);
                 } else {
                     outArray = engineDoFinal(inArray, inOfs, inLen);
                 }
-                input.position(inLimit);
                 if (outArray != null && outArray.length != 0) {
                     output.put(outArray);
                     total = outArray.length;
                 }
-            } else { // input also does not have an accessible byte[]
-                inArray = new byte[getTempArraySize(inLen)];
-                do {
-                    int chunk = Math.min(inLen, inArray.length);
-                    if (chunk > 0) {
-                        input.get(inArray, 0, chunk);
-                    }
-                    int n;
-                    if (isUpdate || (inLen > chunk)) {
-                        outArray = engineUpdate(inArray, 0, chunk);
-                    } else {
-                        outArray = engineDoFinal(inArray, 0, chunk);
-                    }
-                    if (outArray != null && outArray.length != 0) {
-                        output.put(outArray);
-                        total += outArray.length;
-                    }
-                    inLen -= chunk;
-                } while (inLen > 0);
+                // adjust input position manually
+                input.position(inLimit);
+            }
+        } else { // input does not have an accessible byte[]
+            // have to assume the worst, since we have no way of determine
+            // if input and output overlaps or not
+            byte[] tempOut = new byte[outLenNeeded];
+            int outOfs = 0;
+
+            byte[] tempIn = new byte[getTempArraySize(inLen)];
+            do {
+                int chunk = Math.min(inLen, tempIn.length);
+                if (chunk > 0) {
+                    input.get(tempIn, 0, chunk);
+                }
+                int n;
+                if (isUpdate || (inLen > chunk)) {
+                    n = engineUpdate(tempIn, 0, chunk, tempOut, outOfs);
+                } else {
+                    n = engineDoFinal(tempIn, 0, chunk, tempOut, outOfs);
+                }
+                outOfs += n;
+                total += n;
+                inLen -= chunk;
+            } while (inLen > 0);
+            if (total > 0) {
+                output.put(tempOut, 0, total);
             }
         }
+
         return total;
     }