8072784: Better spliterator implementation for BitSet.stream()
authorpsandoz
Wed, 16 Nov 2016 14:26:14 -0800
changeset 42157 3e87fa9d8226
parent 42156 7ccdf3aa0f8c
child 42158 80c04775edbd
8072784: Better spliterator implementation for BitSet.stream() Reviewed-by: martin
jdk/src/java.base/share/classes/java/util/BitSet.java
jdk/test/java/util/BitSet/BitSetStreamTest.java
jdk/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java
--- a/jdk/src/java.base/share/classes/java/util/BitSet.java	Wed Nov 16 14:26:12 2016 -0800
+++ b/jdk/src/java.base/share/classes/java/util/BitSet.java	Wed Nov 16 14:26:14 2016 -0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1995, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1995, 2016, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -29,6 +29,7 @@
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.LongBuffer;
+import java.util.function.IntConsumer;
 import java.util.stream.IntStream;
 import java.util.stream.StreamSupport;
 
@@ -1217,32 +1218,156 @@
      * @since 1.8
      */
     public IntStream stream() {
-        class BitSetIterator implements PrimitiveIterator.OfInt {
-            int next = nextSetBit(0);
+        class BitSetSpliterator implements Spliterator.OfInt {
+            private int index; // current bit index for a set bit
+            private int fence; // -1 until used; then one past last bit index
+            private int est;   // size estimate
+            private boolean root; // true if root and not split
+            // root == true then size estimate is accurate
+            // index == -1 or index >= fence if fully traversed
+            // Special case when the max bit set is Integer.MAX_VALUE
+
+            BitSetSpliterator(int origin, int fence, int est, boolean root) {
+                this.index = origin;
+                this.fence = fence;
+                this.est = est;
+                this.root = root;
+            }
+
+            private int getFence() {
+                int hi;
+                if ((hi = fence) < 0) {
+                    // Round up fence to maximum cardinality for allocated words
+                    // This is sufficient and cheap for sequential access
+                    // When splitting this value is lowered
+                    hi = fence = (wordsInUse >= wordIndex(Integer.MAX_VALUE))
+                                 ? Integer.MAX_VALUE
+                                 : wordsInUse << ADDRESS_BITS_PER_WORD;
+                    est = cardinality();
+                    index = nextSetBit(0);
+                }
+                return hi;
+            }
 
             @Override
-            public boolean hasNext() {
-                return next != -1;
+            public boolean tryAdvance(IntConsumer action) {
+                Objects.requireNonNull(action);
+
+                int hi = getFence();
+                int i = index;
+                if (i < 0 || i >= hi) {
+                    // Check if there is a final bit set for Integer.MAX_VALUE
+                    if (i == Integer.MAX_VALUE && hi == Integer.MAX_VALUE) {
+                        index = -1;
+                        action.accept(Integer.MAX_VALUE);
+                        return true;
+                    }
+                    return false;
+                }
+
+                index = nextSetBit(i + 1, wordIndex(hi - 1));
+                action.accept(i);
+                return true;
+            }
+
+            @Override
+            public void forEachRemaining(IntConsumer action) {
+                Objects.requireNonNull(action);
+
+                int hi = getFence();
+                int i = index;
+                int v = wordIndex(hi - 1);
+                index = -1;
+                while (i >= 0 && i < hi) {
+                    action.accept(i);
+                    i = nextSetBit(i + 1, v);
+                }
+                // Check if there is a final bit set for Integer.MAX_VALUE
+                if (i == Integer.MAX_VALUE && hi == Integer.MAX_VALUE) {
+                    action.accept(Integer.MAX_VALUE);
+                }
             }
 
             @Override
-            public int nextInt() {
-                if (next != -1) {
-                    int ret = next;
-                    next = (next == Integer.MAX_VALUE) ? -1 : nextSetBit(next+1);
-                    return ret;
-                } else {
-                    throw new NoSuchElementException();
+            public OfInt trySplit() {
+                int hi = getFence();
+                int lo = index;
+                if (lo < 0) {
+                    return null;
+                }
+
+                // Lower the fence to be the upper bound of last bit set
+                // The index is the first bit set, thus this spliterator
+                // covers one bit and cannot be split, or two or more
+                // bits
+                hi = fence = (hi < Integer.MAX_VALUE || !get(Integer.MAX_VALUE))
+                        ? previousSetBit(hi - 1) + 1
+                        : Integer.MAX_VALUE;
+
+                // Find the mid point
+                int mid = (lo + hi) >>> 1;
+                if (lo >= mid) {
+                    return null;
                 }
+
+                // Raise the index of this spliterator to be the next set bit
+                // from the mid point
+                index = nextSetBit(mid, wordIndex(hi - 1));
+                root = false;
+
+                // Don't lower the fence (mid point) of the returned spliterator,
+                // traversal or further splitting will do that work
+                return new BitSetSpliterator(lo, mid, est >>>= 1, false);
+            }
+
+            @Override
+            public long estimateSize() {
+                getFence(); // force init
+                return est;
+            }
+
+            @Override
+            public int characteristics() {
+                // Only sized when root and not split
+                return (root ? Spliterator.SIZED : 0) |
+                    Spliterator.ORDERED | Spliterator.DISTINCT | Spliterator.SORTED;
+            }
+
+            @Override
+            public Comparator<? super Integer> getComparator() {
+                return null;
             }
         }
+        return StreamSupport.intStream(new BitSetSpliterator(0, -1, 0, true), false);
+    }
 
-        return StreamSupport.intStream(
-                () -> Spliterators.spliterator(
-                        new BitSetIterator(), cardinality(),
-                        Spliterator.ORDERED | Spliterator.DISTINCT | Spliterator.SORTED),
-                Spliterator.SIZED | Spliterator.SUBSIZED |
-                        Spliterator.ORDERED | Spliterator.DISTINCT | Spliterator.SORTED,
-                false);
+    /**
+     * Returns the index of the first bit that is set to {@code true}
+     * that occurs on or after the specified starting index and up to and
+     * including the specified word index
+     * If no such bit exists then {@code -1} is returned.
+     *
+     * @param  fromIndex the index to start checking from (inclusive)
+     * @param  toWordIndex the last word index to check (inclusive)
+     * @return the index of the next set bit, or {@code -1} if there
+     *         is no such bit
+     */
+    private int nextSetBit(int fromIndex, int toWordIndex) {
+        int u = wordIndex(fromIndex);
+        // Check if out of bounds
+        if (u > toWordIndex)
+            return -1;
+
+        long word = words[u] & (WORD_MASK << fromIndex);
+
+        while (true) {
+            if (word != 0)
+                return (u * BITS_PER_WORD) + Long.numberOfTrailingZeros(word);
+            // Check if out of bounds
+            if (++u > toWordIndex)
+                return -1;
+            word = words[u];
+        }
     }
+
 }
--- a/jdk/test/java/util/BitSet/BitSetStreamTest.java	Wed Nov 16 14:26:12 2016 -0800
+++ b/jdk/test/java/util/BitSet/BitSetStreamTest.java	Wed Nov 16 14:26:14 2016 -0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2012, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2016, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -60,25 +60,6 @@
         public int getAsInt() { int s = n1; n1 = n2; n2 = s + n1; return s; }
     }
 
-    private static final Object[][] testcases = new Object[][] {
-        { "none", IntStream.empty() },
-        { "index 0", IntStream.of(0) },
-        { "index 255", IntStream.of(255) },
-        { "every bit", IntStream.range(0, 255) },
-        { "step 2", IntStream.range(0, 255).map(f -> f * 2) },
-        { "step 3", IntStream.range(0, 255).map(f -> f * 3) },
-        { "step 5", IntStream.range(0, 255).map(f -> f * 5) },
-        { "step 7", IntStream.range(0, 255).map(f -> f * 7) },
-        { "1, 10, 100, 1000", IntStream.of(1, 10, 100, 1000) },
-        { "max int", IntStream.of(Integer.MAX_VALUE) },
-        { "25 fibs", IntStream.generate(new Fibs()).limit(25) }
-    };
-
-    @DataProvider(name = "cases")
-    public static Object[][] produceCases() {
-        return testcases;
-    }
-
     @Test
     public void testFibs() {
         Fibs f = new Fibs();
@@ -93,22 +74,46 @@
         assertEquals(987, Fibs.fibs(16));
     }
 
+
+    @DataProvider(name = "cases")
+    public static Object[][] produceCases() {
+        return new Object[][] {
+                { "none", IntStream.empty() },
+                { "index 0", IntStream.of(0) },
+                { "index 255", IntStream.of(255) },
+                { "index 0 and 255", IntStream.of(0, 255) },
+                { "index Integer.MAX_VALUE", IntStream.of(Integer.MAX_VALUE) },
+                { "index Integer.MAX_VALUE - 1", IntStream.of(Integer.MAX_VALUE - 1) },
+                { "index 0 and Integer.MAX_VALUE", IntStream.of(0, Integer.MAX_VALUE) },
+                { "every bit", IntStream.range(0, 255) },
+                { "step 2", IntStream.range(0, 255).map(f -> f * 2) },
+                { "step 3", IntStream.range(0, 255).map(f -> f * 3) },
+                { "step 5", IntStream.range(0, 255).map(f -> f * 5) },
+                { "step 7", IntStream.range(0, 255).map(f -> f * 7) },
+                { "1, 10, 100, 1000", IntStream.of(1, 10, 100, 1000) },
+                { "25 fibs", IntStream.generate(new Fibs()).limit(25) }
+        };
+    }
+
     @Test(dataProvider = "cases")
     public void testBitsetStream(String name, IntStream data) {
-        BitSet bs = new BitSet();
-        long setBits = data.distinct()
-                           .peek(i -> bs.set(i))
-                           .count();
+        BitSet bs = data.collect(BitSet::new, BitSet::set, BitSet::or);
+
+        assertEquals(bs.cardinality(), bs.stream().count());
 
-        assertEquals(bs.cardinality(), setBits);
-        assertEquals(bs.cardinality(), bs.stream().reduce(0, (s, i) -> s+1));
+        int[] indexHolder = new int[] { -1 };
+        bs.stream().forEach(i -> {
+            int ei = indexHolder[0];
+            indexHolder[0] = bs.nextSetBit(ei + 1);
+            assertEquals(i, indexHolder[0]);
+        });
 
         PrimitiveIterator.OfInt it = bs.stream().iterator();
-        for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
+        for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
             assertTrue(it.hasNext());
             assertEquals(it.nextInt(), i);
             if (i == Integer.MAX_VALUE)
-                break; // or (i+1) would overflow
+                break; // or (i + 1) would overflow
         }
         assertFalse(it.hasNext());
     }
@@ -123,16 +128,20 @@
         for (int seed : seeds) {
             final Random random = new Random(seed);
             random.nextBytes(bytes);
-            final BitSet bitSet = BitSet.valueOf(bytes);
-            final int cardinality = bitSet.cardinality();
-            final IntStream stream = bitSet.stream();
-            final int[] array = stream.toArray();
-            assertEquals(array.length, cardinality);
-            int nextSetBit = -1;
-            for (int i=0; i < cardinality; i++) {
-                nextSetBit = bitSet.nextSetBit(nextSetBit + 1);
-                assertEquals(array[i], nextSetBit);
-            }
+
+            BitSet bitSet = BitSet.valueOf(bytes);
+            testBitSetContents(bitSet, bitSet.stream().toArray());
+            testBitSetContents(bitSet, bitSet.stream().parallel().toArray());
+        }
+    }
+
+    void testBitSetContents(BitSet bitSet, int[] array) {
+        int cardinality = bitSet.cardinality();
+        assertEquals(array.length, cardinality);
+        int nextSetBit = -1;
+        for (int i = 0; i < cardinality; i++) {
+            nextSetBit = bitSet.nextSetBit(nextSetBit + 1);
+            assertEquals(array[i], nextSetBit);
         }
     }
 }
