src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/MulNode.java
changeset 48190 25cfedf27edc
parent 47216 71c04702a3d5
child 50858 2d3e99a72541
--- a/src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/MulNode.java	Fri Dec 01 14:19:16 2017 -0500
+++ b/src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/MulNode.java	Fri Dec 01 11:17:45 2017 -0800
@@ -35,6 +35,7 @@
 import org.graalvm.compiler.lir.gen.ArithmeticLIRGeneratorTool;
 import org.graalvm.compiler.nodeinfo.NodeInfo;
 import org.graalvm.compiler.nodes.ConstantNode;
+import org.graalvm.compiler.nodes.NodeView;
 import org.graalvm.compiler.nodes.ValueNode;
 import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
 
@@ -56,14 +57,14 @@
         super(c, ArithmeticOpTable::getMul, x, y);
     }
 
-    public static ValueNode create(ValueNode x, ValueNode y) {
-        BinaryOp<Mul> op = ArithmeticOpTable.forStamp(x.stamp()).getMul();
-        Stamp stamp = op.foldStamp(x.stamp(), y.stamp());
-        ConstantNode tryConstantFold = tryConstantFold(op, x, y, stamp);
+    public static ValueNode create(ValueNode x, ValueNode y, NodeView view) {
+        BinaryOp<Mul> op = ArithmeticOpTable.forStamp(x.stamp(view)).getMul();
+        Stamp stamp = op.foldStamp(x.stamp(view), y.stamp(view));
+        ConstantNode tryConstantFold = tryConstantFold(op, x, y, stamp, view);
         if (tryConstantFold != null) {
             return tryConstantFold;
         }
-        return canonical(null, op, stamp, x, y);
+        return canonical(null, op, stamp, x, y, view);
     }
 
     @Override
@@ -83,10 +84,11 @@
             return new MulNode(forY, forX);
         }
         BinaryOp<Mul> op = getOp(forX, forY);
-        return canonical(this, op, stamp(), forX, forY);
+        NodeView view = NodeView.from(tool);
+        return canonical(this, op, stamp(view), forX, forY, view);
     }
 
-    private static ValueNode canonical(MulNode self, BinaryOp<Mul> op, Stamp stamp, ValueNode forX, ValueNode forY) {
+    private static ValueNode canonical(MulNode self, BinaryOp<Mul> op, Stamp stamp, ValueNode forX, ValueNode forY, NodeView view) {
         if (forY.isConstant()) {
             Constant c = forY.asConstant();
             if (op.isNeutral(c)) {
@@ -95,57 +97,64 @@
 
             if (c instanceof PrimitiveConstant && ((PrimitiveConstant) c).getJavaKind().isNumericInteger()) {
                 long i = ((PrimitiveConstant) c).asLong();
-
-                if (i == 0) {
-                    return ConstantNode.forIntegerStamp(stamp, 0);
-                } else if (i == 1) {
-                    return forX;
-                } else if (i == -1) {
-                    return NegateNode.create(forX);
-                } else if (i > 0) {
-                    if (CodeUtil.isPowerOf2(i)) {
-                        return new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i)));
-                    } else if (CodeUtil.isPowerOf2(i - 1)) {
-                        return AddNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i - 1))), forX);
-                    } else if (CodeUtil.isPowerOf2(i + 1)) {
-                        return SubNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i + 1))), forX);
-                    } else {
-                        int bitCount = Long.bitCount(i);
-                        long highestBitValue = Long.highestOneBit(i);
-                        if (bitCount == 2) {
-                            // e.g., 0b1000_0010
-                            long lowerBitValue = i - highestBitValue;
-                            assert highestBitValue > 0 && lowerBitValue > 0;
-                            ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(highestBitValue)));
-                            ValueNode right = lowerBitValue == 1 ? forX : new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(lowerBitValue)));
-                            return AddNode.create(left, right);
-                        } else {
-                            // e.g., 0b1111_1101
-                            int shiftToRoundUpToPowerOf2 = CodeUtil.log2(highestBitValue) + 1;
-                            long subValue = (1 << shiftToRoundUpToPowerOf2) - i;
-                            if (CodeUtil.isPowerOf2(subValue) && shiftToRoundUpToPowerOf2 < ((IntegerStamp) stamp).getBits()) {
-                                assert CodeUtil.log2(subValue) >= 1;
-                                ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(shiftToRoundUpToPowerOf2));
-                                ValueNode right = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(subValue)));
-                                return SubNode.create(left, right);
-                            }
-                        }
-                    }
-                } else if (i < 0) {
-                    if (CodeUtil.isPowerOf2(-i)) {
-                        return NegateNode.create(LeftShiftNode.create(forX, ConstantNode.forInt(CodeUtil.log2(-i))));
-                    }
+                ValueNode result = canonical(stamp, forX, i, view);
+                if (result != null) {
+                    return result;
                 }
             }
 
             if (op.isAssociative()) {
                 // canonicalize expressions like "(a * 1) * 2"
-                return reassociate(self != null ? self : (MulNode) new MulNode(forX, forY).maybeCommuteInputs(), ValueNode.isConstantPredicate(), forX, forY);
+                return reassociate(self != null ? self : (MulNode) new MulNode(forX, forY).maybeCommuteInputs(), ValueNode.isConstantPredicate(), forX, forY, view);
             }
         }
         return self != null ? self : new MulNode(forX, forY).maybeCommuteInputs();
     }
 
