8042355: stream with sorted() causes downstream ops not to be lazy
author psandoz
Tue, 06 May 2014 10:29:59 +0200
changeset 24258 0e9ab834f44a
parent 24257 f524e23d7f7b
child 24259 c64d805ad840
8042355: stream with sorted() causes downstream ops not to be lazy Reviewed-by: mduigou
jdk/src/share/classes/java/util/stream/SortedOps.java
jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java
--- a/jdk/src/share/classes/java/util/stream/SortedOps.java	Tue May 06 10:28:48 2014 +0400
+++ b/jdk/src/share/classes/java/util/stream/SortedOps.java	Tue May 06 10:29:59 2014 +0200
@@ -279,16 +279,60 @@
     }
 
     /**
+     * Abstract {@link Sink} for implementing sort on reference streams.
+     *
+     * <p>
+     * Note: documentation below applies to reference and all primitive sinks.
+     * <p>
+     * Sorting sinks first accept all elements, buffering them into an array
+     * or a re-sizable data structure, if the size of the pipeline is known or
+     * unknown respectively.  At the end of the sink protocol those elements are
+     * sorted and then pushed downstream.
+     * This class records if {@link #cancellationRequested} is called.  If so it
+     * can be inferred that the source pushing source elements into the pipeline
+     * knows that the pipeline is short-circuiting.  In such cases sub-classes
+     * pushing elements downstream will preserve the short-circuiting protocol
+     * by calling {@code downstream.cancellationRequested()} and checking the
+     * result is {@code false} before an element is pushed.
+     * <p>
+     * Note that the above behaviour is an optimization for sorting with
+     * sequential streams.  It is not an error that more elements, than strictly
+     * required to produce a result, may flow through the pipeline.  This can
+     * occur, in general (not restricted to just sorting), for short-circuiting
+     * parallel pipelines.
+     */
+    private static abstract class AbstractRefSortingSink<T> extends Sink.ChainedReference<T, T> {
+        protected final Comparator<? super T> comparator;
+        // @@@ could be a lazy final value, if/when support is added
+        protected boolean cancellationWasRequested;
+
+        AbstractRefSortingSink(Sink<? super T> downstream, Comparator<? super T> comparator) {
+            super(downstream);
+            this.comparator = comparator;
+        }
+
+        /**
+         * Records if cancellation is requested so short-circuiting behaviour
+         * can be preserved when the sorted elements are pushed downstream.
+         *
+         * @return false, as this sink never short-circuits.
+         */
+        @Override
+        public final boolean cancellationRequested() {
+            cancellationWasRequested = true;
+            return false;
+        }
+    }
+
+    /**
      * {@link Sink} for implementing sort on SIZED reference streams.
      */
