src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/IntegerLowerThanNode.java
changeset 52578 7dd81e82d083
parent 50858 2d3e99a72541
child 54601 c40b2a190173
--- a/src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/IntegerLowerThanNode.java	Thu Nov 15 21:05:47 2018 +0100
+++ b/src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/IntegerLowerThanNode.java	Thu Nov 15 09:04:07 2018 -0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -40,7 +40,9 @@
 import org.graalvm.compiler.nodes.util.GraphUtil;
 import org.graalvm.compiler.options.OptionValues;
 
+import jdk.vm.ci.code.CodeUtil;
 import jdk.vm.ci.meta.ConstantReflectionProvider;
+import jdk.vm.ci.meta.JavaConstant;
 import jdk.vm.ci.meta.MetaAccessProvider;
 import jdk.vm.ci.meta.TriState;
 
@@ -89,7 +91,7 @@
                 aStamp = (IntegerStamp) addNode.getX().stamp(NodeView.DEFAULT);
             }
             if (aStamp != null) {
-                IntegerStamp result = getOp().getSucceedingStampForXLowerXPlusA(mirror, strict, aStamp);
+                IntegerStamp result = getOp().getSucceedingStampForXLowerXPlusA(mirror, strict, aStamp, xStamp);
                 result = (IntegerStamp) xStamp.tryImproveWith(result);
                 if (result != null) {
                     if (s != null) {
@@ -185,7 +187,8 @@
             if (GraphUtil.unproxify(forX) == GraphUtil.unproxify(forY)) {
                 return LogicConstantNode.contradiction();
             }
-            TriState fold = tryFold(forX.stamp(view), forY.stamp(view));
+            Stamp xStampGeneric = forX.stamp(view);
+            TriState fold = tryFold(xStampGeneric, forY.stamp(view));
             if (fold.isTrue()) {
                 return LogicConstantNode.tautology();
             } else if (fold.isFalse()) {
@@ -193,6 +196,7 @@
             }
             if (forY.stamp(view) instanceof IntegerStamp) {
                 IntegerStamp yStamp = (IntegerStamp) forY.stamp(view);
+                IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
                 int bits = yStamp.getBits();
                 if (forX.isJavaConstant() && !forY.isConstant()) {
                     // bring the constant on the right
@@ -204,14 +208,23 @@
                 }
                 if (forY.isJavaConstant()) {
                     long yValue = forY.asJavaConstant().asLong();
+
+                    // x < MAX <=> x != MAX
                     if (yValue == maxValue(bits)) {
-                        // x < MAX <=> x != MAX
                         return LogicNegationNode.create(IntegerEqualsNode.create(forX, forY, view));
                     }
+
+                    // x < MIN + 1 <=> x <= MIN <=> x == MIN
                     if (yValue == minValue(bits) + 1) {
-                        // x < MIN + 1 <=> x <= MIN <=> x == MIN
                         return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, minValue(bits)), view);
                     }
+
+                    // (x < c && x >= c - 1) => x == c - 1
+                    // If the constant is negative, only signed comparison is allowed.
+                    if (yValue != minValue(bits) && xStamp.lowerBound() == yValue - 1 && (yValue > 0 || getCondition() == CanonicalCondition.LT)) {
+                        return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, yValue - 1), view);
+                    }
+
                 } else if (forY instanceof AddNode) {
                     AddNode addNode = (AddNode) forY;
                     LogicNode canonical = canonicalizeXLowerXPlusA(forX, addNode, false, true, view);
@@ -230,17 +243,83 @@
             return null;
         }
 
