< prev index next >

src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.loop.phases/src/org/graalvm/compiler/loop/phases/LoopTransformations.java

Print this page

        

@@ -20,29 +20,17 @@
  * or visit www.oracle.com if you need additional information or have any
  * questions.
  */
 package org.graalvm.compiler.loop.phases;
 
-import static org.graalvm.compiler.core.common.GraalOptions.MaximumDesiredSize;
-import static org.graalvm.compiler.loop.MathUtil.add;
-import static org.graalvm.compiler.loop.MathUtil.sub;
-
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-
 import org.graalvm.compiler.core.common.RetryableBailoutException;
-import org.graalvm.compiler.core.common.type.Stamp;
 import org.graalvm.compiler.debug.DebugContext;
 import org.graalvm.compiler.debug.GraalError;
 import org.graalvm.compiler.graph.Graph.Mark;
 import org.graalvm.compiler.graph.Node;
-import org.graalvm.compiler.graph.NodeWorkList;
 import org.graalvm.compiler.graph.Position;
-import org.graalvm.compiler.loop.BasicInductionVariable;
 import org.graalvm.compiler.loop.CountedLoopInfo;
-import org.graalvm.compiler.loop.DerivedInductionVariable;
 import org.graalvm.compiler.loop.InductionVariable;
 import org.graalvm.compiler.loop.InductionVariable.Direction;
 import org.graalvm.compiler.loop.LoopEx;
 import org.graalvm.compiler.loop.LoopFragmentInside;
 import org.graalvm.compiler.loop.LoopFragmentWhole;

@@ -57,27 +45,31 @@
 import org.graalvm.compiler.nodes.FixedNode;
 import org.graalvm.compiler.nodes.FixedWithNextNode;
 import org.graalvm.compiler.nodes.IfNode;
 import org.graalvm.compiler.nodes.LogicNode;
 import org.graalvm.compiler.nodes.LoopBeginNode;
-import org.graalvm.compiler.nodes.LoopEndNode;
 import org.graalvm.compiler.nodes.LoopExitNode;
 import org.graalvm.compiler.nodes.PhiNode;
+import org.graalvm.compiler.nodes.SafepointNode;
 import org.graalvm.compiler.nodes.StructuredGraph;
 import org.graalvm.compiler.nodes.ValueNode;
-import org.graalvm.compiler.nodes.ValuePhiNode;
-import org.graalvm.compiler.nodes.calc.AddNode;
-import org.graalvm.compiler.nodes.calc.BinaryArithmeticNode;
 import org.graalvm.compiler.nodes.calc.CompareNode;
 import org.graalvm.compiler.nodes.calc.ConditionalNode;
 import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
-import org.graalvm.compiler.nodes.calc.SubNode;
 import org.graalvm.compiler.nodes.extended.SwitchNode;
 import org.graalvm.compiler.nodes.util.GraphUtil;
 import org.graalvm.compiler.phases.common.CanonicalizerPhase;
 import org.graalvm.compiler.phases.tiers.PhaseContext;
 
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import static org.graalvm.compiler.core.common.GraalOptions.MaximumDesiredSize;
+import static org.graalvm.compiler.loop.MathUtil.add;
+import static org.graalvm.compiler.loop.MathUtil.sub;
+
 public abstract class LoopTransformations {
 
     private LoopTransformations() {
         // does not need to be instantiated
     }

@@ -152,258 +144,17 @@
         }
 
         // TODO (gd) probabilities need some amount of fixup.. (probably also in other transforms)
     }
 
