/* * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. 
*/
package org.graalvm.compiler.replacements.test;

import static org.junit.Assert.assertNotNull;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.core.test.GraalCompilerTest;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.ParameterNode;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.ReturnNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.spi.LoweringTool;
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
import org.graalvm.compiler.phases.common.GuardLoweringPhase;
import org.graalvm.compiler.phases.common.LoweringPhase;
import org.graalvm.compiler.phases.tiers.HighTierContext;
import org.graalvm.compiler.phases.tiers.MidTierContext;
import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticNode;
import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerExactArithmeticSplitNode;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;

/**
 * Checks that {@code Math.addExact}, {@code Math.subtractExact} and {@code Math.multiplyExact}
 * keep their exact (overflow-checking) node after canonicalization if and only if some pair of
 * operands within the given stamp bounds actually overflows. The property is checked both on the
 * high-tier {@link IntegerExactArithmeticNode} and on the lowered
 * {@link IntegerExactArithmeticSplitNode}.
 */
@RunWith(Parameterized.class)
public class IntegerExactFoldTest extends GraalCompilerTest {
    // Inclusive bounds of the first operand's stamp.
    private final long lowerBoundA;
    private final long upperBoundA;
    // Inclusive bounds of the second operand's stamp.
    private final long lowerBoundB;
    private final long upperBoundB;
    // Operand width: 32 selects the int snippet, 64 the long snippet.
    private final int bits;
    private final Operation operation;

    /**
     * @param lowerBoundA inclusive lower bound of the first operand
     * @param upperBoundA inclusive upper bound of the first operand
     * @param lowerBoundB inclusive lower bound of the second operand
     * @param upperBoundB inclusive upper bound of the second operand
     * @param bits operand width, 32 or 64; for 32 all bounds must fit in an {@code int}
     * @param operation the exact operation under test
     */
    public IntegerExactFoldTest(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, Operation operation) {
        this.lowerBoundA = lowerBoundA;
        this.upperBoundA = upperBoundA;
        this.lowerBoundB = lowerBoundB;
        this.upperBoundB = upperBoundB;
        this.bits = bits;
        this.operation = operation;
        assert bits == 32 || bits == 64;
        assert lowerBoundA <= upperBoundA;
        assert lowerBoundB <= upperBoundB;
        assert bits == 64 || isInteger(lowerBoundA);
        assert bits == 64 || isInteger(upperBoundA);
        assert bits == 64 || isInteger(lowerBoundB);
        assert bits == 64 || isInteger(upperBoundB);
    }

    /**
     * Narrows the operand stamps of the exact node via {@link PiNode}s, canonicalizes, and checks
     * that the exact node survived exactly when overflow is possible.
     */
    @Test
    public void testFolding() {
        StructuredGraph graph = prepareGraph();
        IntegerStamp a = StampFactory.forInteger(bits, lowerBoundA, upperBoundA);
        IntegerStamp b = StampFactory.forInteger(bits, lowerBoundB, upperBoundB);

        List<ParameterNode> params = graph.getNodes(ParameterNode.TYPE).snapshot();
        // Only the exact-arithmetic usage of each parameter sees the narrowed stamp.
        params.get(0).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerExactArithmeticNode);
        params.get(1).replaceAtMatchingUsages(graph.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerExactArithmeticNode);

        Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
        assertNotNull("original node must be in the graph", originalNode);

