8210583: Base64.Encoder incorrectly throws NegativeArraySizeException
authornishjain
Thu, 24 Jan 2019 12:45:19 +0530
changeset 53462 091ed8f2e7d7
parent 53461 08d6edeb3145
child 53463 b2d1c3b0bd31
8210583: Base64.Encoder incorrectly throws NegativeArraySizeException Reviewed-by: rriggs, naoto, darcy, alanb
src/java.base/share/classes/java/util/Base64.java
test/jdk/java/util/Base64/TestEncodingDecodingLength.java
--- a/src/java.base/share/classes/java/util/Base64.java	Wed Jan 23 21:17:51 2019 -0500
+++ b/src/java.base/share/classes/java/util/Base64.java	Thu Jan 24 12:45:19 2019 +0530
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2012, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -186,6 +186,10 @@
      * a method of this class will cause a
      * {@link java.lang.NullPointerException NullPointerException} to
      * be thrown.
+     * <p> If the encoded byte output of the needed size can not
+     *     be allocated, the encode methods of this class will
+     *     cause an {@link java.lang.OutOfMemoryError OutOfMemoryError}
+     *     to be thrown.
      *
      * @see     Decoder
      * @since   1.8
@@ -237,16 +241,37 @@
         static final Encoder RFC4648_URLSAFE = new Encoder(true, null, -1, true);
         static final Encoder RFC2045 = new Encoder(false, CRLF, MIMELINEMAX, true);
 