-    public static boolean partialUnroll(LoopEx loop, StructuredGraph graph) {
+    public static void partialUnroll(LoopEx loop, StructuredGraph graph) {
         assert loop.loopBegin().isMainLoop();
         graph.getDebug().log("LoopPartialUnroll %s", loop);
-        boolean changed = false;
-        CountedLoopInfo mainCounted = loop.counted();
-        LoopBeginNode mainLoopBegin = loop.loopBegin();
-        InductionVariable iv = mainCounted.getCounter();
-        IfNode mainLimit = mainCounted.getLimitTest();
-        LogicNode ifTest = mainLimit.condition();
-        CompareNode compareNode = (CompareNode) ifTest;
-        ValueNode compareBound = null;
-        ValueNode curPhi = iv.valueNode();
-        if (compareNode.getX() == curPhi) {
-            compareBound = compareNode.getY();
-        } else if (compareNode.getY() == curPhi) {
-            compareBound = compareNode.getX();
-        }
+
         LoopFragmentInside newSegment = loop.inside().duplicate();
         newSegment.insertWithinAfter(loop);
-        graph.getDebug().dump(DebugContext.VERBOSE_LEVEL, graph, "After duplication inside %s", mainLoopBegin);
-        ValueNode inductionNode = iv.valueNode();
-        Node newStrideNode = null;
-        for (PhiNode mainPhiNode : mainLoopBegin.phis()) {
-            Node segmentOrigOp = null;
-            Node replacementOp = null;
-            changed = false;
-            // Rework each phi with a loop carried dependence
-            for (Node phiUsage : mainPhiNode.usages()) {
-                if (!loop.isOutsideLoop(phiUsage)) {
-                    for (int i = 1; i < mainPhiNode.valueCount(); i++) {
-                        ValueNode v = mainPhiNode.valueAt(i);
-                        if (mainPhiNode != inductionNode) {
-                            if (closureOnPhiInputToPhiUse(v, phiUsage, loop, graph)) {
-                                segmentOrigOp = v;
-                                Node node = newSegment.getDuplicatedNode(v);
-                                replacementOp = updateUnrollSegmentValue(mainPhiNode, inductionNode, phiUsage, v, newSegment);
-
-                                // Update the induction phi with new stride node
-                                mainPhiNode.setValueAt(i, (ValueNode) node);
-                                // This is for induction variables not referenced in the loop body
-                                if (inductionNode == v) {
-                                    newStrideNode = node;
-                                }
-                                changed = true;
-                                break;
-                            }
-                        } else if (v == phiUsage) {
-                            segmentOrigOp = phiUsage;
-                            Node node = newSegment.getDuplicatedNode(phiUsage);
-                            newStrideNode = node;
-                            replacementOp = updateUnrollSegmentValue(mainPhiNode, inductionNode, phiUsage, phiUsage, newSegment);
-
-                            // Update the induction phi with new stride node
-                            mainPhiNode.setValueAt(i, (ValueNode) node);
-                            changed = true;
-                            break;
-                        }
-                    }
-                }
-                if (changed) {
-                    break;
-                }
-            }
 
-            if (changed) {
-                // Patch the new segments induction uses of replacementOp with the old stride node
-                for (Node usage : mainPhiNode.usages()) {
-                    if (usage != segmentOrigOp) {
-                        if (!loop.isOutsideLoop(usage)) {
-                            Node node = newSegment.getDuplicatedNode(usage);
-                            if (node instanceof CompareNode) {
-                                continue;
-                            }
-                            node.replaceFirstInput(replacementOp, segmentOrigOp);
-                        }
-                    }
-                }
-            }
-        }
-
-        if (changed && newStrideNode == null) {
-            throw GraalError.shouldNotReachHere("Can't find stride node");
-        }
-        if (newStrideNode != null) {
-            // If merge the duplicate code into the loop and remove redundant code
-            placeNewSegmentAndCleanup(mainCounted, mainLoopBegin, newSegment);
-            int unrollFactor = mainLoopBegin.getUnrollFactor();
-            // First restore the old pattern of the loop exit condition so we can update it one way
-            if (unrollFactor > 1) {
-                if (compareBound instanceof SubNode) {
-                    SubNode newLimit = (SubNode) compareBound;
-                    ValueNode oldcompareBound = newLimit.getX();
-                    compareNode.replaceFirstInput(newLimit, oldcompareBound);
-                    newLimit.safeDelete();
-                    compareBound = oldcompareBound;
-                } else if (compareBound instanceof AddNode) {
-                    AddNode newLimit = (AddNode) compareBound;
-                    ValueNode oldcompareBound = newLimit.getX();
-                    compareNode.replaceFirstInput(newLimit, oldcompareBound);
-                    newLimit.safeDelete();
-                    compareBound = oldcompareBound;
-                }
-            }
-            unrollFactor *= 2;
-            mainLoopBegin.setUnrollFactor(unrollFactor);
-            // Reset stride to include new segment in loop control.
-            long oldStride = iv.constantStride() * 2;
-            // Now update the induction op and the exit condition
-            if (iv instanceof BasicInductionVariable) {
-                BasicInductionVariable biv = (BasicInductionVariable) iv;
-                BinaryArithmeticNode<?> newOp = (BinaryArithmeticNode<?>) newStrideNode;
-                Stamp strideStamp = newOp.stamp();
-                ConstantNode newStrideVal = graph.unique(ConstantNode.forIntegerStamp(strideStamp, oldStride));
-                newOp.setY(newStrideVal);
-                biv.setOP(newOp);
-                // Now use the current unrollFactor to update the exit condition to power of two
-                if (unrollFactor > 1) {
-                    if (iv.direction() == Direction.Up) {
-                        int modulas = (unrollFactor - 1);
-                        ConstantNode aboveVal = graph.unique(ConstantNode.forIntegerStamp(strideStamp, modulas));
-                        ValueNode newLimit = graph.addWithoutUnique(new SubNode(compareBound, aboveVal));
-                        compareNode.replaceFirstInput(compareBound, newLimit);
-                    } else if (iv.direction() == Direction.Down) {
-                        int modulas = (unrollFactor - 1);
-                        ConstantNode aboveVal = graph.unique(ConstantNode.forIntegerStamp(strideStamp, modulas));
-                        ValueNode newLimit = graph.addWithoutUnique(new AddNode(compareBound, aboveVal));
-                        compareNode.replaceFirstInput(compareBound, newLimit);
-                    }
-                }
-                mainLoopBegin.setLoopFrequency(mainLoopBegin.loopFrequency() / 2);
-            }
-            changed = true;
-        }
-        if (changed) {
-            graph.getDebug().dump(DebugContext.DETAILED_LEVEL, graph, "LoopPartialUnroll %s", loop);
-        }
-        return changed;
-    }
-
-    private static Node updateUnrollSegmentValue(PhiNode mainPhiNode, Node inductionNode, Node phiUsage, Node patchNode, LoopFragmentInside newSegment) {
-        Node node = newSegment.getDuplicatedNode(phiUsage);
-        assert node != null : phiUsage;
-        Node replacementOp = null;
-        int inputCnt = 0;
-        for (Node input : phiUsage.inputs()) {
-            inputCnt++;
-            if (input == mainPhiNode) {
-                break;
-            }
-        }
-        int newInputCnt = 0;
-        for (Node input : node.inputs()) {
-            newInputCnt++;
-            if (newInputCnt == inputCnt) {
-                replacementOp = input;
-                if (mainPhiNode == inductionNode) {
-                    node.replaceFirstInput(input, mainPhiNode);
-                } else {
-                    node.replaceFirstInput(input, patchNode);
-                }
-                break;
-            }
-        }
-        return replacementOp;
-    }
-
-    private static boolean closureOnPhiInputToPhiUse(Node inNode, Node usage, LoopEx loop, StructuredGraph graph) {
-        NodeWorkList nodes = graph.createNodeWorkList();
-        nodes.add(inNode);
-        // Now walk from the inNode to usage if we can find it else we do not have closure
-        for (Node node : nodes) {
-            if (node == usage) {
-                return true;
-            }
-            for (Node input : node.inputs()) {
-                if (!loop.isOutsideLoop(input)) {
-                    if (input != usage) {
-                        nodes.add(input);
-                    } else {
-                        return true;
-                        // For any reason if we have completed a closure, stop processing more
-                    }
-                }
-            }
-        }
-        return false;
-    }
-
-    private static void placeNewSegmentAndCleanup(CountedLoopInfo mainCounted, LoopBeginNode mainLoopBegin, LoopFragmentInside newSegment) {
-        // Discard the segment entry and its flow, after if merging it into the loop
-        StructuredGraph graph = mainLoopBegin.graph();
-        IfNode loopTest = mainCounted.getLimitTest();
-        IfNode newSegmentTest = newSegment.getDuplicatedNode(loopTest);
-        AbstractBeginNode trueSuccessor = loopTest.trueSuccessor();
-        AbstractBeginNode falseSuccessor = loopTest.falseSuccessor();
-        FixedNode firstNode;
-        boolean codeInTrueSide = false;
-        if (trueSuccessor == mainCounted.getBody()) {
-            firstNode = trueSuccessor.next();
-            codeInTrueSide = true;
-        } else {
-            assert (falseSuccessor == mainCounted.getBody());
-            firstNode = falseSuccessor.next();
-        }
-        trueSuccessor = newSegmentTest.trueSuccessor();
-        falseSuccessor = newSegmentTest.falseSuccessor();
-        for (Node usage : falseSuccessor.anchored().snapshot()) {
-            usage.replaceFirstInput(falseSuccessor, loopTest.falseSuccessor());
-        }
-        for (Node usage : trueSuccessor.anchored().snapshot()) {
-            usage.replaceFirstInput(trueSuccessor, loopTest.trueSuccessor());
-        }
-        AbstractBeginNode startBlockNode;
-        if (codeInTrueSide) {
-            startBlockNode = trueSuccessor;
-        } else {
-            graph.getDebug().dump(DebugContext.VERBOSE_LEVEL, mainLoopBegin.graph(), "before");
-            startBlockNode = falseSuccessor;
-        }
-        FixedNode lastNode = getBlockEnd(startBlockNode);
-        LoopEndNode loopEndNode = getSingleLoopEndFromLoop(mainLoopBegin);
-        FixedNode lastCodeNode = (FixedNode) loopEndNode.predecessor();
-        FixedNode newSegmentFirstNode = newSegment.getDuplicatedNode(firstNode);
-        FixedNode newSegmentLastNode = newSegment.getDuplicatedNode(lastCodeNode);
-        graph.getDebug().dump(DebugContext.DETAILED_LEVEL, loopEndNode.graph(), "Before placing segment");
-        if (firstNode instanceof LoopEndNode) {
-            GraphUtil.killCFG(newSegment.getDuplicatedNode(mainLoopBegin));
-        } else {
-            newSegmentLastNode.clearSuccessors();
-            startBlockNode.setNext(lastNode);
-            lastCodeNode.replaceFirstSuccessor(loopEndNode, newSegmentFirstNode);
-            newSegmentLastNode.replaceFirstSuccessor(lastNode, loopEndNode);
-            FixedWithNextNode oldLastNode = (FixedWithNextNode) lastCodeNode;
-            oldLastNode.setNext(newSegmentFirstNode);
-            FixedWithNextNode newLastNode = (FixedWithNextNode) newSegmentLastNode;
-            newLastNode.setNext(loopEndNode);
-            startBlockNode.clearSuccessors();
-            lastNode.safeDelete();
-            Node newSegmentTestStart = newSegmentTest.predecessor();
-            LogicNode newSegmentIfTest = newSegmentTest.condition();
-            newSegmentTestStart.clearSuccessors();
-            newSegmentTest.safeDelete();
-            newSegmentIfTest.safeDelete();
-            trueSuccessor.safeDelete();
-            falseSuccessor.safeDelete();
-            newSegmentTestStart.safeDelete();
-        }
-        graph.getDebug().dump(DebugContext.DETAILED_LEVEL, loopEndNode.graph(), "After placing segment");
     }
 
     // This function splits candidate loops into pre, main and post loops,
     // dividing the iteration space to facilitate the majority of iterations
     // being executed in a main loop, which will have RCE implemented upon it.

@@ -473,16 +224,16 @@
     // be updated to produce vector alignment if applicable.
 
     public static void insertPrePostLoops(LoopEx loop, StructuredGraph graph) {
         graph.getDebug().log("LoopTransformations.insertPrePostLoops %s", loop);
         LoopFragmentWhole preLoop = loop.whole();
-        CountedLoopInfo preCounted = preLoop.loop().counted();
+        CountedLoopInfo preCounted = loop.counted();
         IfNode preLimit = preCounted.getLimitTest();
         if (preLimit != null) {
             LoopBeginNode preLoopBegin = loop.loopBegin();
             InductionVariable preIv = preCounted.getCounter();
-            LoopExitNode preLoopExitNode = getSingleExitFromLoop(preLoopBegin);
+            LoopExitNode preLoopExitNode = preLoopBegin.getSingleLoopExit();
             FixedNode continuationNode = preLoopExitNode.next();
 
             // Each duplication is inserted after the original, ergo create the post loop first
             LoopFragmentWhole mainLoop = preLoop.duplicate();
             LoopFragmentWhole postLoop = preLoop.duplicate();

@@ -495,11 +246,11 @@
             LoopBeginNode postLoopBegin = postLoop.getDuplicatedNode(preLoopBegin);
             postLoopBegin.setPostLoop();
 
             EndNode postEndNode = getBlockEndAfterLoopExit(postLoopBegin);
             AbstractMergeNode postMergeNode = postEndNode.merge();
-            LoopExitNode postLoopExitNode = getSingleExitFromLoop(postLoopBegin);
+            LoopExitNode postLoopExitNode = postLoopBegin.getSingleLoopExit();
 
             // Update the main loop phi initialization to carry from the pre loop
             for (PhiNode prePhiNode : preLoopBegin.phis()) {
                 PhiNode mainPhiNode = mainLoop.getDuplicatedNode(prePhiNode);
                 mainPhiNode.setValueAt(0, prePhiNode);

@@ -509,11 +260,11 @@
             AbstractMergeNode mainMergeNode = mainEndNode.merge();
             AbstractEndNode postEntryNode = postLoopBegin.forwardEnd();
 
             // In the case of no Bounds tests, we just flow right into the main loop
             AbstractBeginNode mainLandingNode = BeginNode.begin(postEntryNode);
-            LoopExitNode mainLoopExitNode = getSingleExitFromLoop(mainLoopBegin);
+            LoopExitNode mainLoopExitNode = mainLoopBegin.getSingleLoopExit();
             mainLoopExitNode.setNext(mainLandingNode);
             preLoopExitNode.setNext(mainLoopBegin.forwardEnd());
 
             // Add and update any phi edges as per merge usage as needed and update usages
             processPreLoopPhis(loop, mainLoop, postLoop);

@@ -526,10 +277,18 @@
             updateMainLoopLimit(preLimit, preIv, mainLoop);
             updatePreLoopLimit(preLimit, preIv, preCounted);
             preLoopBegin.setLoopFrequency(1);
             mainLoopBegin.setLoopFrequency(Math.max(0.0, mainLoopBegin.loopFrequency() - 2));
             postLoopBegin.setLoopFrequency(Math.max(0.0, postLoopBegin.loopFrequency() - 1));
+
+            // The pre and post loops don't require safepoints at all
+            for (SafepointNode safepoint : preLoop.nodes().filter(SafepointNode.class)) {
+                GraphUtil.removeFixedWithUnusedInputs(safepoint);
+            }
+            for (SafepointNode safepoint : postLoop.nodes().filter(SafepointNode.class)) {
+                GraphUtil.removeFixedWithUnusedInputs(safepoint);
+            }
         }
         graph.getDebug().dump(DebugContext.DETAILED_LEVEL, graph, "InsertPrePostLoops %s", loop);
     }
 
     /**

@@ -571,25 +330,15 @@
                 }
             }
         }
     }
 
-    private static LoopExitNode getSingleExitFromLoop(LoopBeginNode curLoopBegin) {
-        assert curLoopBegin.loopExits().count() == 1;
-        return curLoopBegin.loopExits().first();
-    }
-
-    private static LoopEndNode getSingleLoopEndFromLoop(LoopBeginNode curLoopBegin) {
-        assert curLoopBegin.loopEnds().count() == 1;
-        return curLoopBegin.loopEnds().first();
-    }
-
     /**
      * Find the end of the block following the LoopExit.
      */
     private static EndNode getBlockEndAfterLoopExit(LoopBeginNode curLoopBegin) {
-        FixedNode node = getSingleExitFromLoop(curLoopBegin).next();
+        FixedNode node = curLoopBegin.getSingleLoopExit().next();
         // Find the last node after the exit blocks starts
         return getBlockEnd(node);
     }
 
     private static EndNode getBlockEnd(FixedNode node) {

@@ -691,45 +440,20 @@
         }
         return controls;
     }
 
     public static boolean isUnrollableLoop(LoopEx loop) {
-        if (!loop.isCounted()) {
+        if (!loop.isCounted() || !loop.counted().getCounter().isConstantStride()) {
             return false;
         }
         LoopBeginNode loopBegin = loop.loopBegin();
-        boolean isCanonical = false;
         if (loopBegin.isMainLoop() || loopBegin.isSimpleLoop()) {
             // Flow-less loops to partial unroll for now. 3 blocks corresponds to an if that either
             // exits or continues the loop. There might be fixed and floating work within the loop
             // as well.
             if (loop.loop().getBlocks().size() < 3) {
-                isCanonical = true;
-            }
-        }
-        if (!isCanonical) {
-            return false;
-        }
-        for (ValuePhiNode phi : loopBegin.valuePhis()) {
-            if (phi.usages().filter(x -> loopBegin.isPhiAtMerge(x)).isNotEmpty()) {
-                // Filter out Phis which reference Phis at the same merge until the duplication
-                // logic handles it properly.
-                return false;
+                return true;
             }
-            InductionVariable iv = loop.getInductionVariables().get(phi);
-            if (iv == null) {
-                continue;
             }
-            if (iv instanceof DerivedInductionVariable) {
-                return false;
-            } else if (iv instanceof BasicInductionVariable) {
-                BasicInductionVariable biv = (BasicInductionVariable) iv;
-                if (!biv.isConstantStride()) {
                     return false;
                 }
-            } else {
-                return false;
-            }
-        }
-        return true;
-    }
 }
< prev index next >