1 /*
   2  * Copyright (c) 2009, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.nodes.extended;
  26 
  27 import java.util.ArrayList;
  28 import java.util.Arrays;
  29 import java.util.Comparator;
  30 import java.util.HashMap;
  31 import java.util.List;
  32 import java.util.Map;
  33 
  34 import org.graalvm.compiler.core.common.spi.ConstantFieldProvider;
  35 import org.graalvm.compiler.core.common.type.IntegerStamp;
  36 import org.graalvm.compiler.core.common.type.PrimitiveStamp;
  37 import org.graalvm.compiler.core.common.type.Stamp;
  38 import org.graalvm.compiler.core.common.type.StampFactory;
  39 import org.graalvm.compiler.graph.Node;
  40 import org.graalvm.compiler.graph.NodeClass;
  41 import org.graalvm.compiler.graph.spi.Simplifiable;
  42 import org.graalvm.compiler.graph.spi.SimplifierTool;
  43 import org.graalvm.compiler.nodeinfo.NodeInfo;
  44 import org.graalvm.compiler.nodes.AbstractBeginNode;
  45 import org.graalvm.compiler.nodes.ConstantNode;
  46 import org.graalvm.compiler.nodes.FixedGuardNode;
  47 import org.graalvm.compiler.nodes.FixedWithNextNode;
  48 import org.graalvm.compiler.nodes.LogicNode;
  49 import org.graalvm.compiler.nodes.NodeView;
  50 import org.graalvm.compiler.nodes.ValueNode;
  51 import org.graalvm.compiler.nodes.calc.IntegerBelowNode;
  52 import org.graalvm.compiler.nodes.java.LoadIndexedNode;
  53 import org.graalvm.compiler.nodes.spi.LIRLowerable;
  54 import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
  55 import org.graalvm.compiler.nodes.spi.SwitchFoldable;
  56 import org.graalvm.compiler.nodes.util.GraphUtil;
  57 
  58 import jdk.vm.ci.meta.DeoptimizationAction;
  59 import jdk.vm.ci.meta.DeoptimizationReason;
  60 import jdk.vm.ci.meta.JavaConstant;
  61 import jdk.vm.ci.meta.JavaKind;
  62 
  63 /**
  64  * The {@code IntegerSwitchNode} represents a switch on integer keys, with a sorted array of key
  65  * values. The actual implementation of the switch will be decided by the backend.
  66  */
  67 @NodeInfo
  68 public final class IntegerSwitchNode extends SwitchNode implements LIRLowerable, Simplifiable, SwitchFoldable {
  69     public static final NodeClass<IntegerSwitchNode> TYPE = NodeClass.create(IntegerSwitchNode.class);
  70 
  71     protected final int[] keys;
  72 
  73     public IntegerSwitchNode(ValueNode value, AbstractBeginNode[] successors, int[] keys, double[] keyProbabilities, int[] keySuccessors) {
  74         super(TYPE, value, successors, keySuccessors, keyProbabilities);
  75         assert keySuccessors.length == keys.length + 1;
  76         assert keySuccessors.length == keyProbabilities.length;
  77         this.keys = keys;
  78         assert value.stamp(NodeView.DEFAULT) instanceof PrimitiveStamp && value.stamp(NodeView.DEFAULT).getStackKind().isNumericInteger();
  79         assert assertSorted();
  80         assert assertNoUntargettedSuccessor();
  81     }
  82 
  83     private boolean assertSorted() {
  84         for (int i = 1; i < keys.length; i++) {
  85             assert keys[i - 1] < keys[i];
  86         }
  87         return true;
  88     }
  89 
  90     private boolean assertNoUntargettedSuccessor() {
  91         boolean[] checker = new boolean[successors.size()];
  92         for (int successorIndex : keySuccessors) {
  93             checker[successorIndex] = true;
  94         }
  95         checker[defaultSuccessorIndex()] = true;
  96         for (boolean b : checker) {
  97             assert b;
  98         }
  99         return true;
 100     }
 101 
 102     public IntegerSwitchNode(ValueNode value, int successorCount, int[] keys, double[] keyProbabilities, int[] keySuccessors) {
 103         this(value, new AbstractBeginNode[successorCount], keys, keyProbabilities, keySuccessors);
 104     }
 105 
 106     @Override
 107     public boolean isSorted() {
 108         return true;
 109     }
 110 
 111     /**
 112      * Gets the key at the specified index.
 113      *
 114      * @param i the index
 115      * @return the key at that index
 116      */
 117     @Override
 118     public JavaConstant keyAt(int i) {
 119         return JavaConstant.forInt(keys[i]);
 120     }
 121 
 122     /**
 123      * Gets the key at the specified index, as a java int.
 124      */
 125     @Override
 126     public int intKeyAt(int i) {
 127         return keys[i];
 128     }
 129 
 130     @Override
 131     public int keyCount() {
 132         return keys.length;
 133     }
 134 
 135     @Override
 136     public boolean equalKeys(SwitchNode switchNode) {
 137         if (!(switchNode instanceof IntegerSwitchNode)) {
 138             return false;
 139         }
 140         IntegerSwitchNode other = (IntegerSwitchNode) switchNode;
 141         return Arrays.equals(keys, other.keys);
 142     }
 143 
 144     @Override
 145     public void generate(NodeLIRBuilderTool gen) {
 146         gen.emitSwitch(this);
 147     }
 148 
 149     public AbstractBeginNode successorAtKey(int key) {
 150         return blockSuccessor(successorIndexAtKey(key));
 151     }
 152 
 153     public int successorIndexAtKey(int key) {
 154         for (int i = 0; i < keyCount(); i++) {
 155             if (keys[i] == key) {
 156                 return keySuccessorIndex(i);
 157             }
 158         }
 159         return keySuccessorIndex(keyCount());
 160     }
 161 
 162     @Override
 163     public void simplify(SimplifierTool tool) {
 164         NodeView view = NodeView.from(tool);
 165         if (blockSuccessorCount() == 1) {
 166             tool.addToWorkList(defaultSuccessor());
 167             graph().removeSplitPropagate(this, defaultSuccessor());
 168         } else if (value() instanceof ConstantNode) {
 169             killOtherSuccessors(tool, successorIndexAtKey(value().asJavaConstant().asInt()));
 170         } else if (tryOptimizeEnumSwitch(tool)) {
 171             return;
 172         } else if (tryRemoveUnreachableKeys(tool, value().stamp(view))) {
 173             return;
 174         } else if (switchTransformationOptimization(tool)) {
 175             return;
 176         }
 177     }
 178 
 179     private void addSuccessorForDeletion(AbstractBeginNode defaultNode) {
 180         successors.add(defaultNode);
 181     }
 182 
 183     @Override
 184     public Node getNextSwitchFoldableBranch() {
 185         return defaultSuccessor();
 186     }
 187 
 188     @Override
 189     public boolean isInSwitch(ValueNode switchValue) {
 190         return value == switchValue;
 191     }
 192 
 193     @Override
 194     public void cutOffCascadeNode() {
 195         AbstractBeginNode toKill = defaultSuccessor();
 196         clearSuccessors();
 197         addSuccessorForDeletion(toKill);
 198     }
 199 
 200     @Override
 201     public void cutOffLowestCascadeNode() {
 202         clearSuccessors();
 203     }
 204 
 205     @Override
 206     public AbstractBeginNode getDefault() {
 207         return defaultSuccessor();
 208     }
 209 
 210     @Override
 211     public ValueNode switchValue() {
 212         return value();
 213     }
 214 
 215     @Override
 216     public boolean isNonInitializedProfile() {
 217         int nbSuccessors = getSuccessorCount();
 218         double prob = 0.0d;
 219         for (int i = 0; i < nbSuccessors; i++) {
 220             if (keyProbabilities[i] > 0.0d) {
 221                 if (prob == 0.0d) {
 222                     prob = keyProbabilities[i];
 223                 } else if (keyProbabilities[i] != prob) {
 224                     return false;
 225                 }
 226             }
 227         }
 228         return true;
 229     }
 230 
 231     static final class KeyData {
 232         final int key;
 233         final double keyProbability;
 234         final int keySuccessor;
 235 
 236         KeyData(int key, double keyProbability, int keySuccessor) {
 237             this.key = key;
 238             this.keyProbability = keyProbability;
 239             this.keySuccessor = keySuccessor;
 240         }
 241     }
 242 
 243     /**
 244      * Remove unreachable keys from the switch based on the stamp of the value, i.e., based on the
 245      * known range of the switch value.
 246      */
 247     public boolean tryRemoveUnreachableKeys(SimplifierTool tool, Stamp valueStamp) {
 248         if (!(valueStamp instanceof IntegerStamp)) {
 249             return false;
 250         }
 251         IntegerStamp integerStamp = (IntegerStamp) valueStamp;
 252         if (integerStamp.isUnrestricted()) {
 253             return false;
 254         }
 255 
 256         List<KeyData> newKeyDatas = new ArrayList<>(keys.length);
 257         ArrayList<AbstractBeginNode> newSuccessors = new ArrayList<>(blockSuccessorCount());
 258         for (int i = 0; i < keys.length; i++) {
 259             if (integerStamp.contains(keys[i]) && keySuccessor(i) != defaultSuccessor()) {
 260                 newKeyDatas.add(new KeyData(keys[i], keyProbabilities[i], addNewSuccessor(keySuccessor(i), newSuccessors)));
 261             }
 262         }
 263 
 264         if (newKeyDatas.size() == keys.length) {
 265             /* All keys are reachable. */
 266             return false;
 267 
 268         } else if (newKeyDatas.size() == 0) {
 269             if (tool != null) {
 270                 tool.addToWorkList(defaultSuccessor());
 271             }
 272             graph().removeSplitPropagate(this, defaultSuccessor());
 273             return true;
 274 
 275         } else {
 276             int newDefaultSuccessor = addNewSuccessor(defaultSuccessor(), newSuccessors);
 277             double newDefaultProbability = keyProbabilities[keyProbabilities.length - 1];
 278             doReplace(value(), newKeyDatas, newSuccessors, newDefaultSuccessor, newDefaultProbability);
 279             return true;
 280         }
 281     }
 282 
 283     /**
 284      * For switch statements on enum values, the Java compiler has to generate complicated code:
 285      * because {@link Enum#ordinal()} can change when recompiling an enum, it cannot be used
 286      * directly as the value that is switched on. An intermediate int[] array, which is initialized
 287      * once at run time based on the actual {@link Enum#ordinal()} values, is used.
 288      * <p>
 289      * The {@link ConstantFieldProvider} of Graal already detects the int[] arrays and marks them as
 290      * {@link ConstantNode#isDefaultStable() stable}, i.e., the array elements are constant. The
 291      * code in this method detects array loads from such a stable array and re-wires the switch to
 292      * use the keys from the array elements, so that the array load is unnecessary.
 293      */
 294     private boolean tryOptimizeEnumSwitch(SimplifierTool tool) {
 295         if (!(value() instanceof LoadIndexedNode)) {
 296             /* Not the switch pattern we are looking for. */
 297             return false;
 298         }
 299         LoadIndexedNode loadIndexed = (LoadIndexedNode) value();
 300         if (loadIndexed.hasMoreThanOneUsage()) {
 301             /*
 302              * The array load is necessary for other reasons too, so there is no benefit optimizing
 303              * the switch.
 304              */
 305             return false;
 306         }
 307         assert loadIndexed.usages().first() == this;
 308 
 309         ValueNode newValue = loadIndexed.index();
 310         JavaConstant arrayConstant = loadIndexed.array().asJavaConstant();
 311         if (arrayConstant == null || ((ConstantNode) loadIndexed.array()).getStableDimension() != 1 || !((ConstantNode) loadIndexed.array()).isDefaultStable()) {
 312             /*
 313              * The array is a constant that we can optimize. We require the array elements to be
 314              * constant too, since we put them as literal constants into the switch keys.
 315              */
 316             return false;
 317         }
 318 
 319         Integer optionalArrayLength = tool.getConstantReflection().readArrayLength(arrayConstant);
 320         if (optionalArrayLength == null) {
 321             /* Loading a constant value can be denied by the VM. */
 322             return false;
 323         }
 324         int arrayLength = optionalArrayLength;
 325 
 326         Map<Integer, List<Integer>> reverseArrayMapping = new HashMap<>();
 327         for (int i = 0; i < arrayLength; i++) {
 328             JavaConstant elementConstant = tool.getConstantReflection().readArrayElement(arrayConstant, i);
 329             if (elementConstant == null || elementConstant.getJavaKind() != JavaKind.Int) {
 330                 /* Loading a constant value can be denied by the VM. */
 331                 return false;
 332             }
 333             int element = elementConstant.asInt();
 334 
 335             /*
 336              * The value loaded from the array is the old switch key, the index into the array is
 337              * the new switch key. We build a mapping from the old switch key to new keys.
 338              */
 339             reverseArrayMapping.computeIfAbsent(element, e -> new ArrayList<>()).add(i);
 340         }
 341 
 342         /* Build high-level representation of new switch keys. */
 343         List<KeyData> newKeyDatas = new ArrayList<>(arrayLength);
 344         ArrayList<AbstractBeginNode> newSuccessors = new ArrayList<>(blockSuccessorCount());
 345         for (int i = 0; i < keys.length; i++) {
 346             List<Integer> newKeys = reverseArrayMapping.get(keys[i]);
 347             if (newKeys == null || newKeys.size() == 0) {
 348                 /* The switch case is unreachable, we can ignore it. */
 349                 continue;
 350             }
 351 
 352             /*
 353              * We do not have detailed profiling information about the individual new keys, so we
 354              * have to assume they split the probability of the old key.
 355              */
 356             double newKeyProbability = keyProbabilities[i] / newKeys.size();
 357             int newKeySuccessor = addNewSuccessor(keySuccessor(i), newSuccessors);
 358 
 359             for (int newKey : newKeys) {
 360                 newKeyDatas.add(new KeyData(newKey, newKeyProbability, newKeySuccessor));
 361             }
 362         }
 363 
 364         int newDefaultSuccessor = addNewSuccessor(defaultSuccessor(), newSuccessors);
 365         double newDefaultProbability = keyProbabilities[keyProbabilities.length - 1];
 366 
 367         /*
 368          * We remove the array load, but we still need to preserve exception semantics by keeping
 369          * the bounds check. Fortunately the array length is a constant.
 370          */
 371         LogicNode boundsCheck = graph().unique(new IntegerBelowNode(newValue, ConstantNode.forInt(arrayLength, graph())));
 372         graph().addBeforeFixed(this, graph().add(new FixedGuardNode(boundsCheck, DeoptimizationReason.BoundsCheckException, DeoptimizationAction.InvalidateReprofile)));
 373 
 374         /*
 375          * Build the low-level representation of the new switch keys and replace ourself with a new
 376          * node.
 377          */
 378         doReplace(newValue, newKeyDatas, newSuccessors, newDefaultSuccessor, newDefaultProbability);
 379 
 380         /* The array load is now unnecessary. */
 381         assert loadIndexed.hasNoUsages();
 382         GraphUtil.removeFixedWithUnusedInputs(loadIndexed);
 383 
 384         return true;
 385     }
 386 
 387     private static int addNewSuccessor(AbstractBeginNode newSuccessor, ArrayList<AbstractBeginNode> newSuccessors) {
 388         int index = newSuccessors.indexOf(newSuccessor);
 389         if (index == -1) {
 390             index = newSuccessors.size();
 391             newSuccessors.add(newSuccessor);
 392         }
 393         return index;
 394     }
 395 
 396     private void doReplace(ValueNode newValue, List<KeyData> newKeyDatas, ArrayList<AbstractBeginNode> newSuccessors, int newDefaultSuccessor, double newDefaultProbability) {
 397         /* Sort the new keys (invariant of the IntegerSwitchNode). */
 398         newKeyDatas.sort(Comparator.comparingInt(k -> k.key));
 399 
 400         /* Create the final data arrays. */
 401         int newKeyCount = newKeyDatas.size();
 402         int[] newKeys = new int[newKeyCount];
 403         double[] newKeyProbabilities = new double[newKeyCount + 1];
 404         int[] newKeySuccessors = new int[newKeyCount + 1];
 405 
 406         for (int i = 0; i < newKeyCount; i++) {
 407             KeyData keyData = newKeyDatas.get(i);
 408             newKeys[i] = keyData.key;
 409             newKeyProbabilities[i] = keyData.keyProbability;
 410             newKeySuccessors[i] = keyData.keySuccessor;
 411         }
 412 
 413         newKeySuccessors[newKeyCount] = newDefaultSuccessor;
 414         newKeyProbabilities[newKeyCount] = newDefaultProbability;
 415 
 416         /* Normalize new probabilities so that they sum up to 1. */
 417         double totalProbability = 0;
 418         for (double probability : newKeyProbabilities) {
 419             totalProbability += probability;
 420         }
 421         if (totalProbability > 0) {
 422             for (int i = 0; i < newKeyProbabilities.length; i++) {
 423                 newKeyProbabilities[i] /= totalProbability;
 424             }
 425         } else {
 426             for (int i = 0; i < newKeyProbabilities.length; i++) {
 427                 newKeyProbabilities[i] = 1.0 / newKeyProbabilities.length;
 428             }
 429         }
 430 
 431         /*
 432          * Collect dead successors. Successors have to be cleaned before adding the new node to the
 433          * graph.
 434          */
 435         List<AbstractBeginNode> deadSuccessors = successors.filter(s -> !newSuccessors.contains(s)).snapshot();
 436         successors.clear();
 437 
 438         /*
 439          * Create the new switch node. This is done before removing dead successors as `killCFG`
 440          * could edit some of the inputs (e.g., if `newValue` is a loop-phi of the loop that dies
 441          * while removing successors).
 442          */
 443         AbstractBeginNode[] successorsArray = newSuccessors.toArray(new AbstractBeginNode[newSuccessors.size()]);
 444         SwitchNode newSwitch = graph().add(new IntegerSwitchNode(newValue, successorsArray, newKeys, newKeyProbabilities, newKeySuccessors));
 445 
 446         /* Remove dead successors. */
 447         for (AbstractBeginNode successor : deadSuccessors) {
 448             GraphUtil.killCFG(successor);
 449         }
 450 
 451         /* Replace ourselves with the new switch */
 452         ((FixedWithNextNode) predecessor()).setNext(newSwitch);
 453         GraphUtil.killWithUnusedFloatingInputs(this);
 454     }
 455 
 456     @Override
 457     public Stamp getValueStampForSuccessor(AbstractBeginNode beginNode) {
 458         Stamp result = null;
 459         if (beginNode != this.defaultSuccessor()) {
 460             for (int i = 0; i < keyCount(); i++) {
 461                 if (keySuccessor(i) == beginNode) {
 462                     if (result == null) {
 463                         result = StampFactory.forConstant(keyAt(i));
 464                     } else {
 465                         result = result.meet(StampFactory.forConstant(keyAt(i)));
 466                     }
 467                 }
 468             }
 469         }
 470         return result;
 471     }
 472 }