8072784: Better spliterator implementation for BitSet.stream()
Reviewed-by: martin
--- a/jdk/src/java.base/share/classes/java/util/BitSet.java Wed Nov 16 14:26:12 2016 -0800
+++ b/jdk/src/java.base/share/classes/java/util/BitSet.java Wed Nov 16 14:26:14 2016 -0800
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 1995, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1995, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@@ -29,6 +29,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.LongBuffer;
+import java.util.function.IntConsumer;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;
@@ -1217,32 +1218,156 @@
* @since 1.8
*/
public IntStream stream() {
- class BitSetIterator implements PrimitiveIterator.OfInt {
- int next = nextSetBit(0);
+ class BitSetSpliterator implements Spliterator.OfInt {
+ private int index; // current bit index for a set bit
+ private int fence; // -1 until used; then one past last bit index
+ private int est; // size estimate
+ private boolean root; // true if root and not split
+ // root == true then size estimate is accurate
+ // index == -1 or index >= fence if fully traversed
+ // Special case when the max bit set is Integer.MAX_VALUE
+
+ BitSetSpliterator(int origin, int fence, int est, boolean root) {
+ this.index = origin;
+ this.fence = fence;
+ this.est = est;
+ this.root = root;
+ }
+
+ private int getFence() {
+ int hi;
+ if ((hi = fence) < 0) {
+ // Round up fence to maximum cardinality for allocated words
+ // This is sufficient and cheap for sequential access
+ // When splitting this value is lowered
+ hi = fence = (wordsInUse >= wordIndex(Integer.MAX_VALUE))
+ ? Integer.MAX_VALUE
+ : wordsInUse << ADDRESS_BITS_PER_WORD;
+ est = cardinality();
+ index = nextSetBit(0);
+ }
+ return hi;
+ }
@Override
- public boolean hasNext() {
- return next != -1;
+ public boolean tryAdvance(IntConsumer action) {
+ Objects.requireNonNull(action);
+
+ int hi = getFence();
+ int i = index;
+ if (i < 0 || i >= hi) {
+ // Check if there is a final bit set for Integer.MAX_VALUE
+ if (i == Integer.MAX_VALUE && hi == Integer.MAX_VALUE) {
+ index = -1;
+ action.accept(Integer.MAX_VALUE);
+ return true;
+ }
+ return false;
+ }
+
+ index = nextSetBit(i + 1, wordIndex(hi - 1));
+ action.accept(i);
+ return true;
+ }
+
+ @Override
+ public void forEachRemaining(IntConsumer action) {
+ Objects.requireNonNull(action);
+
+ int hi = getFence();
+ int i = index;
+ int v = wordIndex(hi - 1);
+ index = -1;
+ while (i >= 0 && i < hi) {
+ action.accept(i);
+ i = nextSetBit(i + 1, v);
+ }
+ // Check if there is a final bit set for Integer.MAX_VALUE
+ if (i == Integer.MAX_VALUE && hi == Integer.MAX_VALUE) {
+ action.accept(Integer.MAX_VALUE);
+ }
}
@Override
- public int nextInt() {
- if (next != -1) {
- int ret = next;
- next = (next == Integer.MAX_VALUE) ? -1 : nextSetBit(next+1);
- return ret;
- } else {
- throw new NoSuchElementException();
+ public OfInt trySplit() {
+ int hi = getFence();
+ int lo = index;
+ if (lo < 0) {
+ return null;
+ }
+
+ // Lower the fence to be the upper bound of last bit set
+ // The index is the first bit set, thus this spliterator
+ // covers one bit and cannot be split, or two or more
+ // bits
+ hi = fence = (hi < Integer.MAX_VALUE || !get(Integer.MAX_VALUE))
+ ? previousSetBit(hi - 1) + 1
+ : Integer.MAX_VALUE;
+
+ // Find the mid point
+ int mid = (lo + hi) >>> 1;
+ if (lo >= mid) {
+ return null;
}
+
+ // Raise the index of this spliterator to be the next set bit
+ // from the mid point
+ index = nextSetBit(mid, wordIndex(hi - 1));
+ root = false;
+
+ // Don't lower the fence (mid point) of the returned spliterator,
+ // traversal or further splitting will do that work
+ return new BitSetSpliterator(lo, mid, est >>>= 1, false);
+ }
+
+ @Override
+ public long estimateSize() {
+ getFence(); // force init
+ return est;
+ }
+
+ @Override
+ public int characteristics() {
+ // Only sized when root and not split
+ return (root ? Spliterator.SIZED : 0) |
+ Spliterator.ORDERED | Spliterator.DISTINCT | Spliterator.SORTED;
+ }
+
+ @Override
+ public Comparator<? super Integer> getComparator() {
+ return null;
}
}
+ return StreamSupport.intStream(new BitSetSpliterator(0, -1, 0, true), false);
+ }
- return StreamSupport.intStream(
- () -> Spliterators.spliterator(
- new BitSetIterator(), cardinality(),
- Spliterator.ORDERED | Spliterator.DISTINCT | Spliterator.SORTED),
- Spliterator.SIZED | Spliterator.SUBSIZED |
- Spliterator.ORDERED | Spliterator.DISTINCT | Spliterator.SORTED,
- false);
+ /**
+ * Returns the index of the first bit that is set to {@code true}
+ * that occurs on or after the specified starting index and up to and
+ * including the specified word index
+ * If no such bit exists then {@code -1} is returned.
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @param toWordIndex the last word index to check (inclusive)
+ * @return the index of the next set bit, or {@code -1} if there
+ * is no such bit
+ */
+ private int nextSetBit(int fromIndex, int toWordIndex) {
+ int u = wordIndex(fromIndex);
+ // Check if out of bounds
+ if (u > toWordIndex)
+ return -1;
+
+ long word = words[u] & (WORD_MASK << fromIndex);
+
+ while (true) {
+ if (word != 0)
+ return (u * BITS_PER_WORD) + Long.numberOfTrailingZeros(word);
+ // Check if out of bounds
+ if (++u > toWordIndex)
+ return -1;
+ word = words[u];
+ }
}
+
}
--- a/jdk/test/java/util/BitSet/BitSetStreamTest.java Wed Nov 16 14:26:12 2016 -0800
+++ b/jdk/test/java/util/BitSet/BitSetStreamTest.java Wed Nov 16 14:26:14 2016 -0800
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2012, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@@ -60,25 +60,6 @@
public int getAsInt() { int s = n1; n1 = n2; n2 = s + n1; return s; }
}
- private static final Object[][] testcases = new Object[][] {
- { "none", IntStream.empty() },
- { "index 0", IntStream.of(0) },
- { "index 255", IntStream.of(255) },
- { "every bit", IntStream.range(0, 255) },
- { "step 2", IntStream.range(0, 255).map(f -> f * 2) },
- { "step 3", IntStream.range(0, 255).map(f -> f * 3) },
- { "step 5", IntStream.range(0, 255).map(f -> f * 5) },
- { "step 7", IntStream.range(0, 255).map(f -> f * 7) },
- { "1, 10, 100, 1000", IntStream.of(1, 10, 100, 1000) },
- { "max int", IntStream.of(Integer.MAX_VALUE) },
- { "25 fibs", IntStream.generate(new Fibs()).limit(25) }
- };
-
- @DataProvider(name = "cases")
- public static Object[][] produceCases() {
- return testcases;
- }
-
@Test
public void testFibs() {
Fibs f = new Fibs();
@@ -93,22 +74,46 @@
assertEquals(987, Fibs.fibs(16));
}
+
+ @DataProvider(name = "cases")
+ public static Object[][] produceCases() {
+ return new Object[][] {
+ { "none", IntStream.empty() },
+ { "index 0", IntStream.of(0) },
+ { "index 255", IntStream.of(255) },
+ { "index 0 and 255", IntStream.of(0, 255) },
+ { "index Integer.MAX_VALUE", IntStream.of(Integer.MAX_VALUE) },
+ { "index Integer.MAX_VALUE - 1", IntStream.of(Integer.MAX_VALUE - 1) },
+ { "index 0 and Integer.MAX_VALUE", IntStream.of(0, Integer.MAX_VALUE) },
+ { "every bit", IntStream.range(0, 255) },
+ { "step 2", IntStream.range(0, 255).map(f -> f * 2) },
+ { "step 3", IntStream.range(0, 255).map(f -> f * 3) },
+ { "step 5", IntStream.range(0, 255).map(f -> f * 5) },
+ { "step 7", IntStream.range(0, 255).map(f -> f * 7) },
+ { "1, 10, 100, 1000", IntStream.of(1, 10, 100, 1000) },
+ { "25 fibs", IntStream.generate(new Fibs()).limit(25) }
+ };
+ }
+
@Test(dataProvider = "cases")
public void testBitsetStream(String name, IntStream data) {
- BitSet bs = new BitSet();
- long setBits = data.distinct()
- .peek(i -> bs.set(i))
- .count();
+ BitSet bs = data.collect(BitSet::new, BitSet::set, BitSet::or);
+
+ assertEquals(bs.cardinality(), bs.stream().count());
- assertEquals(bs.cardinality(), setBits);
- assertEquals(bs.cardinality(), bs.stream().reduce(0, (s, i) -> s+1));
+ int[] indexHolder = new int[] { -1 };
+ bs.stream().forEach(i -> {
+ int ei = indexHolder[0];
+ indexHolder[0] = bs.nextSetBit(ei + 1);
+ assertEquals(i, indexHolder[0]);
+ });
PrimitiveIterator.OfInt it = bs.stream().iterator();
- for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
+ for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
assertTrue(it.hasNext());
assertEquals(it.nextInt(), i);
if (i == Integer.MAX_VALUE)
- break; // or (i+1) would overflow
+ break; // or (i + 1) would overflow
}
assertFalse(it.hasNext());
}
@@ -123,16 +128,20 @@
for (int seed : seeds) {
final Random random = new Random(seed);
random.nextBytes(bytes);
- final BitSet bitSet = BitSet.valueOf(bytes);
- final int cardinality = bitSet.cardinality();
- final IntStream stream = bitSet.stream();
- final int[] array = stream.toArray();
- assertEquals(array.length, cardinality);
- int nextSetBit = -1;
- for (int i=0; i < cardinality; i++) {
- nextSetBit = bitSet.nextSetBit(nextSetBit + 1);
- assertEquals(array[i], nextSetBit);
- }
+
+ BitSet bitSet = BitSet.valueOf(bytes);
+ testBitSetContents(bitSet, bitSet.stream().toArray());
+ testBitSetContents(bitSet, bitSet.stream().parallel().toArray());
+ }
+ }
+
+ void testBitSetContents(BitSet bitSet, int[] array) {
+ int cardinality = bitSet.cardinality();
+ assertEquals(array.length, cardinality);
+ int nextSetBit = -1;
+ for (int i = 0; i < cardinality; i++) {
+ nextSetBit = bitSet.nextSetBit(nextSetBit + 1);
+ assertEquals(array[i], nextSetBit);
}
}
}
--- a/jdk/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java Wed Nov 16 14:26:12 2016 -0800
+++ b/jdk/test/java/util/Spliterator/SpliteratorTraversingAndSplittingTest.java Wed Nov 16 14:26:14 2016 -0800
@@ -37,6 +37,7 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
@@ -80,7 +81,9 @@
import java.util.function.LongConsumer;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
+import java.util.stream.IntStream;
+import static java.util.stream.Collectors.toList;
import static org.testng.Assert.*;
import static org.testng.Assert.assertEquals;
@@ -883,6 +886,33 @@
cdb.add("new StringBuffer(\"%s\")", StringBuffer::new);
}
+
+ Object[][] bitStreamTestcases = new Object[][] {
+ { "none", IntStream.empty().toArray() },
+ { "index 0", IntStream.of(0).toArray() },
+ { "index 255", IntStream.of(255).toArray() },
+ { "index 0 and 255", IntStream.of(0, 255).toArray() },
+ { "index Integer.MAX_VALUE", IntStream.of(Integer.MAX_VALUE).toArray() },
+ { "index Integer.MAX_VALUE - 1", IntStream.of(Integer.MAX_VALUE - 1).toArray() },
+ { "index 0 and Integer.MAX_VALUE", IntStream.of(0, Integer.MAX_VALUE).toArray() },
+ { "every bit", IntStream.range(0, 255).toArray() },
+ { "step 2", IntStream.range(0, 255).map(f -> f * 2).toArray() },
+ { "step 3", IntStream.range(0, 255).map(f -> f * 3).toArray() },
+ { "step 5", IntStream.range(0, 255).map(f -> f * 5).toArray() },
+ { "step 7", IntStream.range(0, 255).map(f -> f * 7).toArray() },
+ { "1, 10, 100, 1000", IntStream.of(1, 10, 100, 1000).toArray() },
+ };
+ for (Object[] tc : bitStreamTestcases) {
+ String description = (String)tc[0];
+ int[] exp = (int[])tc[1];
+ SpliteratorOfIntDataBuilder db = new SpliteratorOfIntDataBuilder(
+ data, IntStream.of(exp).boxed().collect(toList()));
+
+ db.add("BitSet.stream.spliterator() {" + description + "}", () ->
+ IntStream.of(exp).collect(BitSet::new, BitSet::set, BitSet::or).
+ stream().spliterator()
+ );
+ }
return spliteratorOfIntDataProvider = data.toArray(new Object[0][]);
}