8205461: Create Collector which merges results of two other collectors
authorplevart
Tue, 25 Sep 2018 14:23:37 +0200
changeset 51864 490d9001eba9
parent 51863 bc38c75eed57
child 51865 eb954a4b6083
8205461: Create Collector which merges results of two other collectors Reviewed-by: briangoetz, smarks, plevart Contributed-by: amaembo@gmail.com
src/java.base/share/classes/java/util/stream/Collectors.java
test/jdk/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java
--- a/src/java.base/share/classes/java/util/stream/Collectors.java	Tue Sep 25 14:16:33 2018 +0200
+++ b/src/java.base/share/classes/java/util/stream/Collectors.java	Tue Sep 25 14:23:37 2018 +0200
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2012, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -1885,6 +1885,102 @@
     }
 
     /**
+     * Returns a {@code Collector} that is a composite of two downstream collectors.
+     * Every element passed to the resulting collector is processed by both downstream
+     * collectors, then their results are merged using the specified merge function
+     * into the final result.
+     *
+     * <p>The resulting collector functions do the following:
+     *
+     * <ul>
+     * <li>supplier: creates a result container that contains result containers
+     * obtained by calling each collector's supplier
+     * <li>accumulator: calls each collector's accumulator with its result container
+     * and the input element
+     * <li>combiner: calls each collector's combiner with two result containers
+     * <li>finisher: calls each collector's finisher with its result container,
+     * then calls the supplied merger and returns its result.
+     * </ul>
+     *
+     * <p>The resulting collector is {@link Collector.Characteristics#UNORDERED} if both downstream
+     * collectors are unordered and {@link Collector.Characteristics#CONCURRENT} if both downstream
+     * collectors are concurrent.
+     *
+     * @param <T>         the type of the input elements
+     * @param <R1>        the result type of the first collector
+     * @param <R2>        the result type of the second collector
+     * @param <R>         the final result type
+     * @param downstream1 the first downstream collector
+     * @param downstream2 the second downstream collector
+     * @param merger      the function which merges two results into the single one
+     * @return a {@code Collector} which aggregates the results of two supplied collectors.
+     * @since 12
+     */
+    public static <T, R1, R2, R>
+    Collector<T, ?, R> teeing(Collector<? super T, ?, R1> downstream1,
+                              Collector<? super T, ?, R2> downstream2,
+                              BiFunction<? super R1, ? super R2, R> merger) {
+        return teeing0(downstream1, downstream2, merger);
+    }
+
+    private static <T, A1, A2, R1, R2, R>
+    Collector<T, ?, R> teeing0(Collector<? super T, A1, R1> downstream1,
+                               Collector<? super T, A2, R2> downstream2,
+                               BiFunction<? super R1, ? super R2, R> merger) {
+        Objects.requireNonNull(downstream1, "downstream1");
+        Objects.requireNonNull(downstream2, "downstream2");
+        Objects.requireNonNull(merger, "merger");
+
+        Supplier<A1> c1Supplier = Objects.requireNonNull(downstream1.supplier(), "downstream1 supplier");
+        Supplier<A2> c2Supplier = Objects.requireNonNull(downstream2.supplier(), "downstream2 supplier");
+        BiConsumer<A1, ? super T> c1Accumulator =
+                Objects.requireNonNull(downstream1.accumulator(), "downstream1 accumulator");
+        BiConsumer<A2, ? super T> c2Accumulator =
+                Objects.requireNonNull(downstream2.accumulator(), "downstream2 accumulator");
+        BinaryOperator<A1> c1Combiner = Objects.requireNonNull(downstream1.combiner(), "downstream1 combiner");
+        BinaryOperator<A2> c2Combiner = Objects.requireNonNull(downstream2.combiner(), "downstream2 combiner");
+        Function<A1, R1> c1Finisher = Objects.requireNonNull(downstream1.finisher(), "downstream1 finisher");
+        Function<A2, R2> c2Finisher = Objects.requireNonNull(downstream2.finisher(), "downstream2 finisher");
+
+        Set<Collector.Characteristics> characteristics;
+        Set<Collector.Characteristics> c1Characteristics = downstream1.characteristics();
+        Set<Collector.Characteristics> c2Characteristics = downstream2.characteristics();
+        if (CH_ID.containsAll(c1Characteristics) || CH_ID.containsAll(c2Characteristics)) {
+            characteristics = CH_NOID;
+        } else {
+            EnumSet<Collector.Characteristics> c = EnumSet.noneOf(Collector.Characteristics.class);
+            c.addAll(c1Characteristics);
+            c.retainAll(c2Characteristics);
+            c.remove(Collector.Characteristics.IDENTITY_FINISH);
+            characteristics = Collections.unmodifiableSet(c);
+        }
+
+        class PairBox {
+            A1 left = c1Supplier.get();
+            A2 right = c2Supplier.get();
+
+            void add(T t) {
+                c1Accumulator.accept(left, t);
+                c2Accumulator.accept(right, t);
+            }
+
+            PairBox combine(PairBox other) {
+                left = c1Combiner.apply(left, other.left);
+                right = c2Combiner.apply(right, other.right);
+                return this;
+            }
+
+            R get() {
+                R1 r1 = c1Finisher.apply(left);
+                R2 r2 = c2Finisher.apply(right);
+                return merger.apply(r1, r2);
+            }
+        }
+
+        return new CollectorImpl<>(PairBox::new, PairBox::add, PairBox::combine, PairBox::get, characteristics);
+    }
+
+    /**
      * Implementation class used by partitioningBy.
      */
     private static final class Partition<T>