-    private static final class SizedRefSortingSink<T> extends Sink.ChainedReference<T, T> {
-        private final Comparator<? super T> comparator;
+    private static final class SizedRefSortingSink<T> extends AbstractRefSortingSink<T> {
         private T[] array;
         private int offset;
 
         SizedRefSortingSink(Sink<? super T> sink, Comparator<? super T> comparator) {
-            super(sink);
-            this.comparator = comparator;
+            super(sink, comparator);
         }
 
         @Override
@@ -303,8 +347,14 @@
         public void end() {
             Arrays.sort(array, 0, offset, comparator);
             downstream.begin(offset);
-            for (int i = 0; i < offset; i++)
-                downstream.accept(array[i]);
+            if (!cancellationWasRequested) {
+                for (int i = 0; i < offset; i++)
+                    downstream.accept(array[i]);
+            }
+            else {
+                for (int i = 0; i < offset && !downstream.cancellationRequested(); i++)
+                    downstream.accept(array[i]);
+            }
             downstream.end();
             array = null;
         }
@@ -318,13 +368,11 @@
     /**
      * {@link Sink} for implementing sort on reference streams.
      */
-    private static final class RefSortingSink<T> extends Sink.ChainedReference<T, T> {
-        private final Comparator<? super T> comparator;
+    private static final class RefSortingSink<T> extends AbstractRefSortingSink<T> {
         private ArrayList<T> list;
 
         RefSortingSink(Sink<? super T> sink, Comparator<? super T> comparator) {
-            super(sink);
-            this.comparator = comparator;
+            super(sink, comparator);
         }
 
         @Override
@@ -338,7 +386,15 @@
         public void end() {
             list.sort(comparator);
             downstream.begin(list.size());
-            list.forEach(downstream::accept);
+            if (!cancellationWasRequested) {
+                list.forEach(downstream::accept);
+            }
+            else {
+                for (T t : list) {
+                    if (downstream.cancellationRequested()) break;
+                    downstream.accept(t);
+                }
+            }
             downstream.end();
             list = null;
         }
@@ -350,9 +406,26 @@
     }
 
     /**
+     * Abstract {@link Sink} for implementing sort on int streams.
+     */
+    private static abstract class AbstractIntSortingSink extends Sink.ChainedInt<Integer> {
+        protected boolean cancellationWasRequested;
+
+        AbstractIntSortingSink(Sink<? super Integer> downstream) {
+            super(downstream);
+        }
+
+        @Override
+        public final boolean cancellationRequested() {
+            cancellationWasRequested = true;
+            return false;
+        }
+    }
+
+    /**
      * {@link Sink} for implementing sort on SIZED int streams.
      */
-    private static final class SizedIntSortingSink extends Sink.ChainedInt<Integer> {
+    private static final class SizedIntSortingSink extends AbstractIntSortingSink {
         private int[] array;
         private int offset;
 
@@ -371,8 +444,14 @@
         public void end() {
             Arrays.sort(array, 0, offset);
             downstream.begin(offset);
-            for (int i = 0; i < offset; i++)
-                downstream.accept(array[i]);
+            if (!cancellationWasRequested) {
+                for (int i = 0; i < offset; i++)
+                    downstream.accept(array[i]);
+            }
+            else {
+                for (int i = 0; i < offset && !downstream.cancellationRequested(); i++)
+                    downstream.accept(array[i]);
+            }
             downstream.end();
             array = null;
         }
@@ -386,7 +465,7 @@
     /**
      * {@link Sink} for implementing sort on int streams.
      */
-    private static final class IntSortingSink extends Sink.ChainedInt<Integer> {
+    private static final class IntSortingSink extends AbstractIntSortingSink {
         private SpinedBuffer.OfInt b;
 
         IntSortingSink(Sink<? super Integer> sink) {
@@ -405,8 +484,16 @@
             int[] ints = b.asPrimitiveArray();
             Arrays.sort(ints);
             downstream.begin(ints.length);
-            for (int anInt : ints)
-                downstream.accept(anInt);
+            if (!cancellationWasRequested) {
+                for (int anInt : ints)
+                    downstream.accept(anInt);
+            }
+            else {
+                for (int anInt : ints) {
+                    if (downstream.cancellationRequested()) break;
+                    downstream.accept(anInt);
+                }
+            }
             downstream.end();
         }
 
@@ -417,9 +504,26 @@
     }
 
     /**
+     * Abstract {@link Sink} for implementing sort on long streams.
+     */
+    private static abstract class AbstractLongSortingSink extends Sink.ChainedLong<Long> {
+        protected boolean cancellationWasRequested;
+
+        AbstractLongSortingSink(Sink<? super Long> downstream) {
+            super(downstream);
+        }
+
+        @Override
+        public final boolean cancellationRequested() {
+            cancellationWasRequested = true;
+            return false;
+        }
+    }
+
+    /**
      * {@link Sink} for implementing sort on SIZED long streams.
      */
-    private static final class SizedLongSortingSink extends Sink.ChainedLong<Long> {
+    private static final class SizedLongSortingSink extends AbstractLongSortingSink {
         private long[] array;
         private int offset;
 
@@ -438,8 +542,14 @@
         public void end() {
             Arrays.sort(array, 0, offset);
             downstream.begin(offset);
-            for (int i = 0; i < offset; i++)
-                downstream.accept(array[i]);
+            if (!cancellationWasRequested) {
+                for (int i = 0; i < offset; i++)
+                    downstream.accept(array[i]);
+            }
+            else {
+                for (int i = 0; i < offset && !downstream.cancellationRequested(); i++)
+                    downstream.accept(array[i]);
+            }
             downstream.end();
             array = null;
         }
@@ -453,7 +563,7 @@
     /**
      * {@link Sink} for implementing sort on long streams.
      */
-    private static final class LongSortingSink extends Sink.ChainedLong<Long> {
+    private static final class LongSortingSink extends AbstractLongSortingSink {
         private SpinedBuffer.OfLong b;
 
         LongSortingSink(Sink<? super Long> sink) {
@@ -472,8 +582,16 @@
             long[] longs = b.asPrimitiveArray();
             Arrays.sort(longs);
             downstream.begin(longs.length);
-            for (long aLong : longs)
-                downstream.accept(aLong);
+            if (!cancellationWasRequested) {
+                for (long aLong : longs)
+                    downstream.accept(aLong);
+            }
+            else {
+                for (long aLong : longs) {
+                    if (downstream.cancellationRequested()) break;
+                    downstream.accept(aLong);
+                }
+            }
             downstream.end();
         }
 
@@ -484,9 +602,26 @@
     }
 
     /**
+     * Abstract {@link Sink} for implementing sort on double streams.
+     */
+    private static abstract class AbstractDoubleSortingSink extends Sink.ChainedDouble<Double> {
+        protected boolean cancellationWasRequested;
+
+        AbstractDoubleSortingSink(Sink<? super Double> downstream) {
+            super(downstream);
+        }
+
+        @Override
+        public final boolean cancellationRequested() {
+            cancellationWasRequested = true;
+            return false;
+        }
+    }
+
+    /**
      * {@link Sink} for implementing sort on SIZED double streams.
      */
-    private static final class SizedDoubleSortingSink extends Sink.ChainedDouble<Double> {
+    private static final class SizedDoubleSortingSink extends AbstractDoubleSortingSink {
         private double[] array;
         private int offset;
 
@@ -505,8 +640,14 @@
         public void end() {
             Arrays.sort(array, 0, offset);
             downstream.begin(offset);
-            for (int i = 0; i < offset; i++)
-                downstream.accept(array[i]);
+            if (!cancellationWasRequested) {
+                for (int i = 0; i < offset; i++)
+                    downstream.accept(array[i]);
+            }
+            else {
+                for (int i = 0; i < offset && !downstream.cancellationRequested(); i++)
+                    downstream.accept(array[i]);
+            }
             downstream.end();
             array = null;
         }
@@ -520,7 +661,7 @@
     /**
      * {@link Sink} for implementing sort on double streams.
      */
-    private static final class DoubleSortingSink extends Sink.ChainedDouble<Double> {
+    private static final class DoubleSortingSink extends AbstractDoubleSortingSink {
         private SpinedBuffer.OfDouble b;
 
         DoubleSortingSink(Sink<? super Double> sink) {
@@ -539,8 +680,16 @@
             double[] doubles = b.asPrimitiveArray();
             Arrays.sort(doubles);
             downstream.begin(doubles.length);
-            for (double aDouble : doubles)
-                downstream.accept(aDouble);
+            if (!cancellationWasRequested) {
+                for (double aDouble : doubles)
+                    downstream.accept(aDouble);
+            }
+            else {
+                for (double aDouble : doubles) {
+                    if (downstream.cancellationRequested()) break;
+                    downstream.accept(aDouble);
+                }
+            }
             downstream.end();
         }
 
--- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java	Tue May 06 10:28:48 2014 +0400
+++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java	Tue May 06 10:29:59 2014 +0200
@@ -26,6 +26,9 @@
 
 import java.util.*;
 import java.util.Spliterators;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.*;
@@ -122,24 +125,33 @@
 
     @Test(groups = { "serialization-hostile" })
     public void testSequentialShortCircuitTerminal() {
-        // The sorted op for sequential evaluation will buffer all elements when accepting
-        // then at the end sort those elements and push those elements downstream
+        // The sorted op for sequential evaluation will buffer all elements when
+        // accepting then at the end sort those elements and push those elements
+        // downstream
+        // A peek operation is added in-between the sorted() and terminal
+        // operation that counts the number of calls to its consumer and
+        // asserts that the number of calls is at most the required quantity
 
         List<Integer> l = Arrays.asList(5, 4, 3, 2, 1);
 
+        Function<Integer, Stream<Integer>> knownSize = i -> assertNCallsOnly(
+                l.stream().sorted(), Stream::peek, i);
+        Function<Integer, Stream<Integer>> unknownSize = i -> assertNCallsOnly
+                (unknownSizeStream(l).sorted(), Stream::peek, i);
+
         // Find
-        assertEquals(l.stream().sorted().findFirst(), Optional.of(1));
-        assertEquals(l.stream().sorted().findAny(), Optional.of(1));
-        assertEquals(unknownSizeStream(l).sorted().findFirst(), Optional.of(1));
-        assertEquals(unknownSizeStream(l).sorted().findAny(), Optional.of(1));
+        assertEquals(knownSize.apply(1).findFirst(), Optional.of(1));
+        assertEquals(knownSize.apply(1).findAny(), Optional.of(1));
+        assertEquals(unknownSize.apply(1).findFirst(), Optional.of(1));
+        assertEquals(unknownSize.apply(1).findAny(), Optional.of(1));
 
         // Match
-        assertEquals(l.stream().sorted().anyMatch(i -> i == 2), true);
-        assertEquals(l.stream().sorted().noneMatch(i -> i == 2), false);
-        assertEquals(l.stream().sorted().allMatch(i -> i == 2), false);
-        assertEquals(unknownSizeStream(l).sorted().anyMatch(i -> i == 2), true);
-        assertEquals(unknownSizeStream(l).sorted().noneMatch(i -> i == 2), false);
-        assertEquals(unknownSizeStream(l).sorted().allMatch(i -> i == 2), false);
+        assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true);
+        assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false);
+        assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false);
+        assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true);
+        assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false);
+        assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false);
     }
 
     private <T> Stream<T> unknownSizeStream(List<T> l) {
@@ -199,19 +211,24 @@
     public void testIntSequentialShortCircuitTerminal() {
         int[] a = new int[]{5, 4, 3, 2, 1};
 
+        Function<Integer, IntStream> knownSize = i -> assertNCallsOnly(
+                Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+        Function<Integer, IntStream> unknownSize = i -> assertNCallsOnly
+                (unknownSizeIntStream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+
         // Find
-        assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalInt.of(1));
-        assertEquals(Arrays.stream(a).sorted().findAny(), OptionalInt.of(1));
-        assertEquals(unknownSizeIntStream(a).sorted().findFirst(), OptionalInt.of(1));
-        assertEquals(unknownSizeIntStream(a).sorted().findAny(), OptionalInt.of(1));
+        assertEquals(knownSize.apply(1).findFirst(), OptionalInt.of(1));
+        assertEquals(knownSize.apply(1).findAny(), OptionalInt.of(1));
+        assertEquals(unknownSize.apply(1).findFirst(), OptionalInt.of(1));
+        assertEquals(unknownSize.apply(1).findAny(), OptionalInt.of(1));
 
         // Match
-        assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2), true);
-        assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2), false);
-        assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2), false);
-        assertEquals(unknownSizeIntStream(a).sorted().anyMatch(i -> i == 2), true);
-        assertEquals(unknownSizeIntStream(a).sorted().noneMatch(i -> i == 2), false);
-        assertEquals(unknownSizeIntStream(a).sorted().allMatch(i -> i == 2), false);
+        assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true);
+        assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false);
+        assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false);
+        assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true);
+        assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false);
+        assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false);
     }
 
     private IntStream unknownSizeIntStream(int[] a) {
@@ -242,19 +259,24 @@
     public void testLongSequentialShortCircuitTerminal() {
         long[] a = new long[]{5, 4, 3, 2, 1};
 
+        Function<Integer, LongStream> knownSize = i -> assertNCallsOnly(
+                Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+        Function<Integer, LongStream> unknownSize = i -> assertNCallsOnly
+                (unknownSizeLongStream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+
         // Find
-        assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalLong.of(1));
-        assertEquals(Arrays.stream(a).sorted().findAny(), OptionalLong.of(1));
-        assertEquals(unknownSizeLongStream(a).sorted().findFirst(), OptionalLong.of(1));
-        assertEquals(unknownSizeLongStream(a).sorted().findAny(), OptionalLong.of(1));
+        assertEquals(knownSize.apply(1).findFirst(), OptionalLong.of(1));
+        assertEquals(knownSize.apply(1).findAny(), OptionalLong.of(1));
+        assertEquals(unknownSize.apply(1).findFirst(), OptionalLong.of(1));
+        assertEquals(unknownSize.apply(1).findAny(), OptionalLong.of(1));
 
         // Match
-        assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2), true);
-        assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2), false);
-        assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2), false);
-        assertEquals(unknownSizeLongStream(a).sorted().anyMatch(i -> i == 2), true);
-        assertEquals(unknownSizeLongStream(a).sorted().noneMatch(i -> i == 2), false);
-        assertEquals(unknownSizeLongStream(a).sorted().allMatch(i -> i == 2), false);
+        assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true);
+        assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false);
+        assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false);
+        assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true);
+        assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false);
+        assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false);
     }
 
     private LongStream unknownSizeLongStream(long[] a) {
@@ -285,19 +307,24 @@
     public void testDoubleSequentialShortCircuitTerminal() {
         double[] a = new double[]{5.0, 4.0, 3.0, 2.0, 1.0};
 
+        Function<Integer, DoubleStream> knownSize = i -> assertNCallsOnly(
+                Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+        Function<Integer, DoubleStream> unknownSize = i -> assertNCallsOnly
+                (unknownSizeDoubleStream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+
         // Find
-        assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalDouble.of(1));
-        assertEquals(Arrays.stream(a).sorted().findAny(), OptionalDouble.of(1));
-        assertEquals(unknownSizeDoubleStream(a).sorted().findFirst(), OptionalDouble.of(1));
-        assertEquals(unknownSizeDoubleStream(a).sorted().findAny(), OptionalDouble.of(1));
+        assertEquals(knownSize.apply(1).findFirst(), OptionalDouble.of(1));
+        assertEquals(knownSize.apply(1).findAny(), OptionalDouble.of(1));
+        assertEquals(unknownSize.apply(1).findFirst(), OptionalDouble.of(1));
+        assertEquals(unknownSize.apply(1).findAny(), OptionalDouble.of(1));
 
         // Match
-        assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2.0), true);
-        assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2.0), false);
-        assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2.0), false);
-        assertEquals(unknownSizeDoubleStream(a).sorted().anyMatch(i -> i == 2.0), true);
-        assertEquals(unknownSizeDoubleStream(a).sorted().noneMatch(i -> i == 2.0), false);
-        assertEquals(unknownSizeDoubleStream(a).sorted().allMatch(i -> i == 2.0), false);
+        assertEquals(knownSize.apply(2).anyMatch(i -> i == 2.0), true);
+        assertEquals(knownSize.apply(2).noneMatch(i -> i == 2.0), false);
+        assertEquals(knownSize.apply(2).allMatch(i -> i == 2.0), false);
+        assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2.0), true);
+        assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2.0), false);
+        assertEquals(unknownSize.apply(2).allMatch(i -> i == 2.0), false);
     }
 
     private DoubleStream unknownSizeDoubleStream(double[] a) {
@@ -321,4 +348,14 @@
         assertSorted(result);
         assertContentsUnordered(data, result);
     }
+
+    /**
+     * Interpose a consumer that asserts it is called at most N times.
+     */
+    <T, S extends BaseStream<T, S>, R> S assertNCallsOnly(S s, BiFunction<S, Consumer<T>, S> pf, int n) {
+        AtomicInteger boxedInt = new AtomicInteger();
+        return pf.apply(s, i -> {
+            assertFalse(boxedInt.incrementAndGet() > n, "Intermediate op called more than " + n + " time(s)");
+        });
+    }
 }