8025535: Unsafe typecast in java.util.stream.SortedOps
authorpsandoz
Tue, 01 Oct 2013 18:20:03 +0200
changeset 20503 074dd13d9cdf
parent 20502 33bb53f4ec14
child 20504 3fdfa9294734
8025535: Unsafe typecast in java.util.stream.SortedOps Reviewed-by: mduigou, chegar
jdk/src/share/classes/java/util/stream/SortedOps.java
jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java
--- a/jdk/src/share/classes/java/util/stream/SortedOps.java	Wed Oct 02 17:57:04 2013 +0900
+++ b/jdk/src/share/classes/java/util/stream/SortedOps.java	Tue Oct 01 18:20:03 2013 +0200
@@ -277,8 +277,10 @@
         }
     }
 
+    private static final String BAD_SIZE = "Stream size exceeds max array size";
+
     /**
-     * {@link ForkJoinTask} for implementing sort on SIZED reference streams.
+     * {@link Sink} for implementing sort on SIZED reference streams.
      */
     private static final class SizedRefSortingSink<T> extends Sink.ChainedReference<T, T> {
         private final Comparator<? super T> comparator;
@@ -293,16 +295,12 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             array = (T[]) new Object[(int) size];
         }
 
         @Override
         public void end() {
-            // Need to use offset rather than array.length since the downstream
-            // many be short-circuiting
-            // @@@ A better approach is to know if the downstream short-circuits
-            //     and check sink.cancellationRequested
             Arrays.sort(array, 0, offset, comparator);
             downstream.begin(offset);
             for (int i = 0; i < offset; i++)
@@ -331,6 +329,8 @@
 
         @Override
         public void begin(long size) {
+            if (size >= Nodes.MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
             list = (size >= 0) ? new ArrayList<T>((int) size) : new ArrayList<T>();
         }
 
@@ -363,7 +363,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             array = new int[(int) size];
         }
 
@@ -395,6 +395,8 @@
 
         @Override
         public void begin(long size) {
+            if (size >= Nodes.MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
             b = (size > 0) ? new SpinedBuffer.OfInt((int) size) : new SpinedBuffer.OfInt();
         }
 
@@ -428,7 +430,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             array = new long[(int) size];
         }
 
@@ -460,6 +462,8 @@
 
         @Override
         public void begin(long size) {
+            if (size >= Nodes.MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
             b = (size > 0) ? new SpinedBuffer.OfLong((int) size) : new SpinedBuffer.OfLong();
         }
 
@@ -493,7 +497,7 @@
         @Override
         public void begin(long size) {
             if (size >= Nodes.MAX_ARRAY_SIZE)
-                throw new IllegalArgumentException("Stream size exceeds max array size");
+                throw new IllegalArgumentException(BAD_SIZE);
             array = new double[(int) size];
         }
 
@@ -525,6 +529,8 @@
 
         @Override
         public void begin(long size) {
+            if (size >= Nodes.MAX_ARRAY_SIZE)
+                throw new IllegalArgumentException(BAD_SIZE);
             b = (size > 0) ? new SpinedBuffer.OfDouble((int) size) : new SpinedBuffer.OfDouble();
         }
 
--- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java	Wed Oct 02 17:57:04 2013 +0900
+++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java	Tue Oct 01 18:20:03 2013 +0200
@@ -26,6 +26,8 @@
 
 import java.util.*;
 import java.util.Spliterators;
+import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.*;
 
 import static java.util.stream.LambdaTestHelpers.*;
@@ -37,6 +39,69 @@
  */
 @Test
 public class SortedOpTest extends OpTestCase {
+
+    public void testRefStreamTooLarge() {
+        Function<LongStream, Stream<Long>> f = s ->
+                // Clear the SORTED flag
+                s.mapToObj(i -> i)
+                .sorted();
+
+        testStreamTooLarge(f, Stream::findFirst);
+    }
+
+    public void testIntStreamTooLarge() {
+        Function<LongStream, IntStream> f = s ->
+                // Clear the SORTED flag
+                s.mapToInt(i -> (int) i)
+                .sorted();
+
+        testStreamTooLarge(f, IntStream::findFirst);
+    }
+
+    public void testLongStreamTooLarge() {
+        Function<LongStream, LongStream> f = s ->
+                // Clear the SORTED flag
+                s.map(i -> i)
+                .sorted();
+
+        testStreamTooLarge(f, LongStream::findFirst);
+    }
+
+    public void testDoubleStreamTooLarge() {
+        Function<LongStream, DoubleStream> f = s ->
+                // Clear the SORTED flag
+                s.mapToDouble(i -> (double) i)
+                .sorted();
+
+        testStreamTooLarge(f, DoubleStream::findFirst);
+    }
+
+    <T, S extends BaseStream<T, S>> void testStreamTooLarge(Function<LongStream, S> s,
+                                                            Function<S, ?> terminal) {
+        // Set up conditions for a large input > maximum array size
+        Supplier<LongStream> input = () -> LongStream.range(0, 1L + Integer.MAX_VALUE);
+
+        // Transformation functions
+        List<Function<LongStream, LongStream>> transforms = Arrays.asList(
+                ls -> ls,
+                ls -> ls.parallel(),
+                // Clear the SIZED flag
+                ls -> ls.limit(Long.MAX_VALUE),
+                ls -> ls.limit(Long.MAX_VALUE).parallel());
+
+        for (Function<LongStream, LongStream> transform : transforms) {
+            RuntimeException caught = null;
+            try {
+                terminal.apply(s.apply(transform.apply(input.get())));
+            } catch (RuntimeException e) {
+                caught = e;
+            }
+            assertNotNull(caught, "Expected an instance of exception IllegalArgumentException but no exception thrown");
+            assertTrue(caught instanceof IllegalArgumentException,
+                       String.format("Expected an instance of exception IllegalArgumentException but got %s", caught));
+        }
+    }
+
     public void testSorted() {
         assertCountSum(countTo(0).stream().sorted(), 0, 0);
         assertCountSum(countTo(10).stream().sorted(), 10, 55);