8145239: JShell: throws AssertionError when replace classes with some methods which depends on these classes
Reviewed-by: rfield
Contributed-by: bitterfoxc@gmail.com
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/Eval.java Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/Eval.java Tue Dec 29 21:27:25 2015 -0800
@@ -190,7 +190,7 @@
private List<SnippetEvent> processVariables(String userSource, List<? extends Tree> units, String compileSource, ParseTask pt) {
List<SnippetEvent> allEvents = new ArrayList<>();
- TreeDissector dis = new TreeDissector(pt);
+ TreeDissector dis = TreeDissector.createByFirstClass(pt);
for (Tree unitTree : units) {
VariableTree vt = (VariableTree) unitTree;
String name = vt.getName().toString();
@@ -295,7 +295,7 @@
TreeDependencyScanner tds = new TreeDependencyScanner();
tds.scan(unitTree);
- TreeDissector dis = new TreeDissector(pt);
+ TreeDissector dis = TreeDissector.createByFirstClass(pt);
ClassTree klassTree = (ClassTree) unitTree;
String name = klassTree.getSimpleName().toString();
@@ -354,7 +354,7 @@
tds.scan(unitTree);
MethodTree mt = (MethodTree) unitTree;
- TreeDissector dis = new TreeDissector(pt);
+ TreeDissector dis = TreeDissector.createByFirstClass(pt);
DiagList modDiag = modifierDiagnostics(mt.getModifiers(), dis, true);
if (modDiag.hasErrors()) {
return compileFailResult(modDiag, userSource);
@@ -418,8 +418,8 @@
private ExpressionInfo typeOfExpression(String expression) {
Wrap guts = Wrap.methodReturnWrap(expression);
TaskFactory.AnalyzeTask at = trialCompile(guts);
- if (!at.hasErrors() && at.cuTree() != null) {
- return new TreeDissector(at)
+ if (!at.hasErrors() && at.firstCuTree() != null) {
+ return TreeDissector.createByFirstClass(at)
.typeOfReturnStatement(at.messages(), state.maps::fullClassNameAndPackageToClass);
}
return null;
@@ -513,13 +513,17 @@
ins.stream().forEach(u -> u.initialize(ins));
AnalyzeTask at = state.taskFactory.new AnalyzeTask(ins);
ins.stream().forEach(u -> u.setDiagnostics(at));
+
// corral any Snippets that need it
- if (ins.stream().filter(u -> u.corralIfNeeded(ins)).count() > 0) {
+ AnalyzeTask cat;
+ if (ins.stream().anyMatch(u -> u.corralIfNeeded(ins))) {
// if any were corralled, re-analyze everything
- AnalyzeTask cat = state.taskFactory.new AnalyzeTask(ins);
+ cat = state.taskFactory.new AnalyzeTask(ins);
ins.stream().forEach(u -> u.setCorralledDiagnostics(cat));
+ } else {
+ cat = at;
}
- ins.stream().forEach(u -> u.setStatus());
+ ins.stream().forEach(u -> u.setStatus(cat));
// compile and load the legit snippets
boolean success;
while (true) {
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/SourceCodeAnalysisImpl.java Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/SourceCodeAnalysisImpl.java Tue Dec 29 21:27:25 2015 -0800
@@ -239,7 +239,7 @@
private List<Suggestion> computeSuggestions(OuterWrap code, int cursor, int[] anchor) {
AnalyzeTask at = proc.taskFactory.new AnalyzeTask(code);
SourcePositions sp = at.trees().getSourcePositions();
- CompilationUnitTree topLevel = at.cuTree();
+ CompilationUnitTree topLevel = at.firstCuTree();
List<Suggestion> result = new ArrayList<>();
TreePath tp = pathFor(topLevel, sp, code.snippetIndexToWrapIndex(cursor));
if (tp != null) {
@@ -976,7 +976,7 @@
OuterWrap codeWrap = wrapInClass(Wrap.methodWrap(code));
AnalyzeTask at = proc.taskFactory.new AnalyzeTask(codeWrap);
SourcePositions sp = at.trees().getSourcePositions();
- CompilationUnitTree topLevel = at.cuTree();
+ CompilationUnitTree topLevel = at.firstCuTree();
TreePath tp = pathFor(topLevel, sp, codeWrap.snippetIndexToWrapIndex(cursor));
if (tp == null)
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java Tue Dec 29 21:27:25 2015 -0800
@@ -56,6 +56,7 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
+import static java.util.stream.Collectors.toList;
import java.util.stream.Stream;
import javax.lang.model.util.Elements;
import javax.tools.FileObject;
@@ -196,7 +197,7 @@
*/
class ParseTask extends BaseTask {
- private final CompilationUnitTree cut;
+ private final Iterable<? extends CompilationUnitTree> cuts;
private final List<? extends Tree> units;
ParseTask(final String source) {
@@ -204,16 +205,13 @@
new StringSourceHandler(),
"-XDallowStringFolding=false", "-proc:none");
ReplParserFactory.instance(getContext());
- Iterable<? extends CompilationUnitTree> asts = parse();
- Iterator<? extends CompilationUnitTree> it = asts.iterator();
- if (it.hasNext()) {
- this.cut = it.next();
- List<? extends ImportTree> imps = cut.getImports();
- this.units = !imps.isEmpty() ? imps : cut.getTypeDecls();
- } else {
- this.cut = null;
- this.units = Collections.emptyList();
- }
+ cuts = parse();
+ units = Util.stream(cuts)
+ .flatMap(cut -> {
+ List<? extends ImportTree> imps = cut.getImports();
+ return (!imps.isEmpty() ? imps : cut.getTypeDecls()).stream();
+ })
+ .collect(toList());
}
private Iterable<? extends CompilationUnitTree> parse() {
@@ -229,8 +227,8 @@
}
@Override
- CompilationUnitTree cuTree() {
- return cut;
+ Iterable<? extends CompilationUnitTree> cuTrees() {
+ return cuts;
}
}
@@ -239,7 +237,7 @@
*/
class AnalyzeTask extends BaseTask {
- private final CompilationUnitTree cut;
+ private final Iterable<? extends CompilationUnitTree> cuts;
AnalyzeTask(final OuterWrap wrap) {
this(Stream.of(wrap),
@@ -255,14 +253,7 @@
<T>AnalyzeTask(final Stream<T> stream, SourceHandler<T> sourceHandler,
String... extraOptions) {
super(stream, sourceHandler, extraOptions);
- Iterator<? extends CompilationUnitTree> cuts = analyze().iterator();
- if (cuts.hasNext()) {
- this.cut = cuts.next();
- //proc.debug("AnalyzeTask element=%s cutp=%s cut=%s\n", e, cutp, cut);
- } else {
- this.cut = null;
- //proc.debug("AnalyzeTask -- no elements -- %s\n", getDiagnostics());
- }
+ cuts = analyze();
}
private Iterable<? extends CompilationUnitTree> analyze() {
@@ -276,8 +267,8 @@
}
@Override
- CompilationUnitTree cuTree() {
- return cut;
+ Iterable<? extends CompilationUnitTree> cuTrees() {
+ return cuts;
}
Elements getElements() {
@@ -332,7 +323,7 @@
}
@Override
- CompilationUnitTree cuTree() {
+ Iterable<? extends CompilationUnitTree> cuTrees() {
throw new UnsupportedOperationException("Not supported.");
}
}
@@ -362,7 +353,11 @@
compilationUnits, context);
}
- abstract CompilationUnitTree cuTree();
+ abstract Iterable<? extends CompilationUnitTree> cuTrees();
+
+ CompilationUnitTree firstCuTree() {
+ return cuTrees().iterator().next();
+ }
Diag diag(Diagnostic<? extends JavaFileObject> diag) {
return sourceHandler.diag(diag);
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/TreeDissector.java Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/TreeDissector.java Tue Dec 29 21:27:25 2015 -0800
@@ -48,7 +48,10 @@
import java.util.List;
import java.util.Locale;
import java.util.function.BinaryOperator;
+import java.util.function.Predicate;
+import java.util.stream.Stream;
import javax.lang.model.type.TypeMirror;
+import jdk.jshell.Util.Pair;
/**
* Utilities for analyzing compiler API parse trees.
@@ -68,23 +71,48 @@
}
private final TaskFactory.BaseTask bt;
- private ClassTree firstClass;
+ private final ClassTree targetClass;
+ private final CompilationUnitTree targetCompilationUnit;
private SourcePositions theSourcePositions = null;
- TreeDissector(TaskFactory.BaseTask bt) {
+ private TreeDissector(TaskFactory.BaseTask bt, CompilationUnitTree targetCompilationUnit, ClassTree targetClass) {
this.bt = bt;
+ this.targetCompilationUnit = targetCompilationUnit;
+ this.targetClass = targetClass;
+ }
+
+ static TreeDissector createByFirstClass(TaskFactory.BaseTask bt) {
+ Pair<CompilationUnitTree, ClassTree> pair = classes(bt.firstCuTree())
+ .findFirst().orElseGet(() -> new Pair<>(bt.firstCuTree(), null));
+
+ return new TreeDissector(bt, pair.first, pair.second);
}
+ private static final Predicate<? super Tree> isClassOrInterface =
+ t -> t.getKind() == Tree.Kind.CLASS || t.getKind() == Tree.Kind.INTERFACE;
- ClassTree firstClass() {
- if (firstClass == null) {
- firstClass = computeFirstClass();
- }
- return firstClass;
+ private static Stream<Pair<CompilationUnitTree, ClassTree>> classes(CompilationUnitTree cut) {
+ return cut == null
+ ? Stream.empty()
+ : cut.getTypeDecls().stream()
+ .filter(isClassOrInterface)
+ .map(decl -> new Pair<>(cut, (ClassTree)decl));
}
- CompilationUnitTree cuTree() {
- return bt.cuTree();
+ private static Stream<Pair<CompilationUnitTree, ClassTree>> classes(Iterable<? extends CompilationUnitTree> cuts) {
+ return Util.stream(cuts)
+ .flatMap(TreeDissector::classes);
+ }
+
+ static TreeDissector createBySnippet(TaskFactory.BaseTask bt, Snippet si) {
+ String name = si.className();
+
+ Pair<CompilationUnitTree, ClassTree> pair = classes(bt.cuTrees())
+ .filter(p -> p.second.getSimpleName().contentEquals(name))
+ .findFirst().orElseThrow(() ->
+ new IllegalArgumentException("Class " + name + " is not found."));
+
+ return new TreeDissector(bt, pair.first, pair.second);
}
Types types() {
@@ -103,11 +131,11 @@
}
int getStartPosition(Tree tree) {
- return (int) getSourcePositions().getStartPosition(cuTree(), tree);
+ return (int) getSourcePositions().getStartPosition(targetCompilationUnit, tree);
}
int getEndPosition(Tree tree) {
- return (int) getSourcePositions().getEndPosition(cuTree(), tree);
+ return (int) getSourcePositions().getEndPosition(targetCompilationUnit, tree);
}
Range treeToRange(Tree tree) {
@@ -134,9 +162,9 @@
}
Tree firstClassMember() {
- if (firstClass() != null) {
+ if (targetClass != null) {
//TODO: missing classes
- for (Tree mem : firstClass().getMembers()) {
+ for (Tree mem : targetClass.getMembers()) {
if (mem.getKind() == Tree.Kind.VARIABLE) {
return mem;
}
@@ -152,8 +180,8 @@
}
StatementTree firstStatement() {
- if (firstClass() != null) {
- for (Tree mem : firstClass().getMembers()) {
+ if (targetClass != null) {
+ for (Tree mem : targetClass.getMembers()) {
if (mem.getKind() == Tree.Kind.METHOD) {
MethodTree mt = (MethodTree) mem;
if (isDoIt(mt.getName())) {
@@ -169,8 +197,8 @@
}
VariableTree firstVariable() {
- if (firstClass() != null) {
- for (Tree mem : firstClass().getMembers()) {
+ if (targetClass != null) {
+ for (Tree mem : targetClass.getMembers()) {
if (mem.getKind() == Tree.Kind.VARIABLE) {
VariableTree vt = (VariableTree) mem;
return vt;
@@ -180,17 +208,6 @@
return null;
}
- private ClassTree computeFirstClass() {
- if (cuTree() == null) {
- return null;
- }
- for (Tree decl : cuTree().getTypeDecls()) {
- if (decl.getKind() == Tree.Kind.CLASS || decl.getKind() == Tree.Kind.INTERFACE) {
- return (ClassTree) decl;
- }
- }
- return null;
- }
ExpressionInfo typeOfReturnStatement(JavacMessages messages, BinaryOperator<String> fullClassNameAndPackageToClass) {
ExpressionInfo ei = new ExpressionInfo();
@@ -198,7 +215,7 @@
if (unitTree instanceof ReturnTree) {
ei.tree = ((ReturnTree) unitTree).getExpression();
if (ei.tree != null) {
- TreePath viPath = trees().getPath(cuTree(), ei.tree);
+ TreePath viPath = trees().getPath(targetCompilationUnit, ei.tree);
if (viPath != null) {
TypeMirror tm = trees().getTypeMirror(viPath);
if (tm != null) {
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/Unit.java Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/Unit.java Tue Dec 29 21:27:25 2015 -0800
@@ -225,7 +225,7 @@
return false;
}
- void setStatus() {
+ void setStatus(AnalyzeTask at) {
if (!compilationDiagnostics.hasErrors()) {
status = VALID;
} else if (isRecoverable()) {
@@ -237,7 +237,7 @@
} else {
status = REJECTED;
}
- checkForOverwrite();
+ checkForOverwrite(at);
state.debug(DBG_GEN, "setStatus() %s - status: %s\n",
si, status);
@@ -361,17 +361,18 @@
si, status, unresolved);
}
- private void checkForOverwrite() {
+ private void checkForOverwrite(AnalyzeTask at) {
secondaryEvents = new ArrayList<>();
if (replaceOldEvent != null) secondaryEvents.add(replaceOldEvent);
// Defined methods can overwrite methods of other (equivalent) snippets
if (si.kind() == Kind.METHOD && status.isDefined) {
- String oqpt = ((MethodSnippet) si).qualifiedParameterTypes();
- String nqpt = computeQualifiedParameterTypes(si);
+ MethodSnippet msi = (MethodSnippet)si;
+ String oqpt = msi.qualifiedParameterTypes();
+ String nqpt = computeQualifiedParameterTypes(at, msi);
if (!nqpt.equals(oqpt)) {
- ((MethodSnippet) si).setQualifiedParamaterTypes(nqpt);
- Status overwrittenStatus = overwriteMatchingMethod(si);
+ msi.setQualifiedParamaterTypes(nqpt);
+ Status overwrittenStatus = overwriteMatchingMethod(msi);
if (overwrittenStatus != null) {
prevStatus = overwrittenStatus;
signatureChanged = true;
@@ -383,19 +384,19 @@
// Check if there is a method whose user-declared parameter types are
// different (and thus has a different snippet) but whose compiled parameter
// types are the same. if so, consider it an overwrite replacement.
- private Status overwriteMatchingMethod(Snippet si) {
- String qpt = ((MethodSnippet) si).qualifiedParameterTypes();
+ private Status overwriteMatchingMethod(MethodSnippet msi) {
+ String qpt = msi.qualifiedParameterTypes();
// Look through all methods for a method of the same name, with the
// same computed qualified parameter types
Status overwrittenStatus = null;
for (MethodSnippet sn : state.methods()) {
- if (sn != null && sn != si && sn.status().isActive && sn.name().equals(si.name())) {
+ if (sn != null && sn != msi && sn.status().isActive && sn.name().equals(msi.name())) {
if (qpt.equals(sn.qualifiedParameterTypes())) {
overwrittenStatus = sn.status();
SnippetEvent se = new SnippetEvent(
sn, overwrittenStatus, OVERWRITTEN,
- false, si, null, null);
+ false, msi, null, null);
sn.setOverwritten();
secondaryEvents.add(se);
state.debug(DBG_EVNT,
@@ -408,20 +409,16 @@
return overwrittenStatus;
}
- private String computeQualifiedParameterTypes(Snippet si) {
- MethodSnippet msi = (MethodSnippet) si;
- String qpt;
- AnalyzeTask at = state.taskFactory.new AnalyzeTask(msi.outerWrap());
- String rawSig = new TreeDissector(at).typeOfMethod();
+ private String computeQualifiedParameterTypes(AnalyzeTask at, MethodSnippet msi) {
+ String rawSig = TreeDissector.createBySnippet(at, msi).typeOfMethod();
String signature = expunge(rawSig);
int paren = signature.lastIndexOf(')');
- if (paren < 0) {
- // Uncompilable snippet, punt with user parameter types
- qpt = msi.parameterTypes();
- } else {
- qpt = signature.substring(0, paren + 1);
- }
- return qpt;
+
+ // Extract the parameter type string from the method signature,
+ // if method did not compile use the user-supplied parameter types
+ return paren >= 0
+ ? signature.substring(0, paren + 1)
+ : msi.parameterTypes();
}
SnippetEvent event(String value, Exception exception) {
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/Util.java Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/Util.java Tue Dec 29 21:27:25 2015 -0800
@@ -91,4 +91,14 @@
static <T> Stream<T> stream(Iterable<T> iterable) {
return StreamSupport.stream(iterable.spliterator(), false);
}
+
+ static class Pair<T, U> {
+ final T first;
+ final U second;
+
+ Pair(T first, U second) {
+ this.first = first;
+ this.second = second;
+ }
+ }
}
--- a/langtools/test/jdk/jshell/ClassesTest.java Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/test/jdk/jshell/ClassesTest.java Tue Dec 29 21:27:25 2015 -0800
@@ -23,6 +23,7 @@
/*
* @test
+ * @bug 8145239
* @summary Tests for EvaluationState.classes
* @build KullaTesting TestingInputStream ExpectedDiagnostic
* @run testng ClassesTest
@@ -174,6 +175,27 @@
assertActiveKeys();
}
+ public void classesRedeclaration3() {
+ Snippet a = classKey(assertEval("class A { }"));
+ assertClasses(clazz(KullaTesting.ClassType.CLASS, "A"));
+ assertActiveKeys();
+
+ Snippet test1 = methodKey(assertEval("A test() { return null; }"));
+ Snippet test2 = methodKey(assertEval("void test(A a) { }"));
+ Snippet test3 = methodKey(assertEval("void test(int n) {A a;}"));
+ assertActiveKeys();
+
+ assertEval("interface A { }",
+ ste(MAIN_SNIPPET, VALID, VALID, true, null),
+ ste(test1, VALID, VALID, true, MAIN_SNIPPET),
+ ste(test2, VALID, VALID, true, MAIN_SNIPPET),
+ ste(test3, VALID, VALID, false, MAIN_SNIPPET),
+ ste(a, VALID, OVERWRITTEN, false, MAIN_SNIPPET));
+ assertClasses(clazz(KullaTesting.ClassType.INTERFACE, "A"));
+ assertMethods(method("()A", "test"), method("(A)void", "test"), method("(int)void", "test"));
+ assertActiveKeys();
+ }
+
public void classesCyclic1() {
Snippet b = classKey(assertEval("class B extends A { }",
added(RECOVERABLE_NOT_DEFINED)));