-        private final int outLength(int srclen) {
+        /**
+         * Calculates the length of the encoded output bytes.
+         *
+         * @param srclen length of the bytes to encode
+         * @param throwOOME if true, throws OutOfMemoryError if the length of
+         *                  the encoded bytes overflows; else returns the
+         *                  length
+         * @return length of the encoded bytes, or -1 if the length overflows
+         *
+         */
+        private final int outLength(int srclen, boolean throwOOME) {
             int len = 0;
-            if (doPadding) {
-                len = 4 * ((srclen + 2) / 3);
-            } else {
-                int n = srclen % 3;
-                len = 4 * (srclen / 3) + (n == 0 ? 0 : n + 1);
+            try {
+                if (doPadding) {
+                    len = Math.multiplyExact(4, (Math.addExact(srclen, 2) / 3));
+                } else {
+                    int n = srclen % 3;
+                    len = Math.addExact(Math.multiplyExact(4, (srclen / 3)), (n == 0 ? 0 : n + 1));
+                }
+                if (linemax > 0) {                             // line separators
+                    len = Math.addExact(len, (len - 1) / linemax * newline.length);
+                }
+            } catch (ArithmeticException ex) {
+                if (throwOOME) {
+                    throw new OutOfMemoryError("Encoded size is too large");
+                } else {
+                    // let the caller know that encoded bytes length
+                    // is too large
+                    len = -1;
+                }
             }
-            if (linemax > 0)                                  // line separators
-                len += (len - 1) / linemax * newline.length;
             return len;
         }
 
@@ -261,7 +286,7 @@
          *          encoded bytes.
          */
         public byte[] encode(byte[] src) {
-            int len = outLength(src.length);          // dst array size
+            int len = outLength(src.length, true);          // dst array size
             byte[] dst = new byte[len];
             int ret = encode0(src, 0, src.length, dst);
             if (ret != dst.length)
@@ -289,8 +314,8 @@
          *          space for encoding all input bytes.
          */
         public int encode(byte[] src, byte[] dst) {
-            int len = outLength(src.length);         // dst array size
-            if (dst.length < len)
+            int len = outLength(src.length, false);         // dst array size
+            if (dst.length < len || len == -1)
                 throw new IllegalArgumentException(
                     "Output byte array is too small for encoding all input bytes");
             return encode0(src, 0, src.length, dst);
@@ -334,7 +359,7 @@
          * @return  A newly-allocated byte buffer containing the encoded bytes.
          */
         public ByteBuffer encode(ByteBuffer buffer) {
-            int len = outLength(buffer.remaining());
+            int len = outLength(buffer.remaining(), true);
             byte[] dst = new byte[len];
             int ret = 0;
             if (buffer.hasArray()) {
@@ -469,6 +494,10 @@
      * a method of this class will cause a
      * {@link java.lang.NullPointerException NullPointerException} to
      * be thrown.
+     * <p> If the decoded byte output of the needed size can not
+     *     be allocated, the decode methods of this class will
+     *     cause an {@link java.lang.OutOfMemoryError OutOfMemoryError}
+     *     to be thrown.
      *
      * @see     Encoder
      * @since   1.8
@@ -531,7 +560,7 @@
          *          if {@code src} is not in valid Base64 scheme
          */
         public byte[] decode(byte[] src) {
-            byte[] dst = new byte[outLength(src, 0, src.length)];
+            byte[] dst = new byte[outLength(src, 0, src.length, true)];
             int ret = decode0(src, 0, src.length, dst);
             if (ret != dst.length) {
                 dst = Arrays.copyOf(dst, ret);
@@ -584,8 +613,8 @@
          *          does not have enough space for decoding all input bytes.
          */
         public int decode(byte[] src, byte[] dst) {
-            int len = outLength(src, 0, src.length);
-            if (dst.length < len)
+            int len = outLength(src, 0, src.length, false);
+            if (dst.length < len || len == -1)
                 throw new IllegalArgumentException(
                     "Output byte array is too small for decoding all input bytes");
             return decode0(src, 0, src.length, dst);
@@ -610,7 +639,7 @@
          * @return  A newly-allocated byte buffer containing the decoded bytes
          *
          * @throws  IllegalArgumentException
-         *          if {@code src} is not in valid Base64 scheme.
+         *          if {@code buffer} is not in valid Base64 scheme
          */
         public ByteBuffer decode(ByteBuffer buffer) {
             int pos0 = buffer.position();
@@ -628,7 +657,7 @@
                     sp = 0;
                     sl = src.length;
                 }
-                byte[] dst = new byte[outLength(src, sp, sl)];
+                byte[] dst = new byte[outLength(src, sp, sl, true)];
                 return ByteBuffer.wrap(dst, 0, decode0(src, sp, sl, dst));
             } catch (IllegalArgumentException iae) {
                 buffer.position(pos0);
@@ -656,7 +685,19 @@
             return new DecInputStream(is, isURL ? fromBase64URL : fromBase64, isMIME);
         }
 
-        private int outLength(byte[] src, int sp, int sl) {
+        /**
+         * Calculates the length of the decoded output bytes.
+         *
+         * @param src the byte array to decode
+         * @param sp the source  position
+         * @param sl the source limit
+         * @param throwOOME if true, throws OutOfMemoryError if the length of
+         *                  the decoded bytes overflows; else returns the
+         *                  length
+         * @return length of the decoded bytes, or -1 if the length overflows
+         *
+         */
+        private int outLength(byte[] src, int sp, int sl, boolean throwOOME) {
             int[] base64 = isURL ? fromBase64URL : fromBase64;
             int paddings = 0;
             int len = sl - sp;
@@ -691,7 +732,19 @@
             }
             if (paddings == 0 && (len & 0x3) !=  0)
                 paddings = 4 - (len & 0x3);
-            return 3 * ((len + 3) / 4) - paddings;
+
+            try {
+                len = Math.multiplyExact(3, (Math.addExact(len, 3) / 4)) - paddings;
+            } catch (ArithmeticException ex) {
+                if (throwOOME) {
+                    throw new OutOfMemoryError("Decoded size is too large");
+                } else {
+                    // let the caller know that the decoded bytes length
+                    // is too large
+                    len = -1;
+                }
+            }
+            return len;
         }
 
         private int decode0(byte[] src, int sp, int sl, byte[] dst) {
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/java/util/Base64/TestEncodingDecodingLength.java	Thu Jan 24 12:45:19 2019 +0530
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import java.nio.ByteBuffer;
+import java.util.Base64;
+
+/**
+ * @test
+ * @bug 8210583
+ * @summary Tests Base64.Encoder.encode/Decoder.decode for the large size
+ *          of resulting bytes which can not be allocated
+ * @requires os.maxMemory >= 6g
+ * @run main/othervm -Xms4g -Xmx6g TestEncodingDecodingLength
+ *
+ */
+
+public class TestEncodingDecodingLength {
+
+    public static void main(String[] args) {
+        int size = Integer.MAX_VALUE - 2;
+        byte[] inputBytes = new byte[size];
+        byte[] outputBytes = new byte[size];
+
+        // Check encoder with large array length
+        Base64.Encoder encoder = Base64.getEncoder();
+        checkOOM("encode(byte[])", () -> encoder.encode(inputBytes));
+        checkIAE("encode(byte[] byte[])", () -> encoder.encode(inputBytes, outputBytes));
+        checkOOM("encodeToString(byte[])", () -> encoder.encodeToString(inputBytes));
+        checkOOM("encode(ByteBuffer)", () -> encoder.encode(ByteBuffer.allocate(size)));
+
+        // Check decoder with large array length
+        Base64.Decoder decoder = Base64.getDecoder();
+        checkOOM("decode(byte[])", () -> decoder.decode(inputBytes));
+        checkIAE("decode(byte[], byte[])", () -> decoder.decode(inputBytes, outputBytes));
+        checkOOM("decode(ByteBuffer)", () -> decoder.decode(ByteBuffer.allocate(size)));
+    }
+
+    private static final void checkOOM(String methodName, Runnable r) {
+        try {
+            r.run();
+            throw new RuntimeException("OutOfMemoryError should have been thrown by: " + methodName);
+        } catch (OutOfMemoryError er) {}
+    }
+
+    private static final void checkIAE(String methodName, Runnable r) {
+        try {
+            r.run();
+            throw new RuntimeException("IllegalArgumentException should have been thrown by: " + methodName);
+        } catch (IllegalArgumentException iae) {}
+    }
+}
+