+    public static ValueNode canonical(Stamp stamp, ValueNode forX, long i, NodeView view) {
+        if (i == 0) {
+            return ConstantNode.forIntegerStamp(stamp, 0);
+        } else if (i == 1) {
+            return forX;
+        } else if (i == -1) {
+            return NegateNode.create(forX, view);
+        } else if (i > 0) {
+            if (CodeUtil.isPowerOf2(i)) {
+                return new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i)));
+            } else if (CodeUtil.isPowerOf2(i - 1)) {
+                return AddNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i - 1))), forX, view);
+            } else if (CodeUtil.isPowerOf2(i + 1)) {
+                return SubNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i + 1))), forX, view);
+            } else {
+                int bitCount = Long.bitCount(i);
+                long highestBitValue = Long.highestOneBit(i);
+                if (bitCount == 2) {
+                    // e.g., 0b1000_0010
+                    long lowerBitValue = i - highestBitValue;
+                    assert highestBitValue > 0 && lowerBitValue > 0;
+                    ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(highestBitValue)));
+                    ValueNode right = lowerBitValue == 1 ? forX : new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(lowerBitValue)));
+                    return AddNode.create(left, right, view);
+                } else {
+                    // e.g., 0b1111_1101
+                    int shiftToRoundUpToPowerOf2 = CodeUtil.log2(highestBitValue) + 1;
+                    long subValue = (1 << shiftToRoundUpToPowerOf2) - i;
+                    if (CodeUtil.isPowerOf2(subValue) && shiftToRoundUpToPowerOf2 < ((IntegerStamp) stamp).getBits()) {
+                        assert CodeUtil.log2(subValue) >= 1;
+                        ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(shiftToRoundUpToPowerOf2));
+                        ValueNode right = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(subValue)));
+                        return SubNode.create(left, right, view);
+                    }
+                }
+            }
+        } else if (i < 0) {
+            if (CodeUtil.isPowerOf2(-i)) {
+                return NegateNode.create(LeftShiftNode.create(forX, ConstantNode.forInt(CodeUtil.log2(-i)), view), view);
+            }
+        }
+        return null;
+    }
+
     @Override
     public void generate(NodeLIRBuilderTool nodeValueMap, ArithmeticLIRGeneratorTool gen) {
         Value op1 = nodeValueMap.operand(getX());