1 /*
   2  * Copyright (c) 2009, 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.nodes.calc;
  26 
  27 import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_1;
  28 import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_2;
  29 import static org.graalvm.compiler.nodes.calc.CompareNode.createCompareNode;
  30 
  31 import org.graalvm.compiler.core.common.calc.CanonicalCondition;
  32 import org.graalvm.compiler.core.common.type.IntegerStamp;
  33 import org.graalvm.compiler.core.common.type.Stamp;
  34 import org.graalvm.compiler.core.common.type.StampFactory;
  35 import org.graalvm.compiler.graph.NodeClass;
  36 import org.graalvm.compiler.graph.spi.Canonicalizable;
  37 import org.graalvm.compiler.graph.spi.CanonicalizerTool;
  38 import org.graalvm.compiler.nodeinfo.InputType;
  39 import org.graalvm.compiler.nodeinfo.NodeInfo;
  40 import org.graalvm.compiler.nodes.ConstantNode;
  41 import org.graalvm.compiler.nodes.LogicConstantNode;
  42 import org.graalvm.compiler.nodes.LogicNegationNode;
  43 import org.graalvm.compiler.nodes.LogicNode;
  44 import org.graalvm.compiler.nodes.NodeView;
  45 import org.graalvm.compiler.nodes.StructuredGraph;
  46 import org.graalvm.compiler.nodes.ValueNode;
  47 import org.graalvm.compiler.nodes.spi.LIRLowerable;
  48 import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
  49 
  50 import jdk.vm.ci.meta.JavaConstant;
  51 
  52 /**
  53  * The {@code ConditionalNode} class represents a comparison that yields one of two (eagerly
  54  * evaluated) values.
  55  */
  56 @NodeInfo(cycles = CYCLES_1, size = SIZE_2)
  57 public final class ConditionalNode extends FloatingNode implements Canonicalizable, LIRLowerable {
  58 
  59     public static final NodeClass<ConditionalNode> TYPE = NodeClass.create(ConditionalNode.class);
  60     @Input(InputType.Condition) LogicNode condition;
  61     @Input(InputType.Value) ValueNode trueValue;
  62     @Input(InputType.Value) ValueNode falseValue;
  63 
  64     public LogicNode condition() {
  65         return condition;
  66     }
  67 
  68     public ConditionalNode(LogicNode condition) {
  69         this(condition, ConstantNode.forInt(1, condition.graph()), ConstantNode.forInt(0, condition.graph()));
  70     }
  71 
  72     public ConditionalNode(LogicNode condition, ValueNode trueValue, ValueNode falseValue) {
  73         super(TYPE, trueValue.stamp(NodeView.DEFAULT).meet(falseValue.stamp(NodeView.DEFAULT)));
  74         assert trueValue.stamp(NodeView.DEFAULT).isCompatible(falseValue.stamp(NodeView.DEFAULT));
  75         this.condition = condition;
  76         this.trueValue = trueValue;
  77         this.falseValue = falseValue;
  78     }
  79 
  80     public static ValueNode create(LogicNode condition, NodeView view) {
  81         return create(condition, ConstantNode.forInt(1, condition.graph()), ConstantNode.forInt(0, condition.graph()), view);
  82     }
  83 
  84     public static ValueNode create(LogicNode condition, ValueNode trueValue, ValueNode falseValue, NodeView view) {
  85         ValueNode synonym = findSynonym(condition, trueValue, falseValue, view);
  86         if (synonym != null) {
  87             return synonym;
  88         }
  89         ValueNode result = canonicalizeConditional(condition, trueValue, falseValue, trueValue.stamp(view).meet(falseValue.stamp(view)), view);
  90         if (result != null) {
  91             return result;
  92         }
  93         return new ConditionalNode(condition, trueValue, falseValue);
  94     }
  95 
  96     @Override
  97     public boolean inferStamp() {
  98         Stamp valueStamp = trueValue.stamp(NodeView.DEFAULT).meet(falseValue.stamp(NodeView.DEFAULT));
  99         if (condition instanceof IntegerLessThanNode) {
 100             IntegerLessThanNode lessThan = (IntegerLessThanNode) condition;
 101             if (lessThan.getX() == trueValue && lessThan.getY() == falseValue) {
 102                 // this encodes a min operation
 103                 JavaConstant constant = lessThan.getX().asJavaConstant();
 104                 if (constant == null) {
 105                     constant = lessThan.getY().asJavaConstant();
 106                 }
 107                 if (constant != null) {
 108                     IntegerStamp bounds = StampFactory.forInteger(constant.getJavaKind(), constant.getJavaKind().getMinValue(), constant.asLong());
 109                     valueStamp = valueStamp.join(bounds);
 110                 }
 111             } else if (lessThan.getX() == falseValue && lessThan.getY() == trueValue) {
 112                 // this encodes a max operation
 113                 JavaConstant constant = lessThan.getX().asJavaConstant();
 114                 if (constant == null) {
 115                     constant = lessThan.getY().asJavaConstant();
 116                 }
 117                 if (constant != null) {
 118                     IntegerStamp bounds = StampFactory.forInteger(constant.getJavaKind(), constant.asLong(), constant.getJavaKind().getMaxValue());
 119                     valueStamp = valueStamp.join(bounds);
 120                 }
 121             }
 122         }
 123         return updateStamp(valueStamp);
 124     }
 125 
 126     public ValueNode trueValue() {
 127         return trueValue;
 128     }
 129 
 130     public ValueNode falseValue() {
 131         return falseValue;
 132     }
 133 
 134     @Override
 135     public ValueNode canonical(CanonicalizerTool tool) {
 136         NodeView view = NodeView.from(tool);
 137         ValueNode synonym = findSynonym(condition, trueValue(), falseValue(), view);
 138         if (synonym != null) {
 139             return synonym;
 140         }
 141 
 142         ValueNode result = canonicalizeConditional(condition, trueValue(), falseValue(), stamp, view);
 143         if (result != null) {
 144             return result;
 145         }
 146 
 147         return this;
 148     }
 149 
 150     public static ValueNode canonicalizeConditional(LogicNode condition, ValueNode trueValue, ValueNode falseValue, Stamp stamp, NodeView view) {
 151         if (trueValue == falseValue) {
 152             return trueValue;
 153         }
 154 
 155         if (condition instanceof CompareNode && ((CompareNode) condition).isIdentityComparison()) {
 156             // optimize the pattern (x == y) ? x : y
 157             CompareNode compare = (CompareNode) condition;
 158             if ((compare.getX() == trueValue && compare.getY() == falseValue) || (compare.getX() == falseValue && compare.getY() == trueValue)) {
 159                 return falseValue;
 160             }
 161         }
 162 
 163         if (trueValue.stamp(view) instanceof IntegerStamp) {
 164             // check if the conditional is redundant
 165             if (condition instanceof IntegerLessThanNode) {
 166                 IntegerLessThanNode lessThan = (IntegerLessThanNode) condition;
 167                 IntegerStamp falseValueStamp = (IntegerStamp) falseValue.stamp(view);
 168                 IntegerStamp trueValueStamp = (IntegerStamp) trueValue.stamp(view);
 169                 if (lessThan.getX() == trueValue && lessThan.getY() == falseValue) {
 170                     // return "x" for "x < y ? x : y" in case that we know "x <= y"
 171                     if (trueValueStamp.upperBound() <= falseValueStamp.lowerBound()) {
 172                         return trueValue;
 173                     }
 174                 } else if (lessThan.getX() == falseValue && lessThan.getY() == trueValue) {
 175                     // return "y" for "x < y ? y : x" in case that we know "x <= y"
 176                     if (falseValueStamp.upperBound() <= trueValueStamp.lowerBound()) {
 177                         return trueValue;
 178                     }
 179                 }
 180             }
 181 
 182             // this optimizes the case where a value from the range 0 - 1 is mapped to the
 183             // range 0 - 1
 184             if (trueValue.isConstant() && falseValue.isConstant()) {
 185                 long constTrueValue = trueValue.asJavaConstant().asLong();
 186                 long constFalseValue = falseValue.asJavaConstant().asLong();
 187                 if (condition instanceof IntegerEqualsNode) {
 188                     IntegerEqualsNode equals = (IntegerEqualsNode) condition;
 189                     if (equals.getY().isConstant() && equals.getX().stamp(view) instanceof IntegerStamp) {
 190                         IntegerStamp equalsXStamp = (IntegerStamp) equals.getX().stamp(view);
 191                         if (equalsXStamp.upMask() == 1) {
 192                             long equalsY = equals.getY().asJavaConstant().asLong();
 193                             if (equalsY == 0) {
 194                                 if (constTrueValue == 0 && constFalseValue == 1) {
 195                                     // return x when: x == 0 ? 0 : 1;
 196                                     return IntegerConvertNode.convertUnsigned(equals.getX(), stamp, view);
 197                                 } else if (constTrueValue == 1 && constFalseValue == 0) {
 198                                     // negate a boolean value via xor
 199                                     return IntegerConvertNode.convertUnsigned(XorNode.create(equals.getX(), ConstantNode.forIntegerStamp(equals.getX().stamp(view), 1), view), stamp, view);
 200                                 }
 201                             } else if (equalsY == 1) {
 202                                 if (constTrueValue == 1 && constFalseValue == 0) {
 203                                     // return x when: x == 1 ? 1 : 0;
 204                                     return IntegerConvertNode.convertUnsigned(equals.getX(), stamp, view);
 205                                 } else if (constTrueValue == 0 && constFalseValue == 1) {
 206                                     // negate a boolean value via xor
 207                                     return IntegerConvertNode.convertUnsigned(XorNode.create(equals.getX(), ConstantNode.forIntegerStamp(equals.getX().stamp(view), 1), view), stamp, view);
 208                                 }
 209                             }
 210                         }
 211                     }
 212                 } else if (condition instanceof IntegerTestNode) {
 213                     // replace IntegerTestNode with AndNode for the following patterns:
 214                     // (value & 1) == 0 ? 0 : 1
 215                     // (value & 1) == 1 ? 1 : 0
 216                     IntegerTestNode integerTestNode = (IntegerTestNode) condition;
 217                     if (integerTestNode.getY().isConstant() && integerTestNode.getX().stamp(view) instanceof IntegerStamp) {
 218                         long testY = integerTestNode.getY().asJavaConstant().asLong();
 219                         if (testY == 1 && constTrueValue == 0 && constFalseValue == 1) {
 220                             return IntegerConvertNode.convertUnsigned(AndNode.create(integerTestNode.getX(), integerTestNode.getY(), view), stamp, view);
 221                         }
 222                     }
 223                 }
 224             }
 225 
 226             if (condition instanceof IntegerLessThanNode) {
 227                 /*
 228                  * Convert a conditional add ((x < 0) ? (x + y) : x) into (x + (y & (x >> (bits -
 229                  * 1)))) to avoid the test.
 230                  */
 231                 IntegerLessThanNode lt = (IntegerLessThanNode) condition;
 232                 if (lt.getY().isDefaultConstant()) {
 233                     if (falseValue == lt.getX()) {
 234                         if (trueValue instanceof AddNode) {
 235                             AddNode add = (AddNode) trueValue;
 236                             if (add.getX() == falseValue) {
 237                                 int bits = ((IntegerStamp) trueValue.stamp(NodeView.DEFAULT)).getBits();
 238                                 ValueNode shift = new RightShiftNode(lt.getX(), ConstantNode.forIntegerBits(32, bits - 1));
 239                                 ValueNode and = new AndNode(shift, add.getY());
 240                                 return new AddNode(add.getX(), and);
 241                             }
 242                         }
 243                     }
 244                 }
 245             }
 246         }
 247 
 248         return null;
 249     }
 250 
 251     private static ValueNode findSynonym(ValueNode condition, ValueNode trueValue, ValueNode falseValue, NodeView view) {
 252         if (condition instanceof LogicNegationNode) {
 253             LogicNegationNode negated = (LogicNegationNode) condition;
 254             return ConditionalNode.create(negated.getValue(), falseValue, trueValue, view);
 255         }
 256         if (condition instanceof LogicConstantNode) {
 257             LogicConstantNode c = (LogicConstantNode) condition;
 258             if (c.getValue()) {
 259                 return trueValue;
 260             } else {
 261                 return falseValue;
 262             }
 263         }
 264         return null;
 265     }
 266 
 267     @Override
 268     public void generate(NodeLIRBuilderTool generator) {
 269         generator.emitConditional(this);
 270     }
 271 
 272     public ConditionalNode(StructuredGraph graph, CanonicalCondition condition, ValueNode x, ValueNode y) {
 273         this(createCompareNode(graph, condition, x, y, null, NodeView.DEFAULT));
 274     }
 275 }