# HG changeset patch # User psandoz # Date 1399364999 -7200 # Node ID 0e9ab834f44a71901e06cb2f416c1e1212a3bb2a # Parent f524e23d7f7b0735a7f9119491d5a3eb8ec8e38a 8042355: stream with sorted() causes downstream ops not to be lazy Reviewed-by: mduigou diff -r f524e23d7f7b -r 0e9ab834f44a jdk/src/share/classes/java/util/stream/SortedOps.java --- a/jdk/src/share/classes/java/util/stream/SortedOps.java Tue May 06 10:28:48 2014 +0400 +++ b/jdk/src/share/classes/java/util/stream/SortedOps.java Tue May 06 10:29:59 2014 +0200 @@ -279,16 +279,60 @@ } /** + * Abstract {@link Sink} for implementing sort on reference streams. + * + *

+ * Note: documentation below applies to reference and all primitive sinks. + *

+ * Sorting sinks first accept all elements, buffering then into an array + * or a re-sizable data structure, if the size of the pipeline is known or + * unknown respectively. At the end of the sink protocol those elements are + * sorted and then pushed downstream. + * This class records if {@link #cancellationRequested} is called. If so it + * can be inferred that the source pushing source elements into the pipeline + * knows that the pipeline is short-circuiting. In such cases sub-classes + * pushing elements downstream will preserve the short-circuiting protocol + * by calling {@code downstream.cancellationRequested()} and checking the + * result is {@code false} before an element is pushed. + *

+ * Note that the above behaviour is an optimization for sorting with + * sequential streams. It is not an error that more elements, than strictly + * required to produce a result, may flow through the pipeline. This can + * occur, in general (not restricted to just sorting), for short-circuiting + * parallel pipelines. + */ + private static abstract class AbstractRefSortingSink extends Sink.ChainedReference { + protected final Comparator comparator; + // @@@ could be a lazy final value, if/when support is added + protected boolean cancellationWasRequested; + + AbstractRefSortingSink(Sink downstream, Comparator comparator) { + super(downstream); + this.comparator = comparator; + } + + /** + * Records is cancellation is requested so short-circuiting behaviour + * can be preserved when the sorted elements are pushed downstream. + * + * @return false, as this sink never short-circuits. + */ + @Override + public final boolean cancellationRequested() { + cancellationWasRequested = true; + return false; + } + } + + /** * {@link Sink} for implementing sort on SIZED reference streams. */ - private static final class SizedRefSortingSink extends Sink.ChainedReference { - private final Comparator comparator; + private static final class SizedRefSortingSink extends AbstractRefSortingSink { private T[] array; private int offset; SizedRefSortingSink(Sink sink, Comparator comparator) { - super(sink); - this.comparator = comparator; + super(sink, comparator); } @Override @@ -303,8 +347,14 @@ public void end() { Arrays.sort(array, 0, offset, comparator); downstream.begin(offset); - for (int i = 0; i < offset; i++) - downstream.accept(array[i]); + if (!cancellationWasRequested) { + for (int i = 0; i < offset; i++) + downstream.accept(array[i]); + } + else { + for (int i = 0; i < offset && !downstream.cancellationRequested(); i++) + downstream.accept(array[i]); + } downstream.end(); array = null; } @@ -318,13 +368,11 @@ /** * {@link Sink} for implementing sort on reference streams. */ - private static final class RefSortingSink extends Sink.ChainedReference { - private final Comparator comparator; + private static final class RefSortingSink extends AbstractRefSortingSink { private ArrayList list; RefSortingSink(Sink sink, Comparator comparator) { - super(sink); - this.comparator = comparator; + super(sink, comparator); } @Override @@ -338,7 +386,15 @@ public void end() { list.sort(comparator); downstream.begin(list.size()); - list.forEach(downstream::accept); + if (!cancellationWasRequested) { + list.forEach(downstream::accept); + } + else { + for (T t : list) { + if (downstream.cancellationRequested()) break; + downstream.accept(t); + } + } downstream.end(); list = null; } @@ -350,9 +406,26 @@ } /** + * Abstract {@link Sink} for implementing sort on int streams. + */ + private static abstract class AbstractIntSortingSink extends Sink.ChainedInt { + protected boolean cancellationWasRequested; + + AbstractIntSortingSink(Sink downstream) { + super(downstream); + } + + @Override + public final boolean cancellationRequested() { + cancellationWasRequested = true; + return false; + } + } + + /** * {@link Sink} for implementing sort on SIZED int streams. */ - private static final class SizedIntSortingSink extends Sink.ChainedInt { + private static final class SizedIntSortingSink extends AbstractIntSortingSink { private int[] array; private int offset; @@ -371,8 +444,14 @@ public void end() { Arrays.sort(array, 0, offset); downstream.begin(offset); - for (int i = 0; i < offset; i++) - downstream.accept(array[i]); + if (!cancellationWasRequested) { + for (int i = 0; i < offset; i++) + downstream.accept(array[i]); + } + else { + for (int i = 0; i < offset && !downstream.cancellationRequested(); i++) + downstream.accept(array[i]); + } downstream.end(); array = null; } @@ -386,7 +465,7 @@ /** * {@link Sink} for implementing sort on int streams. */ - private static final class IntSortingSink extends Sink.ChainedInt { + private static final class IntSortingSink extends AbstractIntSortingSink { private SpinedBuffer.OfInt b; IntSortingSink(Sink sink) { @@ -405,8 +484,16 @@ int[] ints = b.asPrimitiveArray(); Arrays.sort(ints); downstream.begin(ints.length); - for (int anInt : ints) - downstream.accept(anInt); + if (!cancellationWasRequested) { + for (int anInt : ints) + downstream.accept(anInt); + } + else { + for (int anInt : ints) { + if (downstream.cancellationRequested()) break; + downstream.accept(anInt); + } + } downstream.end(); } @@ -417,9 +504,26 @@ } /** + * Abstract {@link Sink} for implementing sort on long streams. + */ + private static abstract class AbstractLongSortingSink extends Sink.ChainedLong { + protected boolean cancellationWasRequested; + + AbstractLongSortingSink(Sink downstream) { + super(downstream); + } + + @Override + public final boolean cancellationRequested() { + cancellationWasRequested = true; + return false; + } + } + + /** * {@link Sink} for implementing sort on SIZED long streams. */ - private static final class SizedLongSortingSink extends Sink.ChainedLong { + private static final class SizedLongSortingSink extends AbstractLongSortingSink { private long[] array; private int offset; @@ -438,8 +542,14 @@ public void end() { Arrays.sort(array, 0, offset); downstream.begin(offset); - for (int i = 0; i < offset; i++) - downstream.accept(array[i]); + if (!cancellationWasRequested) { + for (int i = 0; i < offset; i++) + downstream.accept(array[i]); + } + else { + for (int i = 0; i < offset && !downstream.cancellationRequested(); i++) + downstream.accept(array[i]); + } downstream.end(); array = null; } @@ -453,7 +563,7 @@ /** * {@link Sink} for implementing sort on long streams. */ - private static final class LongSortingSink extends Sink.ChainedLong { + private static final class LongSortingSink extends AbstractLongSortingSink { private SpinedBuffer.OfLong b; LongSortingSink(Sink sink) { @@ -472,8 +582,16 @@ long[] longs = b.asPrimitiveArray(); Arrays.sort(longs); downstream.begin(longs.length); - for (long aLong : longs) - downstream.accept(aLong); + if (!cancellationWasRequested) { + for (long aLong : longs) + downstream.accept(aLong); + } + else { + for (long aLong : longs) { + if (downstream.cancellationRequested()) break; + downstream.accept(aLong); + } + } downstream.end(); } @@ -484,9 +602,26 @@ } /** + * Abstract {@link Sink} for implementing sort on long streams. + */ + private static abstract class AbstractDoubleSortingSink extends Sink.ChainedDouble { + protected boolean cancellationWasRequested; + + AbstractDoubleSortingSink(Sink downstream) { + super(downstream); + } + + @Override + public final boolean cancellationRequested() { + cancellationWasRequested = true; + return false; + } + } + + /** * {@link Sink} for implementing sort on SIZED double streams. */ - private static final class SizedDoubleSortingSink extends Sink.ChainedDouble { + private static final class SizedDoubleSortingSink extends AbstractDoubleSortingSink { private double[] array; private int offset; @@ -505,8 +640,14 @@ public void end() { Arrays.sort(array, 0, offset); downstream.begin(offset); - for (int i = 0; i < offset; i++) - downstream.accept(array[i]); + if (!cancellationWasRequested) { + for (int i = 0; i < offset; i++) + downstream.accept(array[i]); + } + else { + for (int i = 0; i < offset && !downstream.cancellationRequested(); i++) + downstream.accept(array[i]); + } downstream.end(); array = null; } @@ -520,7 +661,7 @@ /** * {@link Sink} for implementing sort on double streams. */ - private static final class DoubleSortingSink extends Sink.ChainedDouble { + private static final class DoubleSortingSink extends AbstractDoubleSortingSink { private SpinedBuffer.OfDouble b; DoubleSortingSink(Sink sink) { @@ -539,8 +680,16 @@ double[] doubles = b.asPrimitiveArray(); Arrays.sort(doubles); downstream.begin(doubles.length); - for (double aDouble : doubles) - downstream.accept(aDouble); + if (!cancellationWasRequested) { + for (double aDouble : doubles) + downstream.accept(aDouble); + } + else { + for (double aDouble : doubles) { + if (downstream.cancellationRequested()) break; + downstream.accept(aDouble); + } + } downstream.end(); } diff -r f524e23d7f7b -r 0e9ab834f44a jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java --- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java Tue May 06 10:28:48 2014 +0400 +++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java Tue May 06 10:29:59 2014 +0200 @@ -26,6 +26,9 @@ import java.util.*; import java.util.Spliterators; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.*; @@ -122,24 +125,33 @@ @Test(groups = { "serialization-hostile" }) public void testSequentialShortCircuitTerminal() { - // The sorted op for sequential evaluation will buffer all elements when accepting - // then at the end sort those elements and push those elements downstream + // The sorted op for sequential evaluation will buffer all elements when + // accepting then at the end sort those elements and push those elements + // downstream + // A peek operation is added in-between the sorted() and terminal + // operation that counts the number of calls to its consumer and + // asserts that the number of calls is at most the required quantity List l = Arrays.asList(5, 4, 3, 2, 1); + Function> knownSize = i -> assertNCallsOnly( + l.stream().sorted(), Stream::peek, i); + Function> unknownSize = i -> assertNCallsOnly + (unknownSizeStream(l).sorted(), Stream::peek, i); + // Find - assertEquals(l.stream().sorted().findFirst(), Optional.of(1)); - assertEquals(l.stream().sorted().findAny(), Optional.of(1)); - assertEquals(unknownSizeStream(l).sorted().findFirst(), Optional.of(1)); - assertEquals(unknownSizeStream(l).sorted().findAny(), Optional.of(1)); + assertEquals(knownSize.apply(1).findFirst(), Optional.of(1)); + assertEquals(knownSize.apply(1).findAny(), Optional.of(1)); + assertEquals(unknownSize.apply(1).findFirst(), Optional.of(1)); + assertEquals(unknownSize.apply(1).findAny(), Optional.of(1)); // Match - assertEquals(l.stream().sorted().anyMatch(i -> i == 2), true); - assertEquals(l.stream().sorted().noneMatch(i -> i == 2), false); - assertEquals(l.stream().sorted().allMatch(i -> i == 2), false); - assertEquals(unknownSizeStream(l).sorted().anyMatch(i -> i == 2), true); - assertEquals(unknownSizeStream(l).sorted().noneMatch(i -> i == 2), false); - assertEquals(unknownSizeStream(l).sorted().allMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false); } private Stream unknownSizeStream(List l) { @@ -199,19 +211,24 @@ public void testIntSequentialShortCircuitTerminal() { int[] a = new int[]{5, 4, 3, 2, 1}; + Function knownSize = i -> assertNCallsOnly( + Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i); + Function unknownSize = i -> assertNCallsOnly + (unknownSizeIntStream(a).sorted(), (s, c) -> s.peek(c::accept), i); + // Find - assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalInt.of(1)); - assertEquals(Arrays.stream(a).sorted().findAny(), OptionalInt.of(1)); - assertEquals(unknownSizeIntStream(a).sorted().findFirst(), OptionalInt.of(1)); - assertEquals(unknownSizeIntStream(a).sorted().findAny(), OptionalInt.of(1)); + assertEquals(knownSize.apply(1).findFirst(), OptionalInt.of(1)); + assertEquals(knownSize.apply(1).findAny(), OptionalInt.of(1)); + assertEquals(unknownSize.apply(1).findFirst(), OptionalInt.of(1)); + assertEquals(unknownSize.apply(1).findAny(), OptionalInt.of(1)); // Match - assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2), true); - assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2), false); - assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2), false); - assertEquals(unknownSizeIntStream(a).sorted().anyMatch(i -> i == 2), true); - assertEquals(unknownSizeIntStream(a).sorted().noneMatch(i -> i == 2), false); - assertEquals(unknownSizeIntStream(a).sorted().allMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false); } private IntStream unknownSizeIntStream(int[] a) { @@ -242,19 +259,24 @@ public void testLongSequentialShortCircuitTerminal() { long[] a = new long[]{5, 4, 3, 2, 1}; + Function knownSize = i -> assertNCallsOnly( + Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i); + Function unknownSize = i -> assertNCallsOnly + (unknownSizeLongStream(a).sorted(), (s, c) -> s.peek(c::accept), i); + // Find - assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalLong.of(1)); - assertEquals(Arrays.stream(a).sorted().findAny(), OptionalLong.of(1)); - assertEquals(unknownSizeLongStream(a).sorted().findFirst(), OptionalLong.of(1)); - assertEquals(unknownSizeLongStream(a).sorted().findAny(), OptionalLong.of(1)); + assertEquals(knownSize.apply(1).findFirst(), OptionalLong.of(1)); + assertEquals(knownSize.apply(1).findAny(), OptionalLong.of(1)); + assertEquals(unknownSize.apply(1).findFirst(), OptionalLong.of(1)); + assertEquals(unknownSize.apply(1).findAny(), OptionalLong.of(1)); // Match - assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2), true); - assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2), false); - assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2), false); - assertEquals(unknownSizeLongStream(a).sorted().anyMatch(i -> i == 2), true); - assertEquals(unknownSizeLongStream(a).sorted().noneMatch(i -> i == 2), false); - assertEquals(unknownSizeLongStream(a).sorted().allMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false); } private LongStream unknownSizeLongStream(long[] a) { @@ -285,19 +307,24 @@ public void testDoubleSequentialShortCircuitTerminal() { double[] a = new double[]{5.0, 4.0, 3.0, 2.0, 1.0}; + Function knownSize = i -> assertNCallsOnly( + Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i); + Function unknownSize = i -> assertNCallsOnly + (unknownSizeDoubleStream(a).sorted(), (s, c) -> s.peek(c::accept), i); + // Find - assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalDouble.of(1)); - assertEquals(Arrays.stream(a).sorted().findAny(), OptionalDouble.of(1)); - assertEquals(unknownSizeDoubleStream(a).sorted().findFirst(), OptionalDouble.of(1)); - assertEquals(unknownSizeDoubleStream(a).sorted().findAny(), OptionalDouble.of(1)); + assertEquals(knownSize.apply(1).findFirst(), OptionalDouble.of(1)); + assertEquals(knownSize.apply(1).findAny(), OptionalDouble.of(1)); + assertEquals(unknownSize.apply(1).findFirst(), OptionalDouble.of(1)); + assertEquals(unknownSize.apply(1).findAny(), OptionalDouble.of(1)); // Match - assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2.0), true); - assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2.0), false); - assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2.0), false); - assertEquals(unknownSizeDoubleStream(a).sorted().anyMatch(i -> i == 2.0), true); - assertEquals(unknownSizeDoubleStream(a).sorted().noneMatch(i -> i == 2.0), false); - assertEquals(unknownSizeDoubleStream(a).sorted().allMatch(i -> i == 2.0), false); + assertEquals(knownSize.apply(2).anyMatch(i -> i == 2.0), true); + assertEquals(knownSize.apply(2).noneMatch(i -> i == 2.0), false); + assertEquals(knownSize.apply(2).allMatch(i -> i == 2.0), false); + assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2.0), true); + assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2.0), false); + assertEquals(unknownSize.apply(2).allMatch(i -> i == 2.0), false); } private DoubleStream unknownSizeDoubleStream(double[] a) { @@ -321,4 +348,14 @@ assertSorted(result); assertContentsUnordered(data, result); } + + /** + * Interpose a consumer that asserts it is called at most N times. + */ + , R> S assertNCallsOnly(S s, BiFunction, S> pf, int n) { + AtomicInteger boxedInt = new AtomicInteger(); + return pf.apply(s, i -> { + assertFalse(boxedInt.incrementAndGet() > n, "Intermediate op called more than " + n + " time(s)"); + }); + } }