# HG changeset patch
# User psandoz
# Date 1399364999 -7200
# Node ID 0e9ab834f44a71901e06cb2f416c1e1212a3bb2a
# Parent f524e23d7f7b0735a7f9119491d5a3eb8ec8e38a
8042355: stream with sorted() causes downstream ops not to be lazy
Reviewed-by: mduigou
diff -r f524e23d7f7b -r 0e9ab834f44a jdk/src/share/classes/java/util/stream/SortedOps.java
--- a/jdk/src/share/classes/java/util/stream/SortedOps.java Tue May 06 10:28:48 2014 +0400
+++ b/jdk/src/share/classes/java/util/stream/SortedOps.java Tue May 06 10:29:59 2014 +0200
@@ -279,16 +279,60 @@
}
/**
+ * Abstract {@link Sink} for implementing sort on reference streams.
+ *
+ *
+ * Note: documentation below applies to reference and all primitive sinks.
+ *
+ * Sorting sinks first accept all elements, buffering then into an array
+ * or a re-sizable data structure, if the size of the pipeline is known or
+ * unknown respectively. At the end of the sink protocol those elements are
+ * sorted and then pushed downstream.
+ * This class records if {@link #cancellationRequested} is called. If so it
+ * can be inferred that the source pushing source elements into the pipeline
+ * knows that the pipeline is short-circuiting. In such cases sub-classes
+ * pushing elements downstream will preserve the short-circuiting protocol
+ * by calling {@code downstream.cancellationRequested()} and checking the
+ * result is {@code false} before an element is pushed.
+ *
+ * Note that the above behaviour is an optimization for sorting with
+ * sequential streams. It is not an error that more elements, than strictly
+ * required to produce a result, may flow through the pipeline. This can
+ * occur, in general (not restricted to just sorting), for short-circuiting
+ * parallel pipelines.
+ */
+ private static abstract class AbstractRefSortingSink extends Sink.ChainedReference {
+ protected final Comparator super T> comparator;
+ // @@@ could be a lazy final value, if/when support is added
+ protected boolean cancellationWasRequested;
+
+ AbstractRefSortingSink(Sink super T> downstream, Comparator super T> comparator) {
+ super(downstream);
+ this.comparator = comparator;
+ }
+
+ /**
+ * Records is cancellation is requested so short-circuiting behaviour
+ * can be preserved when the sorted elements are pushed downstream.
+ *
+ * @return false, as this sink never short-circuits.
+ */
+ @Override
+ public final boolean cancellationRequested() {
+ cancellationWasRequested = true;
+ return false;
+ }
+ }
+
+ /**
* {@link Sink} for implementing sort on SIZED reference streams.
*/
- private static final class SizedRefSortingSink extends Sink.ChainedReference {
- private final Comparator super T> comparator;
+ private static final class SizedRefSortingSink extends AbstractRefSortingSink {
private T[] array;
private int offset;
SizedRefSortingSink(Sink super T> sink, Comparator super T> comparator) {
- super(sink);
- this.comparator = comparator;
+ super(sink, comparator);
}
@Override
@@ -303,8 +347,14 @@
public void end() {
Arrays.sort(array, 0, offset, comparator);
downstream.begin(offset);
- for (int i = 0; i < offset; i++)
- downstream.accept(array[i]);
+ if (!cancellationWasRequested) {
+ for (int i = 0; i < offset; i++)
+ downstream.accept(array[i]);
+ }
+ else {
+ for (int i = 0; i < offset && !downstream.cancellationRequested(); i++)
+ downstream.accept(array[i]);
+ }
downstream.end();
array = null;
}
@@ -318,13 +368,11 @@
/**
* {@link Sink} for implementing sort on reference streams.
*/
- private static final class RefSortingSink extends Sink.ChainedReference {
- private final Comparator super T> comparator;
+ private static final class RefSortingSink extends AbstractRefSortingSink {
private ArrayList list;
RefSortingSink(Sink super T> sink, Comparator super T> comparator) {
- super(sink);
- this.comparator = comparator;
+ super(sink, comparator);
}
@Override
@@ -338,7 +386,15 @@
public void end() {
list.sort(comparator);
downstream.begin(list.size());
- list.forEach(downstream::accept);
+ if (!cancellationWasRequested) {
+ list.forEach(downstream::accept);
+ }
+ else {
+ for (T t : list) {
+ if (downstream.cancellationRequested()) break;
+ downstream.accept(t);
+ }
+ }
downstream.end();
list = null;
}
@@ -350,9 +406,26 @@
}
/**
+ * Abstract {@link Sink} for implementing sort on int streams.
+ */
+ private static abstract class AbstractIntSortingSink extends Sink.ChainedInt {
+ protected boolean cancellationWasRequested;
+
+ AbstractIntSortingSink(Sink super Integer> downstream) {
+ super(downstream);
+ }
+
+ @Override
+ public final boolean cancellationRequested() {
+ cancellationWasRequested = true;
+ return false;
+ }
+ }
+
+ /**
* {@link Sink} for implementing sort on SIZED int streams.
*/
- private static final class SizedIntSortingSink extends Sink.ChainedInt {
+ private static final class SizedIntSortingSink extends AbstractIntSortingSink {
private int[] array;
private int offset;
@@ -371,8 +444,14 @@
public void end() {
Arrays.sort(array, 0, offset);
downstream.begin(offset);
- for (int i = 0; i < offset; i++)
- downstream.accept(array[i]);
+ if (!cancellationWasRequested) {
+ for (int i = 0; i < offset; i++)
+ downstream.accept(array[i]);
+ }
+ else {
+ for (int i = 0; i < offset && !downstream.cancellationRequested(); i++)
+ downstream.accept(array[i]);
+ }
downstream.end();
array = null;
}
@@ -386,7 +465,7 @@
/**
* {@link Sink} for implementing sort on int streams.
*/
- private static final class IntSortingSink extends Sink.ChainedInt {
+ private static final class IntSortingSink extends AbstractIntSortingSink {
private SpinedBuffer.OfInt b;
IntSortingSink(Sink super Integer> sink) {
@@ -405,8 +484,16 @@
int[] ints = b.asPrimitiveArray();
Arrays.sort(ints);
downstream.begin(ints.length);
- for (int anInt : ints)
- downstream.accept(anInt);
+ if (!cancellationWasRequested) {
+ for (int anInt : ints)
+ downstream.accept(anInt);
+ }
+ else {
+ for (int anInt : ints) {
+ if (downstream.cancellationRequested()) break;
+ downstream.accept(anInt);
+ }
+ }
downstream.end();
}
@@ -417,9 +504,26 @@
}
/**
+ * Abstract {@link Sink} for implementing sort on long streams.
+ */
+ private static abstract class AbstractLongSortingSink extends Sink.ChainedLong {
+ protected boolean cancellationWasRequested;
+
+ AbstractLongSortingSink(Sink super Long> downstream) {
+ super(downstream);
+ }
+
+ @Override
+ public final boolean cancellationRequested() {
+ cancellationWasRequested = true;
+ return false;
+ }
+ }
+
+ /**
* {@link Sink} for implementing sort on SIZED long streams.
*/
- private static final class SizedLongSortingSink extends Sink.ChainedLong {
+ private static final class SizedLongSortingSink extends AbstractLongSortingSink {
private long[] array;
private int offset;
@@ -438,8 +542,14 @@
public void end() {
Arrays.sort(array, 0, offset);
downstream.begin(offset);
- for (int i = 0; i < offset; i++)
- downstream.accept(array[i]);
+ if (!cancellationWasRequested) {
+ for (int i = 0; i < offset; i++)
+ downstream.accept(array[i]);
+ }
+ else {
+ for (int i = 0; i < offset && !downstream.cancellationRequested(); i++)
+ downstream.accept(array[i]);
+ }
downstream.end();
array = null;
}
@@ -453,7 +563,7 @@
/**
* {@link Sink} for implementing sort on long streams.
*/
- private static final class LongSortingSink extends Sink.ChainedLong {
+ private static final class LongSortingSink extends AbstractLongSortingSink {
private SpinedBuffer.OfLong b;
LongSortingSink(Sink super Long> sink) {
@@ -472,8 +582,16 @@
long[] longs = b.asPrimitiveArray();
Arrays.sort(longs);
downstream.begin(longs.length);
- for (long aLong : longs)
- downstream.accept(aLong);
+ if (!cancellationWasRequested) {
+ for (long aLong : longs)
+ downstream.accept(aLong);
+ }
+ else {
+ for (long aLong : longs) {
+ if (downstream.cancellationRequested()) break;
+ downstream.accept(aLong);
+ }
+ }
downstream.end();
}
@@ -484,9 +602,26 @@
}
/**
+ * Abstract {@link Sink} for implementing sort on long streams.
+ */
+ private static abstract class AbstractDoubleSortingSink extends Sink.ChainedDouble {
+ protected boolean cancellationWasRequested;
+
+ AbstractDoubleSortingSink(Sink super Double> downstream) {
+ super(downstream);
+ }
+
+ @Override
+ public final boolean cancellationRequested() {
+ cancellationWasRequested = true;
+ return false;
+ }
+ }
+
+ /**
* {@link Sink} for implementing sort on SIZED double streams.
*/
- private static final class SizedDoubleSortingSink extends Sink.ChainedDouble {
+ private static final class SizedDoubleSortingSink extends AbstractDoubleSortingSink {
private double[] array;
private int offset;
@@ -505,8 +640,14 @@
public void end() {
Arrays.sort(array, 0, offset);
downstream.begin(offset);
- for (int i = 0; i < offset; i++)
- downstream.accept(array[i]);
+ if (!cancellationWasRequested) {
+ for (int i = 0; i < offset; i++)
+ downstream.accept(array[i]);
+ }
+ else {
+ for (int i = 0; i < offset && !downstream.cancellationRequested(); i++)
+ downstream.accept(array[i]);
+ }
downstream.end();
array = null;
}
@@ -520,7 +661,7 @@
/**
* {@link Sink} for implementing sort on double streams.
*/
- private static final class DoubleSortingSink extends Sink.ChainedDouble {
+ private static final class DoubleSortingSink extends AbstractDoubleSortingSink {
private SpinedBuffer.OfDouble b;
DoubleSortingSink(Sink super Double> sink) {
@@ -539,8 +680,16 @@
double[] doubles = b.asPrimitiveArray();
Arrays.sort(doubles);
downstream.begin(doubles.length);
- for (double aDouble : doubles)
- downstream.accept(aDouble);
+ if (!cancellationWasRequested) {
+ for (double aDouble : doubles)
+ downstream.accept(aDouble);
+ }
+ else {
+ for (double aDouble : doubles) {
+ if (downstream.cancellationRequested()) break;
+ downstream.accept(aDouble);
+ }
+ }
downstream.end();
}
diff -r f524e23d7f7b -r 0e9ab834f44a jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java
--- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java Tue May 06 10:28:48 2014 +0400
+++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java Tue May 06 10:29:59 2014 +0200
@@ -26,6 +26,9 @@
import java.util.*;
import java.util.Spliterators;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.*;
@@ -122,24 +125,33 @@
@Test(groups = { "serialization-hostile" })
public void testSequentialShortCircuitTerminal() {
- // The sorted op for sequential evaluation will buffer all elements when accepting
- // then at the end sort those elements and push those elements downstream
+ // The sorted op for sequential evaluation will buffer all elements when
+ // accepting then at the end sort those elements and push those elements
+ // downstream
+ // A peek operation is added in-between the sorted() and terminal
+ // operation that counts the number of calls to its consumer and
+ // asserts that the number of calls is at most the required quantity
List l = Arrays.asList(5, 4, 3, 2, 1);
+ Function> knownSize = i -> assertNCallsOnly(
+ l.stream().sorted(), Stream::peek, i);
+ Function> unknownSize = i -> assertNCallsOnly
+ (unknownSizeStream(l).sorted(), Stream::peek, i);
+
// Find
- assertEquals(l.stream().sorted().findFirst(), Optional.of(1));
- assertEquals(l.stream().sorted().findAny(), Optional.of(1));
- assertEquals(unknownSizeStream(l).sorted().findFirst(), Optional.of(1));
- assertEquals(unknownSizeStream(l).sorted().findAny(), Optional.of(1));
+ assertEquals(knownSize.apply(1).findFirst(), Optional.of(1));
+ assertEquals(knownSize.apply(1).findAny(), Optional.of(1));
+ assertEquals(unknownSize.apply(1).findFirst(), Optional.of(1));
+ assertEquals(unknownSize.apply(1).findAny(), Optional.of(1));
// Match
- assertEquals(l.stream().sorted().anyMatch(i -> i == 2), true);
- assertEquals(l.stream().sorted().noneMatch(i -> i == 2), false);
- assertEquals(l.stream().sorted().allMatch(i -> i == 2), false);
- assertEquals(unknownSizeStream(l).sorted().anyMatch(i -> i == 2), true);
- assertEquals(unknownSizeStream(l).sorted().noneMatch(i -> i == 2), false);
- assertEquals(unknownSizeStream(l).sorted().allMatch(i -> i == 2), false);
+ assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true);
+ assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false);
+ assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false);
+ assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true);
+ assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false);
+ assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false);
}
private Stream unknownSizeStream(List l) {
@@ -199,19 +211,24 @@
public void testIntSequentialShortCircuitTerminal() {
int[] a = new int[]{5, 4, 3, 2, 1};
+ Function knownSize = i -> assertNCallsOnly(
+ Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+ Function unknownSize = i -> assertNCallsOnly
+ (unknownSizeIntStream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+
// Find
- assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalInt.of(1));
- assertEquals(Arrays.stream(a).sorted().findAny(), OptionalInt.of(1));
- assertEquals(unknownSizeIntStream(a).sorted().findFirst(), OptionalInt.of(1));
- assertEquals(unknownSizeIntStream(a).sorted().findAny(), OptionalInt.of(1));
+ assertEquals(knownSize.apply(1).findFirst(), OptionalInt.of(1));
+ assertEquals(knownSize.apply(1).findAny(), OptionalInt.of(1));
+ assertEquals(unknownSize.apply(1).findFirst(), OptionalInt.of(1));
+ assertEquals(unknownSize.apply(1).findAny(), OptionalInt.of(1));
// Match
- assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2), true);
- assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2), false);
- assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2), false);
- assertEquals(unknownSizeIntStream(a).sorted().anyMatch(i -> i == 2), true);
- assertEquals(unknownSizeIntStream(a).sorted().noneMatch(i -> i == 2), false);
- assertEquals(unknownSizeIntStream(a).sorted().allMatch(i -> i == 2), false);
+ assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true);
+ assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false);
+ assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false);
+ assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true);
+ assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false);
+ assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false);
}
private IntStream unknownSizeIntStream(int[] a) {
@@ -242,19 +259,24 @@
public void testLongSequentialShortCircuitTerminal() {
long[] a = new long[]{5, 4, 3, 2, 1};
+ Function knownSize = i -> assertNCallsOnly(
+ Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+ Function unknownSize = i -> assertNCallsOnly
+ (unknownSizeLongStream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+
// Find
- assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalLong.of(1));
- assertEquals(Arrays.stream(a).sorted().findAny(), OptionalLong.of(1));
- assertEquals(unknownSizeLongStream(a).sorted().findFirst(), OptionalLong.of(1));
- assertEquals(unknownSizeLongStream(a).sorted().findAny(), OptionalLong.of(1));
+ assertEquals(knownSize.apply(1).findFirst(), OptionalLong.of(1));
+ assertEquals(knownSize.apply(1).findAny(), OptionalLong.of(1));
+ assertEquals(unknownSize.apply(1).findFirst(), OptionalLong.of(1));
+ assertEquals(unknownSize.apply(1).findAny(), OptionalLong.of(1));
// Match
- assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2), true);
- assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2), false);
- assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2), false);
- assertEquals(unknownSizeLongStream(a).sorted().anyMatch(i -> i == 2), true);
- assertEquals(unknownSizeLongStream(a).sorted().noneMatch(i -> i == 2), false);
- assertEquals(unknownSizeLongStream(a).sorted().allMatch(i -> i == 2), false);
+ assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true);
+ assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false);
+ assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false);
+ assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true);
+ assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false);
+ assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false);
}
private LongStream unknownSizeLongStream(long[] a) {
@@ -285,19 +307,24 @@
public void testDoubleSequentialShortCircuitTerminal() {
double[] a = new double[]{5.0, 4.0, 3.0, 2.0, 1.0};
+ Function knownSize = i -> assertNCallsOnly(
+ Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+ Function unknownSize = i -> assertNCallsOnly
+ (unknownSizeDoubleStream(a).sorted(), (s, c) -> s.peek(c::accept), i);
+
// Find
- assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalDouble.of(1));
- assertEquals(Arrays.stream(a).sorted().findAny(), OptionalDouble.of(1));
- assertEquals(unknownSizeDoubleStream(a).sorted().findFirst(), OptionalDouble.of(1));
- assertEquals(unknownSizeDoubleStream(a).sorted().findAny(), OptionalDouble.of(1));
+ assertEquals(knownSize.apply(1).findFirst(), OptionalDouble.of(1));
+ assertEquals(knownSize.apply(1).findAny(), OptionalDouble.of(1));
+ assertEquals(unknownSize.apply(1).findFirst(), OptionalDouble.of(1));
+ assertEquals(unknownSize.apply(1).findAny(), OptionalDouble.of(1));
// Match
- assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2.0), true);
- assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2.0), false);
- assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2.0), false);
- assertEquals(unknownSizeDoubleStream(a).sorted().anyMatch(i -> i == 2.0), true);
- assertEquals(unknownSizeDoubleStream(a).sorted().noneMatch(i -> i == 2.0), false);
- assertEquals(unknownSizeDoubleStream(a).sorted().allMatch(i -> i == 2.0), false);
+ assertEquals(knownSize.apply(2).anyMatch(i -> i == 2.0), true);
+ assertEquals(knownSize.apply(2).noneMatch(i -> i == 2.0), false);
+ assertEquals(knownSize.apply(2).allMatch(i -> i == 2.0), false);
+ assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2.0), true);
+ assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2.0), false);
+ assertEquals(unknownSize.apply(2).allMatch(i -> i == 2.0), false);
}
private DoubleStream unknownSizeDoubleStream(double[] a) {
@@ -321,4 +348,14 @@
assertSorted(result);
assertContentsUnordered(data, result);
}
+
+ /**
+ * Interpose a consumer that asserts it is called at most N times.
+ */
+ , R> S assertNCallsOnly(S s, BiFunction, S> pf, int n) {
+ AtomicInteger boxedInt = new AtomicInteger();
+ return pf.apply(s, i -> {
+ assertFalse(boxedInt.incrementAndGet() > n, "Intermediate op called more than " + n + " time(s)");
+ });
+ }
}