        new CanonicalizerPhase().apply(graph, getDefaultHighTierContext());
        ValueNode node = findNode(graph);
        // A surviving exact node means the canonicalizer could not rule out overflow.
        boolean overflowExpected = node instanceof IntegerExactArithmeticNode;
        IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
        operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
    }

    /**
     * Lowers the exact node to its split form, overwrites the stamps of the split node's inputs
     * with the test bounds, canonicalizes again, and checks the same folding property.
     */
    @Test
    public void testFoldingAfterLowering() {
        StructuredGraph graph = prepareGraph();
        Node originalNode = graph.getNodes().filter(x -> x instanceof IntegerExactArithmeticNode).first();
        assertNotNull("original node must be in the graph", originalNode);

        CanonicalizerPhase canonicalizer = new CanonicalizerPhase();
        HighTierContext highTierContext = getDefaultHighTierContext();
        new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, highTierContext);
        MidTierContext midTierContext = getDefaultMidTierContext();
        new GuardLoweringPhase().apply(graph, midTierContext);
        new CanonicalizerPhase().apply(graph, midTierContext);

        IntegerExactArithmeticSplitNode loweredNode = graph.getNodes().filter(IntegerExactArithmeticSplitNode.class).first();
        assertNotNull("the lowered node must be in the graph", loweredNode);
        // Unlike testFolding, the stamps are set on the input nodes themselves.
        loweredNode.getX().setStamp(StampFactory.forInteger(bits, lowerBoundA, upperBoundA));
        loweredNode.getY().setStamp(StampFactory.forInteger(bits, lowerBoundB, upperBoundB));
        new CanonicalizerPhase().apply(graph, midTierContext);

        ValueNode node = findNode(graph);
        // A surviving split node means the canonicalizer could not rule out overflow.
        boolean overflowExpected = node instanceof IntegerExactArithmeticSplitNode;
        IntegerStamp resultStamp = (IntegerStamp) node.stamp(NodeView.DEFAULT);
        operation.verifyOverflow(lowerBoundA, upperBoundA, lowerBoundB, upperBoundB, bits, overflowExpected, resultStamp);
    }

    /** Returns whether {@code value} lies within the {@code int} range. */
    private static boolean isInteger(long value) {
        return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
    }

    /** Returns the value of the first {@link ReturnNode} in {@code graph}; fails if it has none. */
    private static ValueNode findNode(StructuredGraph graph) {
        ValueNode resultNode = graph.getNodes().filter(ReturnNode.class).first().result();
        assertNotNull("some node must be the returned value", resultNode);
        return resultNode;
    }

    /**
     * Parses the operation's {@code snippetInt32} or {@code snippetInt64} method (selected by
     * {@link #bits}) and runs one canonicalization pass over it.
     */
    protected StructuredGraph prepareGraph() {
        String snippet = "snippetInt" + bits;
        StructuredGraph graph = parseEager(getResolvedJavaMethod(operation.getClass(), snippet), AllowAssumptions.NO);
        HighTierContext context = getDefaultHighTierContext();
        new CanonicalizerPhase().apply(graph, context);
        return graph;
    }

    /** Appends one parameter set in constructor-argument order. */
    private static void addTest(List<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits, Operation operation) {
        tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits, operation});
    }

    @Parameters(name = "a[{0} / {1}], b[{2} / {3}], bits={4}, operation={5}")
    public static Collection<Object[]> data() {
        List<Object[]> tests = new ArrayList<>();
        Operation[] operations = new Operation[]{new AddOperation(), new SubOperation(), new MulOperation()};
        for (Operation operation : operations) {
            for (int bits : new int[]{32, 64}) {
                // zero related
                addTest(tests, 0, 0, 1, 1, bits, operation);
                addTest(tests, 1, 1, 0, 0, bits, operation);
                addTest(tests, -1, 1, 0, 1, bits, operation);

                // bounds
                addTest(tests, -2, 2, 3, 3, bits, operation);
                addTest(tests, -1, 1, 1, 1, bits, operation);
                addTest(tests, -1, 1, -1, 1, bits, operation);

                addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, bits, operation);
                addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, bits, operation);
                addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1, bits, operation);
                addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, -1, -1, bits, operation);
                addTest(tests, Integer.MAX_VALUE, Integer.MAX_VALUE, 1, 1, bits, operation);
                addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE, 1, 1, bits, operation);
            }

            // bit-specific test cases
            // NOTE(review): the loop above already adds these two exact parameter sets for
            // bits == 64, so they run twice; confirm whether different bounds were intended.
            addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, Integer.MAX_VALUE - 0xF, Integer.MAX_VALUE, 64, operation);
            addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xF, -1, -1, 64, operation);
        }

        return tests;
    }

    /**
     * One exact-arithmetic operation under test. Each subclass declares static
     * {@code snippetInt32}/{@code snippetInt64} methods, which the enclosing test's
     * {@code prepareGraph()} looks up reflectively by name.
     */
    private abstract static class Operation {
        /**
         * Asserts that {@code overflowExpected} (the exact node survived canonicalization) agrees
         * with whether the operation overflows for operands within the given bounds, and that the
         * non-overflowing results it computes lie inside {@code resultStamp}.
         */
        abstract void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp);

        /** Gives the {@code {5}} placeholder in parameterized test names a stable, readable value. */
        @Override
        public String toString() {
            return getClass().getSimpleName();
        }
    }

    private static final class AddOperation extends Operation {
        @Override
        public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
            // Addition is monotonic in both operands, so the extreme sums are lowerA + lowerB and
            // upperA + upperB; any overflow within the ranges shows up at one of them.
            try {
                long res = addExact(lowerBoundA, lowerBoundB, bits);
                Assert.assertTrue(resultStamp.contains(res));
                res = addExact(upperBoundA, upperBoundB, bits);
                Assert.assertTrue(resultStamp.contains(res));
                Assert.assertFalse(overflowExpected);
            } catch (ArithmeticException e) {
                Assert.assertTrue(overflowExpected);
            }
        }

        /** {@code Math.addExact} at the given width; throws ArithmeticException on overflow. */
        private static long addExact(long x, long y, int bits) {
            if (bits == 32) {
                return Math.addExact((int) x, (int) y);
            } else {
                return Math.addExact(x, y);
            }
        }

        @SuppressWarnings("unused")
        public static int snippetInt32(int a, int b) {
            return Math.addExact(a, b);
        }

        @SuppressWarnings("unused")
        public static long snippetInt64(long a, long b) {
            return Math.addExact(a, b);
        }
    }

    private static final class SubOperation extends Operation {
        @Override
        public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
            // The extreme differences are lowerA - upperB and upperA - lowerB; any overflow within
            // the ranges shows up at one of them.
            try {
                long res = subExact(lowerBoundA, upperBoundB, bits);
                Assert.assertTrue(resultStamp.contains(res));
                res = subExact(upperBoundA, lowerBoundB, bits);
                Assert.assertTrue(resultStamp.contains(res));
                Assert.assertFalse(overflowExpected);
            } catch (ArithmeticException e) {
                Assert.assertTrue(overflowExpected);
            }
        }

        /** {@code Math.subtractExact} at the given width; throws ArithmeticException on overflow. */
        private static long subExact(long x, long y, int bits) {
            if (bits == 32) {
                return Math.subtractExact((int) x, (int) y);
            } else {
                return Math.subtractExact(x, y);
            }
        }

        @SuppressWarnings("unused")
        public static int snippetInt32(int a, int b) {
            return Math.subtractExact(a, b);
        }

        @SuppressWarnings("unused")
        public static long snippetInt64(long a, long b) {
            return Math.subtractExact(a, b);
        }
    }

    private static final class MulOperation extends Operation {
        @Override
        public void verifyOverflow(long lowerBoundA, long upperBoundA, long lowerBoundB, long upperBoundB, int bits, boolean overflowExpected, IntegerStamp resultStamp) {
            // Multiplication is not monotonic across sign changes, so check every pair of values
            // in the stamps for overflow. Cost is the product of the two range sizes.
            boolean overflowOccurred = false;
            for (long l1 = lowerBoundA; l1 <= upperBoundA; l1++) {
                for (long l2 = lowerBoundB; l2 <= upperBoundB; l2++) {
                    try {
                        long res = mulExact(l1, l2, bits);
                        Assert.assertTrue(resultStamp.contains(res));
                    } catch (ArithmeticException e) {
                        overflowOccurred = true;
                    }
                    if (l2 == Long.MAX_VALUE) {
                        // do not want to overflow the check loop
                        break;
                    }
                }
                if (l1 == Long.MAX_VALUE) {
                    // do not want to overflow the check loop
                    break;
                }
            }

            Assert.assertEquals(overflowExpected, overflowOccurred);
        }

        /** {@code Math.multiplyExact} at the given width; throws ArithmeticException on overflow. */
        private static long mulExact(long x, long y, int bits) {
            if (bits == 32) {
                return Math.multiplyExact((int) x, (int) y);
            } else {
                return Math.multiplyExact(x, y);
            }
        }

        @SuppressWarnings("unused")
        public static int snippetInt32(int a, int b) {
            return Math.multiplyExact(a, b);
        }

        @SuppressWarnings("unused")
        public static long snippetInt64(long a, long b) {
            return Math.multiplyExact(a, b);
        }
    }
}