Adding basic support for method bodies. jlahoda-tree-builder
authorjlahoda
Mon, 01 Apr 2019 11:44:31 +0200
branchjlahoda-tree-builder
changeset 57297 ad0be596956b
parent 57296 464cc8d22d94
child 57298 72d5f7163f12
Adding basic support for method bodies.
src/jdk.compiler/share/classes/com/sun/source/util/TreeBuilder.java
src/jdk.compiler/share/classes/com/sun/tools/javac/api/TreeBuilderImpl.java
test/langtools/tools/javac/api/ast/ASTBuilder.java
test/langtools/tools/javac/api/ast/CodeBuilder.java
--- a/src/jdk.compiler/share/classes/com/sun/source/util/TreeBuilder.java	Fri Mar 29 10:12:18 2019 +0100
+++ b/src/jdk.compiler/share/classes/com/sun/source/util/TreeBuilder.java	Mon Apr 01 11:44:31 2019 +0200
@@ -90,9 +90,10 @@
             return parameter(type, P -> {});
         }
 
+        //TODO: parameter overload type+name?
         Method parameter(Consumer<Type> type, Consumer<Parameter> parameter);
 
-        Method body(Consumer<Statements> statements);
+        Method body(Consumer<Block> statements);
         //throws, default value
     }
 
@@ -141,6 +142,7 @@
 //    }
 //
     interface Expression {
+        void equal_to(Consumer<Expression> lhs, Consumer<Expression> rhs);
         void minusminus(Consumer<Expression> expr);
         void plus(Consumer<Expression> lhs, Consumer<Expression> rhs);
         void cond(Consumer<Expression> cond, Consumer<Expression> truePart, Consumer<Expression> falsePart);
@@ -149,10 +151,19 @@
         void literal(Object value);
     }
 
-    interface Statements {
-        Statements _if(Consumer<Expression> cond, Consumer<Statements> ifPart, Consumer<Statements> elsePart);
-        Statements expr(Consumer<Expression> expr);
-        Statements skip();
+    interface StatementBase<S> {
+        S _if(Consumer<Expression> cond, Consumer<Statement> ifPart);
+        S _if(Consumer<Expression> cond, Consumer<Statement> ifPart, Consumer<Statement> elsePart);
+        S _return();
+        S _return(Consumer<Expression> expr);
+        S expr(Consumer<Expression> expr);
+        S skip();
+    }
+
+    interface Statement extends StatementBase<Void> {
+    }
+
+    interface Block extends StatementBase<Block>{
     }
 
     static void test(TreeBuilder builder) {
@@ -163,8 +174,8 @@
                                             M -> M.parameter(T -> T._class("Foo"))
                                                   .parameter(T -> T._float(), P -> P.name("whatever"))
                                                   .body(B -> B._if(E -> E.minusminus(V -> V.select(S -> S.ident("foo"), "bar")),
-                                                                   Statements::skip,
-                                                                   Statements::skip
+                                                                   Statement::skip,
+                                                                   Statement::skip
                                                                   )
                                                        )
                                             )));
--- a/src/jdk.compiler/share/classes/com/sun/tools/javac/api/TreeBuilderImpl.java	Fri Mar 29 10:12:18 2019 +0100
+++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/api/TreeBuilderImpl.java	Mon Apr 01 11:44:31 2019 +0200
@@ -29,14 +29,18 @@
 import com.sun.source.tree.CompilationUnitTree;
 import com.sun.source.util.TreeBuilder;
 import com.sun.tools.javac.code.TypeTag;
+import com.sun.tools.javac.tree.JCTree;
 import com.sun.tools.javac.tree.JCTree.JCClassDecl;
 import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
 import com.sun.tools.javac.tree.JCTree.JCExpression;
+import com.sun.tools.javac.tree.JCTree.JCMethodDecl;
+import com.sun.tools.javac.tree.JCTree.JCStatement;
 import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
 import com.sun.tools.javac.tree.JCTree.Tag;
 
 import com.sun.tools.javac.tree.TreeMaker;
 import com.sun.tools.javac.util.List;
