1 /* 2 * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 25 package org.graalvm.compiler.replacements.test; 26 27 import static org.junit.Assert.assertNotNull; 28 29 import java.util.ArrayList; 30 import java.util.Collection; 31 import java.util.List; 32 33 import org.graalvm.compiler.core.common.type.IntegerStamp; 34 import org.graalvm.compiler.core.common.type.StampFactory; 35 import org.graalvm.compiler.core.test.GraalCompilerTest; 36 import org.graalvm.compiler.graph.Node; 37 import org.graalvm.compiler.nodes.NodeView; 38 import org.graalvm.compiler.nodes.ParameterNode; 39 import org.graalvm.compiler.nodes.PiNode; 40 import org.graalvm.compiler.nodes.ReturnNode; 41 import org.graalvm.compiler.nodes.StructuredGraph; 42 import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions; 43 import org.graalvm.compiler.nodes.ValueNode; 44 import org.graalvm.compiler.nodes.spi.LoweringTool; 45 import org.graalvm.compiler.phases.common.CanonicalizerPhase; 46 import org.graalvm.compiler.phases.common.GuardLoweringPhase; 47 import org.graalvm.compiler.phases.common.LoweringPhase; 48 import org.graalvm.compiler.phases.tiers.HighTierContext; 49 import org.graalvm.compiler.phases.tiers.MidTierContext; 50 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticNode; 51 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticSplitNode; 52 import org.junit.Assert; 53 import org.junit.Test; 54 import org.junit.runner.RunWith; 55 import org.junit.runners.Parameterized; 56 import org.junit.runners.Parameterized.Parameters; 57 58 @RunWith(Parameterized.class) 59 public class IntegerExactFoldTest extends GraalCompilerTest { 60 private final long lowerBoundA; 61 private final long upperBoundA; 62 private final long lowerBoundB; 63 private final long upperBoundB; 64 private final int bits; 65 private final Operation operation; 66 67 public IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation) { 68 this.lowerBoundA = lowerBoundA; 69 this.upperBoundA = upperBoundA; 70 this.lowerBoundB = lowerBoundB; 71 this.upperBoundB = upperBoundB; 72 this.bits = bits; 73 this.operation = operation; 74 75 assert bits == 32 || bits == 64; 76 assert lowerBoundA <= upperBoundA; 77 assert lowerBoundB <= upperBoundB; 78 assert bits == 64 || isInteger(lowerBoundA); 79 assert bits == 64 || isInteger(upperBoundA); 80 assert bits == 64 || isInteger(lowerBoundB); 81 assert bits == 64 || isInteger(upperBoundB); 82 } 83 84 @Test 85 public void testFolding() { 86 StructuredGraph graph = prepareGraph(); 87 IntegerStamp a = StampFactory.forInteger(bits, lowerBoundA, upperBoundA); 88 IntegerStamp b = StampFactory.forInteger(bits, lowerBoundB, upperBoundB); 89 90 List<ParameterNode> params = graph.getNodes(ParameterNode.TYPE).snapshot(); 91 params.get(0).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerExactArithmeticNode); 92 params.get(1).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerExactArithmeticNode); 93 94 Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first(); 95 assertNotNull("original node must be in the graph", originalNode); 96 97 new CanonicalizerPhase().apply(graph, getDefaultHighTierContext()); 98 99 ValueNode node = findNode(graph); 100 boolean overflowExpected = node instanceof IntegerExactArithmeticNode; 101 102 IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT); 103 operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp); 104 } 105 106 @Test 107 public void testFoldingAfterLowering() { 108 StructuredGraph graph = prepareGraph(); 109 110 Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first(); 111 assertNotNull("original node must be in the graph", originalNode); 112 CanonicalizerPhase canonicalizer = new CanonicalizerPhase(); 113 HighTierContext highTierContext = getDefaultHighTierContext(); 114 new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, highTierContext); 115 MidTierContext midTierContext = getDefaultMidTierContext(); 116 new GuardLoweringPhase().apply(graph, midTierContext); 117 new CanonicalizerPhase().apply(graph, midTierContext); 118 119 IntegerExactArithmeticSplitNode loweredNode = graph.getNodes().filter(IntegerExactArithmeticSplitNode.class).first(); 120 assertNotNull("the lowered node must be in the graph", loweredNode); 121 122 loweredNode.getX().setStamp(StampFactory.forInteger(bits, lowerBoundA, upperBoundA)); 123 loweredNode.getY().setStamp(StampFactory.forInteger(bits, lowerBoundB, upperBoundB)); 124 new CanonicalizerPhase().apply(graph, midTierContext); 125 126 ValueNode node = findNode(graph); 127 boolean overflowExpected = node instanceof IntegerExactArithmeticSplitNode; 128 129 IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT); 130 operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp); 131 } 132 133 private static boolean isInteger(long value) { 134 return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE; 135 } 136 137 private static ValueNode findNode(StructuredGraph graph) { 138 ValueNode resultNode = graph.getNodes().filter(ReturnNode.class).first().result(); 139 assertNotNull("some node must be the returned value", resultNode); 140 return resultNode; 141 } 142 143 protected StructuredGraph prepareGraph() { 144 String snippet = "snippetInt" + bits; 145 StructuredGraph graph = parseEager(getResolvedJavaMethod(operation.getClass(), snippet), AllowAssumptions.NO); 146 HighTierContext context = getDefaultHighTierContext(); 147 new CanonicalizerPhase().apply(graph, context); 148 return graph; 149 } 150 151 private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation) { 152 tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits, operation}); 153 } 154 155 @Parameters(name = "a[{0} / {1}], b[{2} / {3}], bits={4}, operation={5}") 156 public static Collection<Object[]> data() { 157 ArrayList<Object[]> tests = new ArrayList<>(); 158 159 Operation[] operations = new Operation[]{new AddOperation(), new SubOperation(), new MulOperation()}; 160 for (Operation operation : operations) { 161 for (int bits : new int[]{32, 64}) { 162 // zero related 163 addTest(tests, 0, 0, 1, 1, bits, operation); 164 addTest(tests, 1, 1, 0, 0, bits, operation); 165 addTest(tests, -1, 1, 0, 1, bits, operation); 166 167 // bounds 168 addTest(tests, -2, 2, 3, 3, bits, operation); 169 addTest(tests, -1, 1, 1, 1, bits, operation); 170 addTest(tests, -1, 1, -1, 1, bits, operation); 171 172 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, bits, operation); 173 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, bits, operation); 174 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1, bits, operation); 175 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, -1, -1, bits, operation); 176 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 1, bits, operation); 177 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, 1, 1, bits, operation); 178 } 179 180 // bit-specific test cases 181 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, 64, operation); 182 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, 64, operation); 183 } 184 185 return tests; 186 } 187 188 private abstract static class Operation { 189 abstract void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp); 190 } 191 192 private static final class AddOperation extends Operation { 193 @Override 194 public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { 195 try { 196 long res = addExact(lowerBoundA, lowerBoundB, bits); 197 resultStamp.contains(res); 198 res = addExact(upperBoundA, upperBoundB, bits); 199 resultStamp.contains(res); 200 Assert.assertFalse(overflowExpected); 201 } catch (ArithmeticException e) { 202 Assert.assertTrue(overflowExpected); 203 } 204 } 205 206 private static long addExact(long x, long y, int bits) { 207 if (bits == 32) { 208 return Math.addExact((int) x, (int) y); 209 } else { 210 return Math.addExact(x, y); 211 } 212 } 213 214 @SuppressWarnings("unused") 215 public static int snippetInt32(int a, int b) { 216 return Math.addExact(a, b); 217 } 218 219 @SuppressWarnings("unused") 220 public static long snippetInt64(long a, long b) { 221 return Math.addExact(a, b); 222 } 223 } 224 225 private static final class SubOperation extends Operation { 226 @Override 227 public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { 228 try { 229 long res = subExact(lowerBoundA, upperBoundB, bits); 230 Assert.assertTrue(resultStamp.contains(res)); 231 res = subExact(upperBoundA, lowerBoundB, bits); 232 Assert.assertTrue(resultStamp.contains(res)); 233 Assert.assertFalse(overflowExpected); 234 } catch (ArithmeticException e) { 235 Assert.assertTrue(overflowExpected); 236 } 237 } 238 239 private static long subExact(long x, long y, int bits) { 240 if (bits == 32) { 241 return Math.subtractExact((int) x, (int) y); 242 } else { 243 return Math.subtractExact(x, y); 244 } 245 } 246 247 @SuppressWarnings("unused") 248 public static int snippetInt32(int a, int b) { 249 return Math.subtractExact(a, b); 250 } 251 252 @SuppressWarnings("unused") 253 public static long snippetInt64(long a, long b) { 254 return Math.subtractExact(a, b); 255 } 256 } 257 258 private static final class MulOperation extends Operation { 259 @Override 260 public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { 261 // now check for all values in the stamp whether their products overflow overflow 262 boolean overflowOccurred = false; 263 264 for (long l1 = lowerBoundA; l1 <= upperBoundA; l1++) { 265 for (long l2 = lowerBoundB; l2 <= upperBoundB; l2++) { 266 try { 267 long res = mulExact(l1, l2, bits); 268 Assert.assertTrue(resultStamp.contains(res)); 269 } catch (ArithmeticException e) { 270 overflowOccurred = true; 271 } 272 if (l2 == Long.MAX_VALUE) { 273 // do not want to overflow the check loop 274 break; 275 } 276 } 277 if (l1 == Long.MAX_VALUE) { 278 // do not want to overflow the check loop 279 break; 280 } 281 } 282 283 Assert.assertEquals(overflowExpected, overflowOccurred); 284 } 285 286 private static long mulExact(long x, long y, int bits) { 287 if (bits == 32) { 288 return Math.multiplyExact((int) x, (int) y); 289 } else { 290 return Math.multiplyExact(x, y); 291 } 292 } 293 294 @SuppressWarnings("unused") 295 public static int snippetInt32(int a, int b) { 296 return Math.multiplyExact(a, b); 297 } 298 299 @SuppressWarnings("unused") 300 public static long snippetInt64(long a, long b) { 301 return Math.multiplyExact(a, b); 302 } 303 } 304 }