8075307: Pipeline calculating inconsistent flag state for parallel stateful ops
authorpsandoz
Wed, 25 Mar 2015 10:50:08 +0000
changeset 29617 4922c98744c7
parent 29616 5a1a6f9fb891
child 29618 eb6ff6f9d1ca
child 29712 833acdf3b1d1
8075307: Pipeline calculating inconsistent flag state for parallel stateful ops Reviewed-by: smarks
jdk/src/java.base/share/classes/java/util/stream/AbstractPipeline.java
jdk/src/java.base/share/classes/java/util/stream/ReduceOps.java
jdk/test/java/util/stream/boottest/java/util/stream/UnorderedTest.java
jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/CountTest.java
jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/DistinctOpTest.java
--- a/jdk/src/java.base/share/classes/java/util/stream/AbstractPipeline.java	Wed Mar 25 17:59:59 2015 +0900
+++ b/jdk/src/java.base/share/classes/java/util/stream/AbstractPipeline.java	Wed Mar 25 10:50:08 2015 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2012, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2015, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -380,60 +380,6 @@
     }
 
     /**
-     * Prepare the pipeline for a parallel execution.  As the pipeline is built,
-     * the flags and depth indicators are set up for a sequential execution.
-     * If the execution is parallel, and there are any stateful operations, then
-     * some of these need to be adjusted, as well as adjusting for flags from
-     * the terminal operation (such as back-propagating UNORDERED).
-     * Need not be called for a sequential execution.
-     *
-     * @param terminalFlags Operation flags for the terminal operation
-     */
-    private void parallelPrepare(int terminalFlags) {
-        @SuppressWarnings("rawtypes")
-        AbstractPipeline backPropagationHead = sourceStage;
-        if (sourceStage.sourceAnyStateful) {
-            int depth = 1;
-            for (  @SuppressWarnings("rawtypes") AbstractPipeline u = sourceStage, p = sourceStage.nextStage;
-                 p != null;
-                 u = p, p = p.nextStage) {
-                int thisOpFlags = p.sourceOrOpFlags;
-                if (p.opIsStateful()) {
-                    // If the stateful operation is a short-circuit operation
-                    // then move the back propagation head forwards
-                    // NOTE: there are no size-injecting ops
-                    if (StreamOpFlag.SHORT_CIRCUIT.isKnown(thisOpFlags)) {
-                        backPropagationHead = p;
-                        // Clear the short circuit flag for next pipeline stage
-                        // This stage encapsulates short-circuiting, the next
-                        // stage may not have any short-circuit operations, and
-                        // if so spliterator.forEachRemaining should be used
-                        // for traversal
-                        thisOpFlags = thisOpFlags & ~StreamOpFlag.IS_SHORT_CIRCUIT;
-                    }
-
-                    depth = 0;
-                    // The following injects size, it is equivalent to:
-                    // StreamOpFlag.combineOpFlags(StreamOpFlag.IS_SIZED, p.combinedFlags);
-                    thisOpFlags = (thisOpFlags & ~StreamOpFlag.NOT_SIZED) | StreamOpFlag.IS_SIZED;
-                }
-                p.depth = depth++;
-                p.combinedFlags = StreamOpFlag.combineOpFlags(thisOpFlags, u.combinedFlags);
-            }
-        }
-
-        // Apply the upstream terminal flags
-        if (terminalFlags != 0) {
-            int upstreamTerminalFlags = terminalFlags & StreamOpFlag.UPSTREAM_TERMINAL_OP_MASK;
-            for ( @SuppressWarnings("rawtypes") AbstractPipeline p = backPropagationHead; p.nextStage != null; p = p.nextStage) {
-                p.combinedFlags = StreamOpFlag.combineOpFlags(upstreamTerminalFlags, p.combinedFlags);
-            }
-
-            combinedFlags = StreamOpFlag.combineOpFlags(terminalFlags, combinedFlags);
-        }
-    }
-
-    /**
      * Get the source spliterator for this pipeline stage.  For a sequential or
      * stateless parallel pipeline, this is the source spliterator.  For a
      * stateful parallel pipeline, this is a spliterator describing the results
@@ -456,24 +402,70 @@
             throw new IllegalStateException(MSG_CONSUMED);
         }
 
-        if (isParallel()) {
-            // @@@ Merge parallelPrepare with the loop below and use the
-            //     spliterator characteristics to determine if SIZED
-            //     should be injected
-            parallelPrepare(terminalFlags);
+        boolean hasTerminalFlags = terminalFlags != 0;
+        if (isParallel() && sourceStage.sourceAnyStateful) {
+            // Adjust pipeline stages if there are stateful ops,
+            // and find the last short circuiting op, if any, that
+            // defines the head stage for back-propagation of terminal flags
+            @SuppressWarnings("rawtypes")
+            AbstractPipeline backPropagationHead = sourceStage;
+            int depth = 1;
+            for (@SuppressWarnings("rawtypes") AbstractPipeline p = sourceStage.nextStage;
+                 p != null;
+                 p = p.nextStage) {
+                if (p.opIsStateful()) {
+                    if (StreamOpFlag.SHORT_CIRCUIT.isKnown(p.sourceOrOpFlags)) {
+                        // If the stateful operation is a short-circuit operation
+                        // then move the back propagation head forwards
+                        // NOTE: there are no size-injecting ops
+                        backPropagationHead = p;
+                    }
+
+                    depth = 0;
+                }
+                p.depth = depth++;
+            }
 
             // Adapt the source spliterator, evaluating each stateful op
             // in the pipeline up to and including this pipeline stage
-            for ( @SuppressWarnings("rawtypes") AbstractPipeline u = sourceStage, p = sourceStage.nextStage, e = this;
+            // Flags for each pipeline stage are adjusted accordingly
+            boolean backPropagate = false;
+            int upstreamTerminalFlags = terminalFlags & StreamOpFlag.UPSTREAM_TERMINAL_OP_MASK;
+            for (@SuppressWarnings("rawtypes") AbstractPipeline u = sourceStage, p = sourceStage.nextStage, e = this;
                  u != e;
                  u = p, p = p.nextStage) {
 
+                if (hasTerminalFlags &&
+                    (backPropagate || (backPropagate = (u == backPropagationHead)))) {
+                    // Back-propagate flags from the terminal operation
+                    u.combinedFlags = StreamOpFlag.combineOpFlags(upstreamTerminalFlags, u.combinedFlags);
+                }
+
+                int thisOpFlags = p.sourceOrOpFlags;
                 if (p.opIsStateful()) {
+                    if (StreamOpFlag.SHORT_CIRCUIT.isKnown(thisOpFlags)) {
+                        // Clear the short circuit flag for next pipeline stage
+                        // This stage encapsulates short-circuiting, the next
+                        // stage may not have any short-circuit operations, and
+                        // if so spliterator.forEachRemaining should be used
+                        // for traversal
+                        thisOpFlags = thisOpFlags & ~StreamOpFlag.IS_SHORT_CIRCUIT;
+                    }
+
                     spliterator = p.opEvaluateParallelLazy(u, spliterator);
+
+                    // Inject or clear SIZED on the source pipeline stage
+                    // based on the stage's spliterator
+                    thisOpFlags = spliterator.hasCharacteristics(Spliterator.SIZED)
+                            ? (thisOpFlags & ~StreamOpFlag.NOT_SIZED) | StreamOpFlag.IS_SIZED
+                            : (thisOpFlags & ~StreamOpFlag.IS_SIZED) | StreamOpFlag.NOT_SIZED;
                 }
+                p.combinedFlags = StreamOpFlag.combineOpFlags(thisOpFlags, u.combinedFlags);
             }
         }
-        else if (terminalFlags != 0)  {
+
+        if (hasTerminalFlags)  {
+            // Apply flags from the terminal operation to last pipeline stage
             combinedFlags = StreamOpFlag.combineOpFlags(terminalFlags, combinedFlags);
         }
 
--- a/jdk/src/java.base/share/classes/java/util/stream/ReduceOps.java	Wed Mar 25 17:59:59 2015 +0900
+++ b/jdk/src/java.base/share/classes/java/util/stream/ReduceOps.java	Wed Mar 25 10:50:08 2015 +0000
@@ -264,6 +264,11 @@
                     return spliterator.getExactSizeIfKnown();
                 return super.evaluateParallel(helper, spliterator);
             }
+
+            @Override
+            public int getOpFlags() {
+                return StreamOpFlag.NOT_ORDERED;
+            }
         };
     }
 
@@ -433,6 +438,11 @@
                     return spliterator.getExactSizeIfKnown();
                 return super.evaluateParallel(helper, spliterator);
             }
+
+            @Override
+            public int getOpFlags() {
+                return StreamOpFlag.NOT_ORDERED;
+            }
         };
     }
 
@@ -602,6 +612,11 @@
                     return spliterator.getExactSizeIfKnown();
                 return super.evaluateParallel(helper, spliterator);
             }
+
+            @Override
+            public int getOpFlags() {
+                return StreamOpFlag.NOT_ORDERED;
+            }
         };
     }
 
@@ -771,6 +786,11 @@
                     return spliterator.getExactSizeIfKnown();
                 return super.evaluateParallel(helper, spliterator);
             }
+
+            @Override
+            public int getOpFlags() {
+                return StreamOpFlag.NOT_ORDERED;
+            }
         };
     }
 
--- a/jdk/test/java/util/stream/boottest/java/util/stream/UnorderedTest.java	Wed Mar 25 17:59:59 2015 +0900
+++ b/jdk/test/java/util/stream/boottest/java/util/stream/UnorderedTest.java	Wed Mar 25 10:50:08 2015 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2013, 2015, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -251,7 +251,7 @@
 
         final int lastLimitIndex = l;
         return s -> {
-            if (lastLimitIndex == -1)
+            if (lastLimitIndex == -1 && fs.size() > 0)
                 s = fi.apply(s);
             for (int i = 0; i < fs.size(); i++) {
                 s = fs.get(i).apply(s);
--- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/CountTest.java	Wed Mar 25 17:59:59 2015 +0900
+++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/CountTest.java	Wed Mar 25 10:50:08 2015 +0000
@@ -24,11 +24,12 @@
 /**
  * @test
  * @summary Tests counting of streams
- * @bug 8031187 8067969
+ * @bug 8031187 8067969 8075307
  */
 
 package org.openjdk.tests.java.util.stream;
 
