< prev index next >

src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.loop/src/org/graalvm/compiler/loop/LoopFragmentInside.java

Print this page
rev 52509 : [mq]: graal2

@@ -28,10 +28,11 @@
 import java.util.LinkedList;
 import java.util.List;
 
 import jdk.internal.vm.compiler.collections.EconomicMap;
 import jdk.internal.vm.compiler.collections.Equivalence;
+import org.graalvm.compiler.core.common.type.IntegerStamp;
 import org.graalvm.compiler.debug.DebugCloseable;
 import org.graalvm.compiler.debug.DebugContext;
 import org.graalvm.compiler.debug.GraalError;
 import org.graalvm.compiler.graph.Graph.DuplicationReplacement;
 import org.graalvm.compiler.graph.Node;

@@ -62,14 +63,19 @@
 import org.graalvm.compiler.nodes.ValueNode;
 import org.graalvm.compiler.nodes.ValuePhiNode;
 import org.graalvm.compiler.nodes.VirtualState.NodeClosure;
 import org.graalvm.compiler.nodes.calc.AddNode;
 import org.graalvm.compiler.nodes.calc.CompareNode;
+import org.graalvm.compiler.nodes.calc.ConditionalNode;
+import org.graalvm.compiler.nodes.calc.IntegerBelowNode;
 import org.graalvm.compiler.nodes.calc.SubNode;
+import org.graalvm.compiler.nodes.extended.OpaqueNode;
 import org.graalvm.compiler.nodes.memory.MemoryPhiNode;
 import org.graalvm.compiler.nodes.util.GraphUtil;
 