+import com.sun.tools.javac.util.Name;
 import com.sun.tools.javac.util.Names;
 
 /**
@@ -124,7 +128,15 @@
 
         @Override
         public Class method(String name, Consumer<Type> restype, Consumer<Method> method) {
-            throw new UnsupportedOperationException("Not supported yet.");
+            TypeImpl ti = new TypeImpl();
+            restype.accept(ti);
+            if (ti.type == null) {
+                throw new IllegalStateException("Type not provided!");
+            }
+            MethodImpl vi = new MethodImpl(ti.type, name);
+            method.accept(vi);
+            result.defs = result.defs.append(vi.result);
+            return this;
         }
 
         @Override
@@ -225,11 +237,155 @@
         
     }
 
+    private final class MethodImpl implements Method {
+
+        private final JCMethodDecl result;
+
+        public MethodImpl(JCExpression restype, String name) {
+            result = make.MethodDef(make.Modifiers(0), names.fromString(name), restype, List.nil(), List.nil(), List.nil(), null, null);
+        }
+
+        @Override
+        public Method parameter(Consumer<Type> type, Consumer<Parameter> parameter) {
+            ParameterImpl paramImpl = new ParameterImpl(visitType(type));
+            parameter.accept(paramImpl);
+            result.params = result.params.append(paramImpl.result);
+            return this;
+        }
+
+        @Override
+        public Method body(Consumer<Block> statements) {
+            BlockImpl block = new BlockImpl();
+            statements.accept(block);
+            result.body = make.Block(0, block.statements);
+            return this;
+        }
+
+        @Override
+        public Method modifiers(Consumer<Modifiers> modifiers) {
+            throw new UnsupportedOperationException("Not supported yet.");
+        }
+
+        @Override
+        public Method javadoc(DocTree doc) {
+            throw new UnsupportedOperationException("Not supported yet.");
+        }
+
+        @Override
+        public Method javadoc(String doc) {
+            throw new UnsupportedOperationException("Not supported yet.");
+        }
+        
+    }
+
+    private final class ParameterImpl implements Parameter {
+
+        private final JCVariableDecl result;
+
+        public ParameterImpl(JCExpression type) {
+            //TODO: infer name
+            result = make.VarDef(make.Modifiers(0), null, type, null);
+        }
+
+        @Override
+        public Parameter modifiers(Consumer<Modifiers> modifiers) {
+            throw new UnsupportedOperationException("Not supported yet.");
+        }
+
+        @Override
+        public Parameter name(String name) {
+            result.name = names.fromString(name); //XXX: check not set yet.
+            return this;
+        }
+        
+    }
+
+    private final class BlockImpl extends StatementBaseImpl<Block> implements Block {
+
+        private List<JCStatement> statements = List.nil();
+
+        @Override
+        protected Block addStatement(JCStatement stat) {
+            statements = statements.append(stat);
+            return this;
+        }
+        
+    }
+
+    private final class StatementImpl extends StatementBaseImpl<Void> implements Statement {
+        private JCStatement result;
+
+        @Override
+        protected Void addStatement(JCStatement stat) {
+            if (result != null) {
+                throw new IllegalStateException();
+            }
+            result = stat;
+            return null;
+        }
+    }
+
+    private abstract class StatementBaseImpl<S> implements StatementBase<S> {
+
+        @Override
+        public S _if(Consumer<Expression> cond, Consumer<Statement> ifPart) {
+            JCExpression expr = visitExpression(cond);
+            //TODO: should this automatic wrapping with parenthesized be here?
+            expr = make.Parens(expr);
+            StatementImpl ifStatement = new StatementImpl();
+            ifPart.accept(ifStatement);
+            //TODO: check ifPart filled!
+            return addStatement(make.If(expr, ifStatement.result, null));
+        }
+
+        @Override
+        public S _if(Consumer<Expression> cond, Consumer<Statement> ifPart, Consumer<Statement> elsePart) {
+            JCExpression expr = visitExpression(cond);
+            //TODO: should this automatic wrapping with parenthesized be here?
+            expr = make.Parens(expr);
+            StatementImpl ifStatement = new StatementImpl();
+            ifPart.accept(ifStatement);
+            //TODO: check ifPart filled!
+            StatementImpl elseStatement = new StatementImpl();
+            elsePart.accept(elseStatement);
+            return addStatement(make.If(expr, ifStatement.result, elseStatement.result));
+        }
+
+        @Override
+        public S _return() {
+            throw new UnsupportedOperationException("Not supported yet.");
+        }
+
+        @Override
+        public S _return(Consumer<Expression> expr) {
+            return addStatement(make.Return(visitExpression(expr)));
+        }
+
+        @Override
+        public S expr(Consumer<Expression> expr) {
+            throw new UnsupportedOperationException("Not supported yet.");
+        }
+
+        @Override
+        public S skip() {
+            throw new UnsupportedOperationException("Not supported yet.");
+        }
+
+        protected abstract S addStatement(JCStatement stat);
+    }
+    
     private final class ExpressionImpl implements Expression {
 
         private JCExpression expr;
 
         @Override
+        public void equal_to(Consumer<Expression> lhs, Consumer<Expression> rhs) {
+            expr = make.Binary(Tag.EQ,
+                               visitExpression(lhs),
+                               visitExpression(rhs));
+        }
+
+        @Override
         public void minusminus(Consumer<Expression> expr) {
             throw new UnsupportedOperationException("Not supported yet.");
         }
--- a/test/langtools/tools/javac/api/ast/ASTBuilder.java	Fri Mar 29 10:12:18 2019 +0100
+++ b/test/langtools/tools/javac/api/ast/ASTBuilder.java	Mon Apr 01 11:44:31 2019 +0200
@@ -70,6 +70,14 @@
         runTest("class Test extends Exception implements java.util.List<Map<String, String>>, CharSequence {" +
                 "    int x1 = 2;" +
                 "    int x2 = 2 + x1;" +
+                "    int test(int param) {" +
+                "        if (param == 0) return 0;" +
+                "        else return 1;" +
+                "    }" +
+                "    int test2(int param) {" +
+                "        if (param == 0) return 0;" +
+                "        return 1;" +
+                "    }" +
                 "}");
     }
 
--- a/test/langtools/tools/javac/api/ast/CodeBuilder.java	Fri Mar 29 10:12:18 2019 +0100
+++ b/test/langtools/tools/javac/api/ast/CodeBuilder.java	Mon Apr 01 11:44:31 2019 +0200
@@ -37,8 +37,13 @@
 import java.util.Locale;
 import java.util.Set;
 
+import com.sun.source.tree.BlockTree;
+import com.sun.source.tree.IfTree;
 import com.sun.source.tree.MemberSelectTree;
+import com.sun.source.tree.MethodTree;
 import com.sun.source.tree.ParameterizedTypeTree;
+import com.sun.source.tree.ReturnTree;
+import com.sun.source.tree.StatementTree;
 
 public class CodeBuilder {
 
@@ -68,6 +73,7 @@
                 result.append(")");
                 return null;
             }
+
             @Override
             public Void visitCompilationUnit(CompilationUnitTree node, Void p) {
                 result.append(currentBuilder() + ".createCompilationUnitTree(");
@@ -75,6 +81,7 @@
                 result.append(")");
                 return null;
             }
+
             @Override
             public Void visitVariable(VariableTree node, Void p) {
                 result.append(currentBuilder() + ".field(\"" + node.getName() + "\", "); //XXX: field/vs local variable!
@@ -93,6 +100,33 @@
             }
 
             @Override
+            public Void visitMethod(MethodTree node, Void p) {
+                result.append(currentBuilder() + ".method(\"" + node.getName() + "\", ");
+                doScan("T", node.getReturnType());
+                result.append(", ");
+                doScan("M", () -> {
+                    //TODO: other attributes!
+                    for (VariableTree param : node.getParameters()) {
+                        result.append(currentBuilder() + ".parameter(");
+                        doScan("T", param.getType());
+                        result.append(", ");
+                        doScan("P", () -> {
+                            result.append(currentBuilder() + ".name(\"" + param.getName() + "\")");
+                        });
+                        //TODO: other attributes!
+                        result.append(")");
+                    }
+                    if (node.getBody() != null) {//TODO: test no/null body!
+                        result.append(currentBuilder() + ".body(");
+                        doScan("B", node.getBody());
+                        result.append(")");
+                    }
+                });
+                result.append(")");
+                return null;
+            }
+
+            @Override
             public Void visitPrimitiveType(PrimitiveTypeTree node, Void p) {
                 result.append(currentBuilder() + "._" + node.getPrimitiveTypeKind().name().toLowerCase(Locale.ROOT) + "()");
                 return null;
@@ -120,16 +154,19 @@
 
             @Override
             public Void visitBinary(BinaryTree node, Void p) {
+                String methodName;
                 switch (node.getKind()) {
                     case PLUS:
-                        result.append(currentBuilder() + ".plus(");
-                        doScan("E", node.getLeftOperand());
-                        result.append(", ");
-                        doScan("E", node.getRightOperand());
-                        result.append(")");
-                        break;
+                        methodName = "plus"; break;
+                    case EQUAL_TO:
+                        methodName = "equal_to"; break;
                     default: throw new IllegalStateException("Not handled: " + node.getKind());
                 }
+                result.append(currentBuilder() + "." + methodName + "(");
+                doScan("E", node.getLeftOperand());
+                result.append(", ");
+                doScan("E", node.getRightOperand());
+                result.append(")");
                 return null;
             }
 
@@ -150,6 +187,38 @@
                 return null;
             }
 
+//            @Override
+//            public Void visitBlock(BlockTree node, Void p) {
+////                for (StatementTree st : node.getStatements()) {
+////                    result.append(curr)
+////                }
+//                return super.visitBlock(node, p);
+//            }
+
+            @Override
+            public Void visitIf(IfTree node, Void p) {
+                result.append(currentBuilder() + "._if(");
+                doScan("E", node.getCondition());
+                result.append(", ");
+                doScan("S", node.getThenStatement());
+                if (node.getElseStatement() != null) {
+                    result.append(", ");
+                    doScan("S", node.getElseStatement());
+                }
+                result.append(")");
+                return null;
+            }
+
+            @Override
+            public Void visitReturn(ReturnTree node, Void p) {
+                result.append(currentBuilder() + "._return(");
+                if (node.getExpression()!= null) {
+                    doScan("E", node.getExpression());
+                }
+                result.append(")");
+                return null;
+            }
+
             private void handleDeclaredType(Tree t) {
                 doScan("T", () -> {
                     result.append(currentBuilder() + "._class(");