1 /*
   2  * Copyright (c) 2016, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package org.graalvm.compiler.api.directives.test;
  24 
  25 import java.io.IOException;
  26 import java.util.ArrayList;
  27 import java.util.HashMap;
  28 import java.util.Iterator;
  29 import java.util.List;
  30 import java.util.Map;
  31 
  32 import org.graalvm.compiler.test.ExportingClassLoader;
  33 
  34 import jdk.internal.org.objectweb.asm.ClassReader;
  35 import jdk.internal.org.objectweb.asm.ClassWriter;
  36 import jdk.internal.org.objectweb.asm.Label;
  37 import jdk.internal.org.objectweb.asm.MethodVisitor;
  38 import jdk.internal.org.objectweb.asm.Opcodes;
  39 import jdk.internal.org.objectweb.asm.tree.AbstractInsnNode;
  40 import jdk.internal.org.objectweb.asm.tree.ClassNode;
  41 import jdk.internal.org.objectweb.asm.tree.IincInsnNode;
  42 import jdk.internal.org.objectweb.asm.tree.InsnList;
  43 import jdk.internal.org.objectweb.asm.tree.JumpInsnNode;
  44 import jdk.internal.org.objectweb.asm.tree.LabelNode;
  45 import jdk.internal.org.objectweb.asm.tree.LineNumberNode;
  46 import jdk.internal.org.objectweb.asm.tree.MethodNode;
  47 import jdk.internal.org.objectweb.asm.tree.VarInsnNode;
  48 
  49 /**
  50  * The {@code TinyInstrumentor} is a bytecode instrumentor using ASM bytecode manipulation
  51  * framework. It injects given code snippet into a target method and creates a temporary class as
  52  * the container. Because the target method is cloned into the temporary class, it is required that
  53  * the target method is public static. Any referred method/field in the target method or the
  54  * instrumentation snippet should be made public as well.
  55  */
  56 public class TinyInstrumentor implements Opcodes {
  57 
  58     private InsnList instrumentationInstructions;
  59     private int instrumentationMaxLocal;
  60 
  61     /**
  62      * Create a instrumentor with a instrumentation snippet. The snippet is specified with the given
  63      * class {@code instrumentationClass} and the given method name {@code methodName}.
  64      */
  65     public TinyInstrumentor(Class<?> instrumentationClass, String methodName) throws IOException {
  66         MethodNode instrumentationMethod = getMethodNode(instrumentationClass, methodName);
  67         assert instrumentationMethod != null;
  68         assert (instrumentationMethod.access | ACC_STATIC) != 0;
  69         assert "()V".equals(instrumentationMethod.desc);
  70         instrumentationInstructions = cloneInstructions(instrumentationMethod.instructions);
  71         instrumentationMaxLocal = instrumentationMethod.maxLocals;
  72         // replace return instructions with a goto unless there is a single return at the end. In
  73         // that case, simply remove the return.
  74         List<AbstractInsnNode> returnInstructions = new ArrayList<>();
  75         for (AbstractInsnNode instruction : selectAll(instrumentationInstructions)) {
  76             if (instruction instanceof LineNumberNode) {
  77                 instrumentationInstructions.remove(instruction);
  78             } else if (instruction.getOpcode() == RETURN) {
  79                 returnInstructions.add(instruction);
  80             }
  81         }
  82         LabelNode exit = new LabelNode();
  83         if (returnInstructions.size() == 1) {
  84             AbstractInsnNode returnInstruction = returnInstructions.get(0);
  85             if (instrumentationInstructions.getLast() != returnInstruction) {
  86                 instrumentationInstructions.insertBefore(returnInstruction, new JumpInsnNode(GOTO, exit));
  87             }
  88             instrumentationInstructions.remove(returnInstruction);
  89         } else {
  90             for (AbstractInsnNode returnInstruction : returnInstructions) {
  91                 instrumentationInstructions.insertBefore(returnInstruction, new JumpInsnNode(GOTO, exit));
  92                 instrumentationInstructions.remove(returnInstruction);
  93             }
  94         }
  95         instrumentationInstructions.add(exit);
  96     }
  97 
  98     /**
  99      * @return a {@link MethodNode} called {@code methodName} in the given class.
 100      */
 101     private static MethodNode getMethodNode(Class<?> clazz, String methodName) throws IOException {
 102         ClassReader classReader = new ClassReader(clazz.getName());
 103         ClassNode classNode = new ClassNode();
 104         classReader.accept(classNode, ClassReader.SKIP_FRAMES);
 105 
 106         for (MethodNode methodNode : classNode.methods) {
 107             if (methodNode.name.equals(methodName)) {
 108                 return methodNode;
 109             }
 110         }
 111         return null;
 112     }
 113 
 114     /**
 115      * Create a {@link ClassNode} with empty constructor.
 116      */
 117     private static ClassNode emptyClass(String name) {
 118         ClassNode classNode = new ClassNode();
 119         classNode.visit(52, ACC_SUPER | ACC_PUBLIC, name.replace('.', '/'), null, "java/lang/Object", new String[]{});
 120 
 121         MethodVisitor mv = classNode.visitMethod(ACC_PUBLIC, "<init>", "()V", null, null);
 122         mv.visitCode();
 123         Label l0 = new Label();
 124         mv.visitLabel(l0);
 125         mv.visitVarInsn(ALOAD, 0);
 126         mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false);
 127         mv.visitInsn(RETURN);
 128         Label l1 = new Label();
 129         mv.visitLabel(l1);
 130         mv.visitMaxs(1, 1);
 131         mv.visitEnd();
 132 
 133         return classNode;
 134     }
 135 
 136     /**
 137      * Helper method for iterating the given {@link InsnList}.
 138      */
 139     private static Iterable<AbstractInsnNode> selectAll(InsnList instructions) {
 140         return new Iterable<AbstractInsnNode>() {
 141             @Override
 142             public Iterator<AbstractInsnNode> iterator() {
 143                 return instructions.iterator();
 144             }
 145         };
 146     }
 147 
 148     /**
 149      * Make a clone of the given {@link InsnList}.
 150      */
 151     private static InsnList cloneInstructions(InsnList instructions) {
 152         Map<LabelNode, LabelNode> labelMap = new HashMap<>();
 153         for (AbstractInsnNode instruction : selectAll(instructions)) {
 154             if (instruction instanceof LabelNode) {
 155                 LabelNode clone = new LabelNode(new Label());
 156                 LabelNode original = (LabelNode) instruction;
 157                 labelMap.put(original, clone);
 158             }
 159         }
 160         InsnList clone = new InsnList();
 161         for (AbstractInsnNode insn : selectAll(instructions)) {
 162             clone.add(insn.clone(labelMap));
 163         }
 164         return clone;
 165     }
 166 
 167     /**
 168      * Shifts all local variable slot references by a specified amount.
 169      */
 170     private static void shiftLocalSlots(InsnList instructions, int offset) {
 171         for (AbstractInsnNode insn : selectAll(instructions)) {
 172             if (insn instanceof VarInsnNode) {
 173                 VarInsnNode varInsn = (VarInsnNode) insn;
 174                 varInsn.var += offset;
 175 
 176             } else if (insn instanceof IincInsnNode) {
 177                 IincInsnNode iincInsn = (IincInsnNode) insn;
 178                 iincInsn.var += offset;
 179             }
 180         }
 181     }
 182 
 183     /**
 184      * Instrument the target method specified by the class {@code targetClass} and the method name
 185      * {@code methodName}. For each occurrence of the {@code opcode}, the instrumentor injects a
 186      * copy of the instrumentation snippet.
 187      */
 188     public Class<?> instrument(Class<?> targetClass, String methodName, int opcode) throws IOException, ClassNotFoundException {
 189         return instrument(targetClass, methodName, opcode, true);
 190     }
 191 
 192     public Class<?> instrument(Class<?> targetClass, String methodName, int opcode, boolean insertAfter) throws IOException, ClassNotFoundException {
 193         // create a container class
 194         String className = targetClass.getName() + "$$" + methodName;
 195         ClassNode classNode = emptyClass(className);
 196         // duplicate the target method and add to the container class
 197         MethodNode methodNode = getMethodNode(targetClass, methodName);
 198         MethodNode newMethodNode = new MethodNode(methodNode.access, methodNode.name, methodNode.desc, methodNode.signature, methodNode.exceptions.toArray(new String[methodNode.exceptions.size()]));
 199         methodNode.accept(newMethodNode);
 200         classNode.methods.add(newMethodNode);
 201         // perform bytecode instrumentation
 202         for (AbstractInsnNode instruction : selectAll(newMethodNode.instructions)) {
 203             if (instruction.getOpcode() == opcode) {
 204                 InsnList instrumentation = cloneInstructions(instrumentationInstructions);
 205                 shiftLocalSlots(instrumentation, newMethodNode.maxLocals);
 206                 newMethodNode.maxLocals += instrumentationMaxLocal;
 207                 if (insertAfter) {
 208                     newMethodNode.instructions.insert(instruction, instrumentation);
 209                 } else {
 210                     newMethodNode.instructions.insertBefore(instruction, instrumentation);
 211                 }
 212             }
 213         }
 214         // dump a byte array and load the class with a dedicated loader to separate the namespace
 215         ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
 216         classNode.accept(classWriter);
 217         byte[] bytes = classWriter.toByteArray();
 218         return new Loader(className, bytes).findClass(className);
 219     }
 220 
 221     private static class Loader extends ExportingClassLoader {
 222 
 223         private String className;
 224         private byte[] bytes;
 225 
 226         Loader(String className, byte[] bytes) {
 227             super(TinyInstrumentor.class.getClassLoader());
 228             this.className = className;
 229             this.bytes = bytes;
 230         }
 231 
 232         @Override
 233         protected Class<?> findClass(String name) throws ClassNotFoundException {
 234             if (name.equals(className)) {
 235                 return defineClass(name, bytes, 0, bytes.length);
 236             } else {
 237                 return super.findClass(name);
 238             }
 239         }
 240     }
 241 
 242 }