1 /*
   2  * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package org.graalvm.compiler.replacements.test;
  24 
  25 import java.util.ArrayList;
  26 import java.util.Collection;
  27 import java.util.List;
  28 
  29 import org.junit.Assert;
  30 import org.junit.Test;
  31 import org.junit.runner.RunWith;
  32 import org.junit.runners.Parameterized;
  33 import org.junit.runners.Parameterized.Parameter;
  34 import org.junit.runners.Parameterized.Parameters;
  35 
  36 import org.graalvm.compiler.core.common.type.IntegerStamp;
  37 import org.graalvm.compiler.core.common.type.StampFactory;
  38 import org.graalvm.compiler.core.test.GraalCompilerTest;
  39 import org.graalvm.compiler.nodes.ConstantNode;
  40 import org.graalvm.compiler.nodes.ParameterNode;
  41 import org.graalvm.compiler.nodes.PiNode;
  42 import org.graalvm.compiler.nodes.StructuredGraph;
  43 import org.graalvm.compiler.nodes.ValueNode;
  44 import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
  45 import org.graalvm.compiler.nodes.calc.MulNode;
  46 import org.graalvm.compiler.nodes.java.StoreFieldNode;
  47 import org.graalvm.compiler.phases.common.CanonicalizerPhase;
  48 import org.graalvm.compiler.phases.tiers.HighTierContext;
  49 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerMulExactNode;
  50 
  51 @RunWith(Parameterized.class)
  52 public class IntegerMulExactFoldTest extends GraalCompilerTest {
  53 
  54     public static int SideEffectI;
  55     public static long SideEffectL;
  56 
  57     public static void snippetInt(int a, int b) {
  58         SideEffectI = Math.multiplyExact(a, b);
  59     }
  60 
  61     public static void snippetLong(long a, long b) {
  62         SideEffectL = Math.multiplyExact(a, b);
  63     }
  64 
  65     private StructuredGraph prepareGraph(String snippet) {
  66         StructuredGraph graph = parseEager(snippet, AllowAssumptions.NO);
  67         HighTierContext context = getDefaultHighTierContext();
  68         new CanonicalizerPhase().apply(graph, context);
  69         return graph;
  70     }
  71 
  72     @Parameter(0) public long lowerBound1;
  73     @Parameter(1) public long upperBound1;
  74     @Parameter(2) public long lowerBound2;
  75     @Parameter(3) public long upperBound2;
  76     @Parameter(4) public int bits;
  77 
  78     @Test
  79     public void tryFold() {
  80         assert bits == 32 || bits == 64;
  81 
  82         IntegerStamp a = StampFactory.forInteger(bits, lowerBound1, upperBound1);
  83         IntegerStamp b = StampFactory.forInteger(bits, lowerBound2, upperBound2);
  84 
  85         // prepare the graph once for the given stamps, if the canonicalize method thinks it does
  86         // not overflow it will replace the exact mul with a normal mul
  87         StructuredGraph g = prepareGraph(bits == 32 ? "snippetInt" : "snippetLong");
  88         List<ParameterNode> params = g.getNodes(ParameterNode.TYPE).snapshot();
  89         params.get(0).replaceAtMatchingUsages((g.addOrUnique(new PiNode(params.get(0), a))), x -> x instanceof IntegerMulExactNode);
  90         params.get(1).replaceAtMatchingUsages((g.addOrUnique(new PiNode(params.get(1), b))), x -> x instanceof IntegerMulExactNode);
  91         new CanonicalizerPhase().apply(g, getDefaultHighTierContext());
  92         boolean optimized = g.getNodes().filter(IntegerMulExactNode.class).count() == 0;
  93         ValueNode leftOverMull = optimized ? g.getNodes().filter(MulNode.class).first() : g.getNodes().filter(IntegerMulExactNode.class).first();
  94         new CanonicalizerPhase().apply(g, getDefaultHighTierContext());
  95         if (leftOverMull == null) {
  96             // result may be constant if there is no mul exact or mul node left
  97             leftOverMull = g.getNodes().filter(StoreFieldNode.class).first().inputs().filter(ConstantNode.class).first();
  98         }
  99         if (leftOverMull == null) {
 100             // even mul got canonicalized so we may end up with one of the original nodes
 101             leftOverMull = g.getNodes().filter(PiNode.class).first();
 102         }
 103         IntegerStamp resultStamp = (IntegerStamp) leftOverMull.stamp();
 104 
 105         // now check for all values in the stamp whether their products overflow overflow
 106         for (long l1 = lowerBound1; l1 <= upperBound1; l1++) {
 107             for (long l2 = lowerBound2; l2 <= upperBound2; l2++) {
 108                 try {
 109                     long res = mulExact(l1, l2, bits);
 110                     Assert.assertTrue(resultStamp.contains(res));
 111                 } catch (ArithmeticException e) {
 112                     Assert.assertFalse(optimized);
 113                 }
 114                 if (l2 == Long.MAX_VALUE) {
 115                     // do not want to overflow the check loop
 116                     break;
 117                 }
 118             }
 119             if (l1 == Long.MAX_VALUE) {
 120                 // do not want to overflow the check loop
 121                 break;
 122             }
 123         }
 124 
 125     }
 126 
 127     private static long mulExact(long x, long y, int bits) {
 128         long r = x * y;
 129         if (bits == 8) {
 130             if ((byte) r != r) {
 131                 throw new ArithmeticException("overflow");
 132             }
 133         } else if (bits == 16) {
 134             if ((short) r != r) {
 135                 throw new ArithmeticException("overflow");
 136             }
 137         } else if (bits == 32) {
 138             return Math.multiplyExact((int) x, (int) y);
 139         } else {
 140             return Math.multiplyExact(x, y);
 141         }
 142         return r;
 143     }
 144 
 145     @Parameters(name = "a[{0} - {1}] b[{2} - {3}] bits=32")
 146     public static Collection<Object[]> data() {
 147         ArrayList<Object[]> tests = new ArrayList<>();
 148 
 149         // zero related
 150         addTest(tests, -2, 2, 3, 3, 32);
 151         addTest(tests, 0, 0, 1, 1, 32);
 152         addTest(tests, 1, 1, 0, 0, 32);
 153         addTest(tests, -1, 1, 0, 1, 32);
 154         addTest(tests, -1, 1, 1, 1, 32);
 155         addTest(tests, -1, 1, -1, 1, 32);
 156 
 157         addTest(tests, -2, 2, 3, 3, 64);
 158         addTest(tests, 0, 0, 1, 1, 64);
 159         addTest(tests, 1, 1, 0, 0, 64);
 160         addTest(tests, -1, 1, 0, 1, 64);
 161         addTest(tests, -1, 1, 1, 1, 64);
 162         addTest(tests, -1, 1, -1, 1, 64);
 163 
 164         addTest(tests, -2, 2, 3, 3, 32);
 165         addTest(tests, 0, 0, 1, 1, 32);
 166         addTest(tests, 1, 1, 0, 0, 32);
 167         addTest(tests, -1, 1, 0, 1, 32);
 168         addTest(tests, -1, 1, 1, 1, 32);
 169         addTest(tests, -1, 1, -1, 1, 32);
 170 
 171         addTest(tests, 0, 0, 1, 1, 64);
 172         addTest(tests, 1, 1, 0, 0, 64);
 173         addTest(tests, -1, 1, 0, 1, 64);
 174         addTest(tests, -1, 1, 1, 1, 64);
 175         addTest(tests, -1, 1, -1, 1, 64);
 176 
 177         // bounds
 178         addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF,
 179                         Integer.MAX_VALUE, 32);
 180         addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 32);
 181         addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF,
 182                         Integer.MAX_VALUE, 64);
 183         addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 64);
 184         addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE + 0xFFF, -1, -1, 64);
 185 
 186         // constants
 187         addTest(tests, 2, 2, 2, 2, 32);
 188         addTest(tests, 1, 1, 2, 2, 32);
 189         addTest(tests, 2, 2, 4, 4, 32);
 190         addTest(tests, 3, 3, 3, 3, 32);
 191         addTest(tests, -4, -4, 3, 3, 32);
 192         addTest(tests, -4, -4, -3, -3, 32);
 193         addTest(tests, 4, 4, -3, -3, 32);
 194 
 195         addTest(tests, 2, 2, 2, 2, 64);
 196         addTest(tests, 1, 1, 2, 2, 64);
 197         addTest(tests, 3, 3, 3, 3, 64);
 198 
 199         addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, 1, 1, 64);
 200         addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, -1, -1, 64);
 201         addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, -1, -1, 64);
 202         addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, 1, 1, 64);
 203 
 204         return tests;
 205     }
 206 
 207     private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits) {
 208         tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits});
 209     }
 210 
 211 }