< prev index next >

src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.loop.phases/src/org/graalvm/compiler/loop/phases/LoopTransformations.java

Print this page

        

@@ -52,11 +52,10 @@
 import org.graalvm.compiler.nodes.FixedNode;
 import org.graalvm.compiler.nodes.FixedWithNextNode;
 import org.graalvm.compiler.nodes.IfNode;
 import org.graalvm.compiler.nodes.LogicNode;
 import org.graalvm.compiler.nodes.LoopBeginNode;
-import org.graalvm.compiler.nodes.LoopExitNode;
 import org.graalvm.compiler.nodes.NodeView;
 import org.graalvm.compiler.nodes.PhiNode;
 import org.graalvm.compiler.nodes.SafepointNode;
 import org.graalvm.compiler.nodes.StructuredGraph;
 import org.graalvm.compiler.nodes.ValueNode;

@@ -229,15 +228,15 @@
     public static LoopBeginNode insertPrePostLoops(LoopEx loop) {
         StructuredGraph graph = loop.loopBegin().graph();
         graph.getDebug().log("LoopTransformations.insertPrePostLoops %s", loop);
         LoopFragmentWhole preLoop = loop.whole();
         CountedLoopInfo preCounted = loop.counted();
-        IfNode preLimit = preCounted.getLimitTest();
-        assert preLimit != null;
         LoopBeginNode preLoopBegin = loop.loopBegin();
-        LoopExitNode preLoopExitNode = preLoopBegin.getSingleLoopExit();
-        FixedNode continuationNode = preLoopExitNode.next();
+        AbstractBeginNode preLoopExitNode = preCounted.getCountedExit();
+
+        assert preLoop.nodes().contains(preLoopBegin);
+        assert preLoop.nodes().contains(preLoopExitNode);
 
         // Each duplication is inserted after the original, ergo create the post loop first
         LoopFragmentWhole mainLoop = preLoop.duplicate();
         LoopFragmentWhole postLoop = preLoop.duplicate();
         preLoopBegin.incrementSplits();

@@ -247,27 +246,30 @@
         LoopBeginNode mainLoopBegin = mainLoop.getDuplicatedNode(preLoopBegin);
         mainLoopBegin.setMainLoop();
         LoopBeginNode postLoopBegin = postLoop.getDuplicatedNode(preLoopBegin);
         postLoopBegin.setPostLoop();
 
-        EndNode postEndNode = getBlockEndAfterLoopExit(postLoopBegin);
+        AbstractBeginNode postLoopExitNode = postLoop.getDuplicatedNode(preLoopExitNode);
+        EndNode postEndNode = getBlockEndAfterLoopExit(postLoopExitNode);
         AbstractMergeNode postMergeNode = postEndNode.merge();
-        LoopExitNode postLoopExitNode = postLoopBegin.getSingleLoopExit();
 
         // Update the main loop phi initialization to carry from the pre loop
         for (PhiNode prePhiNode : preLoopBegin.phis()) {
             PhiNode mainPhiNode = mainLoop.getDuplicatedNode(prePhiNode);
             mainPhiNode.setValueAt(0, prePhiNode);
         }
 
-        EndNode mainEndNode = getBlockEndAfterLoopExit(mainLoopBegin);
+        AbstractBeginNode mainLoopExitNode = mainLoop.getDuplicatedNode(preLoopExitNode);
+        EndNode mainEndNode = getBlockEndAfterLoopExit(mainLoopExitNode);
         AbstractMergeNode mainMergeNode = mainEndNode.merge();
         AbstractEndNode postEntryNode = postLoopBegin.forwardEnd();
 
+        // Exits have been merged, find the continuation below the merge
+        FixedNode continuationNode = mainMergeNode.next();
+
         // In the case of no Bounds tests, we just flow right into the main loop
         AbstractBeginNode mainLandingNode = BeginNode.begin(postEntryNode);
-        LoopExitNode mainLoopExitNode = mainLoopBegin.getSingleLoopExit();
         mainLoopExitNode.setNext(mainLandingNode);
         preLoopExitNode.setNext(mainLoopBegin.forwardEnd());
 
         // Add and update any phi edges as per merge usage as needed and update usages
         processPreLoopPhis(loop, mainLoop, postLoop);

@@ -335,12 +337,12 @@
     }
 
     /**
-     * Find the end of the block following the LoopExit.
+     * Find the end of the block that begins at the given loop exit.
+     *
+     * @param exit the begin node of a loop exit (for example a loop's counted exit); scanning
+     *            starts at its successor {@code exit.next()}
+     * @return the {@link EndNode} that {@code getBlockEnd} reports for the block following
+     *         {@code exit}
      */
-    private static EndNode getBlockEndAfterLoopExit(LoopBeginNode curLoopBegin) {
-        FixedNode node = curLoopBegin.getSingleLoopExit().next();
+    private static EndNode getBlockEndAfterLoopExit(AbstractBeginNode exit) {
+        FixedNode node = exit.next();
-        // Find the last node after the exit blocks starts
+        // Walk from the first node after the exit to the end of the block it starts
         return getBlockEnd(node);
     }
 
     private static EndNode getBlockEnd(FixedNode node) {

@@ -421,10 +423,13 @@
             Math.addExact(stride, stride);
         } catch (ArithmeticException ae) {
             condition.getDebug().log(DebugContext.VERBOSE_LEVEL, "isUnrollableLoop %s doubling the stride overflows %d", loopBegin, stride);
             return false;
         }
+        if (!loop.canDuplicateLoop()) {
+            return false;
+        }
         if (loopBegin.isMainLoop() || loopBegin.isSimpleLoop()) {
             // Flow-less loops to partial unroll for now. 3 blocks corresponds to an if that either
             // exits or continues the loop. There might be fixed and floating work within the loop
             // as well.
             if (loop.loop().getBlocks().size() < 3) {
< prev index next >