1 /* 2 * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 25 package org.graalvm.compiler.replacements.test; 26 27 import static org.junit.Assert.assertNotNull; 28 29 import java.util.ArrayList; 30 import java.util.Collection; 31 import java.util.List; 32 33 import org.graalvm.compiler.core.common.type.IntegerStamp; 34 import org.graalvm.compiler.core.common.type.StampFactory; 35 import org.graalvm.compiler.core.test.GraalCompilerTest; 36 import org.graalvm.compiler.graph.Node; 37 import org.graalvm.compiler.nodes.NodeView; 38 import org.graalvm.compiler.nodes.ParameterNode; 39 import org.graalvm.compiler.nodes.PiNode; 40 import org.graalvm.compiler.nodes.ReturnNode; 41 import org.graalvm.compiler.nodes.StructuredGraph; 42 import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions; 43 import org.graalvm.compiler.nodes.StructuredGraph.GuardsStage; 44 import org.graalvm.compiler.nodes.ValueNode; 45 import org.graalvm.compiler.nodes.spi.LoweringTool; 46 import org.graalvm.compiler.phases.common.CanonicalizerPhase; 47 import org.graalvm.compiler.phases.common.LoweringPhase; 48 import org.graalvm.compiler.phases.tiers.HighTierContext; 49 import org.graalvm.compiler.phases.tiers.PhaseContext; 50 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticNode; 51 import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticSplitNode; 52 import org.junit.Assert; 53 import org.junit.Ignore; 54 import org.junit.Test; 55 import org.junit.runner.RunWith; 56 import org.junit.runners.Parameterized; 57 import org.junit.runners.Parameterized.Parameters; 58 59 @Ignore 60 @RunWith(Parameterized.class) 61 public class IntegerExactFoldTest extends GraalCompilerTest { 62 private final long lowerBoundA; 63 private final long upperBoundA; 64 private final long lowerBoundB; 65 private final long upperBoundB; 66 private final int bits; 67 private final Operation operation; 68 69 public IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation) { 70 this.lowerBoundA = lowerBoundA; 71 this.upperBoundA = upperBoundA; 72 this.lowerBoundB = lowerBoundB; 73 this.upperBoundB = upperBoundB; 74 this.bits = bits; 75 this.operation = operation; 76 77 assert bits == 32 || bits == 64; 78 assert lowerBoundA <= upperBoundA; 79 assert lowerBoundB <= upperBoundB; 80 assert bits == 64 || isInteger(lowerBoundA); 81 assert bits == 64 || isInteger(upperBoundA); 82 assert bits == 64 || isInteger(lowerBoundB); 83 assert bits == 64 || isInteger(upperBoundB); 84 } 85 86 @Test 87 public void testFolding() { 88 StructuredGraph graph = prepareGraph(); 89 IntegerStamp a = StampFactory.forInteger(bits, lowerBoundA, upperBoundA); 90 IntegerStamp b = StampFactory.forInteger(bits, lowerBoundB, upperBoundB); 91 92 List<ParameterNode> params = graph.getNodes(ParameterNode.TYPE).snapshot(); 93 params.get(0).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerExactArithmeticNode); 94 params.get(1).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerExactArithmeticNode); 95 96 Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first(); 97 assertNotNull("original node must be in the graph", originalNode); 98 99 new CanonicalizerPhase().apply(graph, getDefaultHighTierContext()); 100 ValueNode node = findNode(graph); 101 boolean overflowExpected = node instanceof IntegerExactArithmeticNode; 102 103 IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT); 104 operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp); 105 } 106 107 @Test 108 public void testFoldingAfterLowering() { 109 StructuredGraph graph = prepareGraph(); 110 111 Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first(); 112 assertNotNull("original node must be in the graph", originalNode); 113 114 graph.setGuardsStage(GuardsStage.FIXED_DEOPTS); 115 CanonicalizerPhase canonicalizer = new CanonicalizerPhase(); 116 PhaseContext context = new PhaseContext(getProviders()); 117 new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, context); 118 IntegerExactArithmeticSplitNode loweredNode = graph.getNodes().filter(IntegerExactArithmeticSplitNode.class).first(); 119 assertNotNull("the lowered node must be in the graph", loweredNode); 120 121 loweredNode.getX().setStamp(StampFactory.forInteger(bits, lowerBoundA, upperBoundA)); 122 loweredNode.getY().setStamp(StampFactory.forInteger(bits, lowerBoundB, upperBoundB)); 123 new CanonicalizerPhase().apply(graph, context); 124 125 ValueNode node = findNode(graph); 126 boolean overflowExpected = node instanceof IntegerExactArithmeticSplitNode; 127 128 IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT); 129 operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp); 130 } 131 132 private static boolean isInteger(long value) { 133 return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE; 134 } 135 136 private static ValueNode findNode(StructuredGraph graph) { 137 ValueNode resultNode = graph.getNodes().filter(ReturnNode.class).first().result(); 138 assertNotNull("some node must be the returned value", resultNode); 139 return resultNode; 140 } 141 142 protected StructuredGraph prepareGraph() { 143 String snippet = "snippetInt" + bits; 144 StructuredGraph graph = parseEager(getResolvedJavaMethod(operation.getClass(), snippet), AllowAssumptions.NO); 145 HighTierContext context = getDefaultHighTierContext(); 146 new CanonicalizerPhase().apply(graph, context); 147 return graph; 148 } 149 150 private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation) { 151 tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits, operation}); 152 } 153 154 @Parameters(name = "a[{0} / {1}], b[{2} / {3}], bits={4}, operation={5}") 155 public static Collection<Object[]> data() { 156 ArrayList<Object[]> tests = new ArrayList<>(); 157 158 Operation[] operations = new Operation[]{new AddOperation(), new SubOperation(), new MulOperation()}; 159 for (Operation operation : operations) { 160 for (int bits : new int[]{32, 64}) { 161 // zero related 162 addTest(tests, 0, 0, 1, 1, bits, operation); 163 addTest(tests, 1, 1, 0, 0, bits, operation); 164 addTest(tests, -1, 1, 0, 1, bits, operation); 165 166 // bounds 167 addTest(tests, -2, 2, 3, 3, bits, operation); 168 addTest(tests, -1, 1, 1, 1, bits, operation); 169 addTest(tests, -1, 1, -1, 1, bits, operation); 170 171 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, bits, operation); 172 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, bits, operation); 173 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1, bits, operation); 174 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, -1, -1, bits, operation); 175 addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 1, bits, operation); 176 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, 1, 1, bits, operation); 177 } 178 179 // bit-specific test cases 180 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, 64, operation); 181 addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, 64, operation); 182 } 183 184 return tests; 185 } 186 187 private abstract static class Operation { 188 abstract void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp); 189 } 190 191 private static final class AddOperation extends Operation { 192 @Override 193 public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { 194 try { 195 long res = addExact(lowerBoundA, lowerBoundB, bits); 196 resultStamp.contains(res); 197 res = addExact(upperBoundA, upperBoundB, bits); 198 resultStamp.contains(res); 199 Assert.assertFalse(overflowExpected); 200 } catch (ArithmeticException e) { 201 Assert.assertTrue(overflowExpected); 202 } 203 } 204 205 private static long addExact(long x, long y, int bits) { 206 if (bits == 32) { 207 return Math.addExact((int) x, (int) y); 208 } else { 209 return Math.addExact(x, y); 210 } 211 } 212 213 @SuppressWarnings("unused") 214 public static int snippetInt32(int a, int b) { 215 return Math.addExact(a, b); 216 } 217 218 @SuppressWarnings("unused") 219 public static long snippetInt64(long a, long b) { 220 return Math.addExact(a, b); 221 } 222 } 223 224 private static final class SubOperation extends Operation { 225 @Override 226 public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { 227 try { 228 long res = subExact(lowerBoundA, upperBoundB, bits); 229 Assert.assertTrue(resultStamp.contains(res)); 230 res = subExact(upperBoundA, lowerBoundB, bits); 231 Assert.assertTrue(resultStamp.contains(res)); 232 Assert.assertFalse(overflowExpected); 233 } catch (ArithmeticException e) { 234 Assert.assertTrue(overflowExpected); 235 } 236 } 237 238 private static long subExact(long x, long y, int bits) { 239 if (bits == 32) { 240 return Math.subtractExact((int) x, (int) y); 241 } else { 242 return Math.subtractExact(x, y); 243 } 244 } 245 246 @SuppressWarnings("unused") 247 public static int snippetInt32(int a, int b) { 248 return Math.subtractExact(a, b); 249 } 250 251 @SuppressWarnings("unused") 252 public static long snippetInt64(long a, long b) { 253 return Math.subtractExact(a, b); 254 } 255 } 256 257 private static final class MulOperation extends Operation { 258 @Override 259 public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) { 260 // now check for all values in the stamp whether their products overflow overflow 261 boolean overflowOccurred = false; 262 263 for (long l1 = lowerBoundA; l1 <= upperBoundA; l1++) { 264 for (long l2 = lowerBoundB; l2 <= upperBoundB; l2++) { 265 try { 266 long res = mulExact(l1, l2, bits); 267 Assert.assertTrue(resultStamp.contains(res)); 268 } catch (ArithmeticException e) { 269 overflowOccurred = true; 270 } 271 if (l2 == Long.MAX_VALUE) { 272 // do not want to overflow the check loop 273 break; 274 } 275 } 276 if (l1 == Long.MAX_VALUE) { 277 // do not want to overflow the check loop 278 break; 279 } 280 } 281 282 Assert.assertEquals(overflowExpected, overflowOccurred); 283 } 284 285 private static long mulExact(long x, long y, int bits) { 286 if (bits == 32) { 287 return Math.multiplyExact((int) x, (int) y); 288 } else { 289 return Math.multiplyExact(x, y); 290 } 291 } 292 293 @SuppressWarnings("unused") 294 public static int snippetInt32(int a, int b) { 295 return Math.multiplyExact(a, b); 296 } 297 298 @SuppressWarnings("unused") 299 public static long snippetInt64(long a, long b) { 300 return Math.multiplyExact(a, b); 301 } 302 } 303 }