8026330: java.util.Base64 urlEncoder should omit padding
authorsherman
Mon, 11 Nov 2013 14:35:36 -0800
changeset 21647 d1825822d9a0
parent 21646 38f9ba0664de
child 21648 b485e4eb2fd9
8026330: java.util.Base64 urlEncoder should omit padding Summary: to add Encoder.withoutPadding() Reviewed-by: alanb
jdk/src/share/classes/java/util/Base64.java
jdk/test/java/util/Base64/TestBase64.java
--- a/jdk/src/share/classes/java/util/Base64.java	Fri Nov 08 18:16:12 2013 +0100
+++ b/jdk/src/share/classes/java/util/Base64.java	Mon Nov 11 14:35:36 2013 -0800
@@ -138,7 +138,7 @@
          if (lineLength <= 0) {
              return Encoder.RFC4648;
          }
-         return new Encoder(false, lineSeparator, lineLength >> 2 << 2);
+         return new Encoder(false, lineSeparator, lineLength >> 2 << 2, true);
     }
 
     /**
@@ -192,11 +192,13 @@
         private final byte[] newline;
         private final int linemax;
         private final boolean isURL;
+        private final boolean doPadding;
 
-        private Encoder(boolean isURL, byte[] newline, int linemax) {
+        private Encoder(boolean isURL, byte[] newline, int linemax, boolean doPadding) {
             this.isURL = isURL;
             this.newline = newline;
             this.linemax = linemax;
+            this.doPadding = doPadding;
         }
 
         /**
@@ -228,9 +230,22 @@
         private static final int MIMELINEMAX = 76;
         private static final byte[] CRLF = new byte[] {'\r', '\n'};
 
-        static final Encoder RFC4648 = new Encoder(false, null, -1);
-        static final Encoder RFC4648_URLSAFE = new Encoder(true, null, -1);
-        static final Encoder RFC2045 = new Encoder(false, CRLF, MIMELINEMAX);
+        static final Encoder RFC4648 = new Encoder(false, null, -1, true);
+        static final Encoder RFC4648_URLSAFE = new Encoder(true, null, -1, true);
+        static final Encoder RFC2045 = new Encoder(false, CRLF, MIMELINEMAX, true);
+
+        private final int outLength(int srclen) {
+            int len = 0;
+            if (doPadding) {
+                len = 4 * ((srclen + 2) / 3);
+            } else {
+                int n = srclen % 3;
+                len = 4 * (srclen / 3) + (n == 0 ? 0 : n + 1);
+            }
+            if (linemax > 0)                                  // line separators
+                len += (len - 1) / linemax * newline.length;
+            return len;
+        }
 
         /**
          * Encodes all bytes from the specified byte array into a newly-allocated
@@ -243,9 +258,7 @@
          *          encoded bytes.
          */
         public byte[] encode(byte[] src) {
-            int len = 4 * ((src.length + 2) / 3);    // dst array size
-            if (linemax > 0)                          // line separators
-                len += (len - 1) / linemax * newline.length;
+            int len = outLength(src.length);          // dst array size
             byte[] dst = new byte[len];
             int ret = encode0(src, 0, src.length, dst);
             if (ret != dst.length)
@@ -273,10 +286,7 @@
          *          space for encoding all input bytes.
          */
         public int encode(byte[] src, byte[] dst) {
-            int len = 4 * ((src.length + 2) / 3);    // dst array size
-            if (linemax > 0) {
-                len += (len - 1) / linemax * newline.length;
-            }
+            int len = outLength(src.length);         // dst array size
             if (dst.length < len)
                 throw new IllegalArgumentException(
                     "Output byte array is too small for encoding all input bytes");
@@ -321,9 +331,7 @@
          * @return  A newly-allocated byte buffer containing the encoded bytes.
          */
         public ByteBuffer encode(ByteBuffer buffer) {
-            int len = 4 * ((buffer.remaining() + 2) / 3);
-            if (linemax > 0)
-                len += (len - 1) / linemax * newline.length;
+            int len = outLength(buffer.remaining());
             byte[] dst = new byte[len];
             int ret = 0;
             if (buffer.hasArray()) {
@@ -415,7 +423,25 @@
         public OutputStream wrap(OutputStream os) {
             Objects.requireNonNull(os);
             return new EncOutputStream(os, isURL ? toBase64URL : toBase64,
-                                       newline, linemax);
+                                       newline, linemax, doPadding);
+        }
+
+        /**
+         * Returns an encoder instance that encodes equivalently to this one,
+         * but without adding any padding character at the end of the encoded
+         * byte data.
+         *
+         * <p> The encoding scheme of this encoder instance is unaffected by
+         * this invocation. The returned encoder instance should be used for
+         * non-padding encoding operation.
+         *
+         * @return an equivalent encoder that encodes without adding any
+         *         padding character at the end
+         */
+        public Encoder withoutPadding() {
+            if (!doPadding)
+                return this;
+            return new Encoder(isURL, newline, linemax, false);
         }
 
         private int encodeArray(ByteBuffer src, ByteBuffer dst, int bytesOut) {
@@ -476,13 +502,17 @@
                     da[dp++] = (byte)base64[b0 >> 2];
                     if (sp == sl) {
                         da[dp++] = (byte)base64[(b0 << 4) & 0x3f];
-                        da[dp++] = '=';
-                        da[dp++] = '=';
+                        if (doPadding) {
+                            da[dp++] = '=';
+                            da[dp++] = '=';
+                        }
                     } else {
                         int b1 = sa[sp++] & 0xff;
                         da[dp++] = (byte)base64[(b0 << 4) & 0x3f | (b1 >> 4)];
                         da[dp++] = (byte)base64[(b1 << 2) & 0x3f];
-                        da[dp++] = '=';
+                        if (doPadding) {
+                            da[dp++] = '=';
+                        }
                     }
                 }
                 return dp - dp00 + bytesOut;
@@ -548,13 +578,17 @@
                     dst.put(dp++, (byte)base64[b0 >> 2]);
                     if (sp == src.limit()) {
                         dst.put(dp++, (byte)base64[(b0 << 4) & 0x3f]);
-                        dst.put(dp++, (byte)'=');
-                        dst.put(dp++, (byte)'=');
+                        if (doPadding) {
+                            dst.put(dp++, (byte)'=');
+                            dst.put(dp++, (byte)'=');
+                        }
                     } else {
                         int b1 = src.get(sp++) & 0xff;
                         dst.put(dp++, (byte)base64[(b0 << 4) & 0x3f | (b1 >> 4)]);
                         dst.put(dp++, (byte)base64[(b1 << 2) & 0x3f]);
-                        dst.put(dp++, (byte)'=');
+                        if (doPadding) {
+                            dst.put(dp++, (byte)'=');
+                        }
                     }
                 }
                 return dp - dp00 + bytesOut;
@@ -597,13 +631,17 @@
                 dst[dp++] = (byte)base64[b0 >> 2];
                 if (sp == end) {
                     dst[dp++] = (byte)base64[(b0 << 4) & 0x3f];
-                    dst[dp++] = '=';
-                    dst[dp++] = '=';
+                    if (doPadding) {
+                        dst[dp++] = '=';
+                        dst[dp++] = '=';
+                    }
                 } else {
                     int b1 = src[sp++] & 0xff;
                     dst[dp++] = (byte)base64[(b0 << 4) & 0x3f | (b1 >> 4)];
                     dst[dp++] = (byte)base64[(b1 << 2) & 0x3f];
-                    dst[dp++] = '=';
+                    if (doPadding) {
+                        dst[dp++] = '=';
+                    }
                 }
             }
             return dp;
@@ -1149,14 +1187,16 @@
         private final char[] base64;    // byte->base64 mapping
         private final byte[] newline;   // line separator, if needed
         private final int linemax;
+        private final boolean doPadding;// whether or not to pad
         private int linepos = 0;
 
-        EncOutputStream(OutputStream os,
-                        char[] base64, byte[] newline, int linemax) {
+        EncOutputStream(OutputStream os, char[] base64,
+                        byte[] newline, int linemax, boolean doPadding) {
             super(os);
             this.base64 = base64;
             this.newline = newline;
             this.linemax = linemax;
+            this.doPadding = doPadding;
         }
 
         @Override
@@ -1228,14 +1268,18 @@
                     checkNewline();
                     out.write(base64[b0 >> 2]);
                     out.write(base64[(b0 << 4) & 0x3f]);
-                    out.write('=');
-                    out.write('=');
+                    if (doPadding) {
+                        out.write('=');
+                        out.write('=');
+                    }
                 } else if (leftover == 2) {
                     checkNewline();
                     out.write(base64[b0 >> 2]);
                     out.write(base64[(b0 << 4) & 0x3f | (b1 >> 4)]);
                     out.write(base64[(b1 << 2) & 0x3f]);
-                    out.write('=');
+                    if (doPadding) {
+                       out.write('=');
+                    }
                 }
                 leftover = 0;
                 out.close();
--- a/jdk/test/java/util/Base64/TestBase64.java	Fri Nov 08 18:16:12 2013 +0100
+++ b/jdk/test/java/util/Base64/TestBase64.java	Mon Nov 11 14:35:36 2013 -0800
@@ -23,7 +23,7 @@
 
 /**
  * @test 4235519 8004212 8005394 8007298 8006295 8006315 8006530 8007379 8008925
- *       8014217 8025003
+ *       8014217 8025003 8026330
  * @summary tests java.util.Base64
  */
 
@@ -47,12 +47,9 @@
             numBytes = Integer.parseInt(args[1]);
         }
 
