--- a/src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/MulNode.java Fri Dec 01 14:19:16 2017 -0500
+++ b/src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/MulNode.java Fri Dec 01 11:17:45 2017 -0800
@@ -35,6 +35,7 @@
import org.graalvm.compiler.lir.gen.ArithmeticLIRGeneratorTool;
import org.graalvm.compiler.nodeinfo.NodeInfo;
import org.graalvm.compiler.nodes.ConstantNode;
+import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
@@ -56,14 +57,14 @@
super(c, ArithmeticOpTable::getMul, x, y);
}
- public static ValueNode create(ValueNode x, ValueNode y) {
- BinaryOp<Mul> op = ArithmeticOpTable.forStamp(x.stamp()).getMul();
- Stamp stamp = op.foldStamp(x.stamp(), y.stamp());
- ConstantNode tryConstantFold = tryConstantFold(op, x, y, stamp);
+ public static ValueNode create(ValueNode x, ValueNode y, NodeView view) {
+ BinaryOp<Mul> op = ArithmeticOpTable.forStamp(x.stamp(view)).getMul();
+ Stamp stamp = op.foldStamp(x.stamp(view), y.stamp(view));
+ ConstantNode tryConstantFold = tryConstantFold(op, x, y, stamp, view);
if (tryConstantFold != null) {
return tryConstantFold;
}
- return canonical(null, op, stamp, x, y);
+ return canonical(null, op, stamp, x, y, view);
}
@Override
@@ -83,10 +84,11 @@
return new MulNode(forY, forX);
}
BinaryOp<Mul> op = getOp(forX, forY);
- return canonical(this, op, stamp(), forX, forY);
+ NodeView view = NodeView.from(tool);
+ return canonical(this, op, stamp(view), forX, forY, view);
}
- private static ValueNode canonical(MulNode self, BinaryOp<Mul> op, Stamp stamp, ValueNode forX, ValueNode forY) {
+ private static ValueNode canonical(MulNode self, BinaryOp<Mul> op, Stamp stamp, ValueNode forX, ValueNode forY, NodeView view) {
if (forY.isConstant()) {
Constant c = forY.asConstant();
if (op.isNeutral(c)) {
@@ -95,57 +97,64 @@
if (c instanceof PrimitiveConstant && ((PrimitiveConstant) c).getJavaKind().isNumericInteger()) {
long i = ((PrimitiveConstant) c).asLong();
-
- if (i == 0) {
- return ConstantNode.forIntegerStamp(stamp, 0);
- } else if (i == 1) {
- return forX;
- } else if (i == -1) {
- return NegateNode.create(forX);
- } else if (i > 0) {
- if (CodeUtil.isPowerOf2(i)) {
- return new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i)));
- } else if (CodeUtil.isPowerOf2(i - 1)) {
- return AddNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i - 1))), forX);
- } else if (CodeUtil.isPowerOf2(i + 1)) {
- return SubNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i + 1))), forX);
- } else {
- int bitCount = Long.bitCount(i);
- long highestBitValue = Long.highestOneBit(i);
- if (bitCount == 2) {
- // e.g., 0b1000_0010
- long lowerBitValue = i - highestBitValue;
- assert highestBitValue > 0 && lowerBitValue > 0;
- ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(highestBitValue)));
- ValueNode right = lowerBitValue == 1 ? forX : new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(lowerBitValue)));
- return AddNode.create(left, right);
- } else {
- // e.g., 0b1111_1101
- int shiftToRoundUpToPowerOf2 = CodeUtil.log2(highestBitValue) + 1;
- long subValue = (1 << shiftToRoundUpToPowerOf2) - i;
- if (CodeUtil.isPowerOf2(subValue) && shiftToRoundUpToPowerOf2 < ((IntegerStamp) stamp).getBits()) {
- assert CodeUtil.log2(subValue) >= 1;
- ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(shiftToRoundUpToPowerOf2));
- ValueNode right = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(subValue)));
- return SubNode.create(left, right);
- }
- }
- }
- } else if (i < 0) {
- if (CodeUtil.isPowerOf2(-i)) {
- return NegateNode.create(LeftShiftNode.create(forX, ConstantNode.forInt(CodeUtil.log2(-i))));
- }
+ ValueNode result = canonical(stamp, forX, i, view);
+ if (result != null) {
+ return result;
}
}
if (op.isAssociative()) {
// canonicalize expressions like "(a * 1) * 2"
- return reassociate(self != null ? self : (MulNode) new MulNode(forX, forY).maybeCommuteInputs(), ValueNode.isConstantPredicate(), forX, forY);
+ return reassociate(self != null ? self : (MulNode) new MulNode(forX, forY).maybeCommuteInputs(), ValueNode.isConstantPredicate(), forX, forY, view);
}
}
return self != null ? self : new MulNode(forX, forY).maybeCommuteInputs();
}
+ public static ValueNode canonical(Stamp stamp, ValueNode forX, long i, NodeView view) {
+ if (i == 0) {
+ return ConstantNode.forIntegerStamp(stamp, 0);
+ } else if (i == 1) {
+ return forX;
+ } else if (i == -1) {
+ return NegateNode.create(forX, view);
+ } else if (i > 0) {
+ if (CodeUtil.isPowerOf2(i)) {
+ return new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i)));
+ } else if (CodeUtil.isPowerOf2(i - 1)) {
+ return AddNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i - 1))), forX, view);
+ } else if (CodeUtil.isPowerOf2(i + 1)) {
+ return SubNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i + 1))), forX, view);
+ } else {
+ int bitCount = Long.bitCount(i);
+ long highestBitValue = Long.highestOneBit(i);
+ if (bitCount == 2) {
+ // e.g., 0b1000_0010
+ long lowerBitValue = i - highestBitValue;
+ assert highestBitValue > 0 && lowerBitValue > 0;
+ ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(highestBitValue)));
+ ValueNode right = lowerBitValue == 1 ? forX : new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(lowerBitValue)));
+ return AddNode.create(left, right, view);
+ } else {
+ // e.g., 0b1111_1101
+ int shiftToRoundUpToPowerOf2 = CodeUtil.log2(highestBitValue) + 1;
+ long subValue = (1 << shiftToRoundUpToPowerOf2) - i;
+ if (CodeUtil.isPowerOf2(subValue) && shiftToRoundUpToPowerOf2 < ((IntegerStamp) stamp).getBits()) {
+ assert CodeUtil.log2(subValue) >= 1;
+ ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(shiftToRoundUpToPowerOf2));
+ ValueNode right = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(subValue)));
+ return SubNode.create(left, right, view);
+ }
+ }
+ }
+ } else if (i < 0) {
+ if (CodeUtil.isPowerOf2(-i)) {
+ return NegateNode.create(LeftShiftNode.create(forX, ConstantNode.forInt(CodeUtil.log2(-i)), view), view);
+ }
+ }
+ return null;
+ }
+
@Override
public void generate(NodeLIRBuilderTool nodeValueMap, ArithmeticLIRGeneratorTool gen) {
Value op1 = nodeValueMap.operand(getX());