langtools/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java
changeset 34857 14d1224cfed3
parent 34092 fd97b0f6abcc
child 35000 952a7b4652f0
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java	Thu Dec 24 10:34:05 2015 -0800
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java	Tue Dec 29 21:27:25 2015 -0800
@@ -56,6 +56,7 @@
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.stream.Collectors;
+import static java.util.stream.Collectors.toList;
 import java.util.stream.Stream;
 import javax.lang.model.util.Elements;
 import javax.tools.FileObject;
@@ -196,7 +197,7 @@
      */
     class ParseTask extends BaseTask {
 
-        private final CompilationUnitTree cut;
+        private final Iterable<? extends CompilationUnitTree> cuts;
         private final List<? extends Tree> units;
 
         ParseTask(final String source) {
@@ -204,16 +205,13 @@
                     new StringSourceHandler(),
                     "-XDallowStringFolding=false", "-proc:none");
             ReplParserFactory.instance(getContext());
-            Iterable<? extends CompilationUnitTree> asts = parse();
-            Iterator<? extends CompilationUnitTree> it = asts.iterator();
-            if (it.hasNext()) {
-                this.cut = it.next();
-                List<? extends ImportTree> imps = cut.getImports();
-                this.units = !imps.isEmpty() ? imps : cut.getTypeDecls();
-            } else {
-                this.cut = null;
-                this.units = Collections.emptyList();
-            }
+            cuts = parse();
+            units = Util.stream(cuts)
+                    .flatMap(cut -> {
+                        List<? extends ImportTree> imps = cut.getImports();
+                        return (!imps.isEmpty() ? imps : cut.getTypeDecls()).stream();
+                    })
+                    .collect(toList());
         }
 
         private Iterable<? extends CompilationUnitTree> parse() {
@@ -229,8 +227,8 @@
         }
 
         @Override
-        CompilationUnitTree cuTree() {
-            return cut;
+        Iterable<? extends CompilationUnitTree> cuTrees() {
+            return cuts;
         }
     }
 
@@ -239,7 +237,7 @@
      */
     class AnalyzeTask extends BaseTask {
 
-        private final CompilationUnitTree cut;
+        private final Iterable<? extends CompilationUnitTree> cuts;
 
         AnalyzeTask(final OuterWrap wrap) {
             this(Stream.of(wrap),
@@ -255,14 +253,7 @@
         <T>AnalyzeTask(final Stream<T> stream, SourceHandler<T> sourceHandler,
                 String... extraOptions) {
             super(stream, sourceHandler, extraOptions);
-            Iterator<? extends CompilationUnitTree> cuts = analyze().iterator();
-            if (cuts.hasNext()) {
-                this.cut = cuts.next();
-                //proc.debug("AnalyzeTask element=%s  cutp=%s  cut=%s\n", e, cutp, cut);
-            } else {
-                this.cut = null;
-                //proc.debug("AnalyzeTask -- no elements -- %s\n", getDiagnostics());
-            }
+            cuts = analyze();
         }
 
         private Iterable<? extends CompilationUnitTree> analyze() {
@@ -276,8 +267,8 @@
         }
 
         @Override
-        CompilationUnitTree cuTree() {
-            return cut;
+        Iterable<? extends CompilationUnitTree> cuTrees() {
+            return cuts;
         }
 
         Elements getElements() {
@@ -332,7 +323,7 @@
         }
 
         @Override
-        CompilationUnitTree cuTree() {
+        Iterable<? extends CompilationUnitTree> cuTrees() {
             throw new UnsupportedOperationException("Not supported.");
         }
     }
@@ -362,7 +353,11 @@
                     compilationUnits, context);
         }
 
-        abstract CompilationUnitTree cuTree();
+        abstract Iterable<? extends CompilationUnitTree> cuTrees();
+
+        CompilationUnitTree firstCuTree() {
+            return cuTrees().iterator().next();
+        }
 
         Diag diag(Diagnostic<? extends JavaFileObject> diag) {
             return sourceHandler.diag(diag);