8144675: Add a filtering collector
authorshinyafox
Sun, 13 Dec 2015 15:20:35 +0100
changeset 34685 ababd79c3b2b
parent 34684 b721350c05c0
child 34686 29ea8310a27a
8144675: Add a filtering collector Reviewed-by: psandoz, smarks
jdk/src/java.base/share/classes/java/util/stream/Collectors.java
jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java
--- a/jdk/src/java.base/share/classes/java/util/stream/Collectors.java	Sun Dec 13 15:10:13 2015 +0100
+++ b/jdk/src/java.base/share/classes/java/util/stream/Collectors.java	Sun Dec 13 15:20:35 2015 +0100
@@ -434,7 +434,7 @@
      * stream returned by mapper
      * @return a collector which applies the mapping function to the input
      * elements and provides the flat mapped results to the downstream collector
-     * @since 1.9
+     * @since 9
      */
     public static <T, U, A, R>
     Collector<T, ?, R> flatMapping(Function<? super T, ? extends Stream<? extends U>> mapper,
@@ -452,6 +452,53 @@
     }
 
     /**
+     * Adapts a {@code Collector} to one accepting elements of the same type
+     * {@code T} by applying the predicate to each input element and only
+     * accumulating if the predicate returns {@code true}.
+     *
+     * @apiNote
+     * The {@code filtering()} collectors are most useful when used in a
+     * multi-level reduction, such as downstream of a {@code groupingBy} or
+     * {@code partitioningBy}.  For example, given a stream of
+     * {@code Employee}, to accumulate the employees in each department that have a
+     * salary above a certain threshold:
+     * <pre>{@code
+     *     Map<Department, Set<Employee>> wellPaidEmployeesByDepartment
+     *         = employees.stream().collect(groupingBy(Employee::getDepartment,
+     *                                              filtering(e -> e.getSalary() > 2000, toSet())));
+     * }</pre>
+     * A filtering collector differs from a stream's {@code filter()} operation.
+     * In this example, suppose there are no employees whose salary is above the
+     * threshold in some department.  Using a filtering collector as shown above
+     * would result in a mapping from that department to an empty {@code Set}.
+     * If a stream {@code filter()} operation were done instead, there would be
+     * no mapping for that department at all.
+     *
+     * @param <T> the type of the input elements
+     * @param <A> intermediate accumulation type of the downstream collector
+     * @param <R> result type of collector
+     * @param predicate a predicate to be applied to the input elements
+     * @param downstream a collector which will accept values that match the
+     * predicate
+     * @return a collector which applies the predicate to the input elements
+     * and provides matching elements to the downstream collector
+     * @since 9
+     */
+    public static <T, A, R>
+    Collector<T, ?, R> filtering(Predicate<? super T> predicate,
+                               Collector<? super T, A, R> downstream) {
+        BiConsumer<A, ? super T> downstreamAccumulator = downstream.accumulator();
+        return new CollectorImpl<>(downstream.supplier(),
+                                   (r, t) -> {
+                                       if (predicate.test(t)) {
+                                           downstreamAccumulator.accept(r, t);
+                                       }
+                                   },
+                                   downstream.combiner(), downstream.finisher(),
+                                   downstream.characteristics());
+    }
+
+    /**
      * Adapts a {@code Collector} to perform an additional finishing
      * transformation.  For example, one could adapt the {@link #toList()}
      * collector to always produce an immutable list with:
--- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java	Sun Dec 13 15:10:13 2015 +0100
+++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java	Sun Dec 13 15:20:35 2015 +0100
@@ -56,6 +56,7 @@
 
 import static java.util.stream.Collectors.collectingAndThen;
 import static java.util.stream.Collectors.flatMapping;
+import static java.util.stream.Collectors.filtering;
 import static java.util.stream.Collectors.groupingBy;
 import static java.util.stream.Collectors.groupingByConcurrent;
 import static java.util.stream.Collectors.mapping;
@@ -72,7 +73,7 @@
 
 /*
  * @test
- * @bug 8071600
+ * @bug 8071600 8144675
  * @summary Test for collectors.
  */
 public class CollectorsTest extends OpTestCase {
@@ -118,6 +119,23 @@
         }
     }
 
+    static class FilteringAssertion<T, R> extends CollectorAssertion<T, R> {
+        private final Predicate<T> filter;
+        private final CollectorAssertion<T, R> downstream;
+
+        public FilteringAssertion(Predicate<T> filter, CollectorAssertion<T, R> downstream) {
+            this.filter = filter;
+            this.downstream = downstream;
+        }
+
+        @Override
+        void assertValue(R value, Supplier<Stream<T>> source, boolean ordered) throws ReflectiveOperationException {
+            downstream.assertValue(value,
+                                   () -> source.get().filter(filter),
+                                   ordered);
+        }
+    }
+
     static class GroupingByAssertion<T, K, V, M extends Map<K, ? extends V>> extends CollectorAssertion<T, M> {
         private final Class<? extends Map> clazz;
         private final Function<T, K> classifier;
@@ -551,6 +569,36 @@
     }
 
     @Test(dataProvider = "StreamTestData<Integer>", dataProviderClass = StreamTestDataProvider.class)
+    public void testGroupingByWithFiltering(String name, TestData.OfRef<Integer> data) throws ReflectiveOperationException {
+        Function<Integer, Integer> classifier = i -> i % 3;
+        Predicate<Integer> filteringByMod2 = i -> i % 2 == 0;
+        Predicate<Integer> filteringByUnder100 = i -> i % 2 < 100;
+        Predicate<Integer> filteringByTrue = i -> true;
+        Predicate<Integer> filteringByFalse = i -> false;
+
+        exerciseMapCollection(data,
+                              groupingBy(classifier, filtering(filteringByMod2, toList())),
+                              new GroupingByAssertion<>(classifier, HashMap.class,
+                                                        new FilteringAssertion<>(filteringByMod2,
+                                                                                   new ToListAssertion<>())));
+        exerciseMapCollection(data,
+                              groupingBy(classifier, filtering(filteringByUnder100, toList())),
+                              new GroupingByAssertion<>(classifier, HashMap.class,
+                                                        new FilteringAssertion<>(filteringByUnder100,
+                                                                                   new ToListAssertion<>())));
+        exerciseMapCollection(data,
+                              groupingBy(classifier, filtering(filteringByTrue, toList())),
+                              new GroupingByAssertion<>(classifier, HashMap.class,
+                                                        new FilteringAssertion<>(filteringByTrue,
+                                                                                   new ToListAssertion<>())));
+        exerciseMapCollection(data,
+                              groupingBy(classifier, filtering(filteringByFalse, toList())),
+                              new GroupingByAssertion<>(classifier, HashMap.class,
+                                                        new FilteringAssertion<>(filteringByFalse,
+                                                                                   new ToListAssertion<>())));
+    }
+
+    @Test(dataProvider = "StreamTestData<Integer>", dataProviderClass = StreamTestDataProvider.class)
     public void testTwoLevelGroupingBy(String name, TestData.OfRef<Integer> data) throws ReflectiveOperationException {
         Function<Integer, Integer> classifier = i -> i % 6;
         Function<Integer, Integer> classifier2 = i -> i % 23;