8025534: Unsafe typecast in java.util.stream.Streams.Nodes
authorpsandoz
Wed, 02 Oct 2013 16:34:12 +0200
changeset 20507 8498104f92c3
parent 20506 d826dd5f8e10
child 20508 0c41e68de505
8025534: Unsafe typecast in java.util.stream.Streams.Nodes 8025538: Unsafe typecast in java.util.stream.SpinedBuffer 8025533: Unsafe typecast in java.util.stream.Streams.RangeIntSpliterator.splitPoint() 8025525: Unsafe typecast in java.util.stream.Node.OfPrimitive.asArray() Reviewed-by: chegar
jdk/src/share/classes/java/util/stream/Node.java
jdk/src/share/classes/java/util/stream/Nodes.java
jdk/src/share/classes/java/util/stream/SortedOps.java
jdk/src/share/classes/java/util/stream/SpinedBuffer.java
jdk/src/share/classes/java/util/stream/Streams.java
--- a/jdk/src/share/classes/java/util/stream/Node.java	Wed Oct 02 19:13:42 2013 -0400
+++ b/jdk/src/share/classes/java/util/stream/Node.java	Wed Oct 02 16:34:12 2013 +0200
@@ -149,7 +149,9 @@
     /**
      * Copies the content of this {@code Node} into an array, starting at a
      * given offset into the array.  It is the caller's responsibility to ensure
-     * there is sufficient room in the array.
+     * there is sufficient room in the array, otherwise unspecified behaviour
+     * will occur if the array length is less than the number of elements
+     * contained in this node.
      *
      * @param array the array into which to copy the contents of this
      *       {@code Node}
@@ -258,6 +260,12 @@
          */
         @Override
         default T[] asArray(IntFunction<T[]> generator) {
+            if (java.util.stream.Tripwire.ENABLED)
+                java.util.stream.Tripwire.trip(getClass(), "{0} calling Node.OfPrimitive.asArray");
+
+            long size = count();
+            if (size >= Nodes.MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             T[] boxed = generator.apply((int) count());
             copyInto(boxed, 0);
             return boxed;
--- a/jdk/src/share/classes/java/util/stream/Nodes.java	Wed Oct 02 19:13:42 2013 -0400
+++ b/jdk/src/share/classes/java/util/stream/Nodes.java	Wed Oct 02 16:34:12 2013 +0200
@@ -60,6 +60,9 @@
      */
     static final long MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
 
+    // IllegalArgumentException messages
+    static final String BAD_SIZE = "Stream size exceeds max array size";
+
     @SuppressWarnings("raw")
     private static final Node EMPTY_NODE = new EmptyNode.OfRef();
     private static final Node.OfInt EMPTY_INT_NODE = new EmptyNode.OfInt();
@@ -317,7 +320,7 @@
         long size = helper.exactOutputSizeIfKnown(spliterator);
         if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) {
             if (size >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             P_OUT[] array = generator.apply((int) size);
             new SizedCollectorTask.OfRef<>(spliterator, helper, array).invoke();
             return node(array);
@@ -354,7 +357,7 @@
         long size = helper.exactOutputSizeIfKnown(spliterator);
         if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) {
             if (size >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             int[] array = new int[(int) size];
             new SizedCollectorTask.OfInt<>(spliterator, helper, array).invoke();
             return node(array);
@@ -392,7 +395,7 @@
         long size = helper.exactOutputSizeIfKnown(spliterator);
         if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) {
             if (size >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             long[] array = new long[(int) size];
             new SizedCollectorTask.OfLong<>(spliterator, helper, array).invoke();
             return node(array);
@@ -430,7 +433,7 @@
         long size = helper.exactOutputSizeIfKnown(spliterator);
         if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) {
             if (size >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             double[] array = new double[(int) size];
             new SizedCollectorTask.OfDouble<>(spliterator, helper, array).invoke();
             return node(array);
@@ -460,7 +463,10 @@
      */
     public static <T> Node<T> flatten(Node<T> node, IntFunction<T[]> generator) {
         if (node.getChildCount() > 0) {
-            T[] array = generator.apply((int) node.count());
+            long size = node.count();
+            if (size >= MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
+            T[] array = generator.apply((int) size);
             new ToArrayTask.OfRef<>(node, array, 0).invoke();
             return node(array);
         } else {
@@ -483,7 +489,10 @@
      */
     public static Node.OfInt flattenInt(Node.OfInt node) {
         if (node.getChildCount() > 0) {
-            int[] array = new int[(int) node.count()];
+            long size = node.count();
+            if (size >= MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
+            int[] array = new int[(int) size];
             new ToArrayTask.OfInt(node, array, 0).invoke();
             return node(array);
         } else {
@@ -506,7 +515,10 @@
      */
     public static Node.OfLong flattenLong(Node.OfLong node) {
         if (node.getChildCount() > 0) {
-            long[] array = new long[(int) node.count()];
+            long size = node.count();
+            if (size >= MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
+            long[] array = new long[(int) size];
             new ToArrayTask.OfLong(node, array, 0).invoke();
             return node(array);
         } else {
@@ -529,7 +541,10 @@
      */
     public static Node.OfDouble flattenDouble(Node.OfDouble node) {
         if (node.getChildCount() > 0) {
-            double[] array = new double[(int) node.count()];
+            long size = node.count();
+            if (size >= MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
+            double[] array = new double[(int) size];
             new ToArrayTask.OfDouble(node, array, 0).invoke();
             return node(array);
         } else {
@@ -627,7 +642,7 @@
         @SuppressWarnings("unchecked")
         ArrayNode(long size, IntFunction<T[]> generator) {
             if (size >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             this.array = generator.apply((int) size);
             this.curSize = 0;
         }
@@ -777,12 +792,17 @@
         public void copyInto(T[] array, int offset) {
             Objects.requireNonNull(array);
             left.copyInto(array, offset);
+            // Cast to int is safe since it is the callers responsibility to
+            // ensure that there is sufficient room in the array
             right.copyInto(array, offset + (int) left.count());
         }
 
         @Override
         public T[] asArray(IntFunction<T[]> generator) {
-            T[] array = generator.apply((int) count());
+            long size = count();
+            if (size >= MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
+            T[] array = generator.apply((int) size);
             copyInto(array, 0);
             return array;
         }
@@ -836,12 +856,17 @@
             @Override
             public void copyInto(T_ARR array, int offset) {
                 left.copyInto(array, offset);
+                // Cast to int is safe since it is the callers responsibility to
+                // ensure that there is sufficient room in the array
                 right.copyInto(array, offset + (int) left.count());
             }
 
             @Override
             public T_ARR asPrimitiveArray() {
-                T_ARR array = newArray((int) count());
+                long size = count();
+                if (size >= MAX_ARRAY_SIZE)
+                    throw new IllegalArgumentException(BAD_SIZE);
+                T_ARR array = newArray((int) size);
                 copyInto(array, 0);
                 return array;
             }
@@ -1287,7 +1312,7 @@
 
         IntArrayNode(long size) {
             if (size >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             this.array = new int[(int) size];
             this.curSize = 0;
         }
@@ -1343,7 +1368,7 @@
 
         LongArrayNode(long size) {
             if (size >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             this.array = new long[(int) size];
             this.curSize = 0;
         }
@@ -1397,7 +1422,7 @@
 
         DoubleArrayNode(long size) {
             if (size >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             this.array = new double[(int) size];
             this.curSize = 0;
         }
@@ -1843,8 +1868,8 @@
                 task = task.makeChild(rightSplit, task.offset + leftSplitSize,
                                       task.length - leftSplitSize);
             }
-            if (task.offset + task.length >= MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+
+            assert task.offset + task.length < MAX_ARRAY_SIZE;
             T_SINK sink = (T_SINK) task;
             task.helper.wrapAndCopyInto(sink, rightSplit);
             task.propagateCompletion();
@@ -1854,10 +1879,13 @@
 
         @Override
         public void begin(long size) {
-            if(size > length)
+            if (size > length)
                 throw new IllegalStateException("size passed to Sink.begin exceeds array length");
+            // Casts to int are safe since absolute size is verified to be within
+            // bounds when the root concrete SizedCollectorTask is constructed
+            // with the shared array
             index = (int) offset;
-            fence = (int) offset + (int) length;
+            fence = index + (int) length;
         }
 
         @SuppressWarnings("serial")
--- a/jdk/src/share/classes/java/util/stream/SortedOps.java	Wed Oct 02 19:13:42 2013 -0400
+++ b/jdk/src/share/classes/java/util/stream/SortedOps.java	Wed Oct 02 16:34:12 2013 +0200
@@ -277,8 +277,6 @@
         }
     }
 
-    private static final String BAD_SIZE = "Stream size exceeds max array size";
-
     /**
      * {@link Sink} for implementing sort on SIZED reference streams.
      */
@@ -295,7 +293,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException(BAD_SIZE);
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             array = (T[]) new Object[(int) size];
         }
 
@@ -330,7 +328,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException(BAD_SIZE);
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             list = (size >= 0) ? new ArrayList<T>((int) size) : new ArrayList<T>();
         }
 
@@ -363,7 +361,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException(BAD_SIZE);
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             array = new int[(int) size];
         }
 
@@ -396,7 +394,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException(BAD_SIZE);
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             b = (size > 0) ? new SpinedBuffer.OfInt((int) size) : new SpinedBuffer.OfInt();
         }
 
@@ -430,7 +428,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException(BAD_SIZE);
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             array = new long[(int) size];
         }
 
@@ -463,7 +461,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException(BAD_SIZE);
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             b = (size > 0) ? new SpinedBuffer.OfLong((int) size) : new SpinedBuffer.OfLong();
         }
 
@@ -497,7 +495,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException(BAD_SIZE);
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             array = new double[(int) size];
         }
 
@@ -530,7 +528,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException(BAD_SIZE);
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
             b = (size > 0) ? new SpinedBuffer.OfDouble((int) size) : new SpinedBuffer.OfDouble();
         }
 
--- a/jdk/src/share/classes/java/util/stream/SpinedBuffer.java	Wed Oct 02 19:13:42 2013 -0400
+++ b/jdk/src/share/classes/java/util/stream/SpinedBuffer.java	Wed Oct 02 16:34:12 2013 +0200
@@ -156,6 +156,9 @@
     public E get(long index) {
         // @@@ can further optimize by caching last seen spineIndex,
         // which is going to be right most of the time
+
+        // Casts to int are safe since the spine array index is the index minus
+        // the prior element count from the current spine
         if (spineIndex == 0) {
             if (index < elementIndex)
                 return curChunk[((int) index)];
@@ -201,11 +204,11 @@
      * elements into it.
      */
     public E[] asArray(IntFunction<E[]> arrayFactory) {
-        // @@@ will fail for size == MAX_VALUE
-        E[] result = arrayFactory.apply((int) count());
-
+        long size = count();
+        if (size >= Nodes.MAX_ARRAY_SIZE)
+            throw new IllegalArgumentException(Nodes.BAD_SIZE);
+        E[] result = arrayFactory.apply((int) size);
         copyInto(result, 0);
-
         return result;
     }
 
@@ -547,8 +550,10 @@
         }
 
         public T_ARR asPrimitiveArray() {
-            // @@@ will fail for size == MAX_VALUE
-            T_ARR result = newArray((int) count());
+            long size = count();
+            if (size >= Nodes.MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(Nodes.BAD_SIZE);
+            T_ARR result = newArray((int) size);
             copyInto(result, 0);
             return result;
         }
@@ -760,11 +765,13 @@
         }
 
         public int get(long index) {
+            // Casts to int are safe since the spine array index is the index minus
+            // the prior element count from the current spine
             int ch = chunkFor(index);
             if (spineIndex == 0 && ch == 0)
                 return curChunk[(int) index];
             else
-                return spine[ch][(int) (index-priorElementCount[ch])];
+                return spine[ch][(int) (index - priorElementCount[ch])];
         }
 
         @Override
@@ -871,11 +878,13 @@
         }
 
         public long get(long index) {
+            // Casts to int are safe since the spine array index is the index minus
+            // the prior element count from the current spine
             int ch = chunkFor(index);
             if (spineIndex == 0 && ch == 0)
                 return curChunk[(int) index];
             else
-                return spine[ch][(int) (index-priorElementCount[ch])];
+                return spine[ch][(int) (index - priorElementCount[ch])];
         }
 
         @Override
@@ -984,11 +993,13 @@
         }
 
         public double get(long index) {
+            // Casts to int are safe since the spine array index is the index minus
+            // the prior element count from the current spine
             int ch = chunkFor(index);
             if (spineIndex == 0 && ch == 0)
                 return curChunk[(int) index];
             else
-                return spine[ch][(int) (index-priorElementCount[ch])];
+                return spine[ch][(int) (index - priorElementCount[ch])];
         }
 
         @Override
--- a/jdk/src/share/classes/java/util/stream/Streams.java	Wed Oct 02 19:13:42 2013 -0400
+++ b/jdk/src/share/classes/java/util/stream/Streams.java	Wed Oct 02 16:34:12 2013 +0200
@@ -169,7 +169,9 @@
 
         private int splitPoint(long size) {
             int d = (size < BALANCED_SPLIT_THRESHOLD) ? 2 : RIGHT_BALANCED_SPLIT_RATIO;
-            // 2 <= size <= 2^32
+            // Cast to int is safe since:
+            //   2 <= size < 2^32
+            //   2 <= d <= 8
             return (int) (size / d);
         }
     }