8020016: Numerous splitereator impls do not throw NPE for null Consumers
authorpsandoz
Thu, 01 Aug 2013 15:28:57 +0100
changeset 19188 bbf287c5cd92
parent 19187 5aa85bc92303
child 19189 a4b8478a2bc5
8020016: Numerous splitereator impls do not throw NPE for null Consumers Reviewed-by: mduigou, alanb, henryjen
jdk/src/share/classes/java/util/stream/SpinedBuffer.java
jdk/src/share/classes/java/util/stream/StreamSpliterators.java
jdk/src/share/classes/java/util/stream/Streams.java
jdk/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java
jdk/test/java/util/stream/bootlib/java/util/stream/SpliteratorTestHelper.java
--- a/jdk/src/share/classes/java/util/stream/SpinedBuffer.java	Thu Aug 01 16:53:40 2013 +0100
+++ b/jdk/src/share/classes/java/util/stream/SpinedBuffer.java	Thu Aug 01 15:28:57 2013 +0100
@@ -28,6 +28,7 @@
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Objects;
 import java.util.PrimitiveIterator;
 import java.util.Spliterator;
 import java.util.Spliterators;
@@ -317,6 +318,8 @@
 
             @Override
             public boolean tryAdvance(Consumer<? super E> consumer) {
+                Objects.requireNonNull(consumer);
+
                 if (splSpineIndex < lastSpineIndex
                     || (splSpineIndex == lastSpineIndex && splElementIndex < lastSpineElementFence)) {
                     consumer.accept(splChunk[splElementIndex++]);
@@ -334,6 +337,8 @@
 
             @Override
             public void forEachRemaining(Consumer<? super E> consumer) {
+                Objects.requireNonNull(consumer);
+
                 if (splSpineIndex < lastSpineIndex
                     || (splSpineIndex == lastSpineIndex && splElementIndex < lastSpineElementFence)) {
                     int i = splElementIndex;
@@ -634,6 +639,8 @@
 
             @Override
             public boolean tryAdvance(T_CONS consumer) {
+                Objects.requireNonNull(consumer);
+
                 if (splSpineIndex < lastSpineIndex
                     || (splSpineIndex == lastSpineIndex && splElementIndex < lastSpineElementFence)) {
                     arrayForOne(splChunk, splElementIndex++, consumer);
@@ -651,6 +658,8 @@
 
             @Override
             public void forEachRemaining(T_CONS consumer) {
+                Objects.requireNonNull(consumer);
+
                 if (splSpineIndex < lastSpineIndex
                     || (splSpineIndex == lastSpineIndex && splElementIndex < lastSpineElementFence)) {
                     int i = splElementIndex;
--- a/jdk/src/share/classes/java/util/stream/StreamSpliterators.java	Thu Aug 01 16:53:40 2013 +0100
+++ b/jdk/src/share/classes/java/util/stream/StreamSpliterators.java	Thu Aug 01 15:28:57 2013 +0100
@@ -25,6 +25,7 @@
 package java.util.stream;
 
 import java.util.Comparator;
+import java.util.Objects;
 import java.util.Spliterator;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.BooleanSupplier;
@@ -294,6 +295,7 @@
 
         @Override
         public boolean tryAdvance(Consumer<? super P_OUT> consumer) {
+            Objects.requireNonNull(consumer);
             boolean hasNext = doAdvance();
             if (hasNext)
                 consumer.accept(buffer.get(nextToConsume));
@@ -303,6 +305,7 @@
         @Override
         public void forEachRemaining(Consumer<? super P_OUT> consumer) {
             if (buffer == null && !finished) {
+                Objects.requireNonNull(consumer);
                 init();
 
                 ph.wrapAndCopyInto((Sink<P_OUT>) consumer::accept, spliterator);
@@ -350,6 +353,7 @@
 
         @Override
         public boolean tryAdvance(IntConsumer consumer) {
+            Objects.requireNonNull(consumer);
             boolean hasNext = doAdvance();
             if (hasNext)
                 consumer.accept(buffer.get(nextToConsume));
@@ -359,6 +363,7 @@
         @Override
         public void forEachRemaining(IntConsumer consumer) {
             if (buffer == null && !finished) {
+                Objects.requireNonNull(consumer);
                 init();
 
                 ph.wrapAndCopyInto((Sink.OfInt) consumer::accept, spliterator);
@@ -406,6 +411,7 @@
 
         @Override
         public boolean tryAdvance(LongConsumer consumer) {
+            Objects.requireNonNull(consumer);
             boolean hasNext = doAdvance();
             if (hasNext)
                 consumer.accept(buffer.get(nextToConsume));
@@ -415,6 +421,7 @@
         @Override
         public void forEachRemaining(LongConsumer consumer) {
             if (buffer == null && !finished) {
+                Objects.requireNonNull(consumer);
                 init();
 
                 ph.wrapAndCopyInto((Sink.OfLong) consumer::accept, spliterator);
@@ -462,6 +469,7 @@
 
         @Override
         public boolean tryAdvance(DoubleConsumer consumer) {
+            Objects.requireNonNull(consumer);
             boolean hasNext = doAdvance();
             if (hasNext)
                 consumer.accept(buffer.get(nextToConsume));
@@ -471,6 +479,7 @@
         @Override
         public void forEachRemaining(DoubleConsumer consumer) {
             if (buffer == null && !finished) {
+                Objects.requireNonNull(consumer);
                 init();
 
                 ph.wrapAndCopyInto((Sink.OfDouble) consumer::accept, spliterator);
@@ -696,6 +705,8 @@
 
             @Override
             public boolean tryAdvance(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 if (sliceOrigin >= fence)
                     return false;
 
@@ -713,6 +724,8 @@
 
             @Override
             public void forEachRemaining(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 if (sliceOrigin >= fence)
                     return;
 
@@ -754,6 +767,8 @@
 
             @Override
             public boolean tryAdvance(T_CONS action) {
+                Objects.requireNonNull(action);
+
                 if (sliceOrigin >= fence)
                     return false;
 
@@ -771,6 +786,8 @@
 
             @Override
             public void forEachRemaining(T_CONS action) {
+                Objects.requireNonNull(action);
+
                 if (sliceOrigin >= fence)
                     return;
 
@@ -985,6 +1002,8 @@
 
             @Override
             public boolean tryAdvance(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 while (permitStatus() != PermitStatus.NO_MORE) {
                     if (!s.tryAdvance(this))
                         return false;
@@ -999,6 +1018,8 @@
 
             @Override
             public void forEachRemaining(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 ArrayBuffer.OfRef<T> sb = null;
                 PermitStatus permitStatus;
                 while ((permitStatus = permitStatus()) != PermitStatus.NO_MORE) {
@@ -1051,6 +1072,8 @@
 
             @Override
             public boolean tryAdvance(T_CONS action) {
+                Objects.requireNonNull(action);
+
                 while (permitStatus() != PermitStatus.NO_MORE) {
                     if (!s.tryAdvance((T_CONS) this))
                         return false;
@@ -1066,6 +1089,8 @@
 
             @Override
             public void forEachRemaining(T_CONS action) {
+                Objects.requireNonNull(action);
+
                 T_BUFF sb = null;
                 PermitStatus permitStatus;
                 while ((permitStatus = permitStatus()) != PermitStatus.NO_MORE) {
@@ -1237,6 +1262,8 @@
 
             @Override
             public boolean tryAdvance(Consumer<? super T> action) {
+                Objects.requireNonNull(action);
+
                 action.accept(s.get());
                 return true;
             }
@@ -1260,6 +1287,8 @@
 
             @Override
             public boolean tryAdvance(IntConsumer action) {
+                Objects.requireNonNull(action);
+
                 action.accept(s.getAsInt());
                 return true;
             }
@@ -1283,6 +1312,8 @@
 
             @Override
             public boolean tryAdvance(LongConsumer action) {
+                Objects.requireNonNull(action);
+
                 action.accept(s.getAsLong());
                 return true;
             }
@@ -1306,6 +1337,8 @@
 
             @Override
             public boolean tryAdvance(DoubleConsumer action) {
+                Objects.requireNonNull(action);
+
                 action.accept(s.getAsDouble());
                 return true;
             }
--- a/jdk/src/share/classes/java/util/stream/Streams.java	Thu Aug 01 16:53:40 2013 +0100
+++ b/jdk/src/share/classes/java/util/stream/Streams.java	Thu Aug 01 15:28:57 2013 +0100
@@ -25,6 +25,7 @@
 package java.util.stream;
 
 import java.util.Comparator;
+import java.util.Objects;
 import java.util.Spliterator;
 import java.util.function.Consumer;
 import java.util.function.DoubleConsumer;
@@ -80,6 +81,8 @@
 
         @Override
         public boolean tryAdvance(IntConsumer consumer) {
+            Objects.requireNonNull(consumer);
+
             final int i = from;
             if (i < upTo) {
                 from++;
@@ -96,6 +99,8 @@
 
         @Override
         public void forEachRemaining(IntConsumer consumer) {
+            Objects.requireNonNull(consumer);
+
             int i = from;
             final int hUpTo = upTo;
             int hLast = last;
@@ -199,6 +204,8 @@
 
         @Override
         public boolean tryAdvance(LongConsumer consumer) {
+            Objects.requireNonNull(consumer);
+
             final long i = from;
             if (i < upTo) {
                 from++;
@@ -215,6 +222,8 @@
 
         @Override
         public void forEachRemaining(LongConsumer consumer) {
+            Objects.requireNonNull(consumer);
+
             long i = from;
             final long hUpTo = upTo;
             int hLast = last;
@@ -388,6 +397,8 @@
 
         @Override
         public boolean tryAdvance(Consumer<? super T> action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -400,6 +411,8 @@
 
         @Override
         public void forEachRemaining(Consumer<? super T> action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -475,6 +488,8 @@
 
         @Override
         public boolean tryAdvance(IntConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -487,6 +502,8 @@
 
         @Override
         public void forEachRemaining(IntConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -562,6 +579,8 @@
 
         @Override
         public boolean tryAdvance(LongConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -574,6 +593,8 @@
 
         @Override
         public void forEachRemaining(LongConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -649,6 +670,8 @@
 
         @Override
         public boolean tryAdvance(DoubleConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
@@ -661,6 +684,8 @@
 
         @Override
         public void forEachRemaining(DoubleConsumer action) {
+            Objects.requireNonNull(action);
+
             if (count == -2) {
                 action.accept(first);
                 count = -1;
--- a/jdk/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java	Thu Aug 01 16:53:40 2013 +0100
+++ b/jdk/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java	Thu Aug 01 15:28:57 2013 +0100
@@ -81,6 +81,10 @@
 import static org.testng.Assert.*;
 import static org.testng.Assert.assertEquals;
 
+/**
+ * @test
+ * @bug 8020016
+ */
 @Test
 public class SpliteratorTraversingAndSplittingTest {
 
@@ -386,11 +390,23 @@
 
             db.addCollection(CopyOnWriteArraySet::new);
 
-            if (size == 1) {
+            if (size == 0) {
+                db.addCollection(c -> Collections.<Integer>emptySet());
+                db.addList(c -> Collections.<Integer>emptyList());
+            }
+            else if (size == 1) {
                 db.addCollection(c -> Collections.singleton(exp.get(0)));
                 db.addCollection(c -> Collections.singletonList(exp.get(0)));
             }
 
+            {
+                Integer[] ai = new Integer[size];
+                Arrays.fill(ai, 1);
+                db.add(String.format("Collections.nCopies(%d, 1)", exp.size()),
+                       Arrays.asList(ai),
+                       () -> Collections.nCopies(exp.size(), 1).spliterator());
+            }
+
             // Collections.synchronized/unmodifiable/checked wrappers
             db.addCollection(Collections::unmodifiableCollection);
             db.addCollection(c -> Collections.unmodifiableSet(new HashSet<>(c)));
@@ -454,6 +470,13 @@
             db.addMap(ConcurrentHashMap::new);
 
             db.addMap(ConcurrentSkipListMap::new);
+
+            if (size == 0) {
+                db.addMap(m -> Collections.<Integer, Integer>emptyMap());
+            }
+            else if (size == 1) {
+                db.addMap(m -> Collections.singletonMap(exp.get(0), exp.get(0)));
+            }
         }
 
         return spliteratorDataProvider = data.toArray(new Object[0][]);
--- a/jdk/test/java/util/stream/bootlib/java/util/stream/SpliteratorTestHelper.java	Thu Aug 01 16:53:40 2013 +0100
+++ b/jdk/test/java/util/stream/bootlib/java/util/stream/SpliteratorTestHelper.java	Thu Aug 01 15:28:57 2013 +0100
@@ -22,6 +22,8 @@
  */
 package java.util.stream;
 
+import org.testng.annotations.Test;
+
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -154,6 +156,7 @@
 
         Collection<T> exp = Collections.unmodifiableList(fromForEach);
 
+        testNullPointerException(supplier);
         testForEach(exp, supplier, boxingAdapter, asserter);
         testTryAdvance(exp, supplier, boxingAdapter, asserter);
         testMixedTryAdvanceForEach(exp, supplier, boxingAdapter, asserter);
@@ -166,6 +169,31 @@
 
     //
 
+    private static <T, S extends Spliterator<T>> void testNullPointerException(Supplier<S> s) {
+        S sp = s.get();
+        // Have to check instances and use casts to avoid tripwire messages and
+        // directly test the primitive methods
+        if (sp instanceof Spliterator.OfInt) {
+            Spliterator.OfInt psp = (Spliterator.OfInt) sp;
+            executeAndCatch(NullPointerException.class, () -> psp.forEachRemaining((IntConsumer) null));
+            executeAndCatch(NullPointerException.class, () -> psp.tryAdvance((IntConsumer) null));
+        }
+        else if (sp instanceof Spliterator.OfLong) {
+            Spliterator.OfLong psp = (Spliterator.OfLong) sp;
+            executeAndCatch(NullPointerException.class, () -> psp.forEachRemaining((LongConsumer) null));
+            executeAndCatch(NullPointerException.class, () -> psp.tryAdvance((LongConsumer) null));
+        }
+        else if (sp instanceof Spliterator.OfDouble) {
+            Spliterator.OfDouble psp = (Spliterator.OfDouble) sp;
+            executeAndCatch(NullPointerException.class, () -> psp.forEachRemaining((DoubleConsumer) null));
+            executeAndCatch(NullPointerException.class, () -> psp.tryAdvance((DoubleConsumer) null));
+        }
+        else {
+            executeAndCatch(NullPointerException.class, () -> sp.forEachRemaining(null));
+            executeAndCatch(NullPointerException.class, () -> sp.tryAdvance(null));
+        }
+    }
+
     private static <T, S extends Spliterator<T>> void testForEach(
             Collection<T> exp,
             Supplier<S> supplier,
@@ -573,6 +601,23 @@
         }
     }
 
+    private static void executeAndCatch(Class<? extends Exception> expected, Runnable r) {
+        Exception caught = null;
+        try {
+            r.run();
+        }
+        catch (Exception e) {
+            caught = e;
+        }
+
+        assertNotNull(caught,
+                      String.format("No Exception was thrown, expected an Exception of %s to be thrown",
+                                    expected.getName()));
+        assertTrue(expected.isInstance(caught),
+                   String.format("Exception thrown %s not an instance of %s",
+                                 caught.getClass().getName(), expected.getName()));
+    }
+
     static<U> void mixedTraverseAndSplit(Consumer<U> b, Spliterator<U> splTop) {
         Spliterator<U> spl1, spl2, spl3;
         splTop.tryAdvance(b);