1 /*
   2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package java.lang.invoke;
  27 
  28 import jdk.internal.misc.Unsafe;
  29 import jdk.internal.org.objectweb.asm.ClassReader;
  30 import jdk.internal.org.objectweb.asm.ClassWriter;
  31 import jdk.internal.org.objectweb.asm.MethodVisitor;
  32 import jdk.internal.org.objectweb.asm.Opcodes;
  33 import jdk.internal.org.objectweb.asm.Type;
  34 import jdk.internal.org.objectweb.asm.util.TraceClassVisitor;
  35 import jdk.internal.vm.annotation.ForceInline;
  36 import sun.security.action.GetBooleanAction;
  37 import sun.security.action.GetPropertyAction;
  38 
  39 import java.foreign.memory.MemoryAddress;
  40 import java.io.File;
  41 import java.io.FileOutputStream;
  42 import java.io.IOException;
  43 import java.io.PrintWriter;
  44 import java.io.StringWriter;
  45 import java.util.ArrayList;
  46 import java.util.Arrays;
  47 import java.util.HashMap;
  48 
  49 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_FINAL;
  50 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_PUBLIC;
  51 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_STATIC;
  52 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_SUPER;
  53 import static jdk.internal.org.objectweb.asm.Opcodes.ALOAD;
  54 import static jdk.internal.org.objectweb.asm.Opcodes.ARETURN;
  55 import static jdk.internal.org.objectweb.asm.Opcodes.CHECKCAST;
  56 import static jdk.internal.org.objectweb.asm.Opcodes.GETFIELD;
  57 import static jdk.internal.org.objectweb.asm.Opcodes.ILOAD;
  58 import static jdk.internal.org.objectweb.asm.Opcodes.INVOKESPECIAL;
  59 import static jdk.internal.org.objectweb.asm.Opcodes.INVOKEVIRTUAL;
  60 import static jdk.internal.org.objectweb.asm.Opcodes.LADD;
  61 import static jdk.internal.org.objectweb.asm.Opcodes.LALOAD;
  62 import static jdk.internal.org.objectweb.asm.Opcodes.LLOAD;
  63 import static jdk.internal.org.objectweb.asm.Opcodes.LMUL;
  64 import static jdk.internal.org.objectweb.asm.Opcodes.PUTFIELD;
  65 import static jdk.internal.org.objectweb.asm.Opcodes.RETURN;
  66 
  67 class AddressVarHandleGenerator {
  68     private static final String DEBUG_DUMP_CLASSES_DIR_PROPERTY = "jdk.internal.foreign.ClassGenerator.DEBUG_DUMP_CLASSES_DIR";
  69 
  70     private static final boolean DEBUG =
  71         GetBooleanAction.privilegedGetProperty("jdk.internal.foreign.ClassGenerator.DEBUG");
  72 
  73     private static final File DEBUG_DUMP_CLASSES_DIR;
  74 
  75     private static final HashMap<Class<?>, Class<?>> baseClassCache;
  76 
  77     static {
  78         baseClassCache = new HashMap<>();
  79         baseClassCache.put(byte.class, VarHandleMemoryAddressAsBytes.MemoryAddressHandleBase.class);
  80         baseClassCache.put(short.class, VarHandleMemoryAddressAsShorts.MemoryAddressHandleBase.class);
  81         baseClassCache.put(char.class, VarHandleMemoryAddressAsChars.MemoryAddressHandleBase.class);
  82         baseClassCache.put(int.class, VarHandleMemoryAddressAsInts.MemoryAddressHandleBase.class);
  83         baseClassCache.put(long.class, VarHandleMemoryAddressAsLongs.MemoryAddressHandleBase.class);
  84         baseClassCache.put(float.class, VarHandleMemoryAddressAsFloats.MemoryAddressHandleBase.class);
  85         baseClassCache.put(double.class, VarHandleMemoryAddressAsDoubles.MemoryAddressHandleBase.class);
  86     }
  87 
  88     static {
  89         String path = GetPropertyAction.privilegedGetProperty(DEBUG_DUMP_CLASSES_DIR_PROPERTY);
  90         if (path == null) {
  91             DEBUG_DUMP_CLASSES_DIR = null;
  92         } else {
  93             DEBUG_DUMP_CLASSES_DIR = new File(path);
  94         }
  95     }
  96 
  97     private static final Unsafe U = Unsafe.getUnsafe();
  98 
  99     private final Class<?> hostClass;
 100     private final String implClassName;
 101     private final int dimensions;
 102     private final Class<?> carrier;
 103     private final VarForm form;
 104 
 105     AddressVarHandleGenerator(Class<?> carrier, int dims) {
 106         this.hostClass = baseClassCache.get(carrier);
 107         this.dimensions = dims;
 108         this.carrier = carrier;
 109         Class<?>[] components = new Class<?>[dimensions];
 110         Arrays.fill(components, long.class);
 111         this.form = new VarForm(hostClass, MemoryAddress.class, carrier, components);
 112         this.implClassName = hostClass.getSimpleName() + dimensions;
 113     }
 114 
 115     /*
 116      * Generate a VarHandle memory access factory.
 117      * The factory has type (ZJJ[J)VarHandle.
 118      */
 119     MethodHandle generateHandleFactory() {
 120         Class<?> implCls = generateClass();
 121         try {
 122             Class<?>[] components = new Class<?>[dimensions];
 123             Arrays.fill(components, long.class);
 124 
 125             VarForm form = new VarForm(implCls, MemoryAddress.class, carrier, components);
 126 
 127             MethodType constrType = MethodType.methodType(void.class, VarForm.class, boolean.class, long.class, long.class, long[].class);
 128             MethodHandle constr = MethodHandles.Lookup.IMPL_LOOKUP.findConstructor(implCls, constrType);
 129             constr = MethodHandles.insertArguments(constr, 0, form);
 130             return constr;
 131         } catch (Throwable ex) {
 132             throw new AssertionError(ex);
 133         }
 134     }
 135 
 136     /*
 137      * Generate a specialized VarHandle class for given carrier
 138      * and access coordinates.
 139      */
 140     Class<?> generateClass() {
 141         BinderClassWriter cw = new BinderClassWriter();
 142 
 143         if (DEBUG) {
 144             System.out.println("Generating header implementation class");
 145         }
 146 
 147         cw.visit(52, ACC_PUBLIC | ACC_SUPER, implClassName, null, Type.getInternalName(hostClass), null);
 148 
 149         //add dimension fields
 150         for (int i = 0; i < dimensions; i++) {
 151             cw.visitField(ACC_FINAL, "dim" + i, "J", null, null);
 152         }
 153 
 154         addConstructor(cw);
 155 
 156         addAccessModeTypeMethod(cw);
 157 
 158         for (VarHandle.AccessMode mode : VarHandle.AccessMode.values()) {
 159             addAccessModeMethodIfNeeded(mode, cw);
 160         }
 161 
 162 
 163         cw.visitEnd();
 164         byte[] classBytes = cw.toByteArray();
 165         return defineClass(cw, classBytes);
 166     }
 167 
 168     void addConstructor(BinderClassWriter cw) {
 169         MethodType constrType = MethodType.methodType(void.class, VarForm.class, boolean.class, long.class, long.class, long[].class);
 170         MethodVisitor mv = cw.visitMethod(0, "<init>", constrType.toMethodDescriptorString(), null, null);
 171         mv.visitCode();
 172         //super call
 173         mv.visitVarInsn(ALOAD, 0);
 174         mv.visitVarInsn(ALOAD, 1);
 175         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(VarForm.class));
 176         mv.visitVarInsn(ILOAD, 2);
 177         mv.visitVarInsn(LLOAD, 3);
 178         mv.visitVarInsn(LLOAD, 5);
 179         mv.visitMethodInsn(INVOKESPECIAL, Type.getInternalName(hostClass), "<init>",
 180                 MethodType.methodType(void.class, VarForm.class, boolean.class, long.class, long.class).toMethodDescriptorString(), false);
 181         //init dimensions
 182         for (int i = 0 ; i < dimensions ; i++) {
 183             mv.visitVarInsn(ALOAD, 0);
 184             mv.visitVarInsn(ALOAD, 7);
 185             mv.visitLdcInsn(i);
 186             mv.visitInsn(LALOAD);
 187             mv.visitFieldInsn(PUTFIELD, implClassName, "dim" + i, "J");
 188         }
 189         mv.visitInsn(RETURN);
 190         mv.visitMaxs(0, 0);
 191         mv.visitEnd();
 192     }
 193 
 194     void addAccessModeTypeMethod(BinderClassWriter cw) {
 195         MethodType modeMethType = MethodType.methodType(MethodType.class, VarHandle.AccessMode.class);
 196         MethodVisitor mv = cw.visitMethod(ACC_FINAL, "accessModeTypeUncached", modeMethType.toMethodDescriptorString(), null, null);
 197         mv.visitCode();
 198         mv.visitVarInsn(ALOAD, 1);
 199         mv.visitFieldInsn(GETFIELD, Type.getInternalName(VarHandle.AccessMode.class), "at", Type.getDescriptor(VarHandle.AccessType.class));
 200         mv.visitLdcInsn(cw.makeConstantPoolPatch(MemoryAddress.class));
 201         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(Class.class));
 202         mv.visitLdcInsn(cw.makeConstantPoolPatch(carrier));
 203         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(Class.class));
 204 
 205         Class<?>[] dims = new Class<?>[dimensions];
 206         Arrays.fill(dims, long.class);
 207         mv.visitLdcInsn(cw.makeConstantPoolPatch(dims));
 208         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(Class[].class));
 209 
 210         mv.visitMethodInsn(INVOKEVIRTUAL, Type.getInternalName(VarHandle.AccessType.class),
 211                 "accessModeType", MethodType.methodType(MethodType.class, Class.class, Class.class, Class[].class).toMethodDescriptorString(), false);
 212 
 213         mv.visitInsn(ARETURN);
 214 
 215         mv.visitMaxs(0, 0);
 216         mv.visitEnd();
 217     }
 218 
 219     void addAccessModeMethodIfNeeded(VarHandle.AccessMode mode, BinderClassWriter cw) {
 220         String methName = mode.methodName();
 221         MethodType methType = form.getMethodType(mode.at.ordinal())
 222                 .insertParameterTypes(0, hostClass);
 223 
 224         try {
 225             MethodType helperType = methType.insertParameterTypes(2, long.class);
 226             if (dimensions > 0) {
 227                 helperType = helperType.dropParameterTypes(3, 3 + dimensions);
 228             }
 229             MethodHandle mh = MethodHandles.Lookup.IMPL_LOOKUP
 230                     .findStatic(hostClass.getDeclaringClass(),
 231                             methName + "0",
 232                             helperType);
 233 
 234 
 235             MethodVisitor mv = cw.visitMethod(ACC_STATIC, methName, methType.toMethodDescriptorString(), null, null);
 236             mv.visitAnnotation(Type.getDescriptor(ForceInline.class), true);
 237             mv.visitCode();
 238             mv.visitLdcInsn(cw.makeConstantPoolPatch(mh));
 239             mv.visitTypeInsn(CHECKCAST, Type.getInternalName(MethodHandle.class));
 240 
 241             mv.visitVarInsn(ALOAD, 0); // handle impl
 242             mv.visitVarInsn(ALOAD, 1); // receiver
 243 
 244             // offset calculation
 245             int slot = 2;
 246             mv.visitVarInsn(ALOAD, 0); // load recv
 247             mv.visitFieldInsn(GETFIELD, Type.getInternalName(hostClass), "offset", "J");
 248             for (int i = 0 ; i < dimensions ; i++) {
 249                 mv.visitVarInsn(ALOAD, 0); // load recv
 250                 mv.visitFieldInsn(GETFIELD, implClassName, "dim" + i, "J");
 251                 mv.visitVarInsn(LLOAD, slot);
 252                 mv.visitInsn(LMUL);
 253                 mv.visitInsn(LADD);
 254                 slot += 2;
 255             }
 256 
 257             for (int i = 2 + dimensions; i < methType.parameterCount() ; i++) {
 258                 Class<?> param = methType.parameterType(i);
 259                 mv.visitVarInsn(loadInsn(param), slot); // load index
 260                 slot += getSlotsForType(param);
 261             }
 262 
 263             //call MH
 264             mv.visitMethodInsn(INVOKEVIRTUAL, Type.getInternalName(MethodHandle.class), "invokeExact",
 265                     helperType.toMethodDescriptorString(), false);
 266 
 267             mv.visitInsn(returnInsn(helperType.returnType()));
 268 
 269             mv.visitMaxs(0, 0);
 270             mv.visitEnd();
 271         } catch (ReflectiveOperationException ex) {
 272             //not found, skip
 273         }
 274     }
 275 
 276     //where
 277     private Class<?> defineClass(BinderClassWriter cw, byte[] classBytes) {
 278         try {
 279             if (DEBUG_DUMP_CLASSES_DIR != null) {
 280                 debugWriteClassToFile(classBytes);
 281             }
 282             Object[] patches = cw.resolvePatches(classBytes);
 283             Class<?> c = U.defineAnonymousClass(hostClass, classBytes, patches);
 284             return c;
 285         } catch (VerifyError e) {
 286             debugPrintClass(classBytes);
 287             throw e;
 288         }
 289     }
 290 
 291     // shared code generation helpers
 292 
 293     private static int getSlotsForType(Class<?> c) {
 294         if (c == long.class || c == double.class) {
 295             return 2;
 296         }
 297         return 1;
 298     }
 299 
 300     /**
 301      * Emits an actual return instruction conforming to the given return type.
 302      */
 303     private int returnInsn(Class<?> type) {
 304         switch (LambdaForm.BasicType.basicType(type)) {
 305             case I_TYPE:  return Opcodes.IRETURN;
 306             case J_TYPE:  return Opcodes.LRETURN;
 307             case F_TYPE:  return Opcodes.FRETURN;
 308             case D_TYPE:  return Opcodes.DRETURN;
 309             case L_TYPE:  return Opcodes.ARETURN;
 310             case V_TYPE:  return RETURN;
 311             default:
 312                 throw new InternalError("unknown return type: " + type);
 313         }
 314     }
 315 
 316     private int loadInsn(Class<?> type) {
 317         switch (LambdaForm.BasicType.basicType(type)) {
 318             case I_TYPE:  return Opcodes.ILOAD;
 319             case J_TYPE:  return LLOAD;
 320             case F_TYPE:  return Opcodes.FLOAD;
 321             case D_TYPE:  return Opcodes.DLOAD;
 322             case L_TYPE:  return Opcodes.ALOAD;
 323             default:
 324                 throw new InternalError("unknown local type: " + type);
 325         }
 326     }
 327 
 328     // debug helpers
 329 
 330     private static String debugPrintClass(byte[] classFile) {
 331         ClassReader cr = new ClassReader(classFile);
 332         StringWriter sw = new StringWriter();
 333         cr.accept(new TraceClassVisitor(new PrintWriter(sw)), 0);
 334         return sw.toString();
 335     }
 336 
 337     private void debugWriteClassToFile(byte[] classFile) {
 338         File file = new File(DEBUG_DUMP_CLASSES_DIR, implClassName + ".class");
 339 
 340         if (DEBUG) {
 341             System.err.println("Dumping class " + implClassName + " to " + file);
 342         }
 343 
 344         try {
 345             debugWriteDataToFile(classFile, file);
 346         } catch (Exception e) {
 347             throw new RuntimeException("Failed to write class " + implClassName + " to file " + file);
 348         }
 349     }
 350 
 351     private void debugWriteDataToFile(byte[] data, File file) {
 352         if (file.exists()) {
 353             file.delete();
 354         }
 355         if (file.exists()) {
 356             throw new RuntimeException("Failed to remove pre-existing file " + file);
 357         }
 358 
 359         File parent = file.getParentFile();
 360         if (!parent.exists()) {
 361             parent.mkdirs();
 362         }
 363         if (!parent.exists()) {
 364             throw new RuntimeException("Failed to create " + parent);
 365         }
 366         if (!parent.isDirectory()) {
 367             throw new RuntimeException(parent + " is not a directory");
 368         }
 369 
 370         try (FileOutputStream fos = new FileOutputStream(file)) {
 371             fos.write(data);
 372         } catch (IOException e) {
 373             throw new RuntimeException("Failed to write class " + implClassName + " to file " + file);
 374         }
 375     }
 376 
 377     static class BinderClassWriter extends ClassWriter {
 378 
 379         private final ArrayList<ConstantPoolPatch> cpPatches = new ArrayList<>();
 380         private int curUniquePatchIndex = 0;
 381 
 382         BinderClassWriter() {
 383             super(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
 384         }
 385 
 386         public String makeConstantPoolPatch(Object o) {
 387             int myUniqueIndex = curUniquePatchIndex++;
 388             String cpPlaceholder = "CONSTANT_PLACEHOLDER_" + myUniqueIndex;
 389             int index = newConst(cpPlaceholder);
 390             cpPatches.add(new ConstantPoolPatch(index, cpPlaceholder, o));
 391             return cpPlaceholder;
 392         }
 393 
 394         public Object[] resolvePatches(byte[] classFile) {
 395             if (cpPatches.isEmpty()) {
 396                 return null;
 397             }
 398 
 399             int size = ((classFile[8] & 0xFF) << 8) | (classFile[9] & 0xFF);
 400 
 401             Object[] patches = new Object[size];
 402             for (ConstantPoolPatch p : cpPatches) {
 403                 if (p.index >= size) {
 404                     throw new InternalError("Failed to resolve constant pool patch entries");
 405                 }
 406                 patches[p.index] = p.value;
 407             }
 408 
 409             return patches;
 410         }
 411 
 412         static class ConstantPoolPatch {
 413             final int index;
 414             final String placeholder;
 415             final Object value;
 416 
 417             ConstantPoolPatch(int index, String placeholder, Object value) {
 418                 this.index = index;
 419                 this.placeholder = placeholder;
 420                 this.value = value;
 421             }
 422 
 423             @Override
 424             public String toString() {
 425                 return "CpPatch/index="+index+",placeholder="+placeholder+",value="+value;
 426             }
 427         }
 428     }
 429 }