1 /*
   2  * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.replacements.test;
  26 
  27 import static org.junit.Assert.assertNotNull;
  28 
  29 import java.util.ArrayList;
  30 import java.util.Collection;
  31 import java.util.List;
  32 
  33 import org.graalvm.compiler.core.common.type.IntegerStamp;
  34 import org.graalvm.compiler.core.common.type.StampFactory;
  35 import org.graalvm.compiler.core.test.GraalCompilerTest;
  36 import org.graalvm.compiler.graph.Node;
  37 import org.graalvm.compiler.nodes.NodeView;
  38 import org.graalvm.compiler.nodes.ParameterNode;
  39 import org.graalvm.compiler.nodes.PiNode;
  40 import org.graalvm.compiler.nodes.ReturnNode;
  41 import org.graalvm.compiler.nodes.StructuredGraph;
  42 import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
  43 import org.graalvm.compiler.nodes.StructuredGraph.GuardsStage;
  44 import org.graalvm.compiler.nodes.ValueNode;
  45 import org.graalvm.compiler.nodes.spi.LoweringTool;
  46 import org.graalvm.compiler.phases.common.CanonicalizerPhase;
  47 import org.graalvm.compiler.phases.common.LoweringPhase;
  48 import org.graalvm.compiler.phases.tiers.HighTierContext;
  49 import org.graalvm.compiler.phases.tiers.PhaseContext;
  50 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticNode;
  51 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticSplitNode;
  52 import org.junit.Assert;
  53 import org.junit.Ignore;
  54 import org.junit.Test;
  55 import org.junit.runner.RunWith;
  56 import org.junit.runners.Parameterized;
  57 import org.junit.runners.Parameterized.Parameters;
  58 
  59 @Ignore
  60 @RunWith(Parameterized.class)
  61 public class IntegerExactFoldTest extends GraalCompilerTest {
  62     private final long lowerBoundA;
  63     private final long upperBoundA;
  64     private final long lowerBoundB;
  65     private final long upperBoundB;
  66     private final int bits;
  67     private final Operation operation;
  68 
  69     public IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation) {
  70         this.lowerBoundA = lowerBoundA;
  71         this.upperBoundA = upperBoundA;
  72         this.lowerBoundB = lowerBoundB;
  73         this.upperBoundB = upperBoundB;
  74         this.bits = bits;
  75         this.operation = operation;
  76 
  77         assert bits == 32 || bits == 64;
  78         assert lowerBoundA <= upperBoundA;
  79         assert lowerBoundB <= upperBoundB;
  80         assert bits == 64 || isInteger(lowerBoundA);
  81         assert bits == 64 || isInteger(upperBoundA);
  82         assert bits == 64 || isInteger(lowerBoundB);
  83         assert bits == 64 || isInteger(upperBoundB);
  84     }
  85 
  86     @Test
  87     public void testFolding() {
  88         StructuredGraph graph = prepareGraph();
  89         IntegerStamp a = StampFactory.forInteger(bits, lowerBoundA, upperBoundA);
  90         IntegerStamp b = StampFactory.forInteger(bits, lowerBoundB, upperBoundB);
  91 
  92         List<ParameterNode> params = graph.getNodes(ParameterNode.TYPE).snapshot();
  93         params.get(0).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerExactArithmeticNode);
  94         params.get(1).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerExactArithmeticNode);
  95 
  96         Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
  97         assertNotNull("original node must be in the graph", originalNode);
  98 
  99         new CanonicalizerPhase().apply(graph, getDefaultHighTierContext());
 100         ValueNode node = findNode(graph);
 101         boolean overflowExpected = node instanceof IntegerExactArithmeticNode;
 102 
 103         IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
 104         operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
 105     }
 106 
 107     @Test
 108     public void testFoldingAfterLowering() {
 109         StructuredGraph graph = prepareGraph();
 110 
 111         Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
 112         assertNotNull("original node must be in the graph", originalNode);
 113 
 114         graph.setGuardsStage(GuardsStage.FIXED_DEOPTS);
 115         CanonicalizerPhase canonicalizer = new CanonicalizerPhase();
 116         PhaseContext context = new PhaseContext(getProviders());
 117         new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, context);
 118         IntegerExactArithmeticSplitNode loweredNode = graph.getNodes().filter(IntegerExactArithmeticSplitNode.class).first();
 119         assertNotNull("the lowered node must be in the graph", loweredNode);
 120 
 121         loweredNode.getX().setStamp(StampFactory.forInteger(bits, lowerBoundA, upperBoundA));
 122         loweredNode.getY().setStamp(StampFactory.forInteger(bits, lowerBoundB, upperBoundB));
 123         new CanonicalizerPhase().apply(graph, context);
 124 
 125         ValueNode node = findNode(graph);
 126         boolean overflowExpected = node instanceof IntegerExactArithmeticSplitNode;
 127 
 128         IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
 129         operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
 130     }
 131 
 132     private static boolean isInteger(long value) {
 133         return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
 134     }
 135 
 136     private static ValueNode findNode(StructuredGraph graph) {
 137         ValueNode resultNode = graph.getNodes().filter(ReturnNode.class).first().result();
 138         assertNotNull("some node must be the returned value", resultNode);
 139         return resultNode;
 140     }
 141 
 142     protected StructuredGraph prepareGraph() {
 143         String snippet = "snippetInt" + bits;
 144         StructuredGraph graph = parseEager(getResolvedJavaMethod(operation.getClass(), snippet), AllowAssumptions.NO);
 145         HighTierContext context = getDefaultHighTierContext();
 146         new CanonicalizerPhase().apply(graph, context);
 147         return graph;
 148     }
 149 
 150     private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation) {
 151         tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits, operation});
 152     }
 153 
 154     @Parameters(name = "a[{0} / {1}], b[{2} / {3}], bits={4}, operation={5}")
 155     public static Collection<Object[]> data() {
 156         ArrayList<Object[]> tests = new ArrayList<>();
 157 
 158         Operation[] operations = new Operation[]{new AddOperation(), new SubOperation(), new MulOperation()};
 159         for (Operation operation : operations) {
 160             for (int bits : new int[]{32, 64}) {
 161                 // zero related
 162                 addTest(tests, 0, 0, 1, 1, bits, operation);
 163                 addTest(tests, 1, 1, 0, 0, bits, operation);
 164                 addTest(tests, -1, 1, 0, 1, bits, operation);
 165 
 166                 // bounds
 167                 addTest(tests, -2, 2, 3, 3, bits, operation);
 168                 addTest(tests, -1, 1, 1, 1, bits, operation);
 169                 addTest(tests, -1, 1, -1, 1, bits, operation);
 170 
 171                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, bits, operation);
 172                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, bits, operation);
 173                 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1, bits, operation);
 174                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, -1, -1, bits, operation);
 175                 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 1, bits, operation);
 176                 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, 1, 1, bits, operation);
 177             }
 178 
 179             // bit-specific test cases
 180             addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, 64, operation);
 181             addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, 64, operation);
 182         }
 183 
 184         return tests;
 185     }
 186 
 187     private abstract static class Operation {
 188         abstract void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp);
 189     }
 190 
 191     private static final class AddOperation extends Operation {
 192         @Override
 193         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
 194             try {
 195                 long res = addExact(lowerBoundA, lowerBoundB, bits);
 196                 resultStamp.contains(res);
 197                 res = addExact(upperBoundA, upperBoundB, bits);
 198                 resultStamp.contains(res);
 199                 Assert.assertFalse(overflowExpected);
 200             } catch (ArithmeticException e) {
 201                 Assert.assertTrue(overflowExpected);
 202             }
 203         }
 204 
 205         private static long addExact(long x, long y, int bits) {
 206             if (bits == 32) {
 207                 return Math.addExact((int) x, (int) y);
 208             } else {
 209                 return Math.addExact(x, y);
 210             }
 211         }
 212 
 213         @SuppressWarnings("unused")
 214         public static int snippetInt32(int a, int b) {
 215             return Math.addExact(a, b);
 216         }
 217 
 218         @SuppressWarnings("unused")
 219         public static long snippetInt64(long a, long b) {
 220             return Math.addExact(a, b);
 221         }
 222     }
 223 
 224     private static final class SubOperation extends Operation {
 225         @Override
 226         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
 227             try {
 228                 long res = subExact(lowerBoundA, upperBoundB, bits);
 229                 Assert.assertTrue(resultStamp.contains(res));
 230                 res = subExact(upperBoundA, lowerBoundB, bits);
 231                 Assert.assertTrue(resultStamp.contains(res));
 232                 Assert.assertFalse(overflowExpected);
 233             } catch (ArithmeticException e) {
 234                 Assert.assertTrue(overflowExpected);
 235             }
 236         }
 237 
 238         private static long subExact(long x, long y, int bits) {
 239             if (bits == 32) {
 240                 return Math.subtractExact((int) x, (int) y);
 241             } else {
 242                 return Math.subtractExact(x, y);
 243             }
 244         }
 245 
 246         @SuppressWarnings("unused")
 247         public static int snippetInt32(int a, int b) {
 248             return Math.subtractExact(a, b);
 249         }
 250 
 251         @SuppressWarnings("unused")
 252         public static long snippetInt64(long a, long b) {
 253             return Math.subtractExact(a, b);
 254         }
 255     }
 256 
 257     private static final class MulOperation extends Operation {
 258         @Override
 259         public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
 260             // now check for all values in the stamp whether their products overflow overflow
 261             boolean overflowOccurred = false;
 262 
 263             for (long l1 = lowerBoundA; l1 <= upperBoundA; l1++) {
 264                 for (long l2 = lowerBoundB; l2 <= upperBoundB; l2++) {
 265                     try {
 266                         long res = mulExact(l1, l2, bits);
 267                         Assert.assertTrue(resultStamp.contains(res));
 268                     } catch (ArithmeticException e) {
 269                         overflowOccurred = true;
 270                     }
 271                     if (l2 == Long.MAX_VALUE) {
 272                         // do not want to overflow the check loop
 273                         break;
 274                     }
 275                 }
 276                 if (l1 == Long.MAX_VALUE) {
 277                     // do not want to overflow the check loop
 278                     break;
 279                 }
 280             }
 281 
 282             Assert.assertEquals(overflowExpected, overflowOccurred);
 283         }
 284 
 285         private static long mulExact(long x, long y, int bits) {
 286             if (bits == 32) {
 287                 return Math.multiplyExact((int) x, (int) y);
 288             } else {
 289                 return Math.multiplyExact(x, y);
 290             }
 291         }
 292 
 293         @SuppressWarnings("unused")
 294         public static int snippetInt32(int a, int b) {
 295             return Math.multiplyExact(a, b);
 296         }
 297 
 298         @SuppressWarnings("unused")
 299         public static long snippetInt64(long a, long b) {
 300             return Math.multiplyExact(a, b);
 301         }
 302     }
 303 }