src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java
changeset 47350 d65c3b21081c
parent 47268 48ec75306997
child 48610 a587f95313f1
--- a/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java	Mon Oct 16 18:15:41 2017 +0530
+++ b/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java	Fri Sep 01 14:04:20 2017 +0200
@@ -29,10 +29,8 @@
 import com.sun.source.tree.Tree;
 import com.sun.source.util.Trees;
 import com.sun.tools.javac.api.JavacTaskImpl;
-import com.sun.tools.javac.api.JavacTool;
 import com.sun.tools.javac.util.Context;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import javax.tools.Diagnostic;
 import javax.tools.DiagnosticCollector;
@@ -62,18 +60,28 @@
 import jdk.jshell.MemoryFileManager.SourceMemoryJavaFileObject;
 import java.lang.Runtime.Version;
 import java.nio.CharBuffer;
+import java.util.function.BiFunction;
 import com.sun.source.tree.Tree.Kind;
+import com.sun.source.util.TaskEvent;
+import com.sun.source.util.TaskListener;
+import com.sun.tools.javac.api.JavacTaskPool;
+import com.sun.tools.javac.code.ClassFinder;
 import com.sun.tools.javac.code.Kinds;
 import com.sun.tools.javac.code.Symbol.ClassSymbol;
+import com.sun.tools.javac.code.Symbol.PackageSymbol;
 import com.sun.tools.javac.code.Symbol.VarSymbol;
+import com.sun.tools.javac.code.Symtab;
+import com.sun.tools.javac.comp.Attr;
 import com.sun.tools.javac.parser.Parser;
+import com.sun.tools.javac.parser.ParserFactory;
 import com.sun.tools.javac.tree.JCTree.JCClassDecl;
 import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
 import com.sun.tools.javac.tree.JCTree.JCExpression;
 import com.sun.tools.javac.tree.JCTree.JCTypeCast;
 import com.sun.tools.javac.tree.JCTree.Tag;
-import com.sun.tools.javac.util.Context.Factory;
+import com.sun.tools.javac.util.Log;
 import com.sun.tools.javac.util.Log.DiscardDiagnosticHandler;
