< prev index next >

src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/IntegerLowerThanNode.java

Print this page
rev 52509 : [mq]: graal

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 2017, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.

@@ -38,11 +38,13 @@
 import org.graalvm.compiler.nodes.NodeView;
 import org.graalvm.compiler.nodes.ValueNode;
 import org.graalvm.compiler.nodes.util.GraphUtil;
 import org.graalvm.compiler.options.OptionValues;
 
+import jdk.vm.ci.code.CodeUtil;
 import jdk.vm.ci.meta.ConstantReflectionProvider;
+import jdk.vm.ci.meta.JavaConstant;
 import jdk.vm.ci.meta.MetaAccessProvider;
 import jdk.vm.ci.meta.TriState;
 
 /**
  * Common super-class for "a < b" comparisons both {@linkplain IntegerLowerThanNode signed} and

@@ -87,11 +89,11 @@
             } else if (addNode.getY() == forX && addNode.getX().stamp(NodeView.DEFAULT) instanceof IntegerStamp) {
                 // x < a + x
                 aStamp = (IntegerStamp) addNode.getX().stamp(NodeView.DEFAULT);
             }
             if (aStamp != null) {
-                IntegerStamp result = getOp().getSucceedingStampForXLowerXPlusA(mirror, strict, aStamp);
+                IntegerStamp result = getOp().getSucceedingStampForXLowerXPlusA(mirror, strict, aStamp, xStamp);
                 result = (IntegerStamp) xStamp.tryImproveWith(result);
                 if (result != null) {
                     if (s != null) {
                         s = s.improveWith(result);
                     } else {

@@ -183,18 +185,20 @@
 
         protected LogicNode findSynonym(ValueNode forX, ValueNode forY, NodeView view) {
             if (GraphUtil.unproxify(forX) == GraphUtil.unproxify(forY)) {
                 return LogicConstantNode.contradiction();
             }
-            TriState fold = tryFold(forX.stamp(view), forY.stamp(view));
+            Stamp xStampGeneric = forX.stamp(view);
+            TriState fold = tryFold(xStampGeneric, forY.stamp(view));
             if (fold.isTrue()) {
                 return LogicConstantNode.tautology();
             } else if (fold.isFalse()) {
                 return LogicConstantNode.contradiction();
             }
             if (forY.stamp(view) instanceof IntegerStamp) {
                 IntegerStamp yStamp = (IntegerStamp) forY.stamp(view);
+                IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
                 int bits = yStamp.getBits();
                 if (forX.isJavaConstant() && !forY.isConstant()) {
                     // bring the constant on the right
                     long xValue = forX.asJavaConstant().asLong();
                     if (xValue != maxValue(bits)) {

@@ -202,18 +206,27 @@
                         return LogicNegationNode.create(create(forY, ConstantNode.forIntegerStamp(yStamp, xValue + 1), view));
                     }
                 }
                 if (forY.isJavaConstant()) {
                     long yValue = forY.asJavaConstant().asLong();
-                    if (yValue == maxValue(bits)) {
+
                         // x < MAX <=> x != MAX
+                    if (yValue == maxValue(bits)) {
                         return LogicNegationNode.create(IntegerEqualsNode.create(forX, forY, view));
                     }
-                    if (yValue == minValue(bits) + 1) {
+
                         // x < MIN + 1 <=> x <= MIN <=> x == MIN
+                    if (yValue == minValue(bits) + 1) {
                         return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, minValue(bits)), view);
                     }
+
+                    // (x < c && x >= c - 1) => x == c - 1
+                    // If the constant is negative, only signed comparison is allowed.
+                    if (yValue != minValue(bits) && xStamp.lowerBound() == yValue - 1 && (yValue > 0 || getCondition() == CanonicalCondition.LT)) {
+                        return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, yValue - 1), view);
+                    }
+
                 } else if (forY instanceof AddNode) {
                     AddNode addNode = (AddNode) forY;
                     LogicNode canonical = canonicalizeXLowerXPlusA(forX, addNode, false, true, view);
                     if (canonical != null) {
                         return canonical;

@@ -228,31 +241,96 @@
                 }
             }
             return null;
         }
 
+        /**
+         * Exploit the fact that adding the (signed) MIN_VALUE on both side flips signed and
+         * unsigned comparison.
+         *
+         * In particular:
+         * <ul>
+         * <li>{@code x + MIN_VALUE < y + MIN_VALUE <=> x |<| y}</li>
+         * <li>{@code x + MIN_VALUE |<| y + MIN_VALUE <=> x < y}</li>
+         * </ul>
+         */
+        protected static LogicNode canonicalizeRangeFlip(ValueNode forX, ValueNode forY, int bits, boolean signed, NodeView view) {
+            long min = CodeUtil.minValue(bits);
+            long xResidue = 0;
+            ValueNode left = null;
+            JavaConstant leftCst = null;
+            if (forX instanceof AddNode) {
+                AddNode xAdd = (AddNode) forX;
+                if (xAdd.getY().isJavaConstant() && !xAdd.getY().asJavaConstant().isDefaultForKind()) {
+                    long xCst = xAdd.getY().asJavaConstant().asLong();
+                    xResidue = xCst - min;
+                    left = xAdd.getX();
+                }
+            } else if (forX.isJavaConstant()) {
+                leftCst = forX.asJavaConstant();
+            }
+            if (left == null && leftCst == null) {
+                return null;
+            }
+            long yResidue = 0;
+            ValueNode right = null;
+            JavaConstant rightCst = null;
+            if (forY instanceof AddNode) {
+                AddNode yAdd = (AddNode) forY;
+                if (yAdd.getY().isJavaConstant() && !yAdd.getY().asJavaConstant().isDefaultForKind()) {
+                    long yCst = yAdd.getY().asJavaConstant().asLong();
+                    yResidue = yCst - min;
+                    right = yAdd.getX();
+                }
+            } else if (forY.isJavaConstant()) {
+                rightCst = forY.asJavaConstant();
+            }
+            if (right == null && rightCst == null) {
+                return null;
+            }
+            if ((xResidue == 0 && left != null) || (yResidue == 0 && right != null)) {
+                if (left == null) {
+                    left = ConstantNode.forIntegerBits(bits, leftCst.asLong() - min);
+                } else if (xResidue != 0) {
+                    left = AddNode.create(left, ConstantNode.forIntegerBits(bits, xResidue), view);
+                }
+                if (right == null) {
+                    right = ConstantNode.forIntegerBits(bits, rightCst.asLong() - min);
+                } else if (yResidue != 0) {
+                    right = AddNode.create(right, ConstantNode.forIntegerBits(bits, yResidue), view);
+                }
+                if (signed) {
+                    return new IntegerBelowNode(left, right);
+                } else {
+                    return new IntegerLessThanNode(left, right);
+                }
+            }
+            return null;
+        }
+
         private LogicNode canonicalizeXLowerXPlusA(ValueNode forX, AddNode addNode, boolean mirrored, boolean strict, NodeView view) {
             // x < x + a
+            // x |<| x + a
+            IntegerStamp xStamp = (IntegerStamp) forX.stamp(view);
             IntegerStamp succeedingXStamp;
             boolean exact;
             if (addNode.getX() == forX && addNode.getY().stamp(view) instanceof IntegerStamp) {
                 IntegerStamp aStamp = (IntegerStamp) addNode.getY().stamp(view);
-                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp);
+                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp, xStamp);
                 exact = aStamp.lowerBound() == aStamp.upperBound();
             } else if (addNode.getY() == forX && addNode.getX().stamp(view) instanceof IntegerStamp) {
                 IntegerStamp aStamp = (IntegerStamp) addNode.getX().stamp(view);
-                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp);
+                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp, xStamp);
                 exact = aStamp.lowerBound() == aStamp.upperBound();
             } else {
                 return null;
             }
             if (succeedingXStamp.join(forX.stamp(view)).isEmpty()) {
                 return LogicConstantNode.contradiction();
             } else if (exact && !succeedingXStamp.isEmpty()) {
                 int bits = succeedingXStamp.getBits();
                 if (compare(lowerBound(succeedingXStamp), minValue(bits)) > 0) {
-                    assert upperBound(succeedingXStamp) == maxValue(bits);
                     // x must be in [L..MAX] <=> x >= L <=> !(x < L)
                     return LogicNegationNode.create(create(forX, ConstantNode.forIntegerStamp(succeedingXStamp, lowerBound(succeedingXStamp)), view));
                 } else if (compare(upperBound(succeedingXStamp), maxValue(bits)) < 0) {
                     // x must be in [MIN..H] <=> x <= H <=> !(H < x)
                     return LogicNegationNode.create(create(ConstantNode.forIntegerStamp(succeedingXStamp, upperBound(succeedingXStamp)), forX, view));

@@ -303,14 +381,15 @@
                 }
             }
             return null;
         }
 
