8145239: JShell: throws AssertionError when replace classes with some methods which depends on these classes
authorrfield
Tue, 29 Dec 2015 21:27:25 -0800
changeset 34857 14d1224cfed3
parent 34856 5ca50af1b45c
child 34858 ec69df775846
8145239: JShell: throws AssertionError when replace classes with some methods which depends on these classes Reviewed-by: rfield Contributed-by: bitterfoxc@gmail.com
langtools/src/jdk.jshell/share/classes/jdk/jshell/Eval.java
langtools/src/jdk.jshell/share/classes/jdk/jshell/SourceCodeAnalysisImpl.java
langtools/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java
langtools/src/jdk.jshell/share/classes/jdk/jshell/TreeDissector.java
langtools/src/jdk.jshell/share/classes/jdk/jshell/Unit.java
langtools/src/jdk.jshell/share/classes/jdk/jshell/Util.java
langtools/test/jdk/jshell/ClassesTest.java
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/Eval.java	Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/Eval.java	Tue Dec 29 21:27:25 2015 -0800
@@ -190,7 +190,7 @@
 
     private List<SnippetEvent> processVariables(String userSource, List<? extends Tree> units, String compileSource, ParseTask pt) {
         List<SnippetEvent> allEvents = new ArrayList<>();
-        TreeDissector dis = new TreeDissector(pt);
+        TreeDissector dis = TreeDissector.createByFirstClass(pt);
         for (Tree unitTree : units) {
             VariableTree vt = (VariableTree) unitTree;
             String name = vt.getName().toString();
@@ -295,7 +295,7 @@
         TreeDependencyScanner tds = new TreeDependencyScanner();
         tds.scan(unitTree);
 
-        TreeDissector dis = new TreeDissector(pt);
+        TreeDissector dis = TreeDissector.createByFirstClass(pt);
 
         ClassTree klassTree = (ClassTree) unitTree;
         String name = klassTree.getSimpleName().toString();
@@ -354,7 +354,7 @@
         tds.scan(unitTree);
 
         MethodTree mt = (MethodTree) unitTree;
-        TreeDissector dis = new TreeDissector(pt);
+        TreeDissector dis = TreeDissector.createByFirstClass(pt);
         DiagList modDiag = modifierDiagnostics(mt.getModifiers(), dis, true);
         if (modDiag.hasErrors()) {
             return compileFailResult(modDiag, userSource);
@@ -418,8 +418,8 @@
     private ExpressionInfo typeOfExpression(String expression) {
         Wrap guts = Wrap.methodReturnWrap(expression);
         TaskFactory.AnalyzeTask at = trialCompile(guts);
-        if (!at.hasErrors() && at.cuTree() != null) {
-            return new TreeDissector(at)
+        if (!at.hasErrors() && at.firstCuTree() != null) {
+            return TreeDissector.createByFirstClass(at)
                     .typeOfReturnStatement(at.messages(), state.maps::fullClassNameAndPackageToClass);
         }
         return null;
@@ -513,13 +513,17 @@
             ins.stream().forEach(u -> u.initialize(ins));
             AnalyzeTask at = state.taskFactory.new AnalyzeTask(ins);
             ins.stream().forEach(u -> u.setDiagnostics(at));
+
             // corral any Snippets that need it
-            if (ins.stream().filter(u -> u.corralIfNeeded(ins)).count() > 0) {
+            AnalyzeTask cat;
+            if (ins.stream().anyMatch(u -> u.corralIfNeeded(ins))) {
                 // if any were corralled, re-analyze everything
-                AnalyzeTask cat = state.taskFactory.new AnalyzeTask(ins);
+                cat = state.taskFactory.new AnalyzeTask(ins);
                 ins.stream().forEach(u -> u.setCorralledDiagnostics(cat));
+            } else {
+                cat = at;
             }
-            ins.stream().forEach(u -> u.setStatus());
+            ins.stream().forEach(u -> u.setStatus(cat));
             // compile and load the legit snippets
             boolean success;
             while (true) {
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/SourceCodeAnalysisImpl.java	Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/SourceCodeAnalysisImpl.java	Tue Dec 29 21:27:25 2015 -0800
@@ -239,7 +239,7 @@
     private List<Suggestion> computeSuggestions(OuterWrap code, int cursor, int[] anchor) {
         AnalyzeTask at = proc.taskFactory.new AnalyzeTask(code);
         SourcePositions sp = at.trees().getSourcePositions();
-        CompilationUnitTree topLevel = at.cuTree();
+        CompilationUnitTree topLevel = at.firstCuTree();
         List<Suggestion> result = new ArrayList<>();
         TreePath tp = pathFor(topLevel, sp, code.snippetIndexToWrapIndex(cursor));
         if (tp != null) {
@@ -976,7 +976,7 @@
         OuterWrap codeWrap = wrapInClass(Wrap.methodWrap(code));
         AnalyzeTask at = proc.taskFactory.new AnalyzeTask(codeWrap);
         SourcePositions sp = at.trees().getSourcePositions();
-        CompilationUnitTree topLevel = at.cuTree();
+        CompilationUnitTree topLevel = at.firstCuTree();
         TreePath tp = pathFor(topLevel, sp, codeWrap.snippetIndexToWrapIndex(cursor));
 
         if (tp == null)
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java	Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java	Tue Dec 29 21:27:25 2015 -0800
@@ -56,6 +56,7 @@
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.stream.Collectors;
+import static java.util.stream.Collectors.toList;
 import java.util.stream.Stream;
 import javax.lang.model.util.Elements;
 import javax.tools.FileObject;
@@ -196,7 +197,7 @@
      */
     class ParseTask extends BaseTask {
 
-        private final CompilationUnitTree cut;
+        private final Iterable<? extends CompilationUnitTree> cuts;
         private final List<? extends Tree> units;
 
         ParseTask(final String source) {
@@ -204,16 +205,13 @@
                     new StringSourceHandler(),
                     "-XDallowStringFolding=false", "-proc:none");
             ReplParserFactory.instance(getContext());
-            Iterable<? extends CompilationUnitTree> asts = parse();
-            Iterator<? extends CompilationUnitTree> it = asts.iterator();
-            if (it.hasNext()) {
-                this.cut = it.next();
-                List<? extends ImportTree> imps = cut.getImports();
-                this.units = !imps.isEmpty() ? imps : cut.getTypeDecls();
-            } else {
-                this.cut = null;
-                this.units = Collections.emptyList();
-            }
+            cuts = parse();
+            units = Util.stream(cuts)
+                    .flatMap(cut -> {
+                        List<? extends ImportTree> imps = cut.getImports();
+                        return (!imps.isEmpty() ? imps : cut.getTypeDecls()).stream();
+                    })
+                    .collect(toList());
         }
 
         private Iterable<? extends CompilationUnitTree> parse() {
@@ -229,8 +227,8 @@
         }
 
         @Override
-        CompilationUnitTree cuTree() {
-            return cut;
+        Iterable<? extends CompilationUnitTree> cuTrees() {
+            return cuts;
         }
     }
 
@@ -239,7 +237,7 @@
      */
     class AnalyzeTask extends BaseTask {
 
-        private final CompilationUnitTree cut;
+        private final Iterable<? extends CompilationUnitTree> cuts;
 
         AnalyzeTask(final OuterWrap wrap) {
             this(Stream.of(wrap),
@@ -255,14 +253,7 @@
         <T>AnalyzeTask(final Stream<T> stream, SourceHandler<T> sourceHandler,
                 String... extraOptions) {
             super(stream, sourceHandler, extraOptions);
-            Iterator<? extends CompilationUnitTree> cuts = analyze().iterator();
-            if (cuts.hasNext()) {
-                this.cut = cuts.next();
-                //proc.debug("AnalyzeTask element=%s  cutp=%s  cut=%s\n", e, cutp, cut);
-            } else {
-                this.cut = null;
-                //proc.debug("AnalyzeTask -- no elements -- %s\n", getDiagnostics());
-            }
+            cuts = analyze();
         }
 
         private Iterable<? extends CompilationUnitTree> analyze() {
@@ -276,8 +267,8 @@
         }
 
         @Override
-        CompilationUnitTree cuTree() {
-            return cut;
+        Iterable<? extends CompilationUnitTree> cuTrees() {
+            return cuts;
         }
 
         Elements getElements() {
@@ -332,7 +323,7 @@
         }
 
         @Override
-        CompilationUnitTree cuTree() {
+        Iterable<? extends CompilationUnitTree> cuTrees() {
             throw new UnsupportedOperationException("Not supported.");
         }
     }
@@ -362,7 +353,11 @@
                     compilationUnits, context);
         }
 
-        abstract CompilationUnitTree cuTree();
+        abstract Iterable<? extends CompilationUnitTree> cuTrees();
+
+        CompilationUnitTree firstCuTree() {
+            return cuTrees().iterator().next();
+        }
 
         Diag diag(Diagnostic<? extends JavaFileObject> diag) {
             return sourceHandler.diag(diag);
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/TreeDissector.java	Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/TreeDissector.java	Tue Dec 29 21:27:25 2015 -0800
@@ -48,7 +48,10 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.function.BinaryOperator;
+import java.util.function.Predicate;
+import java.util.stream.Stream;
 import javax.lang.model.type.TypeMirror;
+import jdk.jshell.Util.Pair;
 
 /**
  * Utilities for analyzing compiler API parse trees.
@@ -68,23 +71,48 @@
     }
 
     private final TaskFactory.BaseTask bt;
-    private ClassTree firstClass;
+    private final ClassTree targetClass;
+    private final CompilationUnitTree targetCompilationUnit;
     private SourcePositions theSourcePositions = null;
 
-    TreeDissector(TaskFactory.BaseTask bt) {
+    private TreeDissector(TaskFactory.BaseTask bt, CompilationUnitTree targetCompilationUnit, ClassTree targetClass) {
         this.bt = bt;
+        this.targetCompilationUnit = targetCompilationUnit;
+        this.targetClass = targetClass;
+    }
+
+    static TreeDissector createByFirstClass(TaskFactory.BaseTask bt) {
+        Pair<CompilationUnitTree, ClassTree> pair = classes(bt.firstCuTree())
+                .findFirst().orElseGet(() -> new Pair<>(bt.firstCuTree(), null));
+
+        return new TreeDissector(bt, pair.first, pair.second);
     }
 
+    private static final Predicate<? super Tree> isClassOrInterface =
+            t -> t.getKind() == Tree.Kind.CLASS || t.getKind() == Tree.Kind.INTERFACE;
 
-    ClassTree firstClass() {
-        if (firstClass == null) {
-            firstClass = computeFirstClass();
-        }
-        return firstClass;
+    private static Stream<Pair<CompilationUnitTree, ClassTree>> classes(CompilationUnitTree cut) {
+        return cut == null
+                ? Stream.empty()
+                : cut.getTypeDecls().stream()
+                        .filter(isClassOrInterface)
+                        .map(decl -> new Pair<>(cut, (ClassTree)decl));
     }
 
-    CompilationUnitTree cuTree() {
-        return bt.cuTree();
+    private static Stream<Pair<CompilationUnitTree, ClassTree>> classes(Iterable<? extends CompilationUnitTree> cuts) {
+        return Util.stream(cuts)
+                .flatMap(TreeDissector::classes);
+    }
+
+    static TreeDissector createBySnippet(TaskFactory.BaseTask bt, Snippet si) {
+        String name = si.className();
+
+        Pair<CompilationUnitTree, ClassTree> pair = classes(bt.cuTrees())
+                .filter(p -> p.second.getSimpleName().contentEquals(name))
+                .findFirst().orElseThrow(() ->
+                        new IllegalArgumentException("Class " + name + " is not found."));
+
+        return new TreeDissector(bt, pair.first, pair.second);
     }
 
     Types types() {
@@ -103,11 +131,11 @@
     }
 
     int getStartPosition(Tree tree) {
-        return (int) getSourcePositions().getStartPosition(cuTree(), tree);
+        return (int) getSourcePositions().getStartPosition(targetCompilationUnit, tree);
     }
 
     int getEndPosition(Tree tree) {
-        return (int) getSourcePositions().getEndPosition(cuTree(), tree);
+        return (int) getSourcePositions().getEndPosition(targetCompilationUnit, tree);
     }
 
     Range treeToRange(Tree tree) {
@@ -134,9 +162,9 @@
     }
 
     Tree firstClassMember() {
-        if (firstClass() != null) {
+        if (targetClass != null) {
             //TODO: missing classes
-            for (Tree mem : firstClass().getMembers()) {
+            for (Tree mem : targetClass.getMembers()) {
                 if (mem.getKind() == Tree.Kind.VARIABLE) {
                     return mem;
                 }
@@ -152,8 +180,8 @@
     }
 
     StatementTree firstStatement() {
-        if (firstClass() != null) {
-            for (Tree mem : firstClass().getMembers()) {
+        if (targetClass != null) {
+            for (Tree mem : targetClass.getMembers()) {
                 if (mem.getKind() == Tree.Kind.METHOD) {
                     MethodTree mt = (MethodTree) mem;
                     if (isDoIt(mt.getName())) {
@@ -169,8 +197,8 @@
     }
 
     VariableTree firstVariable() {
-        if (firstClass() != null) {
-            for (Tree mem : firstClass().getMembers()) {
+        if (targetClass != null) {
+            for (Tree mem : targetClass.getMembers()) {
                 if (mem.getKind() == Tree.Kind.VARIABLE) {
                     VariableTree vt = (VariableTree) mem;
                     return vt;
@@ -180,17 +208,6 @@
         return null;
     }
 
-    private ClassTree computeFirstClass() {
-        if (cuTree() == null) {
-            return null;
-        }
-        for (Tree decl : cuTree().getTypeDecls()) {
-            if (decl.getKind() == Tree.Kind.CLASS || decl.getKind() == Tree.Kind.INTERFACE) {
-                return (ClassTree) decl;
-            }
-        }
-        return null;
-    }
 
     ExpressionInfo typeOfReturnStatement(JavacMessages messages, BinaryOperator<String> fullClassNameAndPackageToClass) {
         ExpressionInfo ei = new ExpressionInfo();
@@ -198,7 +215,7 @@
         if (unitTree instanceof ReturnTree) {
             ei.tree = ((ReturnTree) unitTree).getExpression();
             if (ei.tree != null) {
-                TreePath viPath = trees().getPath(cuTree(), ei.tree);
+                TreePath viPath = trees().getPath(targetCompilationUnit, ei.tree);
                 if (viPath != null) {
                     TypeMirror tm = trees().getTypeMirror(viPath);
                     if (tm != null) {
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/Unit.java	Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/Unit.java	Tue Dec 29 21:27:25 2015 -0800
@@ -225,7 +225,7 @@
         return false;
     }
 
-    void setStatus() {
+    void setStatus(AnalyzeTask at) {
         if (!compilationDiagnostics.hasErrors()) {
             status = VALID;
         } else if (isRecoverable()) {
@@ -237,7 +237,7 @@
         } else {
             status = REJECTED;
         }
-        checkForOverwrite();
+        checkForOverwrite(at);
 
         state.debug(DBG_GEN, "setStatus() %s - status: %s\n",
                 si, status);
@@ -361,17 +361,18 @@
                 si, status, unresolved);
     }
 
-    private void checkForOverwrite() {
+    private void checkForOverwrite(AnalyzeTask at) {
         secondaryEvents = new ArrayList<>();
         if (replaceOldEvent != null) secondaryEvents.add(replaceOldEvent);
 
         // Defined methods can overwrite methods of other (equivalent) snippets
         if (si.kind() == Kind.METHOD && status.isDefined) {
-            String oqpt = ((MethodSnippet) si).qualifiedParameterTypes();
-            String nqpt = computeQualifiedParameterTypes(si);
+            MethodSnippet msi = (MethodSnippet)si;
+            String oqpt = msi.qualifiedParameterTypes();
+            String nqpt = computeQualifiedParameterTypes(at, msi);
             if (!nqpt.equals(oqpt)) {
-                ((MethodSnippet) si).setQualifiedParamaterTypes(nqpt);
-                Status overwrittenStatus = overwriteMatchingMethod(si);
+                msi.setQualifiedParamaterTypes(nqpt);
+                Status overwrittenStatus = overwriteMatchingMethod(msi);
                 if (overwrittenStatus != null) {
                     prevStatus = overwrittenStatus;
                     signatureChanged = true;
@@ -383,19 +384,19 @@
     // Check if there is a method whose user-declared parameter types are
     // different (and thus has a different snippet) but whose compiled parameter
     // types are the same. if so, consider it an overwrite replacement.
-    private Status overwriteMatchingMethod(Snippet si) {
-        String qpt = ((MethodSnippet) si).qualifiedParameterTypes();
+    private Status overwriteMatchingMethod(MethodSnippet msi) {
+        String qpt = msi.qualifiedParameterTypes();
 
         // Look through all methods for a method of the same name, with the
         // same computed qualified parameter types
         Status overwrittenStatus = null;
         for (MethodSnippet sn : state.methods()) {
-            if (sn != null && sn != si && sn.status().isActive && sn.name().equals(si.name())) {
+            if (sn != null && sn != msi && sn.status().isActive && sn.name().equals(msi.name())) {
                 if (qpt.equals(sn.qualifiedParameterTypes())) {
                     overwrittenStatus = sn.status();
                     SnippetEvent se = new SnippetEvent(
                             sn, overwrittenStatus, OVERWRITTEN,
-                            false, si, null, null);
+                            false, msi, null, null);
                     sn.setOverwritten();
                     secondaryEvents.add(se);
                     state.debug(DBG_EVNT,
@@ -408,20 +409,16 @@
         return overwrittenStatus;
     }
 
-    private String computeQualifiedParameterTypes(Snippet si) {
-        MethodSnippet msi = (MethodSnippet) si;
-        String qpt;
-        AnalyzeTask at = state.taskFactory.new AnalyzeTask(msi.outerWrap());
-        String rawSig = new TreeDissector(at).typeOfMethod();
+    private String computeQualifiedParameterTypes(AnalyzeTask at, MethodSnippet msi) {
+        String rawSig = TreeDissector.createBySnippet(at, msi).typeOfMethod();
         String signature = expunge(rawSig);
         int paren = signature.lastIndexOf(')');
-        if (paren < 0) {
-            // Uncompilable snippet, punt with user parameter types
-            qpt = msi.parameterTypes();
-        } else {
-            qpt = signature.substring(0, paren + 1);
-        }
-        return qpt;
+
+        // Extract the parameter type string from the method signature,
+        // if method did not compile use the user-supplied parameter types
+        return paren >= 0
+                ? signature.substring(0, paren + 1)
+                : msi.parameterTypes();
     }
 
     SnippetEvent event(String value, Exception exception) {
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/Util.java	Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/Util.java	Tue Dec 29 21:27:25 2015 -0800
@@ -91,4 +91,14 @@
     static <T> Stream<T> stream(Iterable<T> iterable) {
         return StreamSupport.stream(iterable.spliterator(), false);
     }
+
+    static class Pair<T, U> {
+        final T first;
+        final U second;
+
+        Pair(T first, U second) {
+            this.first = first;
+            this.second = second;
+        }
+    }
 }
--- a/langtools/test/jdk/jshell/ClassesTest.java	Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/test/jdk/jshell/ClassesTest.java	Tue Dec 29 21:27:25 2015 -0800
@@ -23,6 +23,7 @@
 
 /*
  * @test
+ * @bug 8145239
  * @summary Tests for EvaluationState.classes
  * @build KullaTesting TestingInputStream ExpectedDiagnostic
  * @run testng ClassesTest
@@ -174,6 +175,27 @@
         assertActiveKeys();
     }
 
+    public void classesRedeclaration3() {
+        Snippet a = classKey(assertEval("class A { }"));
+        assertClasses(clazz(KullaTesting.ClassType.CLASS, "A"));
+        assertActiveKeys();
+
+        Snippet test1 = methodKey(assertEval("A test() { return null; }"));
+        Snippet test2 = methodKey(assertEval("void test(A a) { }"));
+        Snippet test3 = methodKey(assertEval("void test(int n) {A a;}"));
+        assertActiveKeys();
+
+        assertEval("interface A { }",
+                ste(MAIN_SNIPPET, VALID, VALID, true, null),
+                ste(test1, VALID, VALID, true, MAIN_SNIPPET),
+                ste(test2, VALID, VALID, true, MAIN_SNIPPET),
+                ste(test3, VALID, VALID, false, MAIN_SNIPPET),
+                ste(a, VALID, OVERWRITTEN, false, MAIN_SNIPPET));
+        assertClasses(clazz(KullaTesting.ClassType.INTERFACE, "A"));
+        assertMethods(method("()A", "test"), method("(A)void", "test"), method("(int)void", "test"));
+        assertActiveKeys();
+    }
+
     public void classesCyclic1() {
         Snippet b = classKey(assertEval("class B extends A { }",
                 added(RECOVERABLE_NOT_DEFINED)));