1 /*
   2  * Copyright (c) 2011, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.phases.common;
  26 
  27 import static org.graalvm.compiler.graph.Graph.NodeEvent.NODE_ADDED;
  28 import static org.graalvm.compiler.graph.Graph.NodeEvent.ZERO_USAGES;
  29 import static jdk.internal.vm.compiler.word.LocationIdentity.any;
  30 
  31 import java.util.EnumSet;
  32 import java.util.Iterator;
  33 import java.util.List;
  34 
  35 import jdk.internal.vm.compiler.collections.EconomicMap;
  36 import jdk.internal.vm.compiler.collections.EconomicSet;
  37 import jdk.internal.vm.compiler.collections.Equivalence;
  38 import jdk.internal.vm.compiler.collections.UnmodifiableMapCursor;
  39 import org.graalvm.compiler.core.common.cfg.Loop;
  40 import org.graalvm.compiler.debug.DebugCloseable;
  41 import org.graalvm.compiler.graph.Graph.NodeEventScope;
  42 import org.graalvm.compiler.graph.Node;
  43 import org.graalvm.compiler.nodes.AbstractBeginNode;
  44 import org.graalvm.compiler.nodes.AbstractMergeNode;
  45 import org.graalvm.compiler.nodes.FixedNode;
  46 import org.graalvm.compiler.nodes.InvokeWithExceptionNode;
  47 import org.graalvm.compiler.nodes.LoopBeginNode;
  48 import org.graalvm.compiler.nodes.LoopEndNode;
  49 import org.graalvm.compiler.nodes.LoopExitNode;
  50 import org.graalvm.compiler.nodes.PhiNode;
  51 import org.graalvm.compiler.nodes.ReturnNode;
  52 import org.graalvm.compiler.nodes.StartNode;
  53 import org.graalvm.compiler.nodes.StructuredGraph;
  54 import org.graalvm.compiler.nodes.ValueNodeUtil;
  55 import org.graalvm.compiler.nodes.calc.FloatingNode;
  56 import org.graalvm.compiler.nodes.cfg.Block;
  57 import org.graalvm.compiler.nodes.cfg.ControlFlowGraph;
  58 import org.graalvm.compiler.nodes.cfg.HIRLoop;
  59 import org.graalvm.compiler.nodes.memory.FloatableAccessNode;
  60 import org.graalvm.compiler.nodes.memory.FloatingAccessNode;
  61 import org.graalvm.compiler.nodes.memory.FloatingReadNode;
  62 import org.graalvm.compiler.nodes.memory.MemoryAccess;
  63 import org.graalvm.compiler.nodes.memory.MemoryAnchorNode;
  64 import org.graalvm.compiler.nodes.memory.MemoryCheckpoint;
  65 import org.graalvm.compiler.nodes.memory.MemoryMap;
  66 import org.graalvm.compiler.nodes.memory.MemoryMapNode;
  67 import org.graalvm.compiler.nodes.memory.MemoryNode;
  68 import org.graalvm.compiler.nodes.memory.MemoryPhiNode;
  69 import org.graalvm.compiler.nodes.memory.ReadNode;
  70 import org.graalvm.compiler.nodes.util.GraphUtil;
  71 import org.graalvm.compiler.phases.Phase;
  72 import org.graalvm.compiler.phases.common.util.EconomicSetNodeEventListener;
  73 import org.graalvm.compiler.phases.graph.ReentrantNodeIterator;
  74 import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.LoopInfo;
  75 import org.graalvm.compiler.phases.graph.ReentrantNodeIterator.NodeIteratorClosure;
  76 import jdk.internal.vm.compiler.word.LocationIdentity;
  77 
  78 public class FloatingReadPhase extends Phase {
  79 
  80     private boolean createFloatingReads;
  81     private boolean createMemoryMapNodes;
  82 
  83     public static class MemoryMapImpl implements MemoryMap {
  84 
  85         private final EconomicMap<LocationIdentity, MemoryNode> lastMemorySnapshot;
  86 
  87         public MemoryMapImpl(MemoryMapImpl memoryMap) {
  88             lastMemorySnapshot = EconomicMap.create(Equivalence.DEFAULT, memoryMap.lastMemorySnapshot);
  89         }
  90 
  91         public MemoryMapImpl(StartNode start) {
  92             this();
  93             lastMemorySnapshot.put(any(), start);
  94         }
  95 
  96         public MemoryMapImpl() {
  97             lastMemorySnapshot = EconomicMap.create(Equivalence.DEFAULT);
  98         }
  99 
 100         @Override
 101         public MemoryNode getLastLocationAccess(LocationIdentity locationIdentity) {
 102             MemoryNode lastLocationAccess;
 103             if (locationIdentity.isImmutable()) {
 104                 return null;
 105             } else {
 106                 lastLocationAccess = lastMemorySnapshot.get(locationIdentity);
 107                 if (lastLocationAccess == null) {
 108                     lastLocationAccess = lastMemorySnapshot.get(any());
 109                     assert lastLocationAccess != null;
 110                 }
 111                 return lastLocationAccess;
 112             }
 113         }
 114 
 115         @Override
 116         public Iterable<LocationIdentity> getLocations() {
 117             return lastMemorySnapshot.getKeys();
 118         }
 119 
 120         public EconomicMap<LocationIdentity, MemoryNode> getMap() {
 121             return lastMemorySnapshot;
 122         }
 123     }
 124 
 125     public FloatingReadPhase() {
 126         this(true, false);
 127     }
 128 
 129     /**
 130      * @param createFloatingReads specifies whether {@link FloatableAccessNode}s like
 131      *            {@link ReadNode} should be converted into floating nodes (e.g.,
 132      *            {@link FloatingReadNode}s) where possible
 133      * @param createMemoryMapNodes a {@link MemoryMapNode} will be created for each return if this
 134      *            is true
 135      */
 136     public FloatingReadPhase(boolean createFloatingReads, boolean createMemoryMapNodes) {
 137         this.createFloatingReads = createFloatingReads;
 138         this.createMemoryMapNodes = createMemoryMapNodes;
 139     }
 140 
 141     @Override
 142     public float codeSizeIncrease() {
 143         return 1.25f;
 144     }
 145 
 146     /**
 147      * Removes nodes from a given set that (transitively) have a usage outside the set.
 148      */
 149     private static EconomicSet<Node> removeExternallyUsedNodes(EconomicSet<Node> set) {
 150         boolean change;
 151         do {
 152             change = false;
 153             for (Iterator<Node> iter = set.iterator(); iter.hasNext();) {
 154                 Node node = iter.next();
 155                 for (Node usage : node.usages()) {
 156                     if (!set.contains(usage)) {
 157                         change = true;
 158                         iter.remove();
 159                         break;
 160                     }
 161                 }
 162             }
 163         } while (change);
 164         return set;
 165     }
 166 
 167     protected void processNode(FixedNode node, EconomicSet<LocationIdentity> currentState) {
 168         if (node instanceof MemoryCheckpoint.Single) {
 169             processIdentity(currentState, ((MemoryCheckpoint.Single) node).getLocationIdentity());
 170         } else if (node instanceof MemoryCheckpoint.Multi) {
 171             for (LocationIdentity identity : ((MemoryCheckpoint.Multi) node).getLocationIdentities()) {
 172                 processIdentity(currentState, identity);
 173             }
 174         }
 175     }
 176 
 177     private static void processIdentity(EconomicSet<LocationIdentity> currentState, LocationIdentity identity) {
 178         if (identity.isMutable()) {
 179             currentState.add(identity);
 180         }
 181     }
 182 
 183     protected void processBlock(Block b, EconomicSet<LocationIdentity> currentState) {
 184         for (FixedNode n : b.getNodes()) {
 185             processNode(n, currentState);
 186         }
 187     }
 188 
 189     private EconomicSet<LocationIdentity> processLoop(HIRLoop loop, EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops) {
 190         LoopBeginNode loopBegin = (LoopBeginNode) loop.getHeader().getBeginNode();
 191         EconomicSet<LocationIdentity> result = modifiedInLoops.get(loopBegin);
 192         if (result != null) {
 193             return result;
 194         }
 195 
 196         result = EconomicSet.create(Equivalence.DEFAULT);
 197         for (Loop<Block> inner : loop.getChildren()) {
 198             result.addAll(processLoop((HIRLoop) inner, modifiedInLoops));
 199         }
 200 
 201         for (Block b : loop.getBlocks()) {
 202             if (b.getLoop() == loop) {
 203                 processBlock(b, result);
 204             }
 205         }
 206 
 207         modifiedInLoops.put(loopBegin, result);
 208         return result;
 209     }
 210 
 211     @Override
 212     @SuppressWarnings("try")
 213     protected void run(StructuredGraph graph) {
 214         EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops = null;
 215         if (graph.hasLoops()) {
 216             modifiedInLoops = EconomicMap.create(Equivalence.IDENTITY);
 217             ControlFlowGraph cfg = ControlFlowGraph.compute(graph, true, true, false, false);
 218             for (Loop<?> l : cfg.getLoops()) {
 219                 HIRLoop loop = (HIRLoop) l;
 220                 processLoop(loop, modifiedInLoops);
 221             }
 222         }
 223 
 224         EconomicSetNodeEventListener listener = new EconomicSetNodeEventListener(EnumSet.of(NODE_ADDED, ZERO_USAGES));
 225         try (NodeEventScope nes = graph.trackNodeEvents(listener)) {
 226             ReentrantNodeIterator.apply(new FloatingReadClosure(modifiedInLoops, createFloatingReads, createMemoryMapNodes), graph.start(), new MemoryMapImpl(graph.start()));
 227         }
 228 
 229         for (Node n : removeExternallyUsedNodes(listener.getNodes())) {
 230             if (n.isAlive() && n instanceof FloatingNode) {
 231                 n.replaceAtUsages(null);
 232                 GraphUtil.killWithUnusedFloatingInputs(n);
 233             }
 234         }
 235         if (createFloatingReads) {
 236             assert !graph.isAfterFloatingReadPhase();
 237             graph.setAfterFloatingReadPhase(true);
 238         }
 239     }
 240 
 241     public static MemoryMapImpl mergeMemoryMaps(AbstractMergeNode merge, List<? extends MemoryMap> states) {
 242         MemoryMapImpl newState = new MemoryMapImpl();
 243 
 244         EconomicSet<LocationIdentity> keys = EconomicSet.create(Equivalence.DEFAULT);
 245         for (MemoryMap other : states) {
 246             keys.addAll(other.getLocations());
 247         }
 248         assert checkNoImmutableLocations(keys);
 249 
 250         for (LocationIdentity key : keys) {
 251             int mergedStatesCount = 0;
 252             boolean isPhi = false;
 253             MemoryNode merged = null;
 254             for (MemoryMap state : states) {
 255                 MemoryNode last = state.getLastLocationAccess(key);
 256                 if (isPhi) {
 257                     ((MemoryPhiNode) merged).addInput(ValueNodeUtil.asNode(last));
 258                 } else {
 259                     if (merged == last) {
 260                         // nothing to do
 261                     } else if (merged == null) {
 262                         merged = last;
 263                     } else {
 264                         MemoryPhiNode phi = merge.graph().addWithoutUnique(new MemoryPhiNode(merge, key));
 265                         for (int j = 0; j < mergedStatesCount; j++) {
 266                             phi.addInput(ValueNodeUtil.asNode(merged));
 267                         }
 268                         phi.addInput(ValueNodeUtil.asNode(last));
 269                         merged = phi;
 270                         isPhi = true;
 271                     }
 272                 }
 273                 mergedStatesCount++;
 274             }
 275             newState.lastMemorySnapshot.put(key, merged);
 276         }
 277         return newState;
 278 
 279     }
 280 
 281     private static boolean checkNoImmutableLocations(EconomicSet<LocationIdentity> keys) {
 282         keys.forEach(t -> {
 283             assert t.isMutable();
 284         });
 285         return true;
 286     }
 287 
 288     public static class FloatingReadClosure extends NodeIteratorClosure<MemoryMapImpl> {
 289 
 290         private final EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops;
 291         private boolean createFloatingReads;
 292         private boolean createMemoryMapNodes;
 293 
 294         public FloatingReadClosure(EconomicMap<LoopBeginNode, EconomicSet<LocationIdentity>> modifiedInLoops, boolean createFloatingReads, boolean createMemoryMapNodes) {
 295             this.modifiedInLoops = modifiedInLoops;
 296             this.createFloatingReads = createFloatingReads;
 297             this.createMemoryMapNodes = createMemoryMapNodes;
 298         }
 299 
 300         @Override
 301         protected MemoryMapImpl processNode(FixedNode node, MemoryMapImpl state) {
 302             if (node instanceof MemoryAnchorNode) {
 303                 processAnchor((MemoryAnchorNode) node, state);
 304                 return state;
 305             }
 306 
 307             if (node instanceof MemoryAccess) {
 308                 processAccess((MemoryAccess) node, state);
 309             }
 310 
 311             if (createFloatingReads && node instanceof FloatableAccessNode) {
 312                 processFloatable((FloatableAccessNode) node, state);
 313             } else if (node instanceof MemoryCheckpoint.Single) {
 314                 processCheckpoint((MemoryCheckpoint.Single) node, state);
 315             } else if (node instanceof MemoryCheckpoint.Multi) {
 316                 processCheckpoint((MemoryCheckpoint.Multi) node, state);
 317             }
 318             assert MemoryCheckpoint.TypeAssertion.correctType(node) : node;
 319 
 320             if (createMemoryMapNodes && node instanceof ReturnNode) {
 321                 ((ReturnNode) node).setMemoryMap(node.graph().unique(new MemoryMapNode(state.lastMemorySnapshot)));
 322             }
 323             return state;
 324         }
 325 
 326         /**
 327          * Improve the memory graph by re-wiring all usages of a {@link MemoryAnchorNode} to the
 328          * real last access location.
 329          */
 330         private static void processAnchor(MemoryAnchorNode anchor, MemoryMapImpl state) {
 331             for (Node node : anchor.usages().snapshot()) {
 332                 if (node instanceof MemoryAccess) {
 333                     MemoryAccess access = (MemoryAccess) node;
 334                     if (access.getLastLocationAccess() == anchor) {
 335                         MemoryNode lastLocationAccess = state.getLastLocationAccess(access.getLocationIdentity());
 336                         assert lastLocationAccess != null;
 337                         access.setLastLocationAccess(lastLocationAccess);
 338                     }
 339                 }
 340             }
 341 
 342             if (anchor.hasNoUsages()) {
 343                 anchor.graph().removeFixed(anchor);
 344             }
 345         }
 346 
 347         private static void processAccess(MemoryAccess access, MemoryMapImpl state) {
 348             LocationIdentity locationIdentity = access.getLocationIdentity();
 349             if (!locationIdentity.equals(LocationIdentity.any())) {
 350                 MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
 351                 access.setLastLocationAccess(lastLocationAccess);
 352             }
 353         }
 354 
 355         private static void processCheckpoint(MemoryCheckpoint.Single checkpoint, MemoryMapImpl state) {
 356             processIdentity(checkpoint.getLocationIdentity(), checkpoint, state);
 357         }
 358 
 359         private static void processCheckpoint(MemoryCheckpoint.Multi checkpoint, MemoryMapImpl state) {
 360             for (LocationIdentity identity : checkpoint.getLocationIdentities()) {
 361                 processIdentity(identity, checkpoint, state);
 362             }
 363         }
 364 
 365         private static void processIdentity(LocationIdentity identity, MemoryCheckpoint checkpoint, MemoryMapImpl state) {
 366             if (identity.isAny()) {
 367                 state.lastMemorySnapshot.clear();
 368             }
 369             if (identity.isMutable()) {
 370                 state.lastMemorySnapshot.put(identity, checkpoint);
 371             }
 372         }
 373 
 374         @SuppressWarnings("try")
 375         private static void processFloatable(FloatableAccessNode accessNode, MemoryMapImpl state) {
 376             StructuredGraph graph = accessNode.graph();
 377             LocationIdentity locationIdentity = accessNode.getLocationIdentity();
 378             if (accessNode.canFloat()) {
 379                 assert accessNode.getNullCheck() == false;
 380                 MemoryNode lastLocationAccess = state.getLastLocationAccess(locationIdentity);
 381                 try (DebugCloseable position = accessNode.withNodeSourcePosition()) {
 382                     FloatingAccessNode floatingNode = accessNode.asFloatingNode(lastLocationAccess);
 383                     graph.replaceFixedWithFloating(accessNode, floatingNode);
 384                 }
 385             }
 386         }
 387 
 388         @Override
 389         protected MemoryMapImpl merge(AbstractMergeNode merge, List<MemoryMapImpl> states) {
 390             return mergeMemoryMaps(merge, states);
 391         }
 392 
 393         @Override
 394         protected MemoryMapImpl afterSplit(AbstractBeginNode node, MemoryMapImpl oldState) {
 395             MemoryMapImpl result = new MemoryMapImpl(oldState);
 396             if (node.predecessor() instanceof InvokeWithExceptionNode) {
 397                 /*
 398                  * InvokeWithException cannot be the lastLocationAccess for a FloatingReadNode.
 399                  * Since it is both the invoke and a control flow split, the scheduler cannot
 400                  * schedule anything immediately after the invoke. It can only schedule in the
 401                  * normal or exceptional successor - and we have to tell the scheduler here which
 402                  * side it needs to choose by putting in the location identity on both successors.
 403                  */
 404                 InvokeWithExceptionNode invoke = (InvokeWithExceptionNode) node.predecessor();
 405                 result.lastMemorySnapshot.put(invoke.getLocationIdentity(), (MemoryCheckpoint) node);
 406             }
 407             return result;
 408         }
 409 
 410         @Override
 411         protected EconomicMap<LoopExitNode, MemoryMapImpl> processLoop(LoopBeginNode loop, MemoryMapImpl initialState) {
 412             EconomicSet<LocationIdentity> modifiedLocations = modifiedInLoops.get(loop);
 413             EconomicMap<LocationIdentity, MemoryPhiNode> phis = EconomicMap.create(Equivalence.DEFAULT);
 414             if (modifiedLocations.contains(LocationIdentity.any())) {
 415                 // create phis for all locations if ANY is modified in the loop
 416                 modifiedLocations = EconomicSet.create(Equivalence.DEFAULT, modifiedLocations);
 417                 modifiedLocations.addAll(initialState.lastMemorySnapshot.getKeys());
 418             }
 419 
 420             for (LocationIdentity location : modifiedLocations) {
 421                 createMemoryPhi(loop, initialState, phis, location);
 422             }
 423             initialState.lastMemorySnapshot.putAll(phis);
 424 
 425             LoopInfo<MemoryMapImpl> loopInfo = ReentrantNodeIterator.processLoop(this, loop, initialState);
 426 
 427             UnmodifiableMapCursor<LoopEndNode, MemoryMapImpl> endStateCursor = loopInfo.endStates.getEntries();
 428             while (endStateCursor.advance()) {
 429                 int endIndex = loop.phiPredecessorIndex(endStateCursor.getKey());
 430                 UnmodifiableMapCursor<LocationIdentity, MemoryPhiNode> phiCursor = phis.getEntries();
 431                 while (phiCursor.advance()) {
 432                     LocationIdentity key = phiCursor.getKey();
 433                     PhiNode phi = phiCursor.getValue();
 434                     phi.initializeValueAt(endIndex, ValueNodeUtil.asNode(endStateCursor.getValue().getLastLocationAccess(key)));
 435                 }
 436             }
 437             return loopInfo.exitStates;
 438         }
 439 
 440         private static void createMemoryPhi(LoopBeginNode loop, MemoryMapImpl initialState, EconomicMap<LocationIdentity, MemoryPhiNode> phis, LocationIdentity location) {
 441             MemoryPhiNode phi = loop.graph().addWithoutUnique(new MemoryPhiNode(loop, location));
 442             phi.addInput(ValueNodeUtil.asNode(initialState.getLastLocationAccess(location)));
 443             phis.put(location, phi);
 444         }
 445     }
 446 }