-        test(Base64.getEncoder(),     Base64.getDecoder(),
-             numRuns, numBytes);
-        test(Base64.getUrlEncoder(),  Base64.getUrlDecoder(),
-             numRuns, numBytes);
-        test(Base64.getMimeEncoder(), Base64.getMimeDecoder(),
-             numRuns, numBytes);
+        test(Base64.getEncoder(), Base64.getDecoder(), numRuns, numBytes);
+        test(Base64.getUrlEncoder(), Base64.getUrlDecoder(), numRuns, numBytes);
+        test(Base64.getMimeEncoder(), Base64.getMimeDecoder(), numRuns, numBytes);
 
         Random rnd = new java.util.Random();
         byte[] nl_1 = new byte[] {'\n'};
@@ -142,165 +139,175 @@
         enc.encode(new byte[0]);
         dec.decode(new byte[0]);
 
-        for (int i=0; i<numRuns; i++) {
-            for (int j=1; j<numBytes; j++) {
-                byte[] orig = new byte[j];
-                rnd.nextBytes(orig);
+        for (boolean withoutPadding : new boolean[] { false, true}) {
+            if (withoutPadding) {
+                 enc = enc.withoutPadding();
+            }
+            for (int i=0; i<numRuns; i++) {
+                for (int j=1; j<numBytes; j++) {
+                    byte[] orig = new byte[j];
+                    rnd.nextBytes(orig);
 
-                // --------testing encode/decode(byte[])--------
-                byte[] encoded = enc.encode(orig);
-                byte[] decoded = dec.decode(encoded);
-
-                checkEqual(orig, decoded,
-                           "Base64 array encoding/decoding failed!");
+                    // --------testing encode/decode(byte[])--------
+                    byte[] encoded = enc.encode(orig);
+                    byte[] decoded = dec.decode(encoded);
 
-                // compare to sun.misc.BASE64Encoder
-                byte[] encoded2 = sunmisc.encode(orig).getBytes("ASCII");
-                checkEqual(normalize(encoded),
-                           normalize(encoded2),
-                           "Base64 enc.encode() does not match sun.misc.base64!");
+                    checkEqual(orig, decoded,
+                               "Base64 array encoding/decoding failed!");
+                    if (withoutPadding) {
+                        if (encoded[encoded.length - 1] == '=')
+                            throw new RuntimeException(
+                               "Base64 enc.encode().withoutPadding() has padding!");
+                    }
+                    // compare to sun.misc.BASE64Encoder
 
-                // remove padding '=' to test non-padding decoding case
-                if (encoded[encoded.length -2] == '=')
-                    encoded2 = Arrays.copyOf(encoded,  encoded.length -2);
-                else if (encoded[encoded.length -1] == '=')
-                    encoded2 = Arrays.copyOf(encoded, encoded.length -1);
-                else
-                    encoded2 = null;
-
-                // --------testing encodetoString(byte[])/decode(String)--------
-                String str = enc.encodeToString(orig);
-                if (!Arrays.equals(str.getBytes("ASCII"), encoded)) {
-                    throw new RuntimeException(
-                        "Base64 encodingToString() failed!");
-                }
-                byte[] buf = dec.decode(new String(encoded, "ASCII"));
-                checkEqual(buf, orig, "Base64 decoding(String) failed!");
+                    byte[] encoded2 = sunmisc.encode(orig).getBytes("ASCII");
+                    if (!withoutPadding) {    // don't test for withoutPadding()
+                        checkEqual(normalize(encoded), normalize(encoded2),
+                                   "Base64 enc.encode() does not match sun.misc.base64!");
+                    }
+                    // remove padding '=' to test non-padding decoding case
+                    if (encoded[encoded.length -2] == '=')
+                        encoded2 = Arrays.copyOf(encoded,  encoded.length -2);
+                    else if (encoded[encoded.length -1] == '=')
+                        encoded2 = Arrays.copyOf(encoded, encoded.length -1);
+                    else
+                        encoded2 = null;
 
-                if (encoded2 != null) {
-                    buf = dec.decode(new String(encoded2, "ASCII"));
+                    // --------testing encodetoString(byte[])/decode(String)--------
+                    String str = enc.encodeToString(orig);
+                    if (!Arrays.equals(str.getBytes("ASCII"), encoded)) {
+                        throw new RuntimeException(
+                            "Base64 encodingToString() failed!");
+                    }
+                    byte[] buf = dec.decode(new String(encoded, "ASCII"));
                     checkEqual(buf, orig, "Base64 decoding(String) failed!");
-                }
 
-                //-------- testing encode/decode(Buffer)--------
-                testEncode(enc, ByteBuffer.wrap(orig), encoded);
-                ByteBuffer bin = ByteBuffer.allocateDirect(orig.length);
-                bin.put(orig).flip();
-                testEncode(enc, bin, encoded);
+                    if (encoded2 != null) {
+                        buf = dec.decode(new String(encoded2, "ASCII"));
+                        checkEqual(buf, orig, "Base64 decoding(String) failed!");
+                    }
 
-                testDecode(dec, ByteBuffer.wrap(encoded), orig);
-                bin = ByteBuffer.allocateDirect(encoded.length);
-                bin.put(encoded).flip();
-                testDecode(dec, bin, orig);
+                    //-------- testing encode/decode(Buffer)--------
+                    testEncode(enc, ByteBuffer.wrap(orig), encoded);
+                    ByteBuffer bin = ByteBuffer.allocateDirect(orig.length);
+                    bin.put(orig).flip();
+                    testEncode(enc, bin, encoded);
 
-                if (encoded2 != null)
-                    testDecode(dec, ByteBuffer.wrap(encoded2), orig);
+                    testDecode(dec, ByteBuffer.wrap(encoded), orig);
+                    bin = ByteBuffer.allocateDirect(encoded.length);
+                    bin.put(encoded).flip();
+                    testDecode(dec, bin, orig);
 
-                // -------- testing encode(Buffer, Buffer)--------
-                testEncode(enc, encoded,
-                           ByteBuffer.wrap(orig),
-                           ByteBuffer.allocate(encoded.length + 10));
+                    if (encoded2 != null)
+                        testDecode(dec, ByteBuffer.wrap(encoded2), orig);
+
+                    // -------- testing encode(Buffer, Buffer)--------
+                    testEncode(enc, encoded,
+                               ByteBuffer.wrap(orig),
+                               ByteBuffer.allocate(encoded.length + 10));
 
-                testEncode(enc, encoded,
-                           ByteBuffer.wrap(orig),
-                           ByteBuffer.allocateDirect(encoded.length + 10));
+                    testEncode(enc, encoded,
+                               ByteBuffer.wrap(orig),
+                               ByteBuffer.allocateDirect(encoded.length + 10));
 
-                // --------testing decode(Buffer, Buffer);--------
-                testDecode(dec, orig,
-                           ByteBuffer.wrap(encoded),
-                           ByteBuffer.allocate(orig.length + 10));
+                    // --------testing decode(Buffer, Buffer);--------
+                    testDecode(dec, orig,
+                               ByteBuffer.wrap(encoded),
+                               ByteBuffer.allocate(orig.length + 10));
 
-                testDecode(dec, orig,
-                           ByteBuffer.wrap(encoded),
-                           ByteBuffer.allocateDirect(orig.length + 10));
+                    testDecode(dec, orig,
+                               ByteBuffer.wrap(encoded),
+                               ByteBuffer.allocateDirect(orig.length + 10));
 
-                // --------testing decode.wrap(input stream)--------
-                // 1) random buf length
-                ByteArrayInputStream bais = new ByteArrayInputStream(encoded);
-                InputStream is = dec.wrap(bais);
-                buf = new byte[orig.length + 10];
-                int len = orig.length;
-                int off = 0;
-                while (true) {
-                    int n = rnd.nextInt(len);
-                    if (n == 0)
-                        n = 1;
-                    n = is.read(buf, off, n);
-                    if (n == -1) {
-                        checkEqual(off, orig.length,
-                                   "Base64 stream decoding failed");
-                        break;
+                    // --------testing decode.wrap(input stream)--------
+                    // 1) random buf length
+                    ByteArrayInputStream bais = new ByteArrayInputStream(encoded);
+                    InputStream is = dec.wrap(bais);
+                    buf = new byte[orig.length + 10];
+                    int len = orig.length;
+                    int off = 0;
+                    while (true) {
+                        int n = rnd.nextInt(len);
+                        if (n == 0)
+                            n = 1;
+                        n = is.read(buf, off, n);
+                        if (n == -1) {
+                            checkEqual(off, orig.length,
+                                       "Base64 stream decoding failed");
+                            break;
+                        }
+                        off += n;
+                        len -= n;
+                        if (len == 0)
+                            break;
                     }
-                    off += n;
-                    len -= n;
-                    if (len == 0)
-                        break;
-                }
-                buf = Arrays.copyOf(buf, off);
-                checkEqual(buf, orig, "Base64 stream decoding failed!");
+                    buf = Arrays.copyOf(buf, off);
+                    checkEqual(buf, orig, "Base64 stream decoding failed!");
 
-                // 2) read one byte each
-                bais.reset();
-                is = dec.wrap(bais);
-                buf = new byte[orig.length + 10];
-                off = 0;
-                int b;
-                while ((b = is.read()) != -1) {
-                    buf[off++] = (byte)b;
-                }
-                buf = Arrays.copyOf(buf, off);
-                checkEqual(buf, orig, "Base64 stream decoding failed!");
+                    // 2) read one byte each
+                    bais.reset();
+                    is = dec.wrap(bais);
+                    buf = new byte[orig.length + 10];
+                    off = 0;
+                    int b;
+                    while ((b = is.read()) != -1) {
+                        buf[off++] = (byte)b;
+                    }
+                    buf = Arrays.copyOf(buf, off);
+                    checkEqual(buf, orig, "Base64 stream decoding failed!");
 
-                // --------testing encode.wrap(output stream)--------
-                ByteArrayOutputStream baos = new ByteArrayOutputStream((orig.length + 2) / 3 * 4 + 10);
-                OutputStream os = enc.wrap(baos);
-                off = 0;
-                len = orig.length;
-                for (int k = 0; k < 5; k++) {
-                    if (len == 0)
-                        break;
-                    int n = rnd.nextInt(len);
-                    if (n == 0)
-                        n = 1;
-                    os.write(orig, off, n);
-                    off += n;
-                    len -= n;
-                }
-                if (len != 0)
-                    os.write(orig, off, len);
-                os.close();
-                buf = baos.toByteArray();
-                checkEqual(buf, encoded, "Base64 stream encoding failed!");
+                    // --------testing encode.wrap(output stream)--------
+                    ByteArrayOutputStream baos = new ByteArrayOutputStream((orig.length + 2) / 3 * 4 + 10);
+                    OutputStream os = enc.wrap(baos);
+                    off = 0;
+                    len = orig.length;
+                    for (int k = 0; k < 5; k++) {
+                        if (len == 0)
+                            break;
+                        int n = rnd.nextInt(len);
+                        if (n == 0)
+                            n = 1;
+                        os.write(orig, off, n);
+                        off += n;
+                        len -= n;
+                    }
+                    if (len != 0)
+                        os.write(orig, off, len);
+                    os.close();
+                    buf = baos.toByteArray();
+                    checkEqual(buf, encoded, "Base64 stream encoding failed!");
 
-                // 2) write one byte each
-                baos.reset();
-                os = enc.wrap(baos);
-                off = 0;
-                while (off < orig.length) {
-                    os.write(orig[off++]);
-                }
-                os.close();
-                buf = baos.toByteArray();
-                checkEqual(buf, encoded, "Base64 stream encoding failed!");
+                    // 2) write one byte each
+                    baos.reset();
+                    os = enc.wrap(baos);
+                    off = 0;
+                    while (off < orig.length) {
+                        os.write(orig[off++]);
+                    }
+                    os.close();
+                    buf = baos.toByteArray();
+                    checkEqual(buf, encoded, "Base64 stream encoding failed!");
 
-                // --------testing encode(in, out); -> bigger buf--------
-                buf = new byte[encoded.length + rnd.nextInt(100)];
-                int ret = enc.encode(orig, buf);
-                checkEqual(ret, encoded.length,
-                           "Base64 enc.encode(src, null) returns wrong size!");
-                buf = Arrays.copyOf(buf, ret);
-                checkEqual(buf, encoded,
-                           "Base64 enc.encode(src, dst) failed!");
+                    // --------testing encode(in, out); -> bigger buf--------
+                    buf = new byte[encoded.length + rnd.nextInt(100)];
+                    int ret = enc.encode(orig, buf);
+                    checkEqual(ret, encoded.length,
+                               "Base64 enc.encode(src, null) returns wrong size!");
+                    buf = Arrays.copyOf(buf, ret);
+                    checkEqual(buf, encoded,
+                               "Base64 enc.encode(src, dst) failed!");
 
-                // --------testing decode(in, out); -> bigger buf--------
-                buf = new byte[orig.length + rnd.nextInt(100)];
-                ret = dec.decode(encoded, buf);
-                checkEqual(ret, orig.length,
-                          "Base64 enc.encode(src, null) returns wrong size!");
-                buf = Arrays.copyOf(buf, ret);
-                checkEqual(buf, orig,
-                           "Base64 dec.decode(src, dst) failed!");
+                    // --------testing decode(in, out); -> bigger buf--------
+                    buf = new byte[orig.length + rnd.nextInt(100)];
+                    ret = dec.decode(encoded, buf);
+                    checkEqual(ret, orig.length,
+                              "Base64 enc.encode(src, null) returns wrong size!");
+                    buf = Arrays.copyOf(buf, ret);
+                    checkEqual(buf, orig,
+                               "Base64 dec.decode(src, dst) failed!");
 
+                }
             }
         }
     }