1 /*
   2  * Copyright (c) 2018, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.phases.common;
  26 
  27 import static org.graalvm.compiler.core.common.SpeculativeExecutionAttacksMitigations.NonDeoptGuardTargets;
  28 import static org.graalvm.compiler.core.common.SpeculativeExecutionAttacksMitigations.Options.MitigateSpeculativeExecutionAttacks;
  29 import static org.graalvm.compiler.core.common.SpeculativeExecutionAttacksMitigations.Options.UseIndexMasking;
  30 
  31 import java.util.ArrayList;
  32 import java.util.List;
  33 
  34 import org.graalvm.compiler.core.common.type.IntegerStamp;
  35 import org.graalvm.compiler.core.common.type.StampFactory;
  36 import org.graalvm.compiler.debug.DebugContext;
  37 import org.graalvm.compiler.graph.Node;
  38 import org.graalvm.compiler.graph.Position;
  39 import org.graalvm.compiler.nodeinfo.InputType;
  40 import org.graalvm.compiler.nodes.AbstractBeginNode;
  41 import org.graalvm.compiler.nodes.DeoptimizeNode;
  42 import org.graalvm.compiler.nodes.IfNode;
  43 import org.graalvm.compiler.nodes.LoopExitNode;
  44 import org.graalvm.compiler.nodes.NamedLocationIdentity;
  45 import org.graalvm.compiler.nodes.PiNode;
  46 import org.graalvm.compiler.nodes.StructuredGraph;
  47 import org.graalvm.compiler.nodes.extended.MultiGuardNode;
  48 import org.graalvm.compiler.nodes.memory.Access;
  49 import org.graalvm.compiler.phases.Phase;
  50 
  51 import jdk.vm.ci.meta.DeoptimizationReason;
  52 
  53 /**
  54  * This phase sets the {@linkplain AbstractBeginNode#setWithSpeculationFence() speculation fence}
  55  * flag on {@linkplain AbstractBeginNode begin nodes} in order to mitigate speculative execution
  56  * attacks.
  57  */
  58 public class InsertGuardFencesPhase extends Phase {
  59     @Override
  60     protected void run(StructuredGraph graph) {
  61         for (AbstractBeginNode beginNode : graph.getNodes(AbstractBeginNode.TYPE)) {
  62             if (hasGuardUsages(beginNode)) {
  63                 if (MitigateSpeculativeExecutionAttacks.getValue(graph.getOptions()) == NonDeoptGuardTargets) {
  64                     if (isDeoptGuard(beginNode)) {
  65                         graph.getDebug().log(DebugContext.VERBOSE_LEVEL, "Skipping deoptimizing guard speculation fence at %s", beginNode);
  66                         continue;
  67                     }
  68                 }
  69                 if (UseIndexMasking.getValue(graph.getOptions())) {
  70                     if (isBoundsCheckGuard(beginNode)) {
  71                         graph.getDebug().log(DebugContext.VERBOSE_LEVEL, "Skipping bounds-check speculation fence at %s", beginNode);
  72                         continue;
  73                     }
  74                 }
  75                 if (graph.getDebug().isLogEnabled(DebugContext.DETAILED_LEVEL)) {
  76                     graph.getDebug().log(DebugContext.DETAILED_LEVEL, "Adding speculation fence for %s at %s", guardUsages(beginNode), beginNode);
  77                 } else {
  78                     graph.getDebug().log(DebugContext.VERBOSE_LEVEL, "Adding speculation fence at %s", beginNode);
  79                 }
  80                 beginNode.setWithSpeculationFence();
  81             } else {
  82                 graph.getDebug().log(DebugContext.DETAILED_LEVEL, "No guards on %s", beginNode);
  83             }
  84         }
  85     }
  86 
  87     private static boolean isDeoptGuard(AbstractBeginNode beginNode) {
  88         if (!(beginNode.predecessor() instanceof IfNode)) {
  89             return false;
  90         }
  91         IfNode ifNode = (IfNode) beginNode.predecessor();
  92         AbstractBeginNode otherBegin;
  93         if (ifNode.trueSuccessor() == beginNode) {
  94             otherBegin = ifNode.falseSuccessor();
  95         } else {
  96             assert ifNode.falseSuccessor() == beginNode;
  97             otherBegin = ifNode.trueSuccessor();
  98         }
  99         if (!(otherBegin.next() instanceof DeoptimizeNode)) {
 100             return false;
 101         }
 102         DeoptimizeNode deopt = (DeoptimizeNode) otherBegin.next();
 103         return deopt.getAction().doesInvalidateCompilation();
 104     }
 105 
 106     public static final IntegerStamp POSITIVE_ARRAY_INDEX_STAMP = StampFactory.forInteger(32, 0, Integer.MAX_VALUE - 1);
 107 
 108     private static boolean isBoundsCheckGuard(AbstractBeginNode beginNode) {
 109         if (!(beginNode.predecessor() instanceof IfNode)) {
 110             return false;
 111         }
 112         IfNode ifNode = (IfNode) beginNode.predecessor();
 113         AbstractBeginNode otherBegin;
 114         if (ifNode.trueSuccessor() == beginNode) {
 115             otherBegin = ifNode.falseSuccessor();
 116         } else {
 117             assert ifNode.falseSuccessor() == beginNode;
 118             otherBegin = ifNode.trueSuccessor();
 119         }
 120         if (otherBegin.next() instanceof DeoptimizeNode) {
 121             DeoptimizeNode deopt = (DeoptimizeNode) otherBegin.next();
 122             if (deopt.getReason() == DeoptimizationReason.BoundsCheckException && !hasMultipleGuardUsages(beginNode)) {
 123                 return true;
 124             }
 125         } else if (otherBegin instanceof LoopExitNode && beginNode.usages().filter(MultiGuardNode.class).isNotEmpty() && !hasMultipleGuardUsages(beginNode)) {
 126             return true;
 127         }
 128 
 129         for (Node usage : beginNode.usages()) {
 130             for (Position pos : usage.inputPositions()) {
 131                 if (pos.getInputType() == InputType.Guard && pos.get(usage) == beginNode) {
 132                     if (usage instanceof PiNode) {
 133                         if (!((PiNode) usage).piStamp().equals(POSITIVE_ARRAY_INDEX_STAMP)) {
 134                             return false;
 135                         }
 136                     } else if (usage instanceof Access) {
 137                         if (!NamedLocationIdentity.isArrayLocation(((Access) usage).getLocationIdentity())) {
 138                             return false;
 139                         }
 140                     } else {
 141                         return false;
 142                     }
 143                     break;
 144                 }
 145             }
 146         }
 147 
 148         return true;
 149     }
 150 
 151     private static boolean hasGuardUsages(Node n) {
 152         for (Node usage : n.usages()) {
 153             for (Position pos : usage.inputPositions()) {
 154                 if (pos.getInputType() == InputType.Guard && pos.get(usage) == n) {
 155                     return true;
 156                 }
 157             }
 158         }
 159         return false;
 160     }
 161 
 162     private static boolean hasMultipleGuardUsages(Node n) {
 163         boolean foundOne = false;
 164         for (Node usage : n.usages()) {
 165             for (Position pos : usage.inputPositions()) {
 166                 if (pos.getInputType() == InputType.Guard && pos.get(usage) == n) {
 167                     if (foundOne) {
 168                         return true;
 169                     }
 170                     foundOne = true;
 171                 }
 172             }
 173         }
 174         return false;
 175     }
 176 
 177     private static List<Node> guardUsages(Node n) {
 178         List<Node> ret = new ArrayList<>();
 179         for (Node usage : n.usages()) {
 180             for (Position pos : usage.inputPositions()) {
 181                 if (pos.getInputType() == InputType.Guard && pos.get(usage) == n) {
 182                     ret.add(usage);
 183                 }
 184             }
 185         }
 186         return ret;
 187     }
 188 }