+        /**
+         * Exploit the fact that adding the (signed) MIN_VALUE on both side flips signed and
+         * unsigned comparison.
+         *
+         * In particular:
+         * <ul>
+         * <li>{@code x + MIN_VALUE < y + MIN_VALUE <=> x |<| y}</li>
+         * <li>{@code x + MIN_VALUE |<| y + MIN_VALUE <=> x < y}</li>
+         * </ul>
+         */
+        protected static LogicNode canonicalizeRangeFlip(ValueNode forX, ValueNode forY, int bits, boolean signed, NodeView view) {
+            long min = CodeUtil.minValue(bits);
+            long xResidue = 0;
+            ValueNode left = null;
+            JavaConstant leftCst = null;
+            if (forX instanceof AddNode) {
+                AddNode xAdd = (AddNode) forX;
+                if (xAdd.getY().isJavaConstant() && !xAdd.getY().asJavaConstant().isDefaultForKind()) {
+                    long xCst = xAdd.getY().asJavaConstant().asLong();
+                    xResidue = xCst - min;
+                    left = xAdd.getX();
+                }
+            } else if (forX.isJavaConstant()) {
+                leftCst = forX.asJavaConstant();
+            }
+            if (left == null && leftCst == null) {
+                return null;
+            }
+            long yResidue = 0;
+            ValueNode right = null;
+            JavaConstant rightCst = null;
+            if (forY instanceof AddNode) {
+                AddNode yAdd = (AddNode) forY;
+                if (yAdd.getY().isJavaConstant() && !yAdd.getY().asJavaConstant().isDefaultForKind()) {
+                    long yCst = yAdd.getY().asJavaConstant().asLong();
+                    yResidue = yCst - min;
+                    right = yAdd.getX();
+                }
+            } else if (forY.isJavaConstant()) {
+                rightCst = forY.asJavaConstant();
+            }
+            if (right == null && rightCst == null) {
+                return null;
+            }
+            if ((xResidue == 0 && left != null) || (yResidue == 0 && right != null)) {
+                if (left == null) {
+                    left = ConstantNode.forIntegerBits(bits, leftCst.asLong() - min);
+                } else if (xResidue != 0) {
+                    left = AddNode.create(left, ConstantNode.forIntegerBits(bits, xResidue), view);
+                }
+                if (right == null) {
+                    right = ConstantNode.forIntegerBits(bits, rightCst.asLong() - min);
+                } else if (yResidue != 0) {
+                    right = AddNode.create(right, ConstantNode.forIntegerBits(bits, yResidue), view);
+                }
+                if (signed) {
+                    return new IntegerBelowNode(left, right);
+                } else {
+                    return new IntegerLessThanNode(left, right);
+                }
+            }
+            return null;
+        }
+
         private LogicNode canonicalizeXLowerXPlusA(ValueNode forX, AddNode addNode, boolean mirrored, boolean strict, NodeView view) {
             // x < x + a
+            // x |<| x + a
+            IntegerStamp xStamp = (IntegerStamp) forX.stamp(view);
             IntegerStamp succeedingXStamp;
             boolean exact;
             if (addNode.getX() == forX && addNode.getY().stamp(view) instanceof IntegerStamp) {
                 IntegerStamp aStamp = (IntegerStamp) addNode.getY().stamp(view);
-                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp);
+                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp, xStamp);
                 exact = aStamp.lowerBound() == aStamp.upperBound();
             } else if (addNode.getY() == forX && addNode.getX().stamp(view) instanceof IntegerStamp) {
                 IntegerStamp aStamp = (IntegerStamp) addNode.getX().stamp(view);
-                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp);
+                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp, xStamp);
                 exact = aStamp.lowerBound() == aStamp.upperBound();
             } else {
                 return null;
@@ -250,7 +329,6 @@
             } else if (exact && !succeedingXStamp.isEmpty()) {
                 int bits = succeedingXStamp.getBits();
                 if (compare(lowerBound(succeedingXStamp), minValue(bits)) > 0) {
-                    assert upperBound(succeedingXStamp) == maxValue(bits);
                     // x must be in [L..MAX] <=> x >= L <=> !(x < L)
                     return LogicNegationNode.create(create(forX, ConstantNode.forIntegerStamp(succeedingXStamp, lowerBound(succeedingXStamp)), view));
                 } else if (compare(upperBound(succeedingXStamp), maxValue(bits)) < 0) {
@@ -305,10 +383,11 @@
             return null;
         }
 
-        protected IntegerStamp getSucceedingStampForXLowerXPlusA(boolean mirrored, boolean strict, IntegerStamp a) {
-            int bits = a.getBits();
+        protected IntegerStamp getSucceedingStampForXLowerXPlusA(boolean mirrored, boolean strict, IntegerStamp aStamp, IntegerStamp xStamp) {
+            int bits = aStamp.getBits();
             long min = minValue(bits);
             long max = maxValue(bits);
+
             /*
              * if x < x + a <=> x + a didn't overflow:
              *
@@ -324,14 +403,14 @@
              * addition not the comparison.
              */
             if (mirrored) {
-                if (a.contains(0)) {
+                if (aStamp.contains(0)) {
                     // a may be zero
-                    return a.unrestricted();
+                    return aStamp.unrestricted();
                 }
-                return forInteger(bits, min(max - a.lowerBound() + 1, max - a.upperBound() + 1, bits), max);
+                return forInteger(bits, min(max - aStamp.lowerBound() + 1, max - aStamp.upperBound() + 1, bits), min(max, upperBound(xStamp)));
             } else {
-                long aLower = a.lowerBound();
-                long aUpper = a.upperBound();
+                long aLower = aStamp.lowerBound();
+                long aUpper = aStamp.upperBound();
                 if (strict) {
                     if (aLower == 0) {
                         aLower = 1;
@@ -341,12 +420,12 @@
                     }
                     if (aLower > aUpper) {
                         // impossible
-                        return a.empty();
+                        return aStamp.empty();
                     }
                 }
                 if (aLower < 0 && aUpper > 0) {
                     // a may be zero
-                    return a.unrestricted();
+                    return aStamp.unrestricted();
                 }
                 return forInteger(bits, min, max(max - aLower, max - aUpper, bits));
             }