1 /* 2 * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 package org.graalvm.compiler.replacements.test; 24 25 import java.util.ArrayList; 26 import java.util.Collection; 27 import java.util.List; 28 29 import org.junit.Assert; 30 import org.junit.Test; 31 import org.junit.runner.RunWith; 32 import org.junit.runners.Parameterized; 33 import org.junit.runners.Parameterized.Parameter; 34 import org.junit.runners.Parameterized.Parameters; 35 36 import org.graalvm.compiler.core.common.type.IntegerStamp; 37 import org.graalvm.compiler.core.common.type.StampFactory; 38 import org.graalvm.compiler.core.test.GraalCompilerTest; 39 import org.graalvm.compiler.nodes.ConstantNode; 40 import org.graalvm.compiler.nodes.ParameterNode; 41 import org.graalvm.compiler.nodes.PiNode; 42 import org.graalvm.compiler.nodes.StructuredGraph; 43 import org.graalvm.compiler.nodes.ValueNode; 44 import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions; 45 import org.graalvm.compiler.nodes.calc.MulNode; 46 import org.graalvm.compiler.nodes.java.StoreFieldNode; 47 import org.graalvm.compiler.phases.common.CanonicalizerPhase; 48 import org.graalvm.compiler.phases.tiers.HighTierContext; 49 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerMulExactNode; 50 51 @RunWith(Parameterized.class) 52 public class IntegerMulExactFoldTest extends GraalCompilerTest { 53 54 public static int SideEffectI; 55 public static long SideEffectL; 56 57 public static void snippetInt(int a, int b) { 58 SideEffectI = Math.multiplyExact(a, b); 59 } 60 61 public static void snippetLong(long a, long b) { 62 SideEffectL = Math.multiplyExact(a, b); 63 } 64 65 private StructuredGraph prepareGraph(String snippet) { 66 StructuredGraph graph = parseEager(snippet, AllowAssumptions.NO); 67 HighTierContext context = getDefaultHighTierContext(); 68 new CanonicalizerPhase().apply(graph, context); 69 return graph; 70 } 71 72 @Parameter(0) public long lowerBound1; 73 @Parameter(1) public long upperBound1; 74 @Parameter(2) public long lowerBound2; 75 @Parameter(3) public long upperBound2; 76 @Parameter(4) public int bits; 77 78 @Test 79 public void tryFold() { 80 assert bits == 32 || bits == 64; 81 82 IntegerStamp a = StampFactory.forInteger(bits, lowerBound1, upperBound1); 83 IntegerStamp b = StampFactory.forInteger(bits, lowerBound2, upperBound2); 84 85 // prepare the graph once for the given stamps, if the canonicalize method thinks it does 86 // not overflow it will replace the exact mul with a normal mul 87 StructuredGraph g = prepareGraph(bits == 32 ? "snippetInt" : "snippetLong"); 88 List<ParameterNode> params = g.getNodes(ParameterNode.TYPE).snapshot(); 89 params.get(0).replaceAtMatchingUsages((g.addOrUnique(new PiNode(params.get(0), a))), x -> x instanceof IntegerMulExactNode); 90 params.get(1).replaceAtMatchingUsages((g.addOrUnique(new PiNode(params.get(1), b))), x -> x instanceof IntegerMulExactNode); 91 new CanonicalizerPhase().apply(g, getDefaultHighTierContext()); 92 boolean optimized = g.getNodes().filter(IntegerMulExactNode.class).count() == 0; 93 ValueNode leftOverMull = optimized ? g.getNodes().filter(MulNode.class).first() : g.getNodes().filter(IntegerMulExactNode.class).first(); 94 new CanonicalizerPhase().apply(g, getDefaultHighTierContext()); 95 if (leftOverMull == null) { 96 // result may be constant if there is no mul exact or mul node left 97 leftOverMull = g.getNodes().filter(StoreFieldNode.class).first().inputs().filter(ConstantNode.class).first(); 98 } 99 if (leftOverMull == null) { 100 // even mul got canonicalized so we may end up with one of the original nodes 101 leftOverMull = g.getNodes().filter(PiNode.class).first(); 102 } 103 IntegerStamp resultStamp = (IntegerStamp) leftOverMull.stamp(); 104 105 // now check for all values in the stamp whether their products overflow overflow 106 for (long l1 = lowerBound1; l1 <= upperBound1; l1++) { 107 for (long l2 = lowerBound2; l2 <= upperBound2; l2++) { 108 try { 109 long res = mulExact(l1, l2, bits); 110 Assert.assertTrue(resultStamp.contains(res)); 111 } catch (ArithmeticException e) { 112 Assert.assertFalse(optimized); 113 } 114 if (l2 == Long.MAX_VALUE) { 115 // do not want to overflow the check loop 116 break; 117 } 118 } 119 if (l1 == Long.MAX_VALUE) { 120 // do not want to overflow the check loop 121 break; 122 } 123 } 124 125 } 126 127 private static long mulExact(long x, long y, int bits) { 128 long r = x * y; 129 if (bits == 8) { 130 if ((byte) r != r) { 131 throw new ArithmeticException("overflow"); 132 } 133 } else if (bits == 16) { 134 if ((short) r != r) { 135 throw new ArithmeticException("overflow"); 136 } 137 } else if (bits == 32) { 138 return Math.multiplyExact((int) x, (int) y); 139 } else { 140 return Math.multiplyExact(x, y); 141 } 142 return r; 143 } 144 145 @Parameters(name = "a[{0} - {1}] b[{2} - {3}] bits=32") 146 public static Collection<Object[]> data() { 147 ArrayList<Object[]> tests = new ArrayList<>(); 148 149 // zero related 150 addTest(tests, -2, 2, 3, 3, 32); 151 addTest(tests, 0, 0, 1, 1, 32); 152 addTest(tests, 1, 1, 0, 0, 32); 153 addTest(tests, -1, 1, 0, 1, 32); 154 addTest(tests, -1, 1, 1, 1, 32); 155 addTest(tests, -1, 1, -1, 1, 32); 156 157 addTest(tests, -2, 2, 3, 3, 64); 158 addTest(tests, 0, 0, 1, 1, 64); 159 addTest(tests, 1, 1, 0, 0, 64); 160 addTest(tests, -1, 1, 0, 1, 64); 161 addTest(tests, -1, 1, 1, 1, 64); 162 addTest(tests, -1, 1, -1, 1, 64); 163 164 addTest(tests, -2, 2, 3, 3, 32); 165 addTest(tests, 0, 0, 1, 1, 32); 166 addTest(tests, 1, 1, 0, 0, 32); 167 addTest(tests, -1, 1, 0, 1, 32); 168 addTest(tests, -1, 1, 1, 1, 32); 169 addTest(tests, -1, 1, -1, 1, 32); 170 171 addTest(tests, 0, 0, 1, 1, 64); 172 addTest(tests, 1, 1, 0, 0, 64); 173 addTest(tests, -1, 1, 0, 1, 64); 174 addTest(tests, -1, 1, 1, 1, 64); 175 addTest(tests, -1, 1, -1, 1, 64); 176 177 // bounds 178 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF, 179 Integer.MAX_VALUE, 32); 180 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 32); 181 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF, 182 Integer.MAX_VALUE, 64); 183 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 64); 184 addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE + 0xFFF, -1, -1, 64); 185 186 // constants 187 addTest(tests, 2, 2, 2, 2, 32); 188 addTest(tests, 1, 1, 2, 2, 32); 189 addTest(tests, 2, 2, 4, 4, 32); 190 addTest(tests, 3, 3, 3, 3, 32); 191 addTest(tests, -4, -4, 3, 3, 32); 192 addTest(tests, -4, -4, -3, -3, 32); 193 addTest(tests, 4, 4, -3, -3, 32); 194 195 addTest(tests, 2, 2, 2, 2, 64); 196 addTest(tests, 1, 1, 2, 2, 64); 197 addTest(tests, 3, 3, 3, 3, 64); 198 199 addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, 1, 1, 64); 200 addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, -1, -1, 64); 201 addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, -1, -1, 64); 202 addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, 1, 1, 64); 203 204 return tests; 205 } 206 207 private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits) { 208 tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits}); 209 } 210 211 }