+import com.sun.tools.javac.util.Names;
 import jdk.jshell.Snippet.Status;
 
 /**
@@ -101,6 +109,7 @@
         }
         this.fileManager = new MemoryFileManager(
                 compiler.getStandardFileManager(null, null, null), state);
+        initTaskPool();
     }
 
     void addToClasspath(String path) {
@@ -108,27 +117,130 @@
         List<String> args = new ArrayList<>();
         args.add(classpath);
         fileManager().handleOption("-classpath", args.iterator());
+        initTaskPool();
     }
 
     MemoryFileManager fileManager() {
         return fileManager;
     }
 
+    public <Z> Z parse(String source,
+                       boolean forceExpression,
+                       Worker<ParseTask, Z> worker) {
+        StringSourceHandler sh = new StringSourceHandler();
+        return runTask(Stream.of(source),
+                       sh,
+                       List.of("-XDallowStringFolding=false", "-proc:none",
+                               "-XDneedsReplParserFactory=" + forceExpression),
+                       (jti, diagnostics) -> new ParseTask(sh, jti, diagnostics, forceExpression),
+                       worker);
+    }
+
+    public <Z> Z analyze(OuterWrap wrap,
+                         Worker<AnalyzeTask, Z> worker) {
+        return analyze(Collections.singletonList(wrap), worker);
+    }
+
+    public <Z> Z analyze(OuterWrap wrap,
+                         List<String> extraArgs,
+                         Worker<AnalyzeTask, Z> worker) {
+        return analyze(Collections.singletonList(wrap), extraArgs, worker);
+    }
+
+    public <Z> Z analyze(Collection<OuterWrap> wraps,
+                         Worker<AnalyzeTask, Z> worker) {
+        return analyze(wraps, Collections.emptyList(), worker);
+    }
+
+    public <Z> Z analyze(Collection<OuterWrap> wraps,
+                         List<String> extraArgs,
+                         Worker<AnalyzeTask, Z> worker) {
+        WrapSourceHandler sh = new WrapSourceHandler();
+        List<String> allOptions = new ArrayList<>();
+
+        allOptions.add("--should-stop:at=FLOW");
+        allOptions.add("-Xlint:unchecked");
+        allOptions.add("-proc:none");
+        allOptions.addAll(extraArgs);
+
+        return runTask(wraps.stream(),
+                       sh,
+                       allOptions,
+                       (jti, diagnostics) -> new AnalyzeTask(sh, jti, diagnostics),
+                       worker);
+    }
+
+    public <Z> Z compile(Collection<OuterWrap> wraps,
+                         Worker<CompileTask, Z> worker) {
+        WrapSourceHandler sh = new WrapSourceHandler();
+
+        return runTask(wraps.stream(),
+                       sh,
+                       List.of("-Xlint:unchecked", "-proc:none", "-parameters"),
+                       (jti, diagnostics) -> new CompileTask(sh, jti, diagnostics),
+                       worker);
+    }
+
+    private <S, T extends BaseTask, Z> Z runTask(Stream<S> inputs,
+                                                 SourceHandler<S> sh,
+                                                 List<String> options,
+                                                 BiFunction<JavacTaskImpl, DiagnosticCollector<JavaFileObject>, T> creator,
+                                                 Worker<T, Z> worker) {
+            List<String> allOptions = new ArrayList<>(options.size() + state.extraCompilerOptions.size());
+            allOptions.addAll(options);
+            allOptions.addAll(state.extraCompilerOptions);
+            Iterable<? extends JavaFileObject> compilationUnits = inputs
+                            .map(in -> sh.sourceToFileObject(fileManager, in))
+                            .collect(Collectors.toList());
+            DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<>();
+            return javacTaskPool.getTask(null, fileManager, diagnostics, allOptions, null,
+                                         compilationUnits, task -> {
+                 JavacTaskImpl jti = (JavacTaskImpl) task;
+                 Context context = jti.getContext();
+                 jti.addTaskListener(new TaskListenerImpl(context, state));
+                 try {
+                     return worker.withTask(creator.apply(jti, diagnostics));
+                 } finally {
+                     //additional cleanup: purge the REPL package:
+                     Symtab syms = Symtab.instance(context);
+                     Names names = Names.instance(context);
+                     PackageSymbol repl = syms.getPackage(syms.unnamedModule, names.fromString(Util.REPL_PACKAGE));
+                     if (repl != null) {
+                         for (ClassSymbol clazz : syms.getAllClasses()) {
+                             if (clazz.packge() == repl) {
+                                 syms.removeClass(syms.unnamedModule, clazz.flatName());
+                             }
+                         }
+                         repl.members_field = null;
+                         repl.completer = ClassFinder.instance(context).getCompleter();
+                     }
+                 }
+            });
+    }
+
+    interface Worker<T extends BaseTask, Z> {
+        public Z withTask(T task);
+    }
+
     // Parse a snippet and return our parse task handler
-    ParseTask parse(final String source) {
-        ParseTask pt = state.taskFactory.new ParseTask(source, false);
-        if (!pt.units().isEmpty()
-                && pt.units().get(0).getKind() == Kind.EXPRESSION_STATEMENT
-                && pt.getDiagnostics().hasOtherThanNotStatementErrors()) {
-            // It failed, it may be an expression being incorrectly
-            // parsed as having a leading type variable, example:   a < b
-            // Try forcing interpretation as an expression
-            ParseTask ept = state.taskFactory.new ParseTask(source, true);
-            if (!ept.getDiagnostics().hasOtherThanNotStatementErrors()) {
-                return ept;
+    <Z> Z parse(final String source, Worker<ParseTask, Z> worker) {
+        return parse(source, false, pt -> {
+            if (!pt.units().isEmpty()
+                    && pt.units().get(0).getKind() == Kind.EXPRESSION_STATEMENT
+                    && pt.getDiagnostics().hasOtherThanNotStatementErrors()) {
+                // It failed, it may be an expression being incorrectly
+                // parsed as having a leading type variable, example:   a < b
+                // Try forcing interpretation as an expression
+                return parse(source, true, ept -> {
+                    if (!ept.getDiagnostics().hasOtherThanNotStatementErrors()) {
+                        return worker.withTask(ept);
+                    } else {
+                        return worker.withTask(pt);
+                    }
+                });
             }
-        }
-        return pt;
+            return worker.withTask(pt);
+        });
     }
 
     private interface SourceHandler<T> {
@@ -210,11 +322,12 @@
         private final Iterable<? extends CompilationUnitTree> cuts;
         private final List<? extends Tree> units;
 
-        ParseTask(final String source, final boolean forceExpression) {
-            super(Stream.of(source),
-                    new StringSourceHandler(),
-                    "-XDallowStringFolding=false", "-proc:none");
-            ReplParserFactory.preRegister(getContext(), forceExpression);
+        private ParseTask(SourceHandler<String> sh,
+                          JavacTaskImpl task,
+                          DiagnosticCollector<JavaFileObject> diagnostics,
+                          boolean forceExpression) {
+            super(sh, task, diagnostics);
+            ReplParserFactory.preRegister(context, forceExpression);
             cuts = parse();
             units = Util.stream(cuts)
                     .flatMap(cut -> {
@@ -249,22 +362,10 @@
 
         private final Iterable<? extends CompilationUnitTree> cuts;
 
-        AnalyzeTask(final OuterWrap wrap, String... extraArgs) {
-            this(Collections.singletonList(wrap), extraArgs);
-        }
-
-        AnalyzeTask(final Collection<OuterWrap> wraps, String... extraArgs) {
-            this(wraps.stream(),
-                    new WrapSourceHandler(),
-                    Util.join(new String[] {
-                        "--should-stop:at=FLOW", "-Xlint:unchecked",
-                        "-proc:none"
-                    }, extraArgs));
-        }
-
-        private <T>AnalyzeTask(final Stream<T> stream, SourceHandler<T> sourceHandler,
-                String... extraOptions) {
-            super(stream, sourceHandler, extraOptions);
+        private AnalyzeTask(SourceHandler<OuterWrap> sh,
+                            JavacTaskImpl task,
+                            DiagnosticCollector<JavaFileObject> diagnostics) {
+            super(sh, task, diagnostics);
             cuts = analyze();
         }
 
@@ -299,9 +400,10 @@
 
         private final Map<OuterWrap, List<OutputMemoryJavaFileObject>> classObjs = new HashMap<>();
 
-        CompileTask(final Collection<OuterWrap> wraps) {
-            super(wraps.stream(), new WrapSourceHandler(),
-                    "-Xlint:unchecked", "-proc:none", "-parameters");
+        CompileTask(SourceHandler<OuterWrap>sh,
+                    JavacTaskImpl jti,
+                    DiagnosticCollector<JavaFileObject> diagnostics) {
+            super(sh, jti, diagnostics);
         }
 
         boolean compile() {
@@ -346,32 +448,30 @@
         }
     }
 
+    private JavacTaskPool javacTaskPool;
+
+    private void initTaskPool() {
+        javacTaskPool = new JavacTaskPool(5);
+    }
+
     abstract class BaseTask {
 
-        final DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<>();
+        final DiagnosticCollector<JavaFileObject> diagnostics;
         final JavacTaskImpl task;
         private DiagList diags = null;
         private final SourceHandler<?> sourceHandler;
-        final Context context = new Context();
+        final Context context;
         private Types types;
         private JavacMessages messages;
         private Trees trees;
 
-        private <T>BaseTask(Stream<T> inputs,
-                //BiFunction<MemoryFileManager, T, JavaFileObject> sfoCreator,
-                SourceHandler<T> sh,
-                String... extraOptions) {
+        private <T>BaseTask(SourceHandler<T> sh,
+                            JavacTaskImpl task,
+                            DiagnosticCollector<JavaFileObject> diagnostics) {
             this.sourceHandler = sh;
-            List<String> options = new ArrayList<>(extraOptions.length + state.extraCompilerOptions.size());
-            options.addAll(Arrays.asList(extraOptions));
-            options.addAll(state.extraCompilerOptions);
-            Iterable<? extends JavaFileObject> compilationUnits = inputs
-                            .map(in -> sh.sourceToFileObject(fileManager, in))
-                            .collect(Collectors.toList());
-            JShellJavaCompiler.preRegister(context, state);
-            this.task = (JavacTaskImpl) ((JavacTool) compiler).getTask(null,
-                    fileManager, diagnostics, options, null,
-                    compilationUnits, context);
+            this.task = task;
+            context = task.getContext();
+            this.diagnostics = diagnostics;
         }
 
         abstract Iterable<? extends CompilationUnitTree> cuTrees();
@@ -478,32 +578,36 @@
         }
     }
 
-    private static final class JShellJavaCompiler extends com.sun.tools.javac.main.JavaCompiler {
+    private static final class TaskListenerImpl implements TaskListener {
 
-        public static void preRegister(Context c, JShell state) {
-            c.put(compilerKey, (Factory<com.sun.tools.javac.main.JavaCompiler>) i -> new JShellJavaCompiler(i, state));
-        }
-
+        private final Context context;
         private final JShell state;
 
-        public JShellJavaCompiler(Context context, JShell state) {
-            super(context);
+        public TaskListenerImpl(Context context, JShell state) {
+            this.context = context;
             this.state = state;
         }
 
         @Override
-        public void processAnnotations(com.sun.tools.javac.util.List<JCCompilationUnit> roots, Collection<String> classnames) {
-            super.processAnnotations(roots, classnames);
+        public void finished(TaskEvent e) {
+            if (e.getKind() != TaskEvent.Kind.ENTER)
+                return ;
             state.maps
                  .snippetList()
                  .stream()
                  .filter(s -> s.status() == Status.VALID)
                  .filter(s -> s.kind() == Snippet.Kind.VAR)
                  .filter(s -> s.subKind() == Snippet.SubKind.VAR_DECLARATION_WITH_INITIALIZER_SUBKIND)
-                 .forEach(s -> setVariableType(roots, (VarSnippet) s));
+                 .forEach(s -> setVariableType((JCCompilationUnit) e.getCompilationUnit(), (VarSnippet) s));
         }
 
-        private void setVariableType(com.sun.tools.javac.util.List<JCCompilationUnit> roots, VarSnippet s) {
+        private void setVariableType(JCCompilationUnit root, VarSnippet s) {
+            Symtab syms = Symtab.instance(context);
+            Names names = Names.instance(context);
+            Log log  = Log.instance(context);
+            ParserFactory parserFactory = ParserFactory.instance(context);
+            Attr attr = Attr.instance(context);
+
             ClassSymbol clazz = syms.getClass(syms.unnamedModule, names.fromString(s.classFullName()));
             if (clazz == null || !clazz.isCompleted())
                 return;
@@ -520,7 +624,7 @@
                         JCTypeCast tree = (JCTypeCast) expr;
                         if (tree.clazz.hasTag(Tag.TYPEINTERSECTION)) {
                             field.type = attr.attribType(tree.clazz,
-                                                         ((JCClassDecl) roots.head.getTypeDecls().head).sym);
+                                                         ((JCClassDecl) root.getTypeDecls().head).sym);
                         }
                     }
                 } finally {