8201194: Handle local variable declarations in lambda deduplication
authorcushon
Thu, 05 Apr 2018 14:39:04 -0700
changeset 49541 4f6887eade94
parent 49540 9704789737c1
child 49542 da62fa14a3fe
8201194: Handle local variable declarations in lambda deduplication Reviewed-by: vromero
src/jdk.compiler/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java
src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeDiffer.java
src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeHasher.java
test/langtools/tools/javac/lambda/deduplication/Deduplication.java
test/langtools/tools/javac/lambda/deduplication/DeduplicationTest.java
--- a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java	Fri Apr 06 02:52:24 2018 +0200
+++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java	Thu Apr 05 14:39:04 2018 -0700
@@ -183,15 +183,7 @@
         public int hashCode() {
             int hashCode = this.hashCode;
             if (hashCode == 0) {
-                this.hashCode = hashCode = TreeHasher.hash(tree, sym -> {
-                    if (sym.owner == symbol) {
-                        int idx = symbol.params().indexOf(sym);
-                        if (idx != -1) {
-                            return idx;
-                        }
-                    }
-                    return null;
-                });
+                this.hashCode = hashCode = TreeHasher.hash(tree, symbol.params());
             }
             return hashCode;
         }
@@ -203,17 +195,7 @@
             }
             DedupedLambda that = (DedupedLambda) o;
             return types.isSameType(symbol.asType(), that.symbol.asType())
-                    && new TreeDiffer((lhs, rhs) -> {
-                if (lhs.owner == symbol) {
-                    int idx = symbol.params().indexOf(lhs);
-                    if (idx != -1) {
-                        if (Objects.equals(idx, that.symbol.params().indexOf(rhs))) {
-                            return true;
-                        }
-                    }
-                }
-                return null;
-            }).scan(tree, that.tree);
+                    && new TreeDiffer(symbol.params(), that.symbol.params()).scan(tree, that.tree);
         }
     }
 
--- a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeDiffer.java	Fri Apr 06 02:52:24 2018 +0200
+++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeDiffer.java	Thu Apr 05 14:39:04 2018 -0700
@@ -89,24 +89,34 @@
 import com.sun.tools.javac.tree.TreeInfo;
 import com.sun.tools.javac.tree.TreeScanner;
 import com.sun.tools.javac.util.List;
-
-import javax.lang.model.element.ElementKind;
+import java.util.Collection;
+import java.util.HashMap;
 import java.util.Iterator;
+import java.util.Map;
 import java.util.Objects;