--- a/jdk/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java	Wed Nov 16 14:26:12 2016 -0800
+++ b/jdk/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java	Wed Nov 16 14:26:14 2016 -0800
@@ -37,6 +37,7 @@
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.BitSet;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
@@ -80,7 +81,9 @@
 import java.util.function.LongConsumer;
 import java.util.function.Supplier;
 import java.util.function.UnaryOperator;
+import java.util.stream.IntStream;
 
+import static java.util.stream.Collectors.toList;
 import static org.testng.Assert.*;
 import static org.testng.Assert.assertEquals;
 
@@ -883,6 +886,33 @@
             cdb.add("new StringBuffer(\"%s\")", StringBuffer::new);
         }
 
+
+        Object[][] bitStreamTestcases = new Object[][] {
+                { "none", IntStream.empty().toArray() },
+                { "index 0", IntStream.of(0).toArray() },
+                { "index 255", IntStream.of(255).toArray() },
+                { "index 0 and 255", IntStream.of(0, 255).toArray() },
+                { "index Integer.MAX_VALUE", IntStream.of(Integer.MAX_VALUE).toArray() },
+                { "index Integer.MAX_VALUE - 1", IntStream.of(Integer.MAX_VALUE - 1).toArray() },
+                { "index 0 and Integer.MAX_VALUE", IntStream.of(0, Integer.MAX_VALUE).toArray() },
+                { "every bit", IntStream.range(0, 255).toArray() },
+                { "step 2", IntStream.range(0, 255).map(f -> f * 2).toArray() },
+                { "step 3", IntStream.range(0, 255).map(f -> f * 3).toArray() },
+                { "step 5", IntStream.range(0, 255).map(f -> f * 5).toArray() },
+                { "step 7", IntStream.range(0, 255).map(f -> f * 7).toArray() },
+                { "1, 10, 100, 1000", IntStream.of(1, 10, 100, 1000).toArray() },
+        };
+        for (Object[] tc : bitStreamTestcases) {
+            String description = (String)tc[0];
+            int[] exp = (int[])tc[1];
+            SpliteratorOfIntDataBuilder db = new SpliteratorOfIntDataBuilder(
+                    data, IntStream.of(exp).boxed().collect(toList()));
+
+            db.add("BitSet.stream.spliterator() {" + description + "}", () ->
+                IntStream.of(exp).collect(BitSet::new, BitSet::set, BitSet::or).
+                        stream().spliterator()
+            );
+        }
         return spliteratorOfIntDataProvider = data.toArray(new Object[0][]);
     }