-        protected IntegerStamp getSucceedingStampForXLowerXPlusA(boolean mirrored, boolean strict, IntegerStamp a) {
-            int bits = a.getBits();
+        protected IntegerStamp getSucceedingStampForXLowerXPlusA(boolean mirrored, boolean strict, IntegerStamp aStamp, IntegerStamp xStamp) {
+            int bits = aStamp.getBits();
             long min = minValue(bits);
             long max = maxValue(bits);
+
             /*
              * if x < x + a <=> x + a didn't overflow:
              *
              * x is outside ]MAX - a, MAX], i.e., inside [MIN, MAX - a]
              *

@@ -322,33 +401,33 @@
              *
              * This does not use upper/lowerBound from LowerOp because it's about the (signed)
              * addition not the comparison.
              */
             if (mirrored) {
-                if (a.contains(0)) {
+                if (aStamp.contains(0)) {
                     // a may be zero
-                    return a.unrestricted();
+                    return aStamp.unrestricted();
                 }
-                return forInteger(bits, min(max - a.lowerBound() + 1, max - a.upperBound() + 1, bits), max);
+                return forInteger(bits, min(max - aStamp.lowerBound() + 1, max - aStamp.upperBound() + 1, bits), min(max, upperBound(xStamp)));
             } else {
-                long aLower = a.lowerBound();
-                long aUpper = a.upperBound();
+                long aLower = aStamp.lowerBound();
+                long aUpper = aStamp.upperBound();
                 if (strict) {
                     if (aLower == 0) {
                         aLower = 1;
                     }
                     if (aUpper == 0) {
                         aUpper = -1;
                     }
                     if (aLower > aUpper) {
                         // impossible
-                        return a.empty();
+                        return aStamp.empty();
                     }
                 }
                 if (aLower < 0 && aUpper > 0) {
                     // a may be zero
-                    return a.unrestricted();
+                    return aStamp.unrestricted();
                 }
                 return forInteger(bits, min, max(max - aLower, max - aUpper, bits));
             }
         }
     }
< prev index next >