langtools/src/share/classes/com/sun/tools/javac/comp/Infer.java
changeset 19914 d86271bd430a
parent 19127 8a0cbd5cb055
child 20249 93f8eae31092
--- a/langtools/src/share/classes/com/sun/tools/javac/comp/Infer.java	Fri Aug 30 17:36:47 2013 -0700
+++ b/langtools/src/share/classes/com/sun/tools/javac/comp/Infer.java	Mon Sep 02 22:38:36 2013 +0100
@@ -40,17 +40,17 @@
 import com.sun.tools.javac.comp.Infer.GraphSolver.InferenceGraph.Node;
 import com.sun.tools.javac.comp.Resolve.InapplicableMethodException;
 import com.sun.tools.javac.comp.Resolve.VerboseResolutionMode;
-
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-import java.util.TreeSet;
+import com.sun.tools.javac.util.GraphUtils.TarjanNode;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.EnumMap;
 import java.util.EnumSet;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
 
 import static com.sun.tools.javac.code.TypeTag.*;
 
@@ -114,6 +114,12 @@
         }
 
         @Override
+        InapplicableMethodException setMessage() {
+            //no message to set
+            return this;
+        }
+
+        @Override
         InapplicableMethodException setMessage(JCDiagnostic diag) {
             messages = messages.append(diag);
             return this;
@@ -1006,10 +1012,24 @@
      * and (ii) tell th engine when we are done fixing inference variables
      */
     interface GraphStrategy {
+
+        /**
+         * A NodeNotFoundException is thrown whenever an inference strategy fails
+         * to pick the next node to solve in the inference graph.
+         */
+        public static class NodeNotFoundException extends RuntimeException {
+            private static final long serialVersionUID = 0;
+
+            InferenceGraph graph;
+
+            public NodeNotFoundException(InferenceGraph graph) {
+                this.graph = graph;
+            }
+        }
         /**
          * Pick the next node (leaf) to solve in the graph
          */
-        Node pickNode(InferenceGraph g);
+        Node pickNode(InferenceGraph g) throws NodeNotFoundException;
         /**
          * Is this the last step?
          */
@@ -1022,7 +1042,10 @@
      */
     abstract class LeafSolver implements GraphStrategy {
         public Node pickNode(InferenceGraph g) {
-                        Assert.check(!g.nodes.isEmpty(), "No nodes to solve!");
+            if (g.nodes.isEmpty()) {
+                //should not happen
+                throw new NodeNotFoundException(g);
+            };
             return g.nodes.get(0);
         }
 
@@ -1069,6 +1092,7 @@
      */
     abstract class BestLeafSolver extends LeafSolver {
 
+        /** list of ivars of which at least one must be solved */
         List<Type> varsToSolve;
 
         BestLeafSolver(List<Type> varsToSolve) {
@@ -1076,54 +1100,66 @@
         }
 
         /**
-         * Computes the minimum path that goes from a given node to any of the nodes
-         * containing a variable in {@code varsToSolve}. For any given path, the cost
-         * is computed as the total number of type-variables that should be eagerly
-         * instantiated across that path.
+         * Computes a path that goes from a given node to the leafs in the graph.
+         * Typically this will start from a node containing a variable in
+         * {@code varsToSolve}. For any given path, the cost is computed as the total
+         * number of type-variables that should be eagerly instantiated across that path.
          */
-        int computeMinPath(InferenceGraph g, Node n) {
-            return computeMinPath(g, n, List.<Node>nil(), 0);
+        Pair<List<Node>, Integer> computeTreeToLeafs(Node n) {
+            Pair<List<Node>, Integer> cachedPath = treeCache.get(n);
+            if (cachedPath == null) {
+                //cache miss
+                if (n.isLeaf()) {
+                    //if leaf, stop
+                    cachedPath = new Pair<List<Node>, Integer>(List.of(n), n.data.length());
+                } else {
+                    //if non-leaf, proceed recursively
+                    Pair<List<Node>, Integer> path = new Pair<List<Node>, Integer>(List.of(n), n.data.length());
+                    for (Node n2 : n.getAllDependencies()) {
+                        if (n2 == n) continue;
+                        Pair<List<Node>, Integer> subpath = computeTreeToLeafs(n2);
+                        path = new Pair<List<Node>, Integer>(
+                                path.fst.prependList(subpath.fst),
+                                path.snd + subpath.snd);
+                    }
+                    cachedPath = path;
+                }
+                //save results in cache
+                treeCache.put(n, cachedPath);
+            }
+            return cachedPath;
         }
 
-        int computeMinPath(InferenceGraph g, Node n, List<Node> path, int cost) {
-            if (path.contains(n)) return Integer.MAX_VALUE;
-            List<Node> path2 = path.prepend(n);
-            int cost2 = cost + n.data.size();
-            if (!Collections.disjoint(n.data, varsToSolve)) {
-                return cost2;
-            } else {
-               int bestPath = Integer.MAX_VALUE;
-               for (Node n2 : g.nodes) {
-                   if (n2.deps.contains(n)) {
-                       int res = computeMinPath(g, n2, path2, cost2);
-                       if (res < bestPath) {
-                           bestPath = res;
-                       }
-                   }
-                }
-               return bestPath;
-            }
-        }
+        /** cache used to avoid redundant computation of tree costs */
+        final Map<Node, Pair<List<Node>, Integer>> treeCache =
+                new HashMap<Node, Pair<List<Node>, Integer>>();
+
+        /** constant value used to mark non-existent paths */
+        final Pair<List<Node>, Integer> noPath =
+                new Pair<List<Node>, Integer>(null, Integer.MAX_VALUE);
 
         /**
          * Pick the leaf that minimize cost
          */
         @Override
         public Node pickNode(final InferenceGraph g) {
-            final Map<Node, Integer> leavesMap = new HashMap<Node, Integer>();
+            treeCache.clear(); //graph changes at every step - cache must be cleared
+            Pair<List<Node>, Integer> bestPath = noPath;
             for (Node n : g.nodes) {
-                if (n.isLeaf(n)) {
-                    leavesMap.put(n, computeMinPath(g, n));
+                if (!Collections.disjoint(n.data, varsToSolve)) {
+                    Pair<List<Node>, Integer> path = computeTreeToLeafs(n);
+                    //discard all paths containing at least a node in the
+                    //closure computed above
+                    if (path.snd < bestPath.snd) {
+                        bestPath = path;
+                    }
                 }
             }
-            Assert.check(!leavesMap.isEmpty(), "No nodes to solve!");
-            TreeSet<Node> orderedLeaves = new TreeSet<Node>(new Comparator<Node>() {
-                public int compare(Node n1, Node n2) {
-                    return leavesMap.get(n1) - leavesMap.get(n2);
-                }
-            });
-            orderedLeaves.addAll(leavesMap.keySet());
-            return orderedLeaves.first();
+            if (bestPath == noPath) {
+                //no path leads there
+                throw new NodeNotFoundException(g);
+            }
+            return bestPath.fst.head;
         }
     }
 
@@ -1321,6 +1357,33 @@
     }
 
     /**
+     * There are two kinds of dependencies between inference variables. The basic
+     * kind of dependency (or bound dependency) arises when a variable mention
+     * another variable in one of its bounds. There's also a more subtle kind
+     * of dependency that arises when a variable 'might' lead to better constraints
+     * on another variable (this is typically the case with variables holding up
+     * stuck expressions).
+     */
+    enum DependencyKind implements GraphUtils.DependencyKind {
+
+        /** bound dependency */
+        BOUND("dotted"),
+        /** stuck dependency */
+        STUCK("dashed");
+
+        final String dotSyle;
+
+        private DependencyKind(String dotSyle) {
+            this.dotSyle = dotSyle;
+        }
+
+        @Override
+        public String getDotStyle() {
+            return dotSyle;
+        }
+    }
+
+    /**
      * This is the graph inference solver - the solver organizes all inference variables in
      * a given inference context by bound dependencies - in the general case, such dependencies
      * would lead to a cyclic directed graph (hence the name); the dependency info is used to build
@@ -1331,10 +1394,12 @@
     class GraphSolver {
 
         InferenceContext inferenceContext;
+        Map<Type, Set<Type>> stuckDeps;
         Warner warn;
 
-        GraphSolver(InferenceContext inferenceContext, Warner warn) {
+        GraphSolver(InferenceContext inferenceContext, Map<Type, Set<Type>> stuckDeps, Warner warn) {
             this.inferenceContext = inferenceContext;
+            this.stuckDeps = stuckDeps;
             this.warn = warn;
         }
 
@@ -1345,7 +1410,7 @@
          */
         void solve(GraphStrategy sstrategy) {
             checkWithinBounds(inferenceContext, warn); //initial propagation of bounds
-            InferenceGraph inferenceGraph = new InferenceGraph();
+            InferenceGraph inferenceGraph = new InferenceGraph(stuckDeps);
             while (!sstrategy.done()) {
                 InferenceGraph.Node nodeToSolve = sstrategy.pickNode(inferenceGraph);
                 List<Type> varsToSolve = List.from(nodeToSolve.data);
@@ -1390,64 +1455,172 @@
              */
             class Node extends GraphUtils.TarjanNode<ListBuffer<Type>> {
 
-                Set<Node> deps;
+                /** map listing all dependencies (grouped by kind) */
+                EnumMap<DependencyKind, Set<Node>> deps;
 
                 Node(Type ivar) {
                     super(ListBuffer.of(ivar));
-                    this.deps = new HashSet<Node>();
+                    this.deps = new EnumMap<DependencyKind, Set<Node>>(DependencyKind.class);
+                }
+
+                @Override
+                public GraphUtils.DependencyKind[] getSupportedDependencyKinds() {
+                    return DependencyKind.values();
                 }
 
                 @Override
-                public Iterable<? extends Node> getDependencies() {
-                    return deps;
+                public String getDependencyName(GraphUtils.Node<ListBuffer<Type>> to, GraphUtils.DependencyKind dk) {
+                    if (dk == DependencyKind.STUCK) return "";
+                    else {
+                        StringBuilder buf = new StringBuilder();
+                        String sep = "";
+                        for (Type from : data) {
+                            UndetVar uv = (UndetVar)inferenceContext.asFree(from);
+                            for (Type bound : uv.getBounds(InferenceBound.values())) {
+                                if (bound.containsAny(List.from(to.data))) {
+                                    buf.append(sep);
+                                    buf.append(bound);
+                                    sep = ",";
+                                }
+                            }
+                        }
+                        return buf.toString();
+                    }
+                }
+
+                @Override
+                public Iterable<? extends Node> getAllDependencies() {
+                    return getDependencies(DependencyKind.values());
                 }
 
                 @Override
-                public String printDependency(GraphUtils.Node<ListBuffer<Type>> to) {
-                    StringBuilder buf = new StringBuilder();
-                    String sep = "";
-                    for (Type from : data) {
-                        UndetVar uv = (UndetVar)inferenceContext.asFree(from);
-                        for (Type bound : uv.getBounds(InferenceBound.values())) {
-                            if (bound.containsAny(List.from(to.data))) {
-                                buf.append(sep);
-                                buf.append(bound);
-                                sep = ",";
-                            }
+                public Iterable<? extends TarjanNode<ListBuffer<Type>>> getDependenciesByKind(GraphUtils.DependencyKind dk) {
+                    return getDependencies((DependencyKind)dk);
+                }
+
+                /**
+                 * Retrieves all dependencies with given kind(s).
+                 */
+                protected Set<Node> getDependencies(DependencyKind... depKinds) {
+                    Set<Node> buf = new LinkedHashSet<Node>();
+                    for (DependencyKind dk : depKinds) {
+                        Set<Node> depsByKind = deps.get(dk);
+                        if (depsByKind != null) {
+                            buf.addAll(depsByKind);
                         }
                     }
-                    return buf.toString();
+                    return buf;
+                }
+
+                /**
+                 * Adds dependency with given kind.
+                 */
+                protected void addDependency(DependencyKind dk, Node depToAdd) {
+                    Set<Node> depsByKind = deps.get(dk);
+                    if (depsByKind == null) {
+                        depsByKind = new LinkedHashSet<Node>();
+                        deps.put(dk, depsByKind);
+                    }
+                    depsByKind.add(depToAdd);
+                }
+
+                /**
+                 * Add multiple dependencies of same given kind.
+                 */
+                protected void addDependencies(DependencyKind dk, Set<Node> depsToAdd) {
+                    for (Node n : depsToAdd) {
+                        addDependency(dk, n);
+                    }
+                }
+
+                /**
+                 * Remove a dependency, regardless of its kind.
+                 */
+                protected Set<DependencyKind> removeDependency(Node n) {
+                    Set<DependencyKind> removedKinds = new HashSet<>();
+                    for (DependencyKind dk : DependencyKind.values()) {
+                        Set<Node> depsByKind = deps.get(dk);
+                        if (depsByKind == null) continue;
+                        if (depsByKind.remove(n)) {
+                            removedKinds.add(dk);
+                        }
+                    }
+                    return removedKinds;
                 }
 
-                boolean isLeaf(Node n) {
-                    //no deps, or only one self dep
-                    return (n.deps.isEmpty() ||
-                            n.deps.size() == 1 && n.deps.contains(n));
+                /**
+                 * Compute closure of a give node, by recursively walking
+                 * through all its dependencies (of given kinds)
+                 */
+                protected Set<Node> closure(DependencyKind... depKinds) {
+                    boolean progress = true;
+                    Set<Node> closure = new HashSet<Node>();
+                    closure.add(this);
+                    while (progress) {
+                        progress = false;
+                        for (Node n1 : new HashSet<Node>(closure)) {
+                            progress = closure.addAll(n1.getDependencies(depKinds));
+                        }
+                    }
+                    return closure;
                 }
 
-                void mergeWith(List<? extends Node> nodes) {
+                /**
+                 * Is this node a leaf? This means either the node has no dependencies,
+                 * or it just has self-dependencies.
+                 */
+                protected boolean isLeaf() {
+                    //no deps, or only one self dep
+                    Set<Node> allDeps = getDependencies(DependencyKind.BOUND, DependencyKind.STUCK);
+                    if (allDeps.isEmpty()) return true;
+                    for (Node n : allDeps) {
+                        if (n != this) {
+                            return false;
+                        }
+                    }
+                    return true;
+                }
+
+                /**
+                 * Merge this node with another node, acquiring its dependencies.
+                 * This routine is used to merge all cyclic node together and
+                 * form an acyclic graph.
+                 */
+                protected void mergeWith(List<? extends Node> nodes) {
                     for (Node n : nodes) {
                         Assert.check(n.data.length() == 1, "Attempt to merge a compound node!");
                         data.appendList(n.data);
-                        deps.addAll(n.deps);
+                        for (DependencyKind dk : DependencyKind.values()) {
+                            addDependencies(dk, n.getDependencies(dk));
+                        }
                     }
                     //update deps
-                    Set<Node> deps2 = new HashSet<Node>();
-                    for (Node d : deps) {
-                        if (data.contains(d.data.first())) {
-                            deps2.add(this);
-                        } else {
-                            deps2.add(d);
+                    EnumMap<DependencyKind, Set<Node>> deps2 = new EnumMap<DependencyKind, Set<Node>>(DependencyKind.class);
+                    for (DependencyKind dk : DependencyKind.values()) {
+                        for (Node d : getDependencies(dk)) {
+                            Set<Node> depsByKind = deps2.get(dk);
+                            if (depsByKind == null) {
+                                depsByKind = new LinkedHashSet<Node>();
+                                deps2.put(dk, depsByKind);
+                            }
+                            if (data.contains(d.data.first())) {
+                                depsByKind.add(this);
+                            } else {
+                                depsByKind.add(d);
+                            }
                         }
                     }
                     deps = deps2;
                 }
 
-                void graphChanged(Node from, Node to) {
-                    if (deps.contains(from)) {
-                        deps.remove(from);
+                /**
+                 * Notify all nodes that something has changed in the graph
+                 * topology.
+                 */
+                private void graphChanged(Node from, Node to) {
+                    for (DependencyKind dk : removeDependency(from)) {
                         if (to != null) {
-                            deps.add(to);
+                            addDependency(dk, to);
                         }
                     }
                 }
@@ -1456,8 +1629,21 @@
             /** the nodes in the inference graph */
             ArrayList<Node> nodes;
 
-            InferenceGraph() {
-                initNodes();
+            InferenceGraph(Map<Type, Set<Type>> optDeps) {
+                initNodes(optDeps);
+            }
+
+            /**
+             * Basic lookup helper for retrieving a graph node given an inference
+             * variable type.
+             */
+            public Node findNode(Type t) {
+                for (Node n : nodes) {
+                    if (n.data.contains(t)) {
+                        return n;
+                    }
+                }
+                return null;
             }
 
             /**
@@ -1484,24 +1670,32 @@
              * Create the graph nodes. First a simple node is created for every inference
              * variables to be solved. Then Tarjan is used to found all connected components
              * in the graph. For each component containing more than one node, a super node is
-                 * created, effectively replacing the original cyclic nodes.
+             * created, effectively replacing the original cyclic nodes.
              */
-            void initNodes() {
+            void initNodes(Map<Type, Set<Type>> stuckDeps) {
+                //add nodes
                 nodes = new ArrayList<Node>();
                 for (Type t : inferenceContext.restvars()) {
                     nodes.add(new Node(t));
                 }
+                //add dependencies
                 for (Node n_i : nodes) {
                     Type i = n_i.data.first();
+                    Set<Type> optDepsByNode = stuckDeps.get(i);
                     for (Node n_j : nodes) {
                         Type j = n_j.data.first();
                         UndetVar uv_i = (UndetVar)inferenceContext.asFree(i);
                         if (Type.containsAny(uv_i.getBounds(InferenceBound.values()), List.of(j))) {
-                            //update i's deps
-                            n_i.deps.add(n_j);
+                            //update i's bound dependencies
+                            n_i.addDependency(DependencyKind.BOUND, n_j);
+                        }
+                        if (optDepsByNode != null && optDepsByNode.contains(j)) {
+                            //update i's stuck dependencies
+                            n_i.addDependency(DependencyKind.STUCK, n_j);
                         }
                     }
                 }
+                //merge cyclic nodes
                 ArrayList<Node> acyclicNodes = new ArrayList<Node>();
                 for (List<? extends Node> conSubGraph : GraphUtils.tarjan(nodes)) {
                     if (conSubGraph.length() > 1) {
@@ -1631,8 +1825,8 @@
             return filterVars(new Filter<UndetVar>() {
                 public boolean accepts(UndetVar uv) {
                     return uv.getBounds(InferenceBound.UPPER)
-                            .diff(uv.getDeclaredBounds())
-                            .appendList(uv.getBounds(InferenceBound.EQ, InferenceBound.LOWER)).nonEmpty();
+                             .diff(uv.getDeclaredBounds())
+                             .appendList(uv.getBounds(InferenceBound.EQ, InferenceBound.LOWER)).nonEmpty();
                 }
             });
         }
@@ -1822,11 +2016,15 @@
             }
         }
 
+        private void solve(GraphStrategy ss, Warner warn) {
+            solve(ss, new HashMap<Type, Set<Type>>(), warn);
+        }
+
         /**
          * Solve with given graph strategy.
          */
-        private void solve(GraphStrategy ss, Warner warn) {
-            GraphSolver s = new GraphSolver(this, warn);
+        private void solve(GraphStrategy ss, Map<Type, Set<Type>> stuckDeps, Warner warn) {
+            GraphSolver s = new GraphSolver(this, stuckDeps, warn);
             s.solve(ss);
         }
 
@@ -1855,18 +2053,12 @@
         /**
          * Solve at least one variable in given list.
          */
-        public void solveAny(List<Type> varsToSolve, Warner warn) {
-            checkWithinBounds(this, warn); //propagate bounds
-            List<Type> boundedVars = boundedVars().intersect(restvars()).intersect(varsToSolve);
-            if (boundedVars.isEmpty()) {
-                throw inferenceException.setMessage("cyclic.inference",
-                                freeVarsIn(varsToSolve));
-            }
-            solve(new BestLeafSolver(boundedVars) {
+        public void solveAny(List<Type> varsToSolve, Map<Type, Set<Type>> optDeps, Warner warn) {
+            solve(new BestLeafSolver(varsToSolve.intersect(restvars())) {
                 public boolean done() {
                     return instvars().intersect(varsToSolve).nonEmpty();
                 }
-            }, warn);
+            }, optDeps, warn);
         }
 
         /**