1 /*
   2  * Copyright (c) 2012, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.nodes.test;
  26 
  27 import java.util.EnumSet;
  28 import java.util.HashSet;
  29 
  30 import org.graalvm.compiler.core.common.calc.FloatConvert;
  31 import org.graalvm.compiler.core.common.calc.FloatConvertCategory;
  32 import org.graalvm.compiler.core.common.type.ArithmeticOpTable;
  33 import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp;
  34 import org.graalvm.compiler.core.common.type.ArithmeticOpTable.IntegerConvertOp;
  35 import org.graalvm.compiler.core.common.type.ArithmeticOpTable.ShiftOp;
  36 import org.graalvm.compiler.core.common.type.FloatStamp;
  37 import org.graalvm.compiler.core.common.type.IntegerStamp;
  38 import org.graalvm.compiler.core.common.type.PrimitiveStamp;
  39 import org.graalvm.compiler.core.common.type.Stamp;
  40 import org.graalvm.compiler.core.common.type.StampFactory;
  41 import org.graalvm.compiler.test.GraalTest;
  42 import org.junit.Test;
  43 
  44 import jdk.vm.ci.meta.Constant;
  45 import jdk.vm.ci.meta.JavaConstant;
  46 import jdk.vm.ci.meta.JavaKind;
  47 
  48 /**
  49  * Exercise the various stamp folding operations by generating ranges from a set of boundary values
  50  * and then ensuring that the values that produced those ranges are in the resulting stamp.
  51  */
  52 public class PrimitiveStampBoundaryTest extends GraalTest {
  53 
  54     static long[] longBoundaryValues = {Long.MIN_VALUE, Long.MIN_VALUE + 1, Integer.MIN_VALUE, Integer.MIN_VALUE + 1, -1, 0, 1, Integer.MAX_VALUE - 1, Integer.MAX_VALUE, Long.MAX_VALUE - 1,
  55                     Long.MAX_VALUE};
  56 
  57     static int[] shiftBoundaryValues = {-128, -1, 0, 1, 4, 8, 16, 31, 63, 128};
  58 
  59     static HashSet<IntegerStamp> shiftStamps;
  60     static HashSet<PrimitiveStamp> integerTestStamps;
  61     static HashSet<PrimitiveStamp> floatTestStamps;
  62 
  63     static {
  64         shiftStamps = new HashSet<>();
  65         for (long v1 : shiftBoundaryValues) {
  66             for (long v2 : shiftBoundaryValues) {
  67                 shiftStamps.add(IntegerStamp.create(32, Math.min(v1, v2), Math.max(v1, v2)));
  68             }
  69         }
  70         shiftStamps.add((IntegerStamp) StampFactory.empty(JavaKind.Int));
  71 
  72         integerTestStamps = new HashSet<>();
  73         for (long v1 : longBoundaryValues) {
  74             for (long v2 : longBoundaryValues) {
  75                 if (v2 == (int) v2 && v1 == (int) v1) {
  76                     integerTestStamps.add(IntegerStamp.create(32, Math.min(v1, v2), Math.max(v1, v2)));
  77                 }
  78                 integerTestStamps.add(IntegerStamp.create(64, Math.min(v1, v2), Math.max(v1, v2)));
  79             }
  80         }
  81         integerTestStamps.add((PrimitiveStamp) StampFactory.empty(JavaKind.Int));
  82         integerTestStamps.add((PrimitiveStamp) StampFactory.empty(JavaKind.Long));
  83     }
  84 
  85     static double[] doubleBoundaryValues = {Double.NEGATIVE_INFINITY, Double.MIN_VALUE, Float.NEGATIVE_INFINITY, Float.MIN_VALUE,
  86                     Long.MIN_VALUE, Long.MIN_VALUE + 1, Integer.MIN_VALUE, Integer.MIN_VALUE + 1, -1, -0.0, +0.0, 1,
  87                     Integer.MAX_VALUE - 1, Integer.MAX_VALUE, Long.MAX_VALUE - 1, Long.MAX_VALUE,
  88                     Float.MAX_VALUE, Float.POSITIVE_INFINITY, Double.MAX_VALUE, Double.POSITIVE_INFINITY};
  89 
  90     static double[] doubleSpecialValues = {Double.NaN, -0.0, -0.0F, Float.NaN};
  91 
  92     static {
  93         floatTestStamps = new HashSet<>();
  94 
  95         for (double d1 : doubleBoundaryValues) {
  96             for (double d2 : doubleBoundaryValues) {
  97                 float f1 = (float) d2;
  98                 float f2 = (float) d1;
  99                 if (d2 == f1 && d1 == f2) {
 100                     generateFloatingStamps(new FloatStamp(32, Math.min(f2, f1), Math.max(f2, f1), true));
 101                     generateFloatingStamps(new FloatStamp(32, Math.min(f2, f1), Math.max(f2, f1), false));
 102                 }
 103                 generateFloatingStamps(new FloatStamp(64, Math.min(d1, d2), Math.max(d1, d2), true));
 104                 generateFloatingStamps(new FloatStamp(64, Math.min(d1, d2), Math.max(d1, d2), false));
 105             }
 106         }
 107         floatTestStamps.add((PrimitiveStamp) StampFactory.empty(JavaKind.Float));
 108         floatTestStamps.add((PrimitiveStamp) StampFactory.empty(JavaKind.Double));
 109     }
 110 
 111     private static void generateFloatingStamps(FloatStamp floatStamp) {
 112         floatTestStamps.add(floatStamp);
 113         for (double d : doubleSpecialValues) {
 114             FloatStamp newStamp = (FloatStamp) floatStamp.meet(floatStampForConstant(d, floatStamp.getBits()));
 115             if (!newStamp.isUnrestricted()) {
 116                 floatTestStamps.add(newStamp);
 117             }
 118         }
 119     }
 120 
 121     @Test
 122     public void testConvertBoundaryValues() {
 123         testConvertBoundaryValues(IntegerStamp.OPS.getSignExtend(), 32, 64, integerTestStamps);
 124         testConvertBoundaryValues(IntegerStamp.OPS.getZeroExtend(), 32, 64, integerTestStamps);
 125         testConvertBoundaryValues(IntegerStamp.OPS.getNarrow(), 64, 32, integerTestStamps);
 126     }
 127 
 128     private static void testConvertBoundaryValues(IntegerConvertOp<?> op, int inputBits, int resultBits, HashSet<PrimitiveStamp> stamps) {
 129         for (PrimitiveStamp stamp : stamps) {
 130             if (inputBits == stamp.getBits()) {
 131                 Stamp lower = boundaryStamp(stamp, false);
 132                 Stamp upper = boundaryStamp(stamp, true);
 133                 checkConvertOperation(op, inputBits, resultBits, op.foldStamp(inputBits, resultBits, stamp), lower);
 134                 checkConvertOperation(op, inputBits, resultBits, op.foldStamp(inputBits, resultBits, stamp), upper);
 135             }
 136         }
 137     }
 138 
 139     private static void checkConvertOperation(IntegerConvertOp<?> op, int inputBits, int resultBits, Stamp result, Stamp v1stamp) {
 140         Stamp folded = op.foldStamp(inputBits, resultBits, v1stamp);
 141         assertTrue(folded.isEmpty() || folded.asConstant() != null, "should constant fold %s %s %s", op, v1stamp, folded);
 142         assertTrue(result.meet(folded).equals(result), "result out of range %s %s %s %s %s", op, v1stamp, folded, result, result.meet(folded));
 143     }
 144 
 145     @Test
 146     public void testFloatConvertBoundaryValues() {
 147         for (FloatConvert op : EnumSet.allOf(FloatConvert.class)) {
 148             ArithmeticOpTable.FloatConvertOp floatConvert = IntegerStamp.OPS.getFloatConvert(op);
 149             if (floatConvert == null) {
 150                 continue;
 151             }
 152             assert op.getCategory() == FloatConvertCategory.IntegerToFloatingPoint : op;
 153             testConvertBoundaryValues(floatConvert, op.getInputBits(), integerTestStamps);
 154         }
 155         for (FloatConvert op : EnumSet.allOf(FloatConvert.class)) {
 156             ArithmeticOpTable.FloatConvertOp floatConvert = FloatStamp.OPS.getFloatConvert(op);
 157             if (floatConvert == null) {
 158                 continue;
 159             }
 160             assert op.getCategory() == FloatConvertCategory.FloatingPointToInteger || op.getCategory() == FloatConvertCategory.FloatingPointToFloatingPoint : op;
 161             testConvertBoundaryValues(floatConvert, op.getInputBits(), floatTestStamps);
 162         }
 163     }
 164 
 165     private static void testConvertBoundaryValues(ArithmeticOpTable.FloatConvertOp op, int bits, HashSet<PrimitiveStamp> stamps) {
 166         for (PrimitiveStamp stamp : stamps) {
 167             if (bits == stamp.getBits()) {
 168                 Stamp lower = boundaryStamp(stamp, false);
 169                 Stamp upper = boundaryStamp(stamp, true);
 170                 checkConvertOperation(op, op.foldStamp(stamp), lower);
 171                 checkConvertOperation(op, op.foldStamp(stamp), upper);
 172             }
 173         }
 174 
 175     }
 176 
 177     static void shouldConstantFold(boolean b, Stamp folded, Object o, Stamp s1) {
 178         assertTrue(b || (folded instanceof FloatStamp && ((FloatStamp) folded).contains(0.0)), "should constant fold %s %s %s", o, s1, folded);
 179     }
 180 
 181     private static boolean constantFloatStampMayIncludeNegativeZero(Stamp s) {
 182         if (s instanceof FloatStamp) {
 183             FloatStamp f = (FloatStamp) s;
 184             return Double.compare(f.lowerBound(), f.upperBound()) == 0 && f.isNonNaN();
 185         }
 186         return false;
 187     }
 188 
 189     private static void checkConvertOperation(ArithmeticOpTable.FloatConvertOp op, Stamp result, Stamp v1stamp) {
 190         Stamp folded = op.foldStamp(v1stamp);
 191         shouldConstantFold(folded.isEmpty() || folded.asConstant() != null, folded, op, v1stamp);
 192         assertTrue(result.meet(folded).equals(result), "result out of range %s %s %s %s %s", op, v1stamp, folded, result, result.meet(folded));
 193     }
 194 
 195     @Test
 196     public void testShiftBoundaryValues() {
 197         for (ShiftOp<?> op : IntegerStamp.OPS.getShiftOps()) {
 198             testShiftBoundaryValues(op, integerTestStamps, shiftStamps);
 199         }
 200     }
 201 
 202     private static void testShiftBoundaryValues(ShiftOp<?> shiftOp, HashSet<PrimitiveStamp> stamps, HashSet<IntegerStamp> shifts) {
 203         for (PrimitiveStamp testStamp : stamps) {
 204             if (testStamp instanceof IntegerStamp) {
 205                 IntegerStamp stamp = (IntegerStamp) testStamp;
 206                 for (IntegerStamp shiftStamp : shifts) {
 207                     IntegerStamp foldedStamp = (IntegerStamp) shiftOp.foldStamp(stamp, shiftStamp);
 208                     if (foldedStamp.isEmpty()) {
 209                         assertTrue(stamp.isEmpty() || shiftStamp.isEmpty());
 210                         continue;
 211                     }
 212                     checkShiftOperation(stamp.getBits(), shiftOp, foldedStamp, stamp.lowerBound(), shiftStamp.lowerBound());
 213                     checkShiftOperation(stamp.getBits(), shiftOp, foldedStamp, stamp.lowerBound(), shiftStamp.upperBound());
 214                     checkShiftOperation(stamp.getBits(), shiftOp, foldedStamp, stamp.upperBound(), shiftStamp.lowerBound());
 215                     checkShiftOperation(stamp.getBits(), shiftOp, foldedStamp, stamp.upperBound(), shiftStamp.upperBound());
 216                 }
 217             }
 218         }
 219     }
 220 
 221     private static void checkShiftOperation(int bits, ShiftOp<?> op, IntegerStamp result, long v1, long v2) {
 222         IntegerStamp v1stamp = IntegerStamp.create(bits, v1, v1);
 223         IntegerStamp v2stamp = IntegerStamp.create(32, v2, v2);
 224         IntegerStamp folded = (IntegerStamp) op.foldStamp(v1stamp, v2stamp);
 225         Constant constant = op.foldConstant(JavaConstant.forPrimitiveInt(bits, v1), (int) v2);
 226         assertTrue(constant != null);
 227         assertTrue(folded.asConstant() != null, "should constant fold %s %s %s %s", op, v1stamp, v2stamp, folded);
 228         assertTrue(result.meet(folded).equals(result), "result out of range %s %s %s %s %s %s", op, v1stamp, v2stamp, folded, result, result.meet(folded));
 229     }
 230 
 231     private static void checkBinaryOperation(ArithmeticOpTable.BinaryOp<?> op, Stamp result, Stamp v1stamp, Stamp v2stamp) {
 232         if (constantFloatStampMayIncludeNegativeZero(v1stamp) || constantFloatStampMayIncludeNegativeZero(v2stamp)) {
 233             return;
 234         }
 235         Stamp folded = op.foldStamp(v1stamp, v2stamp);
 236         if (v1stamp.isEmpty() || v2stamp.isEmpty()) {
 237             assertTrue(folded.isEmpty());
 238             assertTrue(v1stamp.asConstant() != null || v1stamp.isEmpty());
 239             assertTrue(v2stamp.asConstant() != null || v2stamp.isEmpty());
 240             return;
 241         }
 242         Constant constant = op.foldConstant(v1stamp.asConstant(), v2stamp.asConstant());
 243         if (constant != null) {
 244             assertFalse(folded.isEmpty());
 245             Constant constant2 = folded.asConstant();
 246             if (constant2 == null && v1stamp instanceof FloatStamp) {
 247                 JavaConstant c = (JavaConstant) constant;
 248                 assertTrue((c.getJavaKind() == JavaKind.Double && Double.isNaN(c.asDouble())) ||
 249                                 (c.getJavaKind() == JavaKind.Float && Float.isNaN(c.asFloat())));
 250             } else {
 251                 assertTrue(constant2 != null, "should constant fold %s %s %s %s", op, v1stamp, v2stamp, folded);
 252                 if (!constant.equals(constant2)) {
 253                     op.foldConstant(v1stamp.asConstant(), v2stamp.asConstant());
 254                     op.foldStamp(v1stamp, v2stamp);
 255                 }
 256                 assertTrue(constant.equals(constant2), "should produce same constant %s %s %s %s %s", op, v1stamp, v2stamp, constant, constant2);
 257             }
 258             assertTrue(result.meet(folded).equals(result), "result out of range %s %s %s %s %s %s", op, v1stamp, v2stamp, folded, result, result.meet(folded));
 259         }
 260     }
 261 
 262     @Test
 263     public void testBinaryBoundaryValues() {
 264         for (BinaryOp<?> op : IntegerStamp.OPS.getBinaryOps()) {
 265             if (op != null) {
 266                 testBinaryBoundaryValues(op, integerTestStamps);
 267             }
 268         }
 269         for (BinaryOp<?> op : FloatStamp.OPS.getBinaryOps()) {
 270             if (op != null) {
 271                 testBinaryBoundaryValues(op, floatTestStamps);
 272             }
 273         }
 274     }
 275 
 276     private static Stamp boundaryStamp(Stamp v1, boolean upper) {
 277         if (v1.isEmpty()) {
 278             return v1;
 279         }
 280         if (v1 instanceof IntegerStamp) {
 281             IntegerStamp istamp = (IntegerStamp) v1;
 282             long bound = upper ? istamp.upperBound() : istamp.lowerBound();
 283             return IntegerStamp.create(istamp.getBits(), bound, bound);
 284         } else if (v1 instanceof FloatStamp) {
 285             FloatStamp floatStamp = (FloatStamp) v1;
 286             double bound = upper ? floatStamp.upperBound() : floatStamp.lowerBound();
 287             int bits = floatStamp.getBits();
 288             return floatStampForConstant(bound, bits);
 289         } else {
 290             throw new InternalError("unexpected stamp type " + v1);
 291         }
 292     }
 293 
 294     private static FloatStamp floatStampForConstant(double bound, int bits) {
 295         if (bits == 32) {
 296             float fbound = (float) bound;
 297             return new FloatStamp(bits, fbound, fbound, !Float.isNaN(fbound));
 298         } else {
 299             return new FloatStamp(bits, bound, bound, !Double.isNaN(bound));
 300         }
 301     }
 302 
 303     private static void testBinaryBoundaryValues(ArithmeticOpTable.BinaryOp<?> op, HashSet<PrimitiveStamp> stamps) {
 304         for (PrimitiveStamp v1 : stamps) {
 305             for (PrimitiveStamp v2 : stamps) {
 306                 if (v1.getBits() == v2.getBits() && v1.getClass() == v2.getClass()) {
 307                     Stamp result = op.foldStamp(v1, v2);
 308                     Stamp v1lower = boundaryStamp(v1, false);
 309                     Stamp v1upper = boundaryStamp(v1, true);
 310                     Stamp v2lower = boundaryStamp(v2, false);
 311                     Stamp v2upper = boundaryStamp(v2, true);
 312                     checkBinaryOperation(op, result, v1lower, v2lower);
 313                     checkBinaryOperation(op, result, v1lower, v2upper);
 314                     checkBinaryOperation(op, result, v1upper, v2lower);
 315                     checkBinaryOperation(op, result, v1upper, v2upper);
 316                 }
 317             }
 318         }
 319     }
 320 
 321     @Test
 322     public void testUnaryBoundaryValues() {
 323         for (ArithmeticOpTable.UnaryOp<?> op : IntegerStamp.OPS.getUnaryOps()) {
 324             if (op != null) {
 325                 testUnaryBoundaryValues(op, integerTestStamps);
 326             }
 327         }
 328         for (ArithmeticOpTable.UnaryOp<?> op : FloatStamp.OPS.getUnaryOps()) {
 329             if (op != null) {
 330                 testUnaryBoundaryValues(op, floatTestStamps);
 331             }
 332         }
 333     }
 334 
 335     private static void testUnaryBoundaryValues(ArithmeticOpTable.UnaryOp<?> op, HashSet<PrimitiveStamp> stamps) {
 336         for (PrimitiveStamp v1 : stamps) {
 337             Stamp result = op.foldStamp(v1);
 338             checkUnaryOperation(op, result, boundaryStamp(v1, false));
 339             checkUnaryOperation(op, result, boundaryStamp(v1, true));
 340         }
 341     }
 342 
 343     private static void checkUnaryOperation(ArithmeticOpTable.UnaryOp<?> op, Stamp result, Stamp v1stamp) {
 344         Stamp folded = op.foldStamp(v1stamp);
 345         Constant v1constant = v1stamp.asConstant();
 346         if (v1constant != null) {
 347             Constant constant = op.foldConstant(v1constant);
 348             if (constant != null) {
 349                 Constant constant2 = folded.asConstant();
 350                 if (constant2 == null && v1stamp instanceof FloatStamp) {
 351                     JavaConstant c = (JavaConstant) constant;
 352                     assertTrue((c.getJavaKind() == JavaKind.Double && Double.isNaN(c.asDouble())) ||
 353                                     (c.getJavaKind() == JavaKind.Float && Float.isNaN(c.asFloat())));
 354                 } else {
 355                     assertTrue(constant2 != null, "should constant fold %s %s %s", op, v1stamp, folded);
 356                     assertTrue(constant.equals(constant2), "should produce same constant %s %s %s %s", op, v1stamp, constant, constant2);
 357                 }
 358             }
 359         } else {
 360             assertTrue(v1stamp.isEmpty() || v1stamp instanceof FloatStamp);
 361         }
 362         assertTrue(result.meet(folded).equals(result), "result out of range %s %s %s %s %s", op, v1stamp, folded, result, result.meet(folded));
 363     }
 364 }