1 /*
   2  * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.nodes.calc;
  26 
  27 import static jdk.vm.ci.code.CodeUtil.mask;
  28 
  29 import org.graalvm.compiler.core.common.calc.CanonicalCondition;
  30 import org.graalvm.compiler.core.common.type.IntegerStamp;
  31 import org.graalvm.compiler.core.common.type.Stamp;
  32 import org.graalvm.compiler.graph.NodeClass;
  33 import org.graalvm.compiler.nodeinfo.NodeInfo;
  34 import org.graalvm.compiler.nodes.ConstantNode;
  35 import org.graalvm.compiler.nodes.LogicConstantNode;
  36 import org.graalvm.compiler.nodes.LogicNegationNode;
  37 import org.graalvm.compiler.nodes.LogicNode;
  38 import org.graalvm.compiler.nodes.NodeView;
  39 import org.graalvm.compiler.nodes.ValueNode;
  40 import org.graalvm.compiler.nodes.util.GraphUtil;
  41 import org.graalvm.compiler.options.OptionValues;
  42 
  43 import jdk.vm.ci.code.CodeUtil;
  44 import jdk.vm.ci.meta.ConstantReflectionProvider;
  45 import jdk.vm.ci.meta.JavaConstant;
  46 import jdk.vm.ci.meta.MetaAccessProvider;
  47 import jdk.vm.ci.meta.TriState;
  48 
  49 /**
  50  * Common super-class for "a < b" comparisons both {@linkplain IntegerLowerThanNode signed} and
  51  * {@linkplain IntegerBelowNode unsigned}.
  52  */
  53 @NodeInfo()
  54 public abstract class IntegerLowerThanNode extends CompareNode {
  55     public static final NodeClass<IntegerLowerThanNode> TYPE = NodeClass.create(IntegerLowerThanNode.class);
  56     private final LowerOp op;
  57 
  58     protected IntegerLowerThanNode(NodeClass<? extends CompareNode> c, ValueNode x, ValueNode y, LowerOp op) {
  59         super(c, op.getCondition(), false, x, y);
  60         this.op = op;
  61     }
  62 
  63     protected LowerOp getOp() {
  64         return op;
  65     }
  66 
  67     @Override
  68     public Stamp getSucceedingStampForX(boolean negated, Stamp xStampGeneric, Stamp yStampGeneric) {
  69         return getSucceedingStampForX(negated, !negated, xStampGeneric, yStampGeneric, getX(), getY());
  70     }
  71 
  72     @Override
  73     public Stamp getSucceedingStampForY(boolean negated, Stamp xStampGeneric, Stamp yStampGeneric) {
  74         return getSucceedingStampForX(!negated, !negated, yStampGeneric, xStampGeneric, getY(), getX());
  75     }
  76 
  77     private Stamp getSucceedingStampForX(boolean mirror, boolean strict, Stamp xStampGeneric, Stamp yStampGeneric, ValueNode forX, ValueNode forY) {
  78         Stamp s = getSucceedingStampForX(mirror, strict, xStampGeneric, yStampGeneric);
  79         if (s != null && s.isUnrestricted()) {
  80             s = null;
  81         }
  82         if (forY instanceof AddNode && xStampGeneric instanceof IntegerStamp) {
  83             IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
  84             AddNode addNode = (AddNode) forY;
  85             IntegerStamp aStamp = null;
  86             if (addNode.getX() == forX && addNode.getY().stamp(NodeView.DEFAULT) instanceof IntegerStamp) {
  87                 // x < x + a
  88                 aStamp = (IntegerStamp) addNode.getY().stamp(NodeView.DEFAULT);
  89             } else if (addNode.getY() == forX && addNode.getX().stamp(NodeView.DEFAULT) instanceof IntegerStamp) {
  90                 // x < a + x
  91                 aStamp = (IntegerStamp) addNode.getX().stamp(NodeView.DEFAULT);
  92             }
  93             if (aStamp != null) {
  94                 IntegerStamp result = getOp().getSucceedingStampForXLowerXPlusA(mirror, strict, aStamp, xStamp);
  95                 result = (IntegerStamp) xStamp.tryImproveWith(result);
  96                 if (result != null) {
  97                     if (s != null) {
  98                         s = s.improveWith(result);
  99                     } else {
 100                         s = result;
 101                     }
 102                 }
 103             }
 104         }
 105         return s;
 106     }
 107 
 108     private Stamp getSucceedingStampForX(boolean mirror, boolean strict, Stamp xStampGeneric, Stamp yStampGeneric) {
 109         if (xStampGeneric instanceof IntegerStamp) {
 110             IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
 111             if (yStampGeneric instanceof IntegerStamp) {
 112                 IntegerStamp yStamp = (IntegerStamp) yStampGeneric;
 113                 assert yStamp.getBits() == xStamp.getBits();
 114                 Stamp s = getOp().getSucceedingStampForX(xStamp, yStamp, mirror, strict);
 115                 if (s != null) {
 116                     return s;
 117                 }
 118             }
 119         }
 120         return null;
 121     }
 122 
 123     @Override
 124     public TriState tryFold(Stamp xStampGeneric, Stamp yStampGeneric) {
 125         return getOp().tryFold(xStampGeneric, yStampGeneric);
 126     }
 127 
 128     public abstract static class LowerOp extends CompareOp {
 129         @Override
 130         public LogicNode canonical(ConstantReflectionProvider constantReflection, MetaAccessProvider metaAccess, OptionValues options, Integer smallestCompareWidth, CanonicalCondition condition,
 131                         boolean unorderedIsTrue, ValueNode forX, ValueNode forY, NodeView view) {
 132             LogicNode result = super.canonical(constantReflection, metaAccess, options, smallestCompareWidth, condition, unorderedIsTrue, forX, forY, view);
 133             if (result != null) {
 134                 return result;
 135             }
 136             LogicNode synonym = findSynonym(forX, forY, view);
 137             if (synonym != null) {
 138                 return synonym;
 139             }
 140             return null;
 141         }
 142 
 143         protected abstract long upperBound(IntegerStamp stamp);
 144 
 145         protected abstract long lowerBound(IntegerStamp stamp);
 146 
 147         protected abstract int compare(long a, long b);
 148 
 149         protected abstract long min(long a, long b);
 150 
 151         protected abstract long max(long a, long b);
 152 
 153         protected long min(long a, long b, int bits) {
 154             return min(cast(a, bits), cast(b, bits));
 155         }
 156 
 157         protected long max(long a, long b, int bits) {
 158             return max(cast(a, bits), cast(b, bits));
 159         }
 160 
 161         protected abstract long cast(long a, int bits);
 162 
 163         protected abstract long minValue(int bits);
 164 
 165         protected abstract long maxValue(int bits);
 166 
 167         protected abstract IntegerStamp forInteger(int bits, long min, long max);
 168 
 169         protected abstract CanonicalCondition getCondition();
 170 
 171         protected abstract IntegerLowerThanNode createNode(ValueNode x, ValueNode y);
 172 
 173         public LogicNode create(ValueNode x, ValueNode y, NodeView view) {
 174             LogicNode result = CompareNode.tryConstantFoldPrimitive(getCondition(), x, y, false, view);
 175             if (result != null) {
 176                 return result;
 177             } else {
 178                 result = findSynonym(x, y, view);
 179                 if (result != null) {
 180                     return result;
 181                 }
 182                 return createNode(x, y);
 183             }
 184         }
 185 
 186         protected LogicNode findSynonym(ValueNode forX, ValueNode forY, NodeView view) {
 187             if (GraphUtil.unproxify(forX) == GraphUtil.unproxify(forY)) {
 188                 return LogicConstantNode.contradiction();
 189             }
 190             Stamp xStampGeneric = forX.stamp(view);
 191             TriState fold = tryFold(xStampGeneric, forY.stamp(view));
 192             if (fold.isTrue()) {
 193                 return LogicConstantNode.tautology();
 194             } else if (fold.isFalse()) {
 195                 return LogicConstantNode.contradiction();
 196             }
 197             if (forY.stamp(view) instanceof IntegerStamp) {
 198                 IntegerStamp yStamp = (IntegerStamp) forY.stamp(view);
 199                 IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
 200                 int bits = yStamp.getBits();
 201                 if (forX.isJavaConstant() && !forY.isConstant()) {
 202                     // bring the constant on the right
 203                     long xValue = forX.asJavaConstant().asLong();
 204                     if (xValue != maxValue(bits)) {
 205                         // c < x <=> !(c >= x) <=> !(x <= c) <=> !(x < c + 1)
 206                         return LogicNegationNode.create(create(forY, ConstantNode.forIntegerStamp(yStamp, xValue + 1), view));
 207                     }
 208                 }
 209                 if (forY.isJavaConstant()) {
 210                     long yValue = forY.asJavaConstant().asLong();
 211 
 212                     // x < MAX <=> x != MAX
 213                     if (yValue == maxValue(bits)) {
 214                         return LogicNegationNode.create(IntegerEqualsNode.create(forX, forY, view));
 215                     }
 216 
 217                     // x < MIN + 1 <=> x <= MIN <=> x == MIN
 218                     if (yValue == minValue(bits) + 1) {
 219                         return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, minValue(bits)), view);
 220                     }
 221 
 222                     // (x < c && x >= c - 1) => x == c - 1
 223                     // If the constant is negative, only signed comparison is allowed.
 224                     if (yValue != minValue(bits) && xStamp.lowerBound() == yValue - 1 && (yValue > 0 || getCondition() == CanonicalCondition.LT)) {
 225                         return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, yValue - 1), view);
 226                     }
 227 
 228                 } else if (forY instanceof AddNode) {
 229                     AddNode addNode = (AddNode) forY;
 230                     LogicNode canonical = canonicalizeXLowerXPlusA(forX, addNode, false, true, view);
 231                     if (canonical != null) {
 232                         return canonical;
 233                     }
 234                 }
 235                 if (forX instanceof AddNode) {
 236                     AddNode addNode = (AddNode) forX;
 237                     LogicNode canonical = canonicalizeXLowerXPlusA(forY, addNode, true, false, view);
 238                     if (canonical != null) {
 239                         return canonical;
 240                     }
 241                 }
 242             }
 243             return null;
 244         }
 245 
 246         /**
 247          * Exploit the fact that adding the (signed) MIN_VALUE on both side flips signed and
 248          * unsigned comparison.
 249          *
 250          * In particular:
 251          * <ul>
 252          * <li>{@code x + MIN_VALUE < y + MIN_VALUE <=> x |<| y}</li>
 253          * <li>{@code x + MIN_VALUE |<| y + MIN_VALUE <=> x < y}</li>
 254          * </ul>
 255          */
 256         protected static LogicNode canonicalizeRangeFlip(ValueNode forX, ValueNode forY, int bits, boolean signed, NodeView view) {
 257             long min = CodeUtil.minValue(bits);
 258             long xResidue = 0;
 259             ValueNode left = null;
 260             JavaConstant leftCst = null;
 261             if (forX instanceof AddNode) {
 262                 AddNode xAdd = (AddNode) forX;
 263                 if (xAdd.getY().isJavaConstant() && !xAdd.getY().asJavaConstant().isDefaultForKind()) {
 264                     long xCst = xAdd.getY().asJavaConstant().asLong();
 265                     xResidue = xCst - min;
 266                     left = xAdd.getX();
 267                 }
 268             } else if (forX.isJavaConstant()) {
 269                 leftCst = forX.asJavaConstant();
 270             }
 271             if (left == null && leftCst == null) {
 272                 return null;
 273             }
 274             long yResidue = 0;
 275             ValueNode right = null;
 276             JavaConstant rightCst = null;
 277             if (forY instanceof AddNode) {
 278                 AddNode yAdd = (AddNode) forY;
 279                 if (yAdd.getY().isJavaConstant() && !yAdd.getY().asJavaConstant().isDefaultForKind()) {
 280                     long yCst = yAdd.getY().asJavaConstant().asLong();
 281                     yResidue = yCst - min;
 282                     right = yAdd.getX();
 283                 }
 284             } else if (forY.isJavaConstant()) {
 285                 rightCst = forY.asJavaConstant();
 286             }
 287             if (right == null && rightCst == null) {
 288                 return null;
 289             }
 290             if ((xResidue == 0 && left != null) || (yResidue == 0 && right != null)) {
 291                 if (left == null) {
 292                     // Fortify: Suppress Null Dereference false positive
 293                     assert leftCst != null;
 294 
 295                     left = ConstantNode.forIntegerBits(bits, leftCst.asLong() - min);
 296                 } else if (xResidue != 0) {
 297                     left = AddNode.create(left, ConstantNode.forIntegerBits(bits, xResidue), view);
 298                 }
 299                 if (right == null) {
 300                     // Fortify: Suppress Null Dereference false positive
 301                     assert rightCst != null;
 302 
 303                     right = ConstantNode.forIntegerBits(bits, rightCst.asLong() - min);
 304                 } else if (yResidue != 0) {
 305                     right = AddNode.create(right, ConstantNode.forIntegerBits(bits, yResidue), view);
 306                 }
 307                 if (signed) {
 308                     return new IntegerBelowNode(left, right);
 309                 } else {
 310                     return new IntegerLessThanNode(left, right);
 311                 }
 312             }
 313             return null;
 314         }
 315 
 316         private LogicNode canonicalizeXLowerXPlusA(ValueNode forX, AddNode addNode, boolean mirrored, boolean strict, NodeView view) {
 317             // x < x + a
 318             // x |<| x + a
 319             IntegerStamp xStamp = (IntegerStamp) forX.stamp(view);
 320             IntegerStamp succeedingXStamp;
 321             boolean exact;
 322             if (addNode.getX() == forX && addNode.getY().stamp(view) instanceof IntegerStamp) {
 323                 IntegerStamp aStamp = (IntegerStamp) addNode.getY().stamp(view);
 324                 succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp, xStamp);
 325                 exact = aStamp.lowerBound() == aStamp.upperBound();
 326             } else if (addNode.getY() == forX && addNode.getX().stamp(view) instanceof IntegerStamp) {
 327                 IntegerStamp aStamp = (IntegerStamp) addNode.getX().stamp(view);
 328                 succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp, xStamp);
 329                 exact = aStamp.lowerBound() == aStamp.upperBound();
 330             } else {
 331                 return null;
 332             }
 333             if (succeedingXStamp.join(forX.stamp(view)).isEmpty()) {
 334                 return LogicConstantNode.contradiction();
 335             } else if (exact && !succeedingXStamp.isEmpty()) {
 336                 int bits = succeedingXStamp.getBits();
 337                 if (compare(lowerBound(succeedingXStamp), minValue(bits)) > 0) {
 338                     // x must be in [L..MAX] <=> x >= L <=> !(x < L)
 339                     return LogicNegationNode.create(create(forX, ConstantNode.forIntegerStamp(succeedingXStamp, lowerBound(succeedingXStamp)), view));
 340                 } else if (compare(upperBound(succeedingXStamp), maxValue(bits)) < 0) {
 341                     // x must be in [MIN..H] <=> x <= H <=> !(H < x)
 342                     return LogicNegationNode.create(create(ConstantNode.forIntegerStamp(succeedingXStamp, upperBound(succeedingXStamp)), forX, view));
 343                 }
 344             }
 345             return null;
 346         }
 347 
 348         protected TriState tryFold(Stamp xStampGeneric, Stamp yStampGeneric) {
 349             if (xStampGeneric instanceof IntegerStamp && yStampGeneric instanceof IntegerStamp) {
 350                 IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
 351                 IntegerStamp yStamp = (IntegerStamp) yStampGeneric;
 352                 if (compare(upperBound(xStamp), lowerBound(yStamp)) < 0) {
 353                     return TriState.TRUE;
 354                 }
 355                 if (compare(lowerBound(xStamp), upperBound(yStamp)) >= 0) {
 356                     return TriState.FALSE;
 357                 }
 358             }
 359             return TriState.UNKNOWN;
 360         }
 361 
 362         protected IntegerStamp getSucceedingStampForX(IntegerStamp xStamp, IntegerStamp yStamp, boolean mirror, boolean strict) {
 363             int bits = xStamp.getBits();
 364             assert yStamp.getBits() == bits;
 365             if (mirror) {
 366                 long low = lowerBound(yStamp);
 367                 if (strict) {
 368                     if (low == maxValue(bits)) {
 369                         return null;
 370                     }
 371                     low += 1;
 372                 }
 373                 if (compare(low, lowerBound(xStamp)) > 0 || upperBound(xStamp) != (xStamp.upperBound() & mask(xStamp.getBits()))) {
 374                     return forInteger(bits, low, upperBound(xStamp));
 375                 }
 376             } else {
 377                 // x < y, i.e., x < y <= Y_UPPER_BOUND so x <= Y_UPPER_BOUND - 1
 378                 long low = upperBound(yStamp);
 379                 if (strict) {
 380                     if (low == minValue(bits)) {
 381                         return null;
 382                     }
 383                     low -= 1;
 384                 }
 385                 if (compare(low, upperBound(xStamp)) < 0 || lowerBound(xStamp) != (xStamp.lowerBound() & mask(xStamp.getBits()))) {
 386                     return forInteger(bits, lowerBound(xStamp), low);
 387                 }
 388             }
 389             return null;
 390         }
 391 
 392         protected IntegerStamp getSucceedingStampForXLowerXPlusA(boolean mirrored, boolean strict, IntegerStamp aStamp, IntegerStamp xStamp) {
 393             int bits = aStamp.getBits();
 394             long min = minValue(bits);
 395             long max = maxValue(bits);
 396 
 397             /*
 398              * if x < x + a <=> x + a didn't overflow:
 399              *
 400              * x is outside ]MAX - a, MAX], i.e., inside [MIN, MAX - a]
 401              *
 402              * if a is negative those bounds wrap around correctly.
 403              *
 404              * If a is exactly zero this gives an unbounded stamp (any integer) in the positive case
 405              * and an empty stamp in the negative case: if x |<| x is true, then either x has no
 406              * value or any value...
 407              *
 408              * This does not use upper/lowerBound from LowerOp because it's about the (signed)
 409              * addition not the comparison.
 410              */
 411             if (mirrored) {
 412                 if (aStamp.contains(0)) {
 413                     // a may be zero
 414                     return aStamp.unrestricted();
 415                 }
 416                 return forInteger(bits, min(max - aStamp.lowerBound() + 1, max - aStamp.upperBound() + 1, bits), min(max, upperBound(xStamp)));
 417             } else {
 418                 long aLower = aStamp.lowerBound();
 419                 long aUpper = aStamp.upperBound();
 420                 if (strict) {
 421                     if (aLower == 0) {
 422                         aLower = 1;
 423                     }
 424                     if (aUpper == 0) {
 425                         aUpper = -1;
 426                     }
 427                     if (aLower > aUpper) {
 428                         // impossible
 429                         return aStamp.empty();
 430                     }
 431                 }
 432                 if (aLower < 0 && aUpper > 0) {
 433                     // a may be zero
 434                     return aStamp.unrestricted();
 435                 }
 436                 return forInteger(bits, min, max(max - aLower, max - aUpper, bits));
 437             }
 438         }
 439     }
 440 }