-import java.util.function.BiFunction;
-import java.util.function.Consumer;
 
 /** A visitor that compares two lambda bodies for structural equality. */
 public class TreeDiffer extends TreeScanner {
 
-    private BiFunction<Symbol, Symbol, Boolean> symbolDiffer;
+    public TreeDiffer(
+            Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
+        this.equiv = equiv(symbols, otherSymbols);
+    }
 
-    public TreeDiffer(BiFunction<Symbol, Symbol, Boolean> symbolDiffer) {
-        this.symbolDiffer = Objects.requireNonNull(symbolDiffer);
+    private static Map<Symbol, Symbol> equiv(
+            Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
+        Map<Symbol, Symbol> result = new HashMap<>();
+        Iterator<? extends Symbol> it = otherSymbols.iterator();
+        for (Symbol symbol : symbols) {
+            if (!it.hasNext()) break;
+            result.put(symbol, it.next());
+        }
+        return result;
     }
 
     private JCTree parameter;
     private boolean result;
+    private Map<Symbol, Symbol> equiv = new HashMap<>();
 
     public boolean scan(JCTree tree, JCTree parameter) {
         if (tree == null || parameter == null) {
@@ -172,9 +182,8 @@
         Symbol symbol = tree.sym;
         Symbol otherSymbol = that.sym;
         if (symbol != null && otherSymbol != null) {
-            Boolean tmp = symbolDiffer.apply(symbol, otherSymbol);
-            if (tmp != null) {
-                result = tmp;
+            if (Objects.equals(equiv.get(symbol), otherSymbol)) {
+                result = true;
                 return;
             }
         }
@@ -598,6 +607,10 @@
                         && scan(tree.nameexpr, that.nameexpr)
                         && scan(tree.vartype, that.vartype)
                         && scan(tree.init, that.init);
+        if (!result) {
+            return;
+        }
+        equiv.put(tree.sym, that.sym);
     }
 
     @Override
--- a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeHasher.java	Fri Apr 06 02:52:24 2018 +0200
+++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeHasher.java	Thu Apr 05 14:39:04 2018 -0700
@@ -30,26 +30,31 @@
 import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
 import com.sun.tools.javac.tree.JCTree.JCIdent;
 import com.sun.tools.javac.tree.JCTree.JCLiteral;
+import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
 import com.sun.tools.javac.tree.TreeInfo;
 import com.sun.tools.javac.tree.TreeScanner;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Objects;
-import java.util.function.Function;
 
 /** A tree visitor that computes a hash code. */
 public class TreeHasher extends TreeScanner {
 
-    private final Function<Symbol, Integer> symbolHasher;
+    private final Map<Symbol, Integer> symbolHashes;
     private int result = 17;
 
-    public TreeHasher(Function<Symbol, Integer> symbolHasher) {
-        this.symbolHasher = Objects.requireNonNull(symbolHasher);
+    public TreeHasher(Map<Symbol, Integer> symbolHashes) {
+        this.symbolHashes = Objects.requireNonNull(symbolHashes);
     }
 
-    public static int hash(JCTree tree, Function<Symbol, Integer> symbolHasher) {
+    public static int hash(JCTree tree, Collection<? extends Symbol> symbols) {
         if (tree == null) {
             return 0;
         }
-        TreeHasher hasher = new TreeHasher(symbolHasher);
+        Map<Symbol, Integer> symbolHashes = new HashMap<>();
+        symbols.forEach(s -> symbolHashes.put(s, symbolHashes.size()));
+        TreeHasher hasher = new TreeHasher(symbolHashes);
         tree.accept(hasher);
         return hasher.result;
     }
@@ -85,7 +90,7 @@
     public void visitIdent(JCIdent tree) {
         Symbol sym = tree.sym;
         if (sym != null) {
-            Integer hash = symbolHasher.apply(sym);
+            Integer hash = symbolHashes.get(sym);
             if (hash != null) {
                 hash(hash);
                 return;
@@ -99,4 +104,10 @@
         hash(tree.sym);
         super.visitSelect(tree);
     }
+
+    @Override
+    public void visitVarDef(JCVariableDecl tree) {
+        symbolHashes.computeIfAbsent(tree.sym, k -> symbolHashes.size());
+        super.visitVarDef(tree);
+    }
 }
--- a/test/langtools/tools/javac/lambda/deduplication/Deduplication.java	Fri Apr 06 02:52:24 2018 +0200
+++ b/test/langtools/tools/javac/lambda/deduplication/Deduplication.java	Thu Apr 05 14:39:04 2018 -0700
@@ -77,18 +77,45 @@
         group((Function<Integer, Integer>) y -> j);
 
         group(
-                (Function<Integer, Integer>) y -> {
-                        while (true) {
-                              break;
-                        }
-                        return 42;
-                },
-                (Function<Integer, Integer>) y -> {
-                        while (true) {
-                              break;
-                        }
-                        return 42;
-                });
+                (Function<Integer, Integer>)
+                        y -> {
+                            while (true) {
+                                break;
+                            }
+                            return 42;
+                        },
+                (Function<Integer, Integer>)
+                        y -> {
+                            while (true) {
+                                break;
+                            }
+                            return 42;
+                        });
+
+        group(
+                (Function<Integer, Integer>)
+                        x -> {
+                            int y = x;
+                            return y;
+                        },
+                (Function<Integer, Integer>)
+                        x -> {
+                            int y = x;
+                            return y;
+                        });
+
+        group(
+                (Function<Integer, Integer>)
+                        x -> {
+                            int y = 0, z = x;
+                            return y;
+                        });
+        group(
+                (Function<Integer, Integer>)
+                        x -> {
+                            int y = 0, z = x;
+                            return z;
+                        });
 
         class Local {
             int i;
--- a/test/langtools/tools/javac/lambda/deduplication/DeduplicationTest.java	Fri Apr 06 02:52:24 2018 +0200
+++ b/test/langtools/tools/javac/lambda/deduplication/DeduplicationTest.java	Thu Apr 05 14:39:04 2018 -0700
@@ -22,7 +22,7 @@
  */
 
 /**
- * @test 8200301
+ * @test 8200301 8201194
  * @summary deduplicate lambda methods with the same body, target type, and captured state
  * @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api
  *     jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp
@@ -32,6 +32,7 @@
  */
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.stream.Collectors.joining;
+import static java.util.stream.Collectors.toList;
 import static java.util.stream.Collectors.toMap;
 import static java.util.stream.Collectors.toSet;
 
@@ -57,7 +58,6 @@
 import com.sun.tools.javac.tree.JCTree.JCLambda;
 import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
 import com.sun.tools.javac.tree.JCTree.JCTypeCast;
-import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
 import com.sun.tools.javac.tree.JCTree.Tag;
 import com.sun.tools.javac.tree.TreeScanner;
 import com.sun.tools.javac.util.Context;
@@ -70,10 +70,8 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Set;
 import java.util.TreeSet;
-import java.util.function.BiFunction;
 import javax.tools.Diagnostic;
 import javax.tools.DiagnosticListener;
 import javax.tools.JavaFileObject;
@@ -160,36 +158,9 @@
         }
     }
 
-    /**
-     * Returns a symbol comparator that treats symbols that correspond to the same parameter of each
-     * of the given lambdas as equal.
-     */
-    private static BiFunction<Symbol, Symbol, Boolean> paramsEqual(JCLambda lhs, JCLambda rhs) {
-        return (x, y) -> {
-            Integer idx = paramIndex(lhs, x);
-            if (idx != null && idx != -1) {
-                if (Objects.equals(idx, paramIndex(rhs, y))) {
-                    return true;
-                }
-            }
-            return null;
-        };
-    }
-
-    /**
-     * Returns the index of the given symbol as a parameter of the given lambda, or else {@code -1}
-     * if is not a parameter.
-     */
-    private static Integer paramIndex(JCLambda lambda, Symbol sym) {
-        if (sym != null) {
-            int idx = 0;
-            for (JCVariableDecl param : lambda.params) {
-                if (sym == param.sym) {
-                    return idx;
-                }
-            }
-        }
-        return null;
+    /** Returns the parameter symbols of the given lambda. */
+    private static List<Symbol> paramSymbols(JCLambda lambda) {
+        return lambda.params.stream().map(x -> x.sym).collect(toList());
     }
 
     /** A diagnostic listener that records debug messages related to lambda desugaring. */
@@ -310,13 +281,14 @@
                         dedupedLambdas.put(lhs, first);
                     }
                     for (JCLambda rhs : curr) {
-                        if (!new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body)) {
+                        if (!new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
+                                .scan(lhs.body, rhs.body)) {
                             throw new AssertionError(
                                     String.format(
                                             "expected lambdas to be equal\n%s\n%s", lhs, rhs));
                         }
-                        if (TreeHasher.hash(lhs, sym -> paramIndex(lhs, sym))
-                                != TreeHasher.hash(rhs, sym -> paramIndex(rhs, sym))) {
+                        if (TreeHasher.hash(lhs, paramSymbols(lhs))
+                                != TreeHasher.hash(rhs, paramSymbols(rhs))) {
                             throw new AssertionError(
                                     String.format(
                                             "expected lambdas to hash to the same value\n%s\n%s",
@@ -334,14 +306,15 @@
                     }
                     for (JCLambda lhs : curr) {
                         for (JCLambda rhs : lambdaGroups.get(j)) {
-                            if (new TreeDiffer(paramsEqual(lhs, rhs)).scan(lhs.body, rhs.body)) {
+                            if (new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
+                                    .scan(lhs.body, rhs.body)) {
                                 throw new AssertionError(
                                         String.format(
                                                 "expected lambdas to not be equal\n%s\n%s",
                                                 lhs, rhs));
                             }
-                            if (TreeHasher.hash(lhs, sym -> paramIndex(lhs, sym))
-                                    == TreeHasher.hash(rhs, sym -> paramIndex(rhs, sym))) {
+                            if (TreeHasher.hash(lhs, paramSymbols(lhs))
+                                    == TreeHasher.hash(rhs, paramSymbols(rhs))) {
                                 throw new AssertionError(
                                         String.format(
                                                 "expected lambdas to hash to different values\n%s\n%s",