1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package org.graalvm.compiler.loop.test;
  24 
  25 import java.util.ListIterator;
  26 
  27 import org.graalvm.compiler.core.common.CompilationIdentifier;
  28 import org.graalvm.compiler.core.test.GraalCompilerTest;
  29 import org.graalvm.compiler.debug.DebugContext;
  30 import org.graalvm.compiler.graph.iterators.NodeIterable;
  31 import org.graalvm.compiler.java.ComputeLoopFrequenciesClosure;
  32 import org.graalvm.compiler.loop.DefaultLoopPolicies;
  33 import org.graalvm.compiler.loop.LoopEx;
  34 import org.graalvm.compiler.loop.LoopFragmentInside;
  35 import org.graalvm.compiler.loop.LoopsData;
  36 import org.graalvm.compiler.loop.phases.LoopPartialUnrollPhase;
  37 import org.graalvm.compiler.nodes.LoopBeginNode;
  38 import org.graalvm.compiler.nodes.StructuredGraph;
  39 import org.graalvm.compiler.nodes.spi.LoweringTool;
  40 import org.graalvm.compiler.options.OptionValues;
  41 import org.graalvm.compiler.phases.BasePhase;
  42 import org.graalvm.compiler.phases.OptimisticOptimizations;
  43 import org.graalvm.compiler.phases.PhaseSuite;
  44 import org.graalvm.compiler.phases.common.CanonicalizerPhase;
  45 import org.graalvm.compiler.phases.common.ConditionalEliminationPhase;
  46 import org.graalvm.compiler.phases.common.DeadCodeEliminationPhase;
  47 import org.graalvm.compiler.phases.common.DeoptimizationGroupingPhase;
  48 import org.graalvm.compiler.phases.common.FloatingReadPhase;
  49 import org.graalvm.compiler.phases.common.FrameStateAssignmentPhase;
  50 import org.graalvm.compiler.phases.common.GuardLoweringPhase;
  51 import org.graalvm.compiler.phases.common.LoweringPhase;
  52 import org.graalvm.compiler.phases.common.RemoveValueProxyPhase;
  53 import org.graalvm.compiler.phases.tiers.MidTierContext;
  54 import org.graalvm.compiler.phases.tiers.Suites;
  55 import org.junit.Ignore;
  56 import org.junit.Test;
  57 
  58 import jdk.vm.ci.meta.ResolvedJavaMethod;
  59 
  60 public class LoopPartialUnrollTest extends GraalCompilerTest {
  61 
  62     @Override
  63     protected boolean checkMidTierGraph(StructuredGraph graph) {
  64         NodeIterable<LoopBeginNode> loops = graph.getNodes().filter(LoopBeginNode.class);
  65         for (LoopBeginNode loop : loops) {
  66             if (loop.isMainLoop()) {
  67                 return true;
  68             }
  69         }
  70         return false;
  71     }
  72 
  73     public static long sumWithEqualityLimit(int[] text) {
  74         long sum = 0;
  75         for (int i = 0; branchProbability(0.99, i != text.length); ++i) {
  76             sum += volatileInt;
  77         }
  78         return sum;
  79     }
  80 
  81     @Ignore("equality limits aren't working properly")
  82     @Test
  83     public void testSumWithEqualityLimit() {
  84         for (int i = 0; i < 128; i++) {
  85             int[] data = new int[i];
  86             test("sumWithEqualityLimit", data);
  87         }
  88     }
  89 
  90     @Test
  91     public void testLoopCarried() {
  92         for (int i = 0; i < 64; i++) {
  93             test("testLoopCarriedSnippet", i);
  94         }
  95     }
  96 
  97     @Test
  98     public void testLoopCarriedDuplication() {
  99         testDuplicateBody("testLoopCarriedReference", "testLoopCarriedSnippet");
 100     }
 101 
 102     static volatile int volatileInt = 3;
 103 
 104     public int testLoopCarriedSnippet(int iterations) {
 105         int a = 0;
 106         int b = 0;
 107         int c = 0;
 108 
 109         for (int i = 0; branchProbability(0.99, i < iterations); i++) {
 110             int t1 = volatileInt;
 111             int t2 = a + b;
 112             c = b;
 113             b = a;
 114             a = t1 + t2;
 115         }
 116 
 117         return c;
 118     }
 119 
 120     public int testLoopCarriedReference(int iterations) {
 121         int a = 0;
 122         int b = 0;
 123         int c = 0;
 124 
 125         for (int i = 0; branchProbability(0.99, i < iterations); i += 2) {
 126             int t1 = volatileInt;
 127             int t2 = a + b;
 128             c = b;
 129             b = a;
 130             a = t1 + t2;
 131             t1 = volatileInt;
 132             t2 = a + b;
 133             c = b;
 134             b = a;
 135             a = t1 + t2;
 136         }
 137 
 138         return c;
 139     }
 140 
 141     public static long init = Runtime.getRuntime().totalMemory();
 142     private int x;
 143     private int z;
 144 
 145     public int[] testComplexSnippet(int d) {
 146         x = 3;
 147         int y = 5;
 148         z = 7;
 149         for (int i = 0; i < d; i++) {
 150             for (int j = 0; branchProbability(0.99, j < i); j++) {
 151                 z += x;
 152             }
 153             y = x ^ z;
 154             if ((i & 4) == 0) {
 155                 z--;
 156             } else if ((i & 8) == 0) {
 157                 Runtime.getRuntime().totalMemory();
 158             }
 159         }
 160         return new int[]{x, y, z};
 161     }
 162 
 163     @Test
 164     public void testComplex() {
 165         for (int i = 0; i < 10; i++) {
 166             test("testComplexSnippet", i);
 167         }
 168         test("testComplexSnippet", 10);
 169         test("testComplexSnippet", 100);
 170         test("testComplexSnippet", 1000);
 171     }
 172 
 173     public static long testSignExtensionSnippet(long arg) {
 174         long r = 1;
 175         for (int i = 0; branchProbability(0.99, i < arg); i++) {
 176             r *= i;
 177         }
 178         return r;
 179     }
 180 
 181     @Test
 182     public void testSignExtension() {
 183         test("testSignExtensionSnippet", 9L);
 184     }
 185 
 186     @Override
 187     protected Suites createSuites(OptionValues opts) {
 188         Suites suites = super.createSuites(opts).copy();
 189         PhaseSuite<MidTierContext> mid = suites.getMidTier();
 190         ListIterator<BasePhase<? super MidTierContext>> iter = mid.findPhase(LoopPartialUnrollPhase.class);
 191         BasePhase<? super MidTierContext> partialUnoll = iter.previous();
 192         if (iter.previous().getClass() != FrameStateAssignmentPhase.class) {
 193             // Ensure LoopPartialUnrollPhase runs immediately after FrameStateAssignment, so it gets
 194             // priority over other optimizations in these tests.
 195             mid.findPhase(LoopPartialUnrollPhase.class).remove();
 196             ListIterator<BasePhase<? super MidTierContext>> fsa = mid.findPhase(FrameStateAssignmentPhase.class);
 197             fsa.add(partialUnoll);
 198         }
 199         return suites;
 200     }
 201 
 202     public void testGraph(String reference, String test) {
 203         StructuredGraph referenceGraph = buildGraph(reference, false);
 204         StructuredGraph testGraph = buildGraph(test, true);
 205         assertEquals(referenceGraph, testGraph, false, false);
 206     }
 207 
 208     @SuppressWarnings("try")
 209     public StructuredGraph buildGraph(String name, boolean partialUnroll) {
 210         CompilationIdentifier id = new CompilationIdentifier() {
 211             @Override
 212             public String toString(Verbosity verbosity) {
 213                 return name;
 214             }
 215         };
 216         ResolvedJavaMethod method = getResolvedJavaMethod(name);
 217         OptionValues options = new OptionValues(getInitialOptions(), DefaultLoopPolicies.Options.UnrollMaxIterations, 2);
 218         StructuredGraph graph = parse(builder(method, StructuredGraph.AllowAssumptions.YES, id, options), getEagerGraphBuilderSuite());
 219         try (DebugContext.Scope buildScope = graph.getDebug().scope(name, method, graph)) {
 220             MidTierContext context = new MidTierContext(getProviders(), getTargetProvider(), OptimisticOptimizations.ALL, null);
 221 
 222             CanonicalizerPhase canonicalizer = new CanonicalizerPhase();
 223             canonicalizer.apply(graph, context);
 224             new RemoveValueProxyPhase().apply(graph);
 225             new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, context);
 226             new FloatingReadPhase().apply(graph);
 227             new DeadCodeEliminationPhase().apply(graph);
 228             new ConditionalEliminationPhase(true).apply(graph, context);
 229             ComputeLoopFrequenciesClosure.compute(graph);
 230             new GuardLoweringPhase().apply(graph, context);
 231             new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.MID_TIER).apply(graph, context);
 232             new FrameStateAssignmentPhase().apply(graph);
 233             new DeoptimizationGroupingPhase().apply(graph, context);
 234             canonicalizer.apply(graph, context);
 235             new ConditionalEliminationPhase(true).apply(graph, context);
 236             if (partialUnroll) {
 237                 LoopsData dataCounted = new LoopsData(graph);
 238                 dataCounted.detectedCountedLoops();
 239                 for (LoopEx loop : dataCounted.countedLoops()) {
 240                     LoopFragmentInside newSegment = loop.inside().duplicate();
 241                     newSegment.insertWithinAfter(loop, false);
 242                 }
 243                 canonicalizer.apply(graph, getDefaultMidTierContext());
 244             }
 245             new DeadCodeEliminationPhase().apply(graph);
 246             canonicalizer.apply(graph, context);
 247             graph.getDebug().dump(DebugContext.BASIC_LEVEL, graph, "before compare");
 248             return graph;
 249         } catch (Throwable e) {
 250             throw getDebugContext().handle(e);
 251         }
 252     }
 253 
 254     public void testDuplicateBody(String reference, String test) {
 255 
 256         StructuredGraph referenceGraph = buildGraph(reference, false);
 257         StructuredGraph testGraph = buildGraph(test, true);
 258         CanonicalizerPhase canonicalizer = new CanonicalizerPhase();
 259         canonicalizer.apply(testGraph, getDefaultMidTierContext());
 260         canonicalizer.apply(referenceGraph, getDefaultMidTierContext());
 261         assertEquals(referenceGraph, testGraph);
 262     }
 263 }