--- a/test/jdk/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java	Tue Sep 25 14:16:33 2018 +0200
+++ b/test/jdk/java/util/stream/test/org/openjdk/tests/java/util/stream/CollectorsTest.java	Tue Sep 25 14:23:37 2018 +0200
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2012, 2015, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2012, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -29,6 +29,7 @@
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.IntSummaryStatistics;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -39,6 +40,7 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentSkipListMap;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
 import java.util.function.BinaryOperator;
 import java.util.function.Function;
 import java.util.function.Predicate;
@@ -96,7 +98,7 @@
         @Override
         void assertValue(R value, Supplier<Stream<T>> source, boolean ordered) throws ReflectiveOperationException {
             downstream.assertValue(value,
-                                   () -> source.get().map(mapper::apply),
+                                   () -> source.get().map(mapper),
                                    ordered);
         }
     }
@@ -114,7 +116,7 @@
         @Override
         void assertValue(R value, Supplier<Stream<T>> source, boolean ordered) throws ReflectiveOperationException {
             downstream.assertValue(value,
-                                   () -> source.get().flatMap(mapper::apply),
+                                   () -> source.get().flatMap(mapper),
                                    ordered);
         }
     }
@@ -287,6 +289,27 @@
         }
     }
 
+    static class TeeingAssertion<T, R1, R2, RR> extends CollectorAssertion<T, RR> {
+        private final Collector<T, ?, R1> c1;
+        private final Collector<T, ?, R2> c2;
+        private final BiFunction<? super R1, ? super R2, ? extends RR> finisher;
+
+        TeeingAssertion(Collector<T, ?, R1> c1, Collector<T, ?, R2> c2,
+                               BiFunction<? super R1, ? super R2, ? extends RR> finisher) {
+            this.c1 = c1;
+            this.c2 = c2;
+            this.finisher = finisher;
+        }
+
+        @Override
+        void assertValue(RR value, Supplier<Stream<T>> source, boolean ordered) {
+            R1 r1 = source.get().collect(c1);
+            R2 r2 = source.get().collect(c2);
+            RR expected = finisher.apply(r1, r2);
+            assertEquals(value, expected);
+        }
+    }
+
     private <T> ResultAsserter<T> mapTabulationAsserter(boolean ordered) {
         return (act, exp, ord, par) -> {
             if (par && (!ordered || !ord)) {
@@ -746,4 +769,42 @@
         catch (UnsupportedOperationException ignored) { }
     }
 
+    @Test(dataProvider = "StreamTestData<Integer>", dataProviderClass = StreamTestDataProvider.class)
+    public void testTeeing(String name, TestData.OfRef<Integer> data) throws ReflectiveOperationException {
+        Collector<Integer, ?, Long> summing = Collectors.summingLong(Integer::valueOf);
+        Collector<Integer, ?, Long> counting = Collectors.counting();
+        Collector<Integer, ?, Integer> min = collectingAndThen(Collectors.<Integer>minBy(Comparator.naturalOrder()),
+                opt -> opt.orElse(Integer.MAX_VALUE));
+        Collector<Integer, ?, Integer> max = collectingAndThen(Collectors.<Integer>maxBy(Comparator.naturalOrder()),
+                opt -> opt.orElse(Integer.MIN_VALUE));
+        Collector<Integer, ?, String> joining = mapping(String::valueOf, Collectors.joining(", ", "[", "]"));
+
+        Collector<Integer, ?, Map.Entry<Long, Long>> sumAndCount = Collectors.teeing(summing, counting, Map::entry);
+        Collector<Integer, ?, Map.Entry<Integer, Integer>> minAndMax = Collectors.teeing(min, max, Map::entry);
+        Collector<Integer, ?, Double> averaging = Collectors.teeing(summing, counting,
+                (sum, count) -> ((double)sum) / count);
+        Collector<Integer, ?, String> summaryStatistics = Collectors.teeing(sumAndCount, minAndMax,
+                (sumCountEntry, minMaxEntry) -> new IntSummaryStatistics(
+                        sumCountEntry.getValue(), minMaxEntry.getKey(),
+                        minMaxEntry.getValue(), sumCountEntry.getKey()).toString());
+        Collector<Integer, ?, String> countAndContent = Collectors.teeing(counting, joining,
+                (count, content) -> count+": "+content);
+
+        assertCollect(data, sumAndCount, stream -> {
+            List<Integer> list = stream.collect(toList());
+            return Map.entry(list.stream().mapToLong(Integer::intValue).sum(), (long) list.size());
+        });
+        assertCollect(data, averaging, stream -> stream.mapToInt(Integer::intValue).average().orElse(Double.NaN));
+        assertCollect(data, summaryStatistics,
+                stream -> stream.mapToInt(Integer::intValue).summaryStatistics().toString());
+        assertCollect(data, countAndContent, stream -> {
+            List<Integer> list = stream.collect(toList());
+            return list.size()+": "+list;
+        });
+
+        Function<Integer, Integer> classifier = i -> i % 3;
+        exerciseMapCollection(data, groupingBy(classifier, sumAndCount),
+                new GroupingByAssertion<>(classifier, Map.class,
+                        new TeeingAssertion<>(summing, counting, Map::entry)));
+    }
 }