--- old/src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/ConditionalNode.java 2017-11-03 23:57:08.690442499 -0700 +++ new/src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/calc/ConditionalNode.java 2017-11-03 23:57:08.364427910 -0700 @@ -22,7 +22,7 @@ */ package org.graalvm.compiler.nodes.calc; -import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_0; +import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_1; import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_2; import static org.graalvm.compiler.nodes.calc.CompareNode.createCompareNode; @@ -47,10 +47,10 @@ import jdk.vm.ci.meta.JavaConstant; /** - * The {@code ConditionalNode} class represents a comparison that yields one of two values. Note - * that these nodes are not built directly from the bytecode but are introduced by canonicalization. + * The {@code ConditionalNode} class represents a comparison that yields one of two (eagerly + * evaluated) values. */ -@NodeInfo(cycles = CYCLES_0, size = SIZE_2) +@NodeInfo(cycles = CYCLES_1, size = SIZE_2) public final class ConditionalNode extends FloatingNode implements Canonicalizable, LIRLowerable { public static final NodeClass TYPE = NodeClass.create(ConditionalNode.class); @@ -116,7 +116,6 @@ valueStamp = valueStamp.join(bounds); } } - } return updateStamp(valueStamp); } @@ -145,49 +144,10 @@ } public static ValueNode canonicalizeConditional(LogicNode condition, ValueNode trueValue, ValueNode falseValue, Stamp stamp) { - // this optimizes the case where a value from the range 0 - 1 is mapped to the range 0 - 1 - if (trueValue.isConstant() && falseValue.isConstant() && trueValue.stamp() instanceof IntegerStamp && falseValue.stamp() instanceof IntegerStamp) { - long constTrueValue = trueValue.asJavaConstant().asLong(); - long constFalseValue = falseValue.asJavaConstant().asLong(); - if (condition instanceof IntegerEqualsNode) { - IntegerEqualsNode equals = (IntegerEqualsNode) condition; - if (equals.getY().isConstant() && equals.getX().stamp() instanceof IntegerStamp) { - IntegerStamp equalsXStamp = (IntegerStamp) equals.getX().stamp(); - if (equalsXStamp.upMask() == 1) { - long equalsY = equals.getY().asJavaConstant().asLong(); - if (equalsY == 0) { - if (constTrueValue == 0 && constFalseValue == 1) { - // return x when: x == 0 ? 0 : 1; - return IntegerConvertNode.convertUnsigned(equals.getX(), stamp); - } else if (constTrueValue == 1 && constFalseValue == 0) { - // negate a boolean value via xor - return IntegerConvertNode.convertUnsigned(XorNode.create(equals.getX(), ConstantNode.forIntegerStamp(equals.getX().stamp(), 1)), stamp); - } - } else if (equalsY == 1) { - if (constTrueValue == 1 && constFalseValue == 0) { - // return x when: x == 1 ? 1 : 0; - return IntegerConvertNode.convertUnsigned(equals.getX(), stamp); - } else if (constTrueValue == 0 && constFalseValue == 1) { - // negate a boolean value via xor - return IntegerConvertNode.convertUnsigned(XorNode.create(equals.getX(), ConstantNode.forIntegerStamp(equals.getX().stamp(), 1)), stamp); - } - } - } - } - } else if (condition instanceof IntegerTestNode) { - // replace IntegerTestNode with AndNode for the following patterns: - // (value & 1) == 0 ? 0 : 1 - // (value & 1) == 1 ? 1 : 0 - IntegerTestNode integerTestNode = (IntegerTestNode) condition; - if (integerTestNode.getY().isConstant()) { - assert integerTestNode.getX().stamp() instanceof IntegerStamp; - long testY = integerTestNode.getY().asJavaConstant().asLong(); - if (testY == 1 && constTrueValue == 0 && constFalseValue == 1) { - return IntegerConvertNode.convertUnsigned(AndNode.create(integerTestNode.getX(), integerTestNode.getY()), stamp); - } - } - } + if (trueValue == falseValue) { + return trueValue; } + if (condition instanceof CompareNode && ((CompareNode) condition).isIdentityComparison()) { // optimize the pattern (x == y) ? x : y CompareNode compare = (CompareNode) condition; @@ -195,25 +155,87 @@ return falseValue; } } - if (trueValue == falseValue) { - return trueValue; - } - if (condition instanceof IntegerLessThanNode && trueValue.stamp() instanceof IntegerStamp) { - /* - * Convert a conditional add ((x < 0) ? (x + y) : x) into (x + (y & (x >> (bits - 1)))) - * to avoid the test. - */ - IntegerLessThanNode lt = (IntegerLessThanNode) condition; - if (lt.getY().isConstant() && lt.getY().asConstant().isDefaultForKind()) { - if (falseValue == lt.getX()) { - if (trueValue instanceof AddNode) { - AddNode add = (AddNode) trueValue; - if (add.getX() == falseValue) { - int bits = ((IntegerStamp) trueValue.stamp()).getBits(); - ValueNode shift = new RightShiftNode(lt.getX(), ConstantNode.forIntegerBits(32, bits - 1)); - ValueNode and = new AndNode(shift, add.getY()); - return new AddNode(add.getX(), and); + if (trueValue.stamp() instanceof IntegerStamp) { + // check if the conditional is redundant + if (condition instanceof IntegerLessThanNode) { + IntegerLessThanNode lessThan = (IntegerLessThanNode) condition; + IntegerStamp falseValueStamp = (IntegerStamp) falseValue.stamp(); + IntegerStamp trueValueStamp = (IntegerStamp) trueValue.stamp(); + if (lessThan.getX() == trueValue && lessThan.getY() == falseValue) { + // return "x" for "x < y ? x : y" in case that we know "x <= y" + if (trueValueStamp.upperBound() <= falseValueStamp.lowerBound()) { + return trueValue; + } + } else if (lessThan.getX() == falseValue && lessThan.getY() == trueValue) { + // return "x" for "x < y ? y : x" in case that we know "x <= y" + if (falseValueStamp.upperBound() <= trueValueStamp.lowerBound()) { + return falseValue; + } + } + } + + // this optimizes the case where a value from the range 0 - 1 is mapped to the + // range 0 - 1 + if (trueValue.isConstant() && falseValue.isConstant()) { + long constTrueValue = trueValue.asJavaConstant().asLong(); + long constFalseValue = falseValue.asJavaConstant().asLong(); + if (condition instanceof IntegerEqualsNode) { + IntegerEqualsNode equals = (IntegerEqualsNode) condition; + if (equals.getY().isConstant() && equals.getX().stamp() instanceof IntegerStamp) { + IntegerStamp equalsXStamp = (IntegerStamp) equals.getX().stamp(); + if (equalsXStamp.upMask() == 1) { + long equalsY = equals.getY().asJavaConstant().asLong(); + if (equalsY == 0) { + if (constTrueValue == 0 && constFalseValue == 1) { + // return x when: x == 0 ? 0 : 1; + return IntegerConvertNode.convertUnsigned(equals.getX(), stamp); + } else if (constTrueValue == 1 && constFalseValue == 0) { + // negate a boolean value via xor + return IntegerConvertNode.convertUnsigned(XorNode.create(equals.getX(), ConstantNode.forIntegerStamp(equals.getX().stamp(), 1)), stamp); + } + } else if (equalsY == 1) { + if (constTrueValue == 1 && constFalseValue == 0) { + // return x when: x == 1 ? 1 : 0; + return IntegerConvertNode.convertUnsigned(equals.getX(), stamp); + } else if (constTrueValue == 0 && constFalseValue == 1) { + // negate a boolean value via xor + return IntegerConvertNode.convertUnsigned(XorNode.create(equals.getX(), ConstantNode.forIntegerStamp(equals.getX().stamp(), 1)), stamp); + } + } + } + } + } else if (condition instanceof IntegerTestNode) { + // replace IntegerTestNode with AndNode for the following patterns: + // (value & 1) == 0 ? 0 : 1 + // (value & 1) == 1 ? 1 : 0 + IntegerTestNode integerTestNode = (IntegerTestNode) condition; + if (integerTestNode.getY().isConstant()) { + assert integerTestNode.getX().stamp() instanceof IntegerStamp; + long testY = integerTestNode.getY().asJavaConstant().asLong(); + if (testY == 1 && constTrueValue == 0 && constFalseValue == 1) { + return IntegerConvertNode.convertUnsigned(AndNode.create(integerTestNode.getX(), integerTestNode.getY()), stamp); + } + } + } + } + + if (condition instanceof IntegerLessThanNode) { + /* + * Convert a conditional add ((x < 0) ? (x + y) : x) into (x + (y & (x >> (bits - + * 1)))) to avoid the test. + */ + IntegerLessThanNode lt = (IntegerLessThanNode) condition; + if (lt.getY().isConstant() && lt.getY().asConstant().isDefaultForKind()) { + if (falseValue == lt.getX()) { + if (trueValue instanceof AddNode) { + AddNode add = (AddNode) trueValue; + if (add.getX() == falseValue) { + int bits = ((IntegerStamp) trueValue.stamp()).getBits(); + ValueNode shift = new RightShiftNode(lt.getX(), ConstantNode.forIntegerBits(32, bits - 1)); + ValueNode and = new AndNode(shift, add.getY()); + return new AddNode(add.getX(), and); + } } } }