1 /*
   2  * Copyright (c) 2012, 2012, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package org.graalvm.compiler.loop;
  24 
  25 import jdk.vm.ci.meta.TriState;
  26 import org.graalvm.compiler.debug.GraalError;
  27 import org.graalvm.compiler.graph.Graph;
  28 import org.graalvm.compiler.graph.Graph.DuplicationReplacement;
  29 import org.graalvm.compiler.graph.Node;
  30 import org.graalvm.compiler.graph.NodeBitMap;
  31 import org.graalvm.compiler.graph.iterators.NodeIterable;
  32 import org.graalvm.compiler.nodes.AbstractBeginNode;
  33 import org.graalvm.compiler.nodes.EndNode;
  34 import org.graalvm.compiler.nodes.FixedNode;
  35 import org.graalvm.compiler.nodes.FrameState;
  36 import org.graalvm.compiler.nodes.GuardNode;
  37 import org.graalvm.compiler.nodes.GuardPhiNode;
  38 import org.graalvm.compiler.nodes.GuardProxyNode;
  39 import org.graalvm.compiler.nodes.Invoke;
  40 import org.graalvm.compiler.nodes.LoopExitNode;
  41 import org.graalvm.compiler.nodes.MergeNode;
  42 import org.graalvm.compiler.nodes.PhiNode;
  43 import org.graalvm.compiler.nodes.ProxyNode;
  44 import org.graalvm.compiler.nodes.StructuredGraph;
  45 import org.graalvm.compiler.nodes.ValueNode;
  46 import org.graalvm.compiler.nodes.ValuePhiNode;
  47 import org.graalvm.compiler.nodes.ValueProxyNode;
  48 import org.graalvm.compiler.nodes.VirtualState;
  49 import org.graalvm.compiler.nodes.cfg.Block;
  50 import org.graalvm.compiler.nodes.java.MonitorEnterNode;
  51 import org.graalvm.compiler.nodes.spi.NodeWithState;
  52 import org.graalvm.compiler.nodes.virtual.CommitAllocationNode;
  53 import org.graalvm.compiler.nodes.virtual.VirtualObjectNode;
  54 import org.graalvm.util.EconomicMap;
  55 
  56 import java.util.ArrayDeque;
  57 import java.util.Collections;
  58 import java.util.Deque;
  59 import java.util.Iterator;
  60 
  61 public abstract class LoopFragment {
  62 
    /** Loop passed to the constructor; exposed through {@link #loop()}. */
    private final LoopEx loop;
    /** Fragment this one was duplicated from, or {@code null} when this fragment is not a duplicate. */
    private final LoopFragment original;
    /** Nodes of this fragment; for duplicates this is assigned by {@code patchNodes}. */
    protected NodeBitMap nodes;
    /**
     * Set to {@code true} by the single-argument constructor and at the end of {@code patchNodes};
     * {@code false} right after the two-argument constructor.
     */
    protected boolean nodesReady;
    /** Maps nodes of the original fragment to their copies; assigned by {@code patchNodes}. */
    private EconomicMap<Node, Node> duplicationMap;
  68 
  69     public LoopFragment(LoopEx loop) {
  70         this(loop, null);
  71         this.nodesReady = true;
  72     }
  73 
  74     public LoopFragment(LoopEx loop, LoopFragment original) {
  75         this.loop = loop;
  76         this.original = original;
  77         this.nodesReady = false;
  78     }
  79 
  80     /**
  81      * Return the original LoopEx for this fragment. For duplicated fragments this returns null.
  82      */
  83     protected LoopEx loop() {
  84         return loop;
  85     }
  86 
    /**
     * Produces another {@link LoopFragment} from this one.
     *
     * NOTE(review): presumably the result is a duplicate whose {@link #original()} is this
     * fragment; the exact contract is defined by the subclasses - confirm there.
     */
    public abstract LoopFragment duplicate();

    /**
     * NOTE(review): presumably places this fragment in the control flow ahead of loop {@code l};
     * the exact contract is defined by the subclasses - confirm there.
     *
     * @param l the loop the fragment is inserted before
     */
    public abstract void insertBefore(LoopEx l);
  90 
    /**
     * Hook for disconnecting this fragment. The implementation in this class does nothing.
     */
    public void disconnect() {
        // TODO (gd) possibly abstract
    }
  94 
  95     public boolean contains(Node n) {
  96         return nodes().isMarkedAndGrow(n);
  97     }
  98 
  99     @SuppressWarnings("unchecked")
 100     public <New extends Node, Old extends New> New getDuplicatedNode(Old n) {
 101         assert isDuplicate();
 102         return (New) duplicationMap.get(n);
 103     }
 104 
 105     protected <New extends Node, Old extends New> void putDuplicatedNode(Old oldNode, New newNode) {
 106         duplicationMap.put(oldNode, newNode);
 107     }
 108 
    /**
     * Gets the corresponding value in this fragment. Should be called on duplicate fragments with a
     * node from the original fragment as argument.
     *
     * @param b original value
     * @return corresponding value in the peel; callers in this class ({@code mergeEarlyExits})
     *         handle a {@code null} result as "no corresponding value"
     */
    protected abstract ValueNode prim(ValueNode b);
 117 
 118     public boolean isDuplicate() {
 119         return original != null;
 120     }
 121 
 122     public LoopFragment original() {
 123         return original;
 124     }
 125 
    /**
     * Returns the bitmap of nodes that make up this fragment. It backs {@link #contains(Node)}
     * and, when called on the original fragment, supplies the nodes that {@code patchNodes}
     * duplicates.
     */
    public abstract NodeBitMap nodes();
 127 
 128     public StructuredGraph graph() {
 129         LoopEx l;
 130         if (isDuplicate()) {
 131             l = original().loop();
 132         } else {
 133             l = loop();
 134         }
 135         return l.loopBegin().graph();
 136     }
 137 
    /**
     * Supplies the replacement that {@code patchNodes} queries on the original fragment and
     * combines with its own {@code dataFix} argument. A {@code null} result is accepted there and
     * means that no such replacement is applied.
     */
    protected abstract DuplicationReplacement getDuplicationReplacement();

    /**
     * Hook invoked by {@code patchNodes} immediately before the nodes of the original fragment are
     * duplicated.
     */
    protected abstract void beforeDuplication();

    /**
     * Hook invoked by {@code patchNodes} immediately after the nodes of the original fragment have
     * been duplicated.
     */
    protected abstract void finishDuplication();
 143 
 144     protected void patchNodes(final DuplicationReplacement dataFix) {
 145         if (isDuplicate() && !nodesReady) {
 146             assert !original.isDuplicate();
 147             final DuplicationReplacement cfgFix = original().getDuplicationReplacement();
 148             DuplicationReplacement dr;
 149             if (cfgFix == null && dataFix != null) {
 150                 dr = dataFix;
 151             } else if (cfgFix != null && dataFix == null) {
 152                 dr = cfgFix;
 153             } else if (cfgFix != null && dataFix != null) {
 154                 dr = new DuplicationReplacement() {
 155 
 156                     @Override
 157                     public Node replacement(Node o) {
 158                         Node r1 = dataFix.replacement(o);
 159                         if (r1 != o) {
 160                             assert cfgFix.replacement(o) == o;
 161                             return r1;
 162                         }
 163                         Node r2 = cfgFix.replacement(o);
 164                         if (r2 != o) {
 165                             return r2;
 166                         }
 167                         return o;
 168                     }
 169                 };
 170             } else {
 171                 dr = null;
 172             }
 173             beforeDuplication();
 174             NodeIterable<Node> nodesIterable = original().nodes();
 175             duplicationMap = graph().addDuplicates(nodesIterable, graph(), nodesIterable.count(), dr);
 176             finishDuplication();
 177             nodes = new NodeBitMap(graph());
 178             nodes.markAll(duplicationMap.getValues());
 179             nodesReady = true;
 180         } else {
 181             // TODO (gd) apply fix ?
 182         }
 183     }
 184 
 185     protected static NodeBitMap computeNodes(Graph graph, Iterable<AbstractBeginNode> blocks) {
 186         return computeNodes(graph, blocks, Collections.emptyList());
 187     }
 188 
 189     protected static NodeBitMap computeNodes(Graph graph, Iterable<AbstractBeginNode> blocks, Iterable<AbstractBeginNode> earlyExits) {
 190         final NodeBitMap nodes = graph.createNodeBitMap();
 191         computeNodes(nodes, graph, blocks, earlyExits);
 192         return nodes;
 193     }
 194 
 195     protected static void computeNodes(NodeBitMap nodes, Graph graph, Iterable<AbstractBeginNode> blocks, Iterable<AbstractBeginNode> earlyExits) {
 196         for (AbstractBeginNode b : blocks) {
 197             if (b.isDeleted()) {
 198                 continue;
 199             }
 200 
 201             for (Node n : b.getBlockNodes()) {
 202                 if (n instanceof Invoke) {
 203                     nodes.mark(((Invoke) n).callTarget());
 204                 }
 205                 if (n instanceof NodeWithState) {
 206                     NodeWithState withState = (NodeWithState) n;
 207                     withState.states().forEach(state -> state.applyToVirtual(node -> nodes.mark(node)));
 208                 }
 209                 nodes.mark(n);
 210             }
 211         }
 212         for (AbstractBeginNode earlyExit : earlyExits) {
 213             if (earlyExit.isDeleted()) {
 214                 continue;
 215             }
 216 
 217             nodes.mark(earlyExit);
 218 
 219             if (earlyExit instanceof LoopExitNode) {
 220                 LoopExitNode loopExit = (LoopExitNode) earlyExit;
 221                 FrameState stateAfter = loopExit.stateAfter();
 222                 if (stateAfter != null) {
 223                     stateAfter.applyToVirtual(node -> nodes.mark(node));
 224                 }
 225                 for (ProxyNode proxy : loopExit.proxies()) {
 226                     nodes.mark(proxy);
 227                 }
 228             }
 229         }
 230 
 231         final NodeBitMap nonLoopNodes = graph.createNodeBitMap();
 232         Deque<WorkListEntry> worklist = new ArrayDeque<>();
 233         for (AbstractBeginNode b : blocks) {
 234             if (b.isDeleted()) {
 235                 continue;
 236             }
 237 
 238             for (Node n : b.getBlockNodes()) {
 239                 if (n instanceof CommitAllocationNode) {
 240                     for (VirtualObjectNode obj : ((CommitAllocationNode) n).getVirtualObjects()) {
 241                         markFloating(worklist, obj, nodes, nonLoopNodes);
 242                     }
 243                 }
 244                 if (n instanceof MonitorEnterNode) {
 245                     markFloating(worklist, ((MonitorEnterNode) n).getMonitorId(), nodes, nonLoopNodes);
 246                 }
 247                 for (Node usage : n.usages()) {
 248                     markFloating(worklist, usage, nodes, nonLoopNodes);
 249                 }
 250             }
 251         }
 252     }
 253 
 254     static class WorkListEntry {
 255         final Iterator<Node> usages;
 256         final Node n;
 257         boolean isLoopNode;
 258 
 259         WorkListEntry(Node n, NodeBitMap loopNodes) {
 260             this.n = n;
 261             this.usages = n.usages().iterator();
 262             this.isLoopNode = loopNodes.isMarked(n);
 263         }
 264     }
 265 
 266     static TriState isLoopNode(Node n, NodeBitMap loopNodes, NodeBitMap nonLoopNodes) {
 267         if (loopNodes.isMarked(n)) {
 268             return TriState.TRUE;
 269         }
 270         if (nonLoopNodes.isMarked(n)) {
 271             return TriState.FALSE;
 272         }
 273         if (n instanceof FixedNode) {
 274             return TriState.FALSE;
 275         }
 276         boolean mark = false;
 277         if (n instanceof PhiNode) {
 278             PhiNode phi = (PhiNode) n;
 279             mark = loopNodes.isMarked(phi.merge());
 280             if (mark) {
 281                 /*
 282                  * This Phi is a loop node but the inputs might not be so they must be processed by
 283                  * the caller.
 284                  */
 285                 loopNodes.mark(n);
 286             } else {
 287                 nonLoopNodes.mark(n);
 288                 return TriState.FALSE;
 289             }
 290         }
 291         return TriState.UNKNOWN;
 292     }
 293 
 294     private static void markFloating(Deque<WorkListEntry> workList, Node start, NodeBitMap loopNodes, NodeBitMap nonLoopNodes) {
 295         if (isLoopNode(start, loopNodes, nonLoopNodes).isKnown()) {
 296             return;
 297         }
 298         workList.push(new WorkListEntry(start, loopNodes));
 299         while (!workList.isEmpty()) {
 300             WorkListEntry currentEntry = workList.peek();
 301             if (currentEntry.usages.hasNext()) {
 302                 Node current = currentEntry.usages.next();
 303                 TriState result = isLoopNode(current, loopNodes, nonLoopNodes);
 304                 if (result.isKnown()) {
 305                     if (result.toBoolean()) {
 306                         currentEntry.isLoopNode = true;
 307                     }
 308                 } else {
 309                     workList.push(new WorkListEntry(current, loopNodes));
 310                 }
 311             } else {
 312                 workList.pop();
 313                 boolean isLoopNode = currentEntry.isLoopNode;
 314                 Node current = currentEntry.n;
 315                 if (!isLoopNode && current instanceof GuardNode) {
 316                     /*
 317                      * (gd) this is only OK if we are not going to make loop transforms based on
 318                      * this
 319                      */
 320                     assert !((GuardNode) current).graph().hasValueProxies();
 321                     isLoopNode = true;
 322                 }
 323                 if (isLoopNode) {
 324                     loopNodes.mark(current);
 325                     for (WorkListEntry e : workList) {
 326                         e.isLoopNode = true;
 327                     }
 328                 } else {
 329                     nonLoopNodes.mark(current);
 330                 }
 331             }
 332         }
 333     }
 334 
 335     public static NodeIterable<AbstractBeginNode> toHirBlocks(final Iterable<Block> blocks) {
 336         return new NodeIterable<AbstractBeginNode>() {
 337 
 338             @Override
 339             public Iterator<AbstractBeginNode> iterator() {
 340                 final Iterator<Block> it = blocks.iterator();
 341                 return new Iterator<AbstractBeginNode>() {
 342 
 343                     @Override
 344                     public void remove() {
 345                         throw new UnsupportedOperationException();
 346                     }
 347 
 348                     @Override
 349                     public AbstractBeginNode next() {
 350                         return it.next().getBeginNode();
 351                     }
 352 
 353                     @Override
 354                     public boolean hasNext() {
 355                         return it.hasNext();
 356                     }
 357                 };
 358             }
 359 
 360         };
 361     }
 362 
 363     public static NodeIterable<AbstractBeginNode> toHirExits(final Iterable<Block> blocks) {
 364         return new NodeIterable<AbstractBeginNode>() {
 365 
 366             @Override
 367             public Iterator<AbstractBeginNode> iterator() {
 368                 final Iterator<Block> it = blocks.iterator();
 369                 return new Iterator<AbstractBeginNode>() {
 370 
 371                     @Override
 372                     public void remove() {
 373                         throw new UnsupportedOperationException();
 374                     }
 375 
 376                     /**
 377                      * Return the true LoopExitNode for this loop or the BeginNode for the block.
 378                      */
 379                     @Override
 380                     public AbstractBeginNode next() {
 381                         Block next = it.next();
 382                         LoopExitNode exit = next.getLoopExit();
 383                         if (exit != null) {
 384                             return exit;
 385                         }
 386                         return next.getBeginNode();
 387                     }
 388 
 389                     @Override
 390                     public boolean hasNext() {
 391                         return it.hasNext();
 392                     }
 393                 };
 394             }
 395 
 396         };
 397     }
 398 
    /**
     * Merges the early exits (i.e. loop exits) that were duplicated as part of this fragment, with
     * the original fragment's exits.
     * <p>
     * For every exit of the original loop that is alive, contained in the original fragment and
     * has a duplicate in this fragment, a new {@link MergeNode} is created. The original exit and
     * its duplicate each get an {@link EndNode} feeding that merge, and the merge continues with
     * the former successor of the original exit. The exit's frame state, the nodes anchored on the
     * exit and the usages of its proxies are then redirected to the merge.
     * <p>
     * Must only be called on duplicate fragments.
     */
    protected void mergeEarlyExits() {
        assert isDuplicate();
        StructuredGraph graph = graph();
        for (AbstractBeginNode earlyExit : LoopFragment.toHirBlocks(original().loop().loop().getExits())) {
            LoopExitNode loopEarlyExit = (LoopExitNode) earlyExit;
            // Remember the successor before the exit is rewired to the new end node below.
            FixedNode next = loopEarlyExit.next();
            if (loopEarlyExit.isDeleted() || !this.original().contains(loopEarlyExit)) {
                continue;
            }
            AbstractBeginNode newEarlyExit = getDuplicatedNode(loopEarlyExit);
            if (newEarlyExit == null) {
                // No duplicate recorded for this exit: nothing to merge.
                continue;
            }
            // Join the original exit and its duplicate in a fresh merge placed before 'next'.
            MergeNode merge = graph.add(new MergeNode());
            EndNode originalEnd = graph.add(new EndNode());
            EndNode newEnd = graph.add(new EndNode());
            merge.addForwardEnd(originalEnd);
            merge.addForwardEnd(newEnd);
            loopEarlyExit.setNext(originalEnd);
            newEarlyExit.setNext(newEnd);
            merge.setNext(next);

            FrameState exitState = loopEarlyExit.stateAfter();
            if (exitState != null) {
                // The merge takes over the exit's existing state; the exit gets a fresh copy.
                FrameState originalExitState = exitState;
                exitState = exitState.duplicateWithVirtualState();
                loopEarlyExit.setStateAfter(exitState);
                merge.setStateAfter(originalExitState);
                /*
                 * Using the old exit's state as the merge's state is necessary because some of the
                 * VirtualState nodes contained in the old exit's state may be shared by other
                 * dominated VirtualStates. Those dominated virtual states need to see the
                 * proxy->phi update that are applied below.
                 *
                 * We now update the original fragment's nodes accordingly:
                 */
                originalExitState.applyToVirtual(node -> original.nodes.clearAndGrow(node));
                exitState.applyToVirtual(node -> original.nodes.markAndGrow(node));
            }
            // Captured by the lambda below; null when the exit had no state.
            FrameState finalExitState = exitState;

            // Nodes anchored on the original exit are re-anchored on the merge.
            for (Node anchored : loopEarlyExit.anchored().snapshot()) {
                anchored.replaceFirstInput(loopEarlyExit, merge);
            }

            boolean newEarlyExitIsLoopExit = newEarlyExit instanceof LoopExitNode;
            // Iterate over a snapshot of the proxies because their usages are rewritten in the loop.
            for (ProxyNode vpn : loopEarlyExit.proxies().snapshot()) {
                if (vpn.hasNoUsages()) {
                    continue;
                }
                if (vpn.value() == null) {
                    // Only guard proxies are expected to have no value; their usages are cleared.
                    assert vpn instanceof GuardProxyNode;
                    vpn.replaceAtUsages(null);
                    continue;
                }
                final ValueNode replaceWith;
                // Look up the counterpart of the proxy itself when the duplicated exit is a loop
                // exit, otherwise the counterpart of the proxied value.
                ValueNode newVpn = prim(newEarlyExitIsLoopExit ? vpn : vpn.value());
                if (newVpn != null) {
                    // Both paths supply a value: select between them with a phi at the merge.
                    PhiNode phi;
                    if (vpn instanceof ValueProxyNode) {
                        phi = graph.addWithoutUnique(new ValuePhiNode(vpn.stamp(), merge));
                    } else if (vpn instanceof GuardProxyNode) {
                        phi = graph.addWithoutUnique(new GuardPhiNode(merge));
                    } else {
                        throw GraalError.shouldNotReachHere();
                    }
                    // Inputs are added in the order of the merge's ends: original first, then new.
                    phi.addInput(vpn);
                    phi.addInput(newVpn);
                    replaceWith = phi;
                } else {
                    replaceWith = vpn.value();
                }
                // Redirect usages of the proxy, except the phis of the new merge (which use the
                // proxy as input) and virtual state that is part of the exit's own state.
                vpn.replaceAtMatchingUsages(replaceWith, usage -> {
                    if (merge.isPhiAtMerge(usage)) {
                        return false;
                    }
                    if (usage instanceof VirtualState) {
                        VirtualState stateUsage = (VirtualState) usage;
                        if (finalExitState != null && finalExitState.isPartOfThisState(stateUsage)) {
                            return false;
                        }
                    }
                    return true;
                });
            }
        }
    }
 490 }