1 /*
   2  * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.loop.test;
  26 
  27 import java.util.ListIterator;
  28 
  29 import org.graalvm.compiler.api.directives.GraalDirectives;
  30 import org.graalvm.compiler.core.common.CompilationIdentifier;
  31 import org.graalvm.compiler.core.common.GraalOptions;
  32 import org.graalvm.compiler.core.test.GraalCompilerTest;
  33 import org.graalvm.compiler.debug.DebugContext;
  34 import org.graalvm.compiler.graph.iterators.NodeIterable;
  35 import org.graalvm.compiler.java.ComputeLoopFrequenciesClosure;
  36 import org.graalvm.compiler.loop.DefaultLoopPolicies;
  37 import org.graalvm.compiler.loop.LoopEx;
  38 import org.graalvm.compiler.loop.LoopFragmentInside;
  39 import org.graalvm.compiler.loop.LoopsData;
  40 import org.graalvm.compiler.loop.phases.LoopPartialUnrollPhase;
  41 import org.graalvm.compiler.nodes.LoopBeginNode;
  42 import org.graalvm.compiler.nodes.StructuredGraph;
  43 import org.graalvm.compiler.nodes.spi.LoweringTool;
  44 import org.graalvm.compiler.options.OptionValues;
  45 import org.graalvm.compiler.phases.BasePhase;
  46 import org.graalvm.compiler.phases.OptimisticOptimizations;
  47 import org.graalvm.compiler.phases.PhaseSuite;
  48 import org.graalvm.compiler.phases.common.CanonicalizerPhase;
  49 import org.graalvm.compiler.phases.common.ConditionalEliminationPhase;
  50 import org.graalvm.compiler.phases.common.DeadCodeEliminationPhase;
  51 import org.graalvm.compiler.phases.common.DeoptimizationGroupingPhase;
  52 import org.graalvm.compiler.phases.common.FloatingReadPhase;
  53 import org.graalvm.compiler.phases.common.FrameStateAssignmentPhase;
  54 import org.graalvm.compiler.phases.common.GuardLoweringPhase;
  55 import org.graalvm.compiler.phases.common.LoweringPhase;
  56 import org.graalvm.compiler.phases.common.RemoveValueProxyPhase;
  57 import org.graalvm.compiler.phases.tiers.MidTierContext;
  58 import org.graalvm.compiler.phases.tiers.Suites;
  59 import org.junit.Ignore;
  60 import org.junit.Test;
  61 
  62 import jdk.vm.ci.meta.ResolvedJavaMethod;
  63 
  64 public class LoopPartialUnrollTest extends GraalCompilerTest {
  65 
  66     @Override
  67     protected void checkMidTierGraph(StructuredGraph graph) {
  68         NodeIterable<LoopBeginNode> loops = graph.getNodes().filter(LoopBeginNode.class);
  69         for (LoopBeginNode loop : loops) {
  70             if (loop.isMainLoop()) {
  71                 return;
  72             }
  73         }
  74         fail("expected a main loop");
  75     }
  76 
  77     public static long sumWithEqualityLimit(int[] text) {
  78         long sum = 0;
  79         for (int i = 0; branchProbability(0.99, i != text.length); ++i) {
  80             sum += volatileInt;
  81         }
  82         return sum;
  83     }
  84 
  85     @Ignore("equality limits aren't working properly")
  86     @Test
  87     public void testSumWithEqualityLimit() {
  88         for (int i = -1; i < 128; i++) {
  89             int[] data = new int[i];
  90             test("sumWithEqualityLimit", data);
  91         }
  92     }
  93 
  94     @Test
  95     public void testLoopCarried() {
  96         for (int i = -1; i < 64; i++) {
  97             test("testLoopCarriedSnippet", i);
  98         }
  99     }
 100 
 101     @Test
 102     public void testLoopCarriedDuplication() {
 103         testDuplicateBody("testLoopCarriedReference", "testLoopCarriedSnippet");
 104     }
 105 
 106     static volatile int volatileInt = 3;
 107 
 108     public static int testLoopCarriedSnippet(int iterations) {
 109         int a = 0;
 110         int b = 0;
 111         int c = 0;
 112 
 113         for (int i = 0; branchProbability(0.99, i < iterations); i++) {
 114             int t1 = volatileInt;
 115             int t2 = a + b;
 116             c = b;
 117             b = a;
 118             a = t1 + t2;
 119         }
 120 
 121         return c;
 122     }
 123 
 124     public static int testLoopCarriedReference(int iterations) {
 125         int a = 0;
 126         int b = 0;
 127         int c = 0;
 128 
 129         for (int i = 0; branchProbability(0.99, i < iterations); i += 2) {
 130             int t1 = volatileInt;
 131             int t2 = a + b;
 132             c = b;
 133             b = a;
 134             a = t1 + t2;
 135             t1 = volatileInt;
 136             t2 = a + b;
 137             c = b;
 138             b = a;
 139             a = t1 + t2;
 140         }
 141 
 142         return c;
 143     }
 144 
 145     @Test
 146     @Ignore
 147     public void testUnsignedLoopCarried() {
 148         for (int i = -1; i < 64; i++) {
 149             for (int j = 0; j < 64; j++) {
 150                 test("testUnsignedLoopCarriedSnippet", i, j);
 151             }
 152         }
 153         test("testUnsignedLoopCarriedSnippet", -1 - 32, -1);
 154         test("testUnsignedLoopCarriedSnippet", -1 - 4, -1);
 155         test("testUnsignedLoopCarriedSnippet", -1 - 32, 0);
 156     }
 157 
 158     public static int testUnsignedLoopCarriedSnippet(int start, int end) {
 159         int a = 0;
 160         int b = 0;
 161         int c = 0;
 162 
 163         for (int i = start; branchProbability(0.99, Integer.compareUnsigned(i, end) < 0); i++) {
 164             int t1 = volatileInt;
 165             int t2 = a + b;
 166             c = b;
 167             b = a;
 168             a = t1 + t2;
 169         }
 170 
 171         return c;
 172     }
 173 
 174     @Test
 175     public void testLoopCarried2() {
 176         for (int i = -1; i < 64; i++) {
 177             for (int j = -1; j < 64; j++) {
 178                 test("testLoopCarried2Snippet", i, j);
 179             }
 180         }
 181         test("testLoopCarried2Snippet", Integer.MAX_VALUE - 32, Integer.MAX_VALUE);
 182         test("testLoopCarried2Snippet", Integer.MAX_VALUE - 4, Integer.MAX_VALUE);
 183         test("testLoopCarried2Snippet", Integer.MAX_VALUE, 0);
 184         test("testLoopCarried2Snippet", Integer.MIN_VALUE, Integer.MIN_VALUE + 32);
 185         test("testLoopCarried2Snippet", Integer.MIN_VALUE, Integer.MIN_VALUE + 4);
 186         test("testLoopCarried2Snippet", 0, Integer.MIN_VALUE);
 187     }
 188 
 189     public static int testLoopCarried2Snippet(int start, int end) {
 190         int a = 0;
 191         int b = 0;
 192         int c = 0;
 193 
 194         for (int i = start; branchProbability(0.99, i < end); i++) {
 195             int t1 = volatileInt;
 196             int t2 = a + b;
 197             c = b;
 198             b = a;
 199             a = t1 + t2;
 200         }
 201 
 202         return c;
 203     }
 204 
 205     public static long init = Runtime.getRuntime().totalMemory();
 206     private int x;
 207     private int z;
 208 
 209     public int[] testComplexSnippet(int d) {
 210         x = 3;
 211         int y = 5;
 212         z = 7;
 213         for (int i = 0; i < d; i++) {
 214             for (int j = 0; branchProbability(0.99, j < i); j++) {
 215                 z += x;
 216             }
 217             y = x ^ z;
 218             if ((i & 4) == 0) {
 219                 z--;
 220             } else if ((i & 8) == 0) {
 221                 Runtime.getRuntime().totalMemory();
 222             }
 223         }
 224         return new int[]{x, y, z};
 225     }
 226 
 227     @Test
 228     public void testComplex() {
 229         for (int i = -1; i < 10; i++) {
 230             test("testComplexSnippet", i);
 231         }
 232         test("testComplexSnippet", 10);
 233         test("testComplexSnippet", 100);
 234         test("testComplexSnippet", 1000);
 235     }
 236 
 237     public static long testSignExtensionSnippet(long arg) {
 238         long r = 1;
 239         for (int i = 0; branchProbability(0.99, i < arg); i++) {
 240             r *= i;
 241         }
 242         return r;
 243     }
 244 
 245     @Test
 246     public void testSignExtension() {
 247         test("testSignExtensionSnippet", 9L);
 248     }
 249 
 250     public static Object objectPhi(int n) {
 251         Integer v = Integer.valueOf(200);
 252         GraalDirectives.blackhole(v); // Prevents PEA
 253         Integer r = 1;
 254 
 255         for (int i = 0; iterationCount(100, i < n); i++) {
 256             GraalDirectives.blackhole(r); // Create a phi of two loop invariants
 257             r = v;
 258         }
 259 
 260         return r;
 261     }
 262 
 263     @Test
 264     public void testObjectPhi() {
 265         OptionValues options = new OptionValues(getInitialOptions(), GraalOptions.LoopPeeling, false);
 266         test(options, "objectPhi", 1);
 267     }
 268 
 269     @Override
 270     protected Suites createSuites(OptionValues opts) {
 271         Suites suites = super.createSuites(opts).copy();
 272         PhaseSuite<MidTierContext> mid = suites.getMidTier();
 273         ListIterator<BasePhase<? super MidTierContext>> iter = mid.findPhase(LoopPartialUnrollPhase.class);
 274         BasePhase<? super MidTierContext> partialUnoll = iter.previous();
 275         if (iter.previous().getClass() != FrameStateAssignmentPhase.class) {
 276             // Ensure LoopPartialUnrollPhase runs immediately after FrameStateAssignment, so it gets
 277             // priority over other optimizations in these tests.
 278             mid.findPhase(LoopPartialUnrollPhase.class).remove();
 279             ListIterator<BasePhase<? super MidTierContext>> fsa = mid.findPhase(FrameStateAssignmentPhase.class);
 280             fsa.add(partialUnoll);
 281         }
 282         return suites;
 283     }
 284 
 285     public void testGraph(String reference, String test) {
 286         StructuredGraph referenceGraph = buildGraph(reference, false);
 287         StructuredGraph testGraph = buildGraph(test, true);
 288         assertEquals(referenceGraph, testGraph, false, false);
 289     }
 290 
 291     @SuppressWarnings("try")
 292     public StructuredGraph buildGraph(String name, boolean partialUnroll) {
 293         CompilationIdentifier id = new CompilationIdentifier() {
 294             @Override
 295             public String toString(Verbosity verbosity) {
 296                 return name;
 297             }
 298         };
 299         ResolvedJavaMethod method = getResolvedJavaMethod(name);
 300         OptionValues options = new OptionValues(getInitialOptions(), DefaultLoopPolicies.Options.UnrollMaxIterations, 2);
 301         StructuredGraph graph = parse(builder(method, StructuredGraph.AllowAssumptions.YES, id, options), getEagerGraphBuilderSuite());
 302         try (DebugContext.Scope buildScope = graph.getDebug().scope(name, method, graph)) {
 303             MidTierContext context = new MidTierContext(getProviders(), getTargetProvider(), OptimisticOptimizations.ALL, null);
 304 
 305             CanonicalizerPhase canonicalizer = new CanonicalizerPhase();
 306             canonicalizer.apply(graph, context);
 307             new RemoveValueProxyPhase().apply(graph);
 308             new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.HIGH_TIER).apply(graph, context);
 309             new FloatingReadPhase().apply(graph);
 310             new DeadCodeEliminationPhase().apply(graph);
 311             new ConditionalEliminationPhase(true).apply(graph, context);
 312             ComputeLoopFrequenciesClosure.compute(graph);
 313             new GuardLoweringPhase().apply(graph, context);
 314             new LoweringPhase(canonicalizer, LoweringTool.StandardLoweringStage.MID_TIER).apply(graph, context);
 315             new FrameStateAssignmentPhase().apply(graph);
 316             new DeoptimizationGroupingPhase().apply(graph, context);
 317             canonicalizer.apply(graph, context);
 318             new ConditionalEliminationPhase(true).apply(graph, context);
 319             if (partialUnroll) {
 320                 LoopsData dataCounted = new LoopsData(graph);
 321                 dataCounted.detectedCountedLoops();
 322                 for (LoopEx loop : dataCounted.countedLoops()) {
 323                     LoopFragmentInside newSegment = loop.inside().duplicate();
 324                     newSegment.insertWithinAfter(loop, null);
 325                 }
 326                 canonicalizer.apply(graph, getDefaultMidTierContext());
 327             }
 328             new DeadCodeEliminationPhase().apply(graph);
 329             canonicalizer.apply(graph, context);
 330             graph.getDebug().dump(DebugContext.BASIC_LEVEL, graph, "before compare");
 331             return graph;
 332         } catch (Throwable e) {
 333             throw getDebugContext().handle(e);
 334         }
 335     }
 336 
 337     public void testDuplicateBody(String reference, String test) {
 338 
 339         StructuredGraph referenceGraph = buildGraph(reference, false);
 340         StructuredGraph testGraph = buildGraph(test, true);
 341         CanonicalizerPhase canonicalizer = new CanonicalizerPhase();
 342         canonicalizer.apply(testGraph, getDefaultMidTierContext());
 343         canonicalizer.apply(referenceGraph, getDefaultMidTierContext());
 344         assertEquals(referenceGraph, testGraph);
 345     }
 346 }