+import jdk.vm.ci.code.CodeUtil;
+
 public class LoopFragmentInside extends LoopFragment {
 
     /**
      * mergedInitializers. When an inside fragment's (loop)ends are merged to create a unique exit
      * point, some phis must be created : they phis together all the back-values of the loop-phis

@@ -148,24 +154,12 @@
     }
 
     /**
      * Duplicate the body within the loop after the current copy of the body, updating the
      * iteration limit to account for the duplication.
-     *
-     * @param loop
      */
-    public void insertWithinAfter(LoopEx loop) {
-        insertWithinAfter(loop, true);
-    }
-
-    /**
-     * Duplicate the body within the loop after the current copy copy of the body.
-     *
-     * @param loop
-     * @param updateLimit true if the iteration limit should be adjusted.
-     */
-    public void insertWithinAfter(LoopEx loop, boolean updateLimit) {
+    public void insertWithinAfter(LoopEx loop, EconomicMap<LoopBeginNode, OpaqueNode> opaqueUnrolledStrides) {
         assert isDuplicate() && original().loop() == loop;
 
         patchNodes(dataFixWithinAfter);
 
         /*

@@ -199,36 +193,47 @@
         assert loop.whole().nodes().filter(SafepointNode.class).count() == nodes().filter(SafepointNode.class).count();
         for (SafepointNode safepoint : loop.whole().nodes().filter(SafepointNode.class)) {
             graph().removeFixed(safepoint);
         }
 
-        int unrollFactor = mainLoopBegin.getUnrollFactor();
         StructuredGraph graph = mainLoopBegin.graph();
-        if (updateLimit) {
-            // Now use the previous unrollFactor to update the exit condition to power of two
-            InductionVariable iv = loop.counted().getCounter();
-            CompareNode compareNode = (CompareNode) loop.counted().getLimitTest().condition();
-            ValueNode compareBound;
-            if (compareNode.getX() == iv.valueNode()) {
-                compareBound = compareNode.getY();
-            } else if (compareNode.getY() == iv.valueNode()) {
-                compareBound = compareNode.getX();
+        if (opaqueUnrolledStrides != null) {
+            OpaqueNode opaque = opaqueUnrolledStrides.get(loop.loopBegin());
+            CountedLoopInfo counted = loop.counted();
+            ValueNode counterStride = counted.getCounter().strideNode();
+            if (opaque == null) {
+                opaque = new OpaqueNode(AddNode.add(counterStride, counterStride, NodeView.DEFAULT));
+                ValueNode limit = counted.getLimit();
+                int bits = ((IntegerStamp) limit.stamp(NodeView.DEFAULT)).getBits();
+                ValueNode newLimit = SubNode.create(limit, opaque, NodeView.DEFAULT);
+                LogicNode overflowCheck;
+                ConstantNode extremum;
+                if (counted.getDirection() == InductionVariable.Direction.Up) {
+                    // limit - counterStride could overflow negatively if limit - min <
+                    // counterStride
+                    extremum = ConstantNode.forIntegerBits(bits, CodeUtil.minValue(bits));
+                    overflowCheck = IntegerBelowNode.create(SubNode.create(limit, extremum, NodeView.DEFAULT), opaque, NodeView.DEFAULT);
             } else {
-                throw GraalError.shouldNotReachHere();
-            }
-            long originalStride = unrollFactor == 1 ? iv.constantStride() : iv.constantStride() / unrollFactor;
-            if (iv.direction() == InductionVariable.Direction.Up) {
-                ConstantNode aboveVal = graph.unique(ConstantNode.forIntegerStamp(iv.initNode().stamp(NodeView.DEFAULT), unrollFactor * originalStride));
-                ValueNode newLimit = graph.addWithoutUnique(new SubNode(compareBound, aboveVal));
-                compareNode.replaceFirstInput(compareBound, newLimit);
-            } else if (iv.direction() == InductionVariable.Direction.Down) {
-                ConstantNode aboveVal = graph.unique(ConstantNode.forIntegerStamp(iv.initNode().stamp(NodeView.DEFAULT), unrollFactor * -originalStride));
-                ValueNode newLimit = graph.addWithoutUnique(new AddNode(compareBound, aboveVal));
-                compareNode.replaceFirstInput(compareBound, newLimit);
+                    assert counted.getDirection() == InductionVariable.Direction.Down;
+                    // limit - counterStride could overflow if max - limit < -counterStride
+                    // i.e., counterStride < limit - max
+                    extremum = ConstantNode.forIntegerBits(bits, CodeUtil.maxValue(bits));
+                    overflowCheck = IntegerBelowNode.create(opaque, SubNode.create(limit, extremum, NodeView.DEFAULT), NodeView.DEFAULT);
+                }
+                newLimit = ConditionalNode.create(overflowCheck, extremum, newLimit, NodeView.DEFAULT);
+                CompareNode compareNode = (CompareNode) counted.getLimitTest().condition();
+                compareNode.replaceFirstInput(limit, graph.addOrUniqueWithInputs(newLimit));
+                opaqueUnrolledStrides.put(loop.loopBegin(), opaque);
+            } else {
+                assert counted.getCounter().isConstantStride();
+                assert Math.addExact(counted.getCounter().constantStride(), counted.getCounter().constantStride()) == counted.getCounter().constantStride() * 2;
+                ValueNode previousValue = opaque.getValue();
+                opaque.setValue(graph.addOrUniqueWithInputs(AddNode.add(counterStride, previousValue, NodeView.DEFAULT)));
+                GraphUtil.tryKillUnused(previousValue);
             }
         }
-        mainLoopBegin.setUnrollFactor(unrollFactor * 2);
+        mainLoopBegin.setUnrollFactor(mainLoopBegin.getUnrollFactor() * 2);
         mainLoopBegin.setLoopFrequency(mainLoopBegin.loopFrequency() / 2);
         graph.getDebug().dump(DebugContext.DETAILED_LEVEL, graph, "LoopPartialUnroll %s", loop);
 
         mainLoopBegin.getDebug().dump(DebugContext.VERBOSE_LEVEL, mainLoopBegin.graph(), "After insertWithinAfter %s", mainLoopBegin);
     }
< prev index next >