1 /*
   2  * Copyright (c) 2011, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package org.graalvm.compiler.phases.common;
  24 
  25 import static org.graalvm.compiler.core.common.LocationIdentity.any;
  26 import static org.graalvm.compiler.graph.Graph.NodeEvent.NODE_ADDED;
  27 import static org.graalvm.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
  28 
  29 import java.util.Collection;
  30 import java.util.EnumSet;
  31 import java.util.HashMap;
  32 import java.util.Iterator;
  33 import java.util.List;
  34 import java.util.Map;
  35 import java.util.Set;
  36 
  37 import org.graalvm.compiler.core.common.CollectionsFactory;
  38 import org.graalvm.compiler.core.common.LocationIdentity;
  39 import org.graalvm.compiler.core.common.cfg.Loop;
  40 import org.graalvm.compiler.debug.DebugCloseable;
  41 import org.graalvm.compiler.graph.Graph.NodeEventScope;
  42 import org.graalvm.compiler.graph.Node;
  43 import org.graalvm.compiler.nodes.AbstractBeginNode;
  44 import org.graalvm.compiler.nodes.AbstractMergeNode;
  45 import org.graalvm.compiler.nodes.FixedNode;
  46 import org.graalvm.compiler.nodes.InvokeWithExceptionNode;
  47 import org.graalvm.compiler.nodes.LoopBeginNode;
  48 import org.graalvm.compiler.nodes.LoopEndNode;
  49 import org.graalvm.compiler.nodes.LoopExitNode;
  50 import org.graalvm.compiler.nodes.PhiNode;
  51 import org.graalvm.compiler.nodes.ReturnNode;
  52 import org.graalvm.compiler.nodes.StartNode;
  53 import org.graalvm.compiler.nodes.StructuredGraph;
  54 import org.graalvm.compiler.nodes.ValueNodeUtil;
  55 import org.graalvm.compiler.nodes.calc.FloatingNode;
  56 import org.graalvm.compiler.nodes.cfg.Block;
  57 import org.graalvm.compiler.nodes.cfg.ControlFlowGraph;
  58 import org.graalvm.compiler.nodes.cfg.HIRLoop;
  59 import org.graalvm.compiler.nodes.extended.GuardingNode;
  60 import org.graalvm.compiler.nodes.extended.ValueAnchorNode;
  61 import org.graalvm.compiler.nodes.memory.FloatableAccessNode;
  62 import org.graalvm.compiler.nodes.memory.FloatingAccessNode;
  63 import org.graalvm.compiler.nodes.memory.FloatingReadNode;
  64 import org.graalvm.compiler.nodes.memory.MemoryAccess;
  65 import org.graalvm.compiler.nodes.memory.MemoryAnchorNode;
  66 import org.graalvm.compiler.nodes.memory.MemoryCheckpoint;
  67 import org.graalvm.compiler.nodes.memory.MemoryMap;
  68 import org.graalvm.compiler.nodes.memory.MemoryMapNode;
  69 import org.graalvm.compiler.nodes.memory.MemoryNode;
  70 import org.graalvm.compiler.nodes.memory.MemoryPhiNode;
  71 import org.graalvm.compiler.nodes.memory.ReadNode;
  72 import org.graalvm.compiler.nodes.util.GraphUtil;
  73 import org.graalvm.compiler.phases.Phase;
  74 import org.graalvm.compiler.phases.common.util.HashSetNodeEventListener;
  75 import org.graalvm.compiler.phases.graph.ReentrantNodeIterator;
  76 import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.LoopInfo;
  77 import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.NodeIteratorClosure;
  78 
  79 public class FloatingReadPhase extends Phase {
  80 
  81     private boolean createFloatingReads;
  82     private boolean createMemoryMapNodes;
  83 
  84     public static class MemoryMapImpl implements MemoryMap {
  85 
  86         private final Map<LocationIdentity, MemoryNode> lastMemorySnapshot;
  87 
  88         public MemoryMapImpl(MemoryMapImpl memoryMap) {
  89             lastMemorySnapshot = CollectionsFactory.newMap(memoryMap.lastMemorySnapshot);
  90         }
  91 
  92         public MemoryMapImpl(StartNode start) {
  93             lastMemorySnapshot = CollectionsFactory.newMap();
  94             lastMemorySnapshot.put(any(), start);
  95         }
  96 
  97         public MemoryMapImpl() {
  98             lastMemorySnapshot = CollectionsFactory.newMap();
  99         }
 100 
 101         @Override
 102         public MemoryNode getLastLocationAccess(LocationIdentity locationIdentity) {
 103             MemoryNode lastLocationAccess;
 104             if (locationIdentity.isImmutable()) {
 105                 return null;
 106             } else {
 107                 lastLocationAccess = lastMemorySnapshot.get(locationIdentity);
 108                 if (lastLocationAccess == null) {
 109                     lastLocationAccess = lastMemorySnapshot.get(any());
 110                     assert lastLocationAccess != null;
 111                 }
 112                 return lastLocationAccess;
 113             }
 114         }
 115 
 116         @Override
 117         public Collection<LocationIdentity> getLocations() {
 118             return lastMemorySnapshot.keySet();
 119         }
 120 
 121         public Map<LocationIdentity, MemoryNode> getMap() {
 122             return lastMemorySnapshot;
 123         }
 124     }
 125 
 126     public FloatingReadPhase() {
 127         this(true, false);
 128     }
 129 
 130     /**
 131      * @param createFloatingReads specifies whether {@link FloatableAccessNode}s like
 132      *            {@link ReadNode} should be converted into floating nodes (e.g.,
 133      *            {@link FloatingReadNode}s) where possible
 134      * @param createMemoryMapNodes a {@link MemoryMapNode} will be created for each return if this
 135      *            is true
 136      */
 137     public FloatingReadPhase(boolean createFloatingReads, boolean createMemoryMapNodes) {
 138         this.createFloatingReads = createFloatingReads;
 139         this.createMemoryMapNodes = createMemoryMapNodes;
 140     }
 141 
 142     @Override
 143     public float codeSizeIncrease() {
 144         return 1.25f;
 145     }
 146 
 147     /**
 148      * Removes nodes from a given set that (transitively) have a usage outside the set.
 149      */
 150     private static Set<Node> removeExternallyUsedNodes(Set<Node> set) {
 151         boolean change;
 152         do {
 153             change = false;
 154             for (Iterator<Node> iter = set.iterator(); iter.hasNext();) {
 155                 Node node = iter.next();
 156                 for (Node usage : node.usages()) {
 157                     if (!set.contains(usage)) {
 158                         change = true;
 159                         iter.remove();
 160                         break;
 161                     }
 162                 }
 163             }
 164         } while (change);
 165         return set;
 166     }
 167 
 168     protected void processNode(FixedNode node, Set<LocationIdentity> currentState) {
 169         if (node instanceof MemoryCheckpoint.Single) {
 170             currentState.add(((MemoryCheckpoint.Single) node).getLocationIdentity());
 171         } else if (node instanceof MemoryCheckpoint.Multi) {
 172             for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) {
 173                 currentState.add(identity);
 174             }
 175         }
 176     }
 177 
 178     protected void processBlock(Block b, Set<LocationIdentity> currentState) {
 179         for (FixedNode n : b.getNodes()) {
 180             processNode(n, currentState);
 181         }
 182     }
 183 
 184     private Set<LocationIdentity> processLoop(HIRLoop loop, Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops) {
 185         LoopBeginNode loopBegin = (LoopBeginNode) loop.getHeader().getBeginNode();
 186         Set<LocationIdentity> result = modifiedInLoops.get(loopBegin);
 187         if (result != null) {
 188             return result;
 189         }
 190 
 191         result = CollectionsFactory.newSet();
 192         for (Loop<Block> inner : loop.getChildren()) {
 193             result.addAll(processLoop((HIRLoop) inner, modifiedInLoops));
 194         }
 195 
 196         for (Block b : loop.getBlocks()) {
 197             if (b.getLoop() == loop) {
 198                 processBlock(b, result);
 199             }
 200         }
 201 
 202         modifiedInLoops.put(loopBegin, result);
 203         return result;
 204     }
 205 
 206     @Override
 207     @SuppressWarnings("try")
 208     protected void run(StructuredGraph graph) {
 209         Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops = null;
 210         if (graph.hasLoops()) {
 211             modifiedInLoops = new HashMap<>();
 212             ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, false, false);
 213             for (Loop<?> l : cfg.getLoops()) {
 214                 HIRLoop loop = (HIRLoop) l;
 215                 processLoop(loop, modifiedInLoops);
 216             }
 217         }
 218 
 219         HashSetNodeEventListener listener = new HashSetNodeEventListener(EnumSet.of(NODE_ADDED, ZERO_USAGES));
 220         try (NodeEventScope nes = graph.trackNodeEvents(listener)) {
 221             ReentrantNodeIterator.apply(new FloatingReadClosure(modifiedInLoops, createFloatingReads, createMemoryMapNodes), graph.start(), new MemoryMapImpl(graph.start()));
 222         }
 223 
 224         for (Node n : removeExternallyUsedNodes(listener.getNodes())) {
 225             if (n.isAlive() && n instanceof FloatingNode) {
 226                 n.replaceAtUsages(null);
 227                 GraphUtil.killWithUnusedFloatingInputs(n);
 228             }
 229         }
 230         if (createFloatingReads) {
 231             assert !graph.isAfterFloatingReadPhase();
 232             graph.setAfterFloatingReadPhase(true);
 233         }
 234     }
 235 
 236     public static MemoryMapImpl mergeMemoryMaps(AbstractMergeNode merge, List<? extends MemoryMap> states) {
 237         MemoryMapImpl newState = new MemoryMapImpl();
 238 
 239         Set<LocationIdentity> keys = CollectionsFactory.newSet();
 240         for (MemoryMap other : states) {
 241             keys.addAll(other.getLocations());
 242         }
 243         assert checkNoImmutableLocations(keys);
 244 
 245         for (LocationIdentity key : keys) {
 246             int mergedStatesCount = 0;
 247             boolean isPhi = false;
 248             MemoryNode merged = null;
 249             for (MemoryMap state : states) {
 250                 MemoryNode last = state.getLastLocationAccess(key);
 251                 if (isPhi) {
 252                     ((MemoryPhiNode) merged).addInput(ValueNodeUtil.asNode(last));
 253                 } else {
 254                     if (merged == last) {
 255                         // nothing to do
 256                     } else if (merged == null) {
 257                         merged = last;
 258                     } else {
 259                         MemoryPhiNode phi = merge.graph().addWithoutUnique(new MemoryPhiNode(merge, key));
 260                         for (int j = 0; j < mergedStatesCount; j++) {
 261                             phi.addInput(ValueNodeUtil.asNode(merged));
 262                         }
 263                         phi.addInput(ValueNodeUtil.asNode(last));
 264                         merged = phi;
 265                         isPhi = true;
 266                     }
 267                 }
 268                 mergedStatesCount++;
 269             }
 270             newState.lastMemorySnapshot.put(key, merged);
 271         }
 272         return newState;
 273 
 274     }
 275 
 276     private static boolean checkNoImmutableLocations(Set<LocationIdentity> keys) {
 277         keys.forEach(t -> {
 278             assert !t.isImmutable();
 279         });
 280         return true;
 281     }
 282 
 283     public static class FloatingReadClosure extends NodeIteratorClosure<MemoryMapImpl> {
 284 
 285         private final Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops;
 286         private boolean createFloatingReads;
 287         private boolean createMemoryMapNodes;
 288 
 289         public FloatingReadClosure(Map<LoopBeginNode, Set<LocationIdentity>> modifiedInLoops, boolean createFloatingReads, boolean createMemoryMapNodes) {
 290             this.modifiedInLoops = modifiedInLoops;
 291             this.createFloatingReads = createFloatingReads;
 292             this.createMemoryMapNodes = createMemoryMapNodes;
 293         }
 294 
 295         @Override
 296         protected MemoryMapImpl processNode(FixedNode node, MemoryMapImpl state) {
 297             if (node instanceof MemoryAnchorNode) {
 298                 processAnchor((MemoryAnchorNode) node, state);
 299                 return state;
 300             }
 301 
 302             if (node instanceof MemoryAccess) {
 303                 processAccess((MemoryAccess) node, state);
 304             }
 305 
 306             if (createFloatingReads & node instanceof FloatableAccessNode) {
 307                 processFloatable((FloatableAccessNode) node, state);
 308             } else if (node instanceof MemoryCheckpoint.Single) {
 309                 processCheckpoint((MemoryCheckpoint.Single) node, state);
 310             } else if (node instanceof MemoryCheckpoint.Multi) {
 311                 processCheckpoint((MemoryCheckpoint.Multi) node, state);
 312             }
 313             assert MemoryCheckpoint.TypeAssertion.correctType(node) : node;
 314 
 315             if (createMemoryMapNodes && node instanceof ReturnNode) {
 316                 ((ReturnNode) node).setMemoryMap(node.graph().unique(new MemoryMapNode(state.lastMemorySnapshot)));
 317             }
 318             return state;
 319         }
 320 
 321         /**
 322          * Improve the memory graph by re-wiring all usages of a {@link MemoryAnchorNode} to the
 323          * real last access location.
 324          */
 325         private static void processAnchor(MemoryAnchorNode anchor, MemoryMapImpl state) {
 326             for (Node node : anchor.usages().snapshot()) {
 327                 if (node instanceof MemoryAccess) {
 328                     MemoryAccess access = (MemoryAccess) node;
 329                     if (access.getLastLocationAccess() == anchor) {
 330                         MemoryNode lastLocationAccess = state.getLastLocationAccess(access.getLocationIdentity());
 331                         assert lastLocationAccess != null;
 332                         access.setLastLocationAccess(lastLocationAccess);
 333                     }
 334                 }
 335             }
 336 
 337             if (anchor.hasNoUsages()) {
 338                 anchor.graph().removeFixed(anchor);
 339             }
 340         }
 341 
 342         private static void processAccess(MemoryAccess access, MemoryMapImpl state) {
 343             LocationIdentity locationIdentity = access.getLocationIdentity();
 344             if (!locationIdentity.equals(LocationIdentity.any())) {
 345                 MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
 346                 access.setLastLocationAccess(lastLocationAccess);
 347             }
 348         }
 349 
 350         private static void processCheckpoint(MemoryCheckpoint.Single checkpoint, MemoryMapImpl state) {
 351             processIdentity(checkpoint.getLocationIdentity(), checkpoint, state);
 352         }
 353 
 354         private static void processCheckpoint(MemoryCheckpoint.Multi checkpoint, MemoryMapImpl state) {
 355             for (LocationIdentity identity : checkpoint.getLocationIdentities()) {
 356                 processIdentity(identity, checkpoint, state);
 357             }
 358         }
 359 
 360         private static void processIdentity(LocationIdentity identity, MemoryCheckpoint checkpoint, MemoryMapImpl state) {
 361             if (identity.isAny()) {
 362                 state.lastMemorySnapshot.clear();
 363             }
 364             state.lastMemorySnapshot.put(identity, checkpoint);
 365         }
 366 
 367         @SuppressWarnings("try")
 368         private static void processFloatable(FloatableAccessNode accessNode, MemoryMapImpl state) {
 369             StructuredGraph graph = accessNode.graph();
 370             LocationIdentity locationIdentity = accessNode.getLocationIdentity();
 371             if (accessNode.canFloat()) {
 372                 assert accessNode.getNullCheck() == false;
 373                 MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
 374                 try (DebugCloseable position = accessNode.withNodeSourcePosition()) {
 375                     FloatingAccessNode floatingNode = accessNode.asFloatingNode(lastLocationAccess);
 376                     ValueAnchorNode anchor = null;
 377                     GuardingNode guard = accessNode.getGuard();
 378                     if (guard != null) {
 379                         anchor = graph.add(new ValueAnchorNode(guard.asNode()));
 380                         graph.addAfterFixed(accessNode, anchor);
 381                     }
 382                     graph.replaceFixedWithFloating(accessNode, floatingNode);
 383                 }
 384             }
 385         }
 386 
 387         @Override
 388         protected MemoryMapImpl merge(AbstractMergeNode merge, List<MemoryMapImpl> states) {
 389             return mergeMemoryMaps(merge, states);
 390         }
 391 
 392         @Override
 393         protected MemoryMapImpl afterSplit(AbstractBeginNode node, MemoryMapImpl oldState) {
 394             MemoryMapImpl result = new MemoryMapImpl(oldState);
 395             if (node.predecessor() instanceof InvokeWithExceptionNode) {
 396                 /*
 397                  * InvokeWithException cannot be the lastLocationAccess for a FloatingReadNode.
 398                  * Since it is both the invoke and a control flow split, the scheduler cannot
 399                  * schedule anything immediately after the invoke. It can only schedule in the
 400                  * normal or exceptional successor - and we have to tell the scheduler here which
 401                  * side it needs to choose by putting in the location identity on both successors.
 402                  */
 403                 InvokeWithExceptionNode invoke = (InvokeWithExceptionNode) node.predecessor();
 404                 result.lastMemorySnapshot.put(invoke.getLocationIdentity(), (MemoryCheckpoint) node);
 405             }
 406             return result;
 407         }
 408 
 409         @Override
 410         protected Map<LoopExitNode, MemoryMapImpl> processLoop(LoopBeginNode loop, MemoryMapImpl initialState) {
 411             Set<LocationIdentity> modifiedLocations = modifiedInLoops.get(loop);
 412             Map<LocationIdentity, MemoryPhiNode> phis = CollectionsFactory.newMap();
 413             if (modifiedLocations.contains(LocationIdentity.any())) {
 414                 // create phis for all locations if ANY is modified in the loop
 415                 modifiedLocations = CollectionsFactory.newSet(modifiedLocations);
 416                 modifiedLocations.addAll(initialState.lastMemorySnapshot.keySet());
 417             }
 418 
 419             for (LocationIdentity location : modifiedLocations) {
 420                 createMemoryPhi(loop, initialState, phis, location);
 421             }
 422             for (Map.Entry<LocationIdentity, MemoryPhiNode> entry : phis.entrySet()) {
 423                 initialState.lastMemorySnapshot.put(entry.getKey(), entry.getValue());
 424             }
 425 
 426             LoopInfo<MemoryMapImpl> loopInfo = ReentrantNodeIterator.processLoop(this, loop, initialState);
 427 
 428             for (Map.Entry<LoopEndNode, MemoryMapImpl> entry : loopInfo.endStates.entrySet()) {
 429                 int endIndex = loop.phiPredecessorIndex(entry.getKey());
 430                 for (Map.Entry<LocationIdentity, MemoryPhiNode> phiEntry : phis.entrySet()) {
 431                     LocationIdentity key = phiEntry.getKey();
 432                     PhiNode phi = phiEntry.getValue();
 433                     phi.initializeValueAt(endIndex, ValueNodeUtil.asNode(entry.getValue().getLastLocationAccess(key)));
 434                 }
 435             }
 436             return loopInfo.exitStates;
 437         }
 438 
 439         private static void createMemoryPhi(LoopBeginNode loop, MemoryMapImpl initialState, Map<LocationIdentity, MemoryPhiNode> phis, LocationIdentity location) {
 440             MemoryPhiNode phi = loop.graph().addWithoutUnique(new MemoryPhiNode(loop, location));
 441             phi.addInput(ValueNodeUtil.asNode(initialState.getLastLocationAccess(location)));
 442             phis.put(location, phi);
 443         }
 444     }
 445 }