+import java.util.HashSet;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.DoubleStream;
 import java.util.stream.DoubleStreamTestDataProvider;
@@ -59,6 +60,19 @@
                 terminal(s -> s.filter(e -> true), Stream::count).
                 expectedResult(expectedCount).
                 exercise();
+
+        // Test with stateful distinct op that is a barrier or lazy
+        // depending if source is not already distinct and encounter order is
+        // preserved or not
+        expectedCount = data.into(new HashSet<>()).size();
+        withData(data).
+                terminal(Stream::distinct, Stream::count).
+                expectedResult(expectedCount).
+                exercise();
+        withData(data).
+                terminal(s -> s.unordered().distinct(), Stream::count).
+                expectedResult(expectedCount).
+                exercise();
     }
 
     @Test(dataProvider = "IntStreamTestData", dataProviderClass = IntStreamTestDataProvider.class)
@@ -74,6 +88,16 @@
                 terminal(s -> s.filter(e -> true), IntStream::count).
                 expectedResult(expectedCount).
                 exercise();
+
+        expectedCount = data.into(new HashSet<>()).size();
+        withData(data).
+                terminal(IntStream::distinct, IntStream::count).
+                expectedResult(expectedCount).
+                exercise();
+        withData(data).
+                terminal(s -> s.unordered().distinct(), IntStream::count).
+                expectedResult(expectedCount).
+                exercise();
     }
 
     @Test(dataProvider = "LongStreamTestData", dataProviderClass = LongStreamTestDataProvider.class)
@@ -89,6 +113,16 @@
                 terminal(s -> s.filter(e -> true), LongStream::count).
                 expectedResult(expectedCount).
                 exercise();
+
+        expectedCount = data.into(new HashSet<>()).size();
+        withData(data).
+                terminal(LongStream::distinct, LongStream::count).
+                expectedResult(expectedCount).
+                exercise();
+        withData(data).
+                terminal(s -> s.unordered().distinct(), LongStream::count).
+                expectedResult(expectedCount).
+                exercise();
     }
 
     @Test(dataProvider = "DoubleStreamTestData", dataProviderClass = DoubleStreamTestDataProvider.class)
@@ -104,6 +138,16 @@
                 terminal(s -> s.filter(e -> true), DoubleStream::count).
                 expectedResult(expectedCount).
                 exercise();
+
+        expectedCount = data.into(new HashSet<>()).size();
+        withData(data).
+                terminal(DoubleStream::distinct, DoubleStream::count).
+                expectedResult(expectedCount).
+                exercise();
+        withData(data).
+                terminal(s -> s.unordered().distinct(), DoubleStream::count).
+                expectedResult(expectedCount).
+                exercise();
     }
 
     public void testNoEvaluationForSizedStream() {
@@ -111,24 +155,36 @@
             AtomicInteger ai = new AtomicInteger();
             Stream.of(1, 2, 3, 4).peek(e -> ai.getAndIncrement()).count();
             assertEquals(ai.get(), 0);
+
+            Stream.of(1, 2, 3, 4).peek(e -> ai.getAndIncrement()).parallel().count();
+            assertEquals(ai.get(), 0);
         }
 
         {
             AtomicInteger ai = new AtomicInteger();
             IntStream.of(1, 2, 3, 4).peek(e -> ai.getAndIncrement()).count();
             assertEquals(ai.get(), 0);
+
+            IntStream.of(1, 2, 3, 4).peek(e -> ai.getAndIncrement()).parallel().count();
+            assertEquals(ai.get(), 0);
         }
 
         {
             AtomicInteger ai = new AtomicInteger();
             LongStream.of(1, 2, 3, 4).peek(e -> ai.getAndIncrement()).count();
             assertEquals(ai.get(), 0);
+
+            LongStream.of(1, 2, 3, 4).peek(e -> ai.getAndIncrement()).parallel().count();
+            assertEquals(ai.get(), 0);
         }
 
         {
             AtomicInteger ai = new AtomicInteger();
             DoubleStream.of(1, 2, 3, 4).peek(e -> ai.getAndIncrement()).count();
             assertEquals(ai.get(), 0);
+
+            DoubleStream.of(1, 2, 3, 4).peek(e -> ai.getAndIncrement()).parallel().count();
+            assertEquals(ai.get(), 0);
         }
     }
 }
--- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/DistinctOpTest.java	Wed Mar 25 17:59:59 2015 +0900
+++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/DistinctOpTest.java	Wed Mar 25 10:50:08 2015 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2012, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2015, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -128,7 +128,7 @@
     @Test(dataProvider = "StreamTestData<Integer>", dataProviderClass = StreamTestDataProvider.class)
     public void testDistinctDistinct(String name, TestData.OfRef<Integer> data) {
         Collection<Integer> result = withData(data)
-                .stream(s -> s.distinct().distinct(), new CollectorOps.TestParallelSizedOp<>())
+                .stream(s -> s.distinct().distinct())
                 .exercise();
         assertUnique(result);
     }