1 /*
   2  * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.replacements.test;
  26 
  27 import static org.junit.Assert.assertNotNull;
  28 
  29 import java.util.ArrayList;
  30 import java.util.Collection;
  31 import java.util.List;
  32 
  33 import org.graalvm.compiler.core.common.type.IntegerStamp;
  34 import org.graalvm.compiler.core.common.type.StampFactory;
  35 import org.graalvm.compiler.core.test.GraalCompilerTest;
  36 import org.graalvm.compiler.graph.Node;
  37 import org.graalvm.compiler.nodes.NodeView;
  38 import org.graalvm.compiler.nodes.ParameterNode;
  39 import org.graalvm.compiler.nodes.PiNode;
  40 import org.graalvm.compiler.nodes.ReturnNode;
  41 import org.graalvm.compiler.nodes.StructuredGraph;
  42 import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
  43 import org.graalvm.compiler.nodes.ValueNode;
  44 import org.graalvm.compiler.nodes.spi.LoweringTool;
  45 import org.graalvm.compiler.phases.common.CanonicalizerPhase;
  46 import org.graalvm.compiler.phases.common.GuardLoweringPhase;
  47 import org.graalvm.compiler.phases.common.LoweringPhase;
  48 import org.graalvm.compiler.phases.tiers.HighTierContext;
  49 import org.graalvm.compiler.phases.tiers.MidTierContext;
  50 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticNode;
  51 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticSplitNode;
  52 import org.junit.Assert;
  53 import org.junit.Test;
  54 import org.junit.runner.RunWith;
  55 import org.junit.runners.Parameterized;
  56 import org.junit.runners.Parameterized.Parameters;
  57 
  58 @RunWith(Parameterized.class)
  59 public class IntegerExactFoldTest extends GraalCompilerTest {
  60     private final long lowerBoundA;
  61     private final long upperBoundA;
  62     private final long lowerBoundB;
  63     private final long upperBoundB;
  64     private final int bits;
  65     private final Operation operation;
  66 
  67     public IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation) {
  68         this.lowerBoundA = lowerBoundA;
  69         this.upperBoundA = upperBoundA;
  70         this.lowerBoundB = lowerBoundB;
  71         this.upperBoundB = upperBoundB;
  72         this.bits = bits;
  73         this.operation = operation;
  74 
  75         assert bits == 32 || bits == 64;
  76         assert lowerBoundA <= upperBoundA;
  77         assert lowerBoundB <= upperBoundB;
  78         assert bits == 64 || isInteger(lowerBoundA);
  79         assert bits == 64 || isInteger(upperBoundA);
  80         assert bits == 64 || isInteger(lowerBoundB);
  81         assert bits == 64 || isInteger(upperBoundB);
  82     }
  83 
  84     @Test
  85     public void testFolding() {
  86         StructuredGraph graph = prepareGraph();
  87         IntegerStamp a = StampFactory.forInteger(bits, lowerBoundA, upperBoundA);
  88         IntegerStamp b = StampFactory.forInteger(bits, lowerBoundB, upperBoundB);
  89 
  90         List<ParameterNode> params = graph.getNodes(ParameterNode.TYPE).snapshot();
  91         params.get(0).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerExactArithmeticNode);
  92         params.get(1).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerExactArithmeticNode);
  93 
  94         Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
  95         assertNotNull("original node must be in the graph", originalNode);
  96 
  97         new CanonicalizerPhase().apply(graph, getDefaultHighTierContext());
  98 
  99         ValueNode node = findNode(graph);
 100         boolean overflowExpected = node instanceof IntegerExactArithmeticNode;
 101 
 102         IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
 103         operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
 104     }
 105 
 106     @Test
 107     public void testFoldingAfterLowering() {
 108         StructuredGraph graph = prepareGraph();
 109 
 110         Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
 111         assertNotNull("original node must be in the graph", originalNode);
 112         CanonicalizerPhase canonicalizer = new CanonicalizerPhase();
 113         HighTierContext highTierContext = getDefaultHighTierContext();
 114         new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, highTierContext);
 115         MidTierContext midTierContext = getDefaultMidTierContext();
 116         new GuardLoweringPhase().apply(graph, midTierContext);
 117         new CanonicalizerPhase().apply(graph, midTierContext);
 118 
 119         IntegerExactArithmeticSplitNode loweredNode = graph.getNodes().filter(IntegerExactArithmeticSplitNode.class).first();
 120         assertNotNull("the lowered node must be in the graph", loweredNode);
 121 
 122         loweredNode.getX().setStamp(StampFactory.forInteger(bits, lowerBoundA, upperBoundA));
 123         loweredNode.getY().setStamp(StampFactory.forInteger(bits, lowerBoundB, upperBoundB));
 124         new CanonicalizerPhase().apply(graph, midTierContext);
 125 
 126         ValueNode node = findNode(graph);
 127         boolean overflowExpected = node instanceof IntegerExactArithmeticSplitNode;
 128 
 129         IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
 130         operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
 131     }
 132 
 133     private static boolean isInteger(long value) {
 134         return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
 135     }
 136 
 137     private static ValueNode findNode(StructuredGraph graph) {
 138         ValueNode resultNode = graph.getNodes().filter(ReturnNode.class).first().result();
 139         assertNotNull("some node must be the returned value", resultNode);
 140         return resultNode;
 141     }
 142 
 143     protected StructuredGraph prepareGraph() {
 144         String snippet = "snippetInt" + bits;
 145         StructuredGraph graph = parseEager(getResolvedJavaMethod(operation.getClass(), snippet), AllowAssumptions.NO);
 146         HighTierContext context = getDefaultHighTierContext();
 147         new CanonicalizerPhase().apply(graph, context);
 148         return graph;
 149     }
 150 
 151     private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation) {
 152         tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits, operation});
 153     }
 154 
 155     @Parameters(name = "a[{0} / {1}], b[{2} / {3}], bits={4}, operation={5}")
 156     public static Collection<Object[]> data() {
 157         ArrayList<Object[]> tests = new ArrayList<>();
 158 
 159         Operation[] operations = new Operation[]{new AddOperation(), new SubOperation(), new MulOperation()};
 160         for (Operation operation : operations) {
 161             for (int bits : new int[]{32, 64}) {
 162                 // zero related
 163                 addTest(tests, 0, 0, 1, 1, bits, operation);
 164                 addTest(tests, 1, 1, 0, 0, bits, operation);
 165                 addTest(tests, -1, 1, 0, 1, bits, operation);
 166 
 167                 // bounds
 168                 addTest(tests, -2, 2, 3, 3, bits, operation);
 169                 addTest(tests, -1, 1, 1, 1, bits, operation);
 170                 addTest(tests, -1, 1, -1, 1, bits, operation);
 171 
 172                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, bits, operation);
 173                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, bits, operation);
 174                 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1, bits, operation);
 175                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, -1, -1, bits, operation);
 176                 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 1, bits, operation);
 177                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, 1, 1, bits, operation);
 178             }
 179 
 180             // bit-specific test cases
 181             addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, 64, operation);
 182             addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, 64, operation);
 183         }
 184 
 185         return tests;
 186     }
 187 
 188     private abstract static class Operation {
 189         abstract void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp);
 190     }
 191 
 192     private static final class AddOperation extends Operation {
 193         @Override
 194         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
 195             try {
 196                 long res = addExact(lowerBoundA, lowerBoundB, bits);
 197                 resultStamp.contains(res);
 198                 res = addExact(upperBoundA, upperBoundB, bits);
 199                 resultStamp.contains(res);
 200                 Assert.assertFalse(overflowExpected);
 201             } catch (ArithmeticException e) {
 202                 Assert.assertTrue(overflowExpected);
 203             }
 204         }
 205 
 206         private static long addExact(long x, long y, int bits) {
 207             if (bits == 32) {
 208                 return Math.addExact((int) x, (int) y);
 209             } else {
 210                 return Math.addExact(x, y);
 211             }
 212         }
 213 
 214         @SuppressWarnings("unused")
 215         public static int snippetInt32(int a, int b) {
 216             return Math.addExact(a, b);
 217         }
 218 
 219         @SuppressWarnings("unused")
 220         public static long snippetInt64(long a, long b) {
 221             return Math.addExact(a, b);
 222         }
 223     }
 224 
 225     private static final class SubOperation extends Operation {
 226         @Override
 227         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
 228             try {
 229                 long res = subExact(lowerBoundA, upperBoundB, bits);
 230                 Assert.assertTrue(resultStamp.contains(res));
 231                 res = subExact(upperBoundA, lowerBoundB, bits);
 232                 Assert.assertTrue(resultStamp.contains(res));
 233                 Assert.assertFalse(overflowExpected);
 234             } catch (ArithmeticException e) {
 235                 Assert.assertTrue(overflowExpected);
 236             }
 237         }
 238 
 239         private static long subExact(long x, long y, int bits) {
 240             if (bits == 32) {
 241                 return Math.subtractExact((int) x, (int) y);
 242             } else {
 243                 return Math.subtractExact(x, y);
 244             }
 245         }
 246 
 247         @SuppressWarnings("unused")
 248         public static int snippetInt32(int a, int b) {
 249             return Math.subtractExact(a, b);
 250         }
 251 
 252         @SuppressWarnings("unused")
 253         public static long snippetInt64(long a, long b) {
 254             return Math.subtractExact(a, b);
 255         }
 256     }
 257 
 258     private static final class MulOperation extends Operation {
 259         @Override
 260         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
 261             // now check for all values in the stamp whether their products overflow overflow
 262             boolean overflowOccurred = false;
 263 
 264             for (long l1 = lowerBoundA; l1 <= upperBoundA; l1++) {
 265                 for (long l2 = lowerBoundB; l2 <= upperBoundB; l2++) {
 266                     try {
 267                         long res = mulExact(l1, l2, bits);
 268                         Assert.assertTrue(resultStamp.contains(res));
 269                     } catch (ArithmeticException e) {
 270                         overflowOccurred = true;
 271                     }
 272                     if (l2 == Long.MAX_VALUE) {
 273                         // do not want to overflow the check loop
 274                         break;
 275                     }
 276                 }
 277                 if (l1 == Long.MAX_VALUE) {
 278                     // do not want to overflow the check loop
 279                     break;
 280                 }
 281             }
 282 
 283             Assert.assertEquals(overflowExpected, overflowOccurred);
 284         }
 285 
 286         private static long mulExact(long x, long y, int bits) {
 287             if (bits == 32) {
 288                 return Math.multiplyExact((int) x, (int) y);
 289             } else {
 290                 return Math.multiplyExact(x, y);
 291             }
 292         }
 293 
 294         @SuppressWarnings("unused")
 295         public static int snippetInt32(int a, int b) {
 296             return Math.multiplyExact(a, b);
 297         }
 298 
 299         @SuppressWarnings("unused")
 300         public static long snippetInt64(long a, long b) {
 301             return Math.multiplyExact(a, b);
 302         }
 303     }
 304 }