1 /*
   2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package java.lang.invoke;
  27 
  28 import jdk.internal.access.foreign.MemoryAddressProxy;
  29 import jdk.internal.misc.Unsafe;
  30 import jdk.internal.org.objectweb.asm.ClassReader;
  31 import jdk.internal.org.objectweb.asm.ClassWriter;
  32 import jdk.internal.org.objectweb.asm.MethodVisitor;
  33 import jdk.internal.org.objectweb.asm.Opcodes;
  34 import jdk.internal.org.objectweb.asm.Type;
  35 import jdk.internal.org.objectweb.asm.util.TraceClassVisitor;
  36 import jdk.internal.vm.annotation.ForceInline;
  37 import sun.security.action.GetBooleanAction;
  38 import sun.security.action.GetPropertyAction;
  39 
  40 import java.io.File;
  41 import java.io.FileOutputStream;
  42 import java.io.IOException;
  43 import java.io.PrintWriter;
  44 import java.io.StringWriter;
  45 import java.util.ArrayList;
  46 import java.util.Arrays;
  47 import java.util.HashMap;
  48 
  49 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_FINAL;
  50 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_PRIVATE;
  51 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_PUBLIC;
  52 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_STATIC;
  53 import static jdk.internal.org.objectweb.asm.Opcodes.ACC_SUPER;
  54 import static jdk.internal.org.objectweb.asm.Opcodes.ALOAD;
  55 import static jdk.internal.org.objectweb.asm.Opcodes.ARETURN;
  56 import static jdk.internal.org.objectweb.asm.Opcodes.BIPUSH;
  57 import static jdk.internal.org.objectweb.asm.Opcodes.CHECKCAST;
  58 import static jdk.internal.org.objectweb.asm.Opcodes.GETFIELD;
  59 import static jdk.internal.org.objectweb.asm.Opcodes.ICONST_0;
  60 import static jdk.internal.org.objectweb.asm.Opcodes.ICONST_1;
  61 import static jdk.internal.org.objectweb.asm.Opcodes.ICONST_2;
  62 import static jdk.internal.org.objectweb.asm.Opcodes.ICONST_3;
  63 import static jdk.internal.org.objectweb.asm.Opcodes.ICONST_4;
  64 import static jdk.internal.org.objectweb.asm.Opcodes.ICONST_5;
  65 import static jdk.internal.org.objectweb.asm.Opcodes.ICONST_M1;
  66 import static jdk.internal.org.objectweb.asm.Opcodes.ILOAD;
  67 import static jdk.internal.org.objectweb.asm.Opcodes.INVOKESPECIAL;
  68 import static jdk.internal.org.objectweb.asm.Opcodes.INVOKESTATIC;
  69 import static jdk.internal.org.objectweb.asm.Opcodes.INVOKEVIRTUAL;
  70 import static jdk.internal.org.objectweb.asm.Opcodes.LADD;
  71 import static jdk.internal.org.objectweb.asm.Opcodes.LALOAD;
  72 import static jdk.internal.org.objectweb.asm.Opcodes.LASTORE;
  73 import static jdk.internal.org.objectweb.asm.Opcodes.LLOAD;
  74 import static jdk.internal.org.objectweb.asm.Opcodes.LMUL;
  75 import static jdk.internal.org.objectweb.asm.Opcodes.NEWARRAY;
  76 import static jdk.internal.org.objectweb.asm.Opcodes.PUTFIELD;
  77 import static jdk.internal.org.objectweb.asm.Opcodes.RETURN;
  78 import static jdk.internal.org.objectweb.asm.Opcodes.DUP;
  79 import static jdk.internal.org.objectweb.asm.Opcodes.SIPUSH;
  80 import static jdk.internal.org.objectweb.asm.Opcodes.T_LONG;
  81 
  82 class AddressVarHandleGenerator {
  83     private static final String DEBUG_DUMP_CLASSES_DIR_PROPERTY = "jdk.internal.foreign.ClassGenerator.DEBUG_DUMP_CLASSES_DIR";
  84 
  85     private static final boolean DEBUG =
  86         GetBooleanAction.privilegedGetProperty("jdk.internal.foreign.ClassGenerator.DEBUG");
  87 
  88     private static final Class<?> BASE_CLASS = VarHandleMemoryAddressBase.class;
  89 
  90     private static final HashMap<Class<?>, Class<?>> helperClassCache;
  91 
  92     static {
  93         helperClassCache = new HashMap<>();
  94         helperClassCache.put(byte.class, VarHandleMemoryAddressAsBytes.class);
  95         helperClassCache.put(short.class, VarHandleMemoryAddressAsShorts.class);
  96         helperClassCache.put(char.class, VarHandleMemoryAddressAsChars.class);
  97         helperClassCache.put(int.class, VarHandleMemoryAddressAsInts.class);
  98         helperClassCache.put(long.class, VarHandleMemoryAddressAsLongs.class);
  99         helperClassCache.put(float.class, VarHandleMemoryAddressAsFloats.class);
 100         helperClassCache.put(double.class, VarHandleMemoryAddressAsDoubles.class);
 101     }
 102 
 103     private static final File DEBUG_DUMP_CLASSES_DIR;
 104 
 105     static {
 106         String path = GetPropertyAction.privilegedGetProperty(DEBUG_DUMP_CLASSES_DIR_PROPERTY);
 107         if (path == null) {
 108             DEBUG_DUMP_CLASSES_DIR = null;
 109         } else {
 110             DEBUG_DUMP_CLASSES_DIR = new File(path);
 111         }
 112     }
 113 
 114     private static final Unsafe U = Unsafe.getUnsafe();
 115 
 116     private final String implClassName;
 117     private final int dimensions;
 118     private final Class<?> carrier;
 119     private final Class<?> helperClass;
 120     private final VarForm form;
 121 
 122     AddressVarHandleGenerator(Class<?> carrier, int dims) {
 123         this.dimensions = dims;
 124         this.carrier = carrier;
 125         Class<?>[] components = new Class<?>[dimensions];
 126         Arrays.fill(components, long.class);
 127         this.form = new VarForm(BASE_CLASS, MemoryAddressProxy.class, carrier, components);
 128         this.helperClass = helperClassCache.get(carrier);
 129         this.implClassName = helperClass.getName().replace('.', '/') + dimensions;
 130     }
 131 
 132     /*
 133      * Generate a VarHandle memory access factory.
 134      * The factory has type (ZJJ[J)VarHandle.
 135      */
 136     MethodHandle generateHandleFactory() {
 137         Class<?> implCls = generateClass();
 138         try {
 139             Class<?>[] components = new Class<?>[dimensions];
 140             Arrays.fill(components, long.class);
 141 
 142             VarForm form = new VarForm(implCls, MemoryAddressProxy.class, carrier, components);
 143 
 144             MethodType constrType = MethodType.methodType(void.class, VarForm.class, boolean.class, long.class, long.class, long.class, long[].class);
 145             MethodHandle constr = MethodHandles.Lookup.IMPL_LOOKUP.findConstructor(implCls, constrType);
 146             constr = MethodHandles.insertArguments(constr, 0, form);
 147             return constr;
 148         } catch (Throwable ex) {
 149             throw new AssertionError(ex);
 150         }
 151     }
 152 
 153     /*
 154      * Generate a specialized VarHandle class for given carrier
 155      * and access coordinates.
 156      */
 157     Class<?> generateClass() {
 158         BinderClassWriter cw = new BinderClassWriter();
 159 
 160         if (DEBUG) {
 161             System.out.println("Generating header implementation class");
 162         }
 163 
 164         cw.visit(52, ACC_PUBLIC | ACC_SUPER, implClassName, null, Type.getInternalName(BASE_CLASS), null);
 165 
 166         //add dimension fields
 167         for (int i = 0; i < dimensions; i++) {
 168             cw.visitField(ACC_PRIVATE | ACC_FINAL, "dim" + i, "J", null, null);
 169         }
 170 
 171         addConstructor(cw);
 172 
 173         addAccessModeTypeMethod(cw);
 174 
 175         addStridesAccessor(cw);
 176 
 177         addCarrierAccessor(cw);
 178 
 179         for (VarHandle.AccessMode mode : VarHandle.AccessMode.values()) {
 180             addAccessModeMethodIfNeeded(mode, cw);
 181         }
 182 
 183 
 184         cw.visitEnd();
 185         byte[] classBytes = cw.toByteArray();
 186         return defineClass(cw, classBytes);
 187     }
 188 
 189     void addConstructor(BinderClassWriter cw) {
 190         MethodType constrType = MethodType.methodType(void.class, VarForm.class, boolean.class, long.class, long.class, long.class, long[].class);
 191         MethodVisitor mv = cw.visitMethod(0, "<init>", constrType.toMethodDescriptorString(), null, null);
 192         mv.visitCode();
 193         //super call
 194         mv.visitVarInsn(ALOAD, 0);
 195         mv.visitVarInsn(ALOAD, 1);
 196         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(VarForm.class));
 197         mv.visitVarInsn(ILOAD, 2);
 198         mv.visitVarInsn(LLOAD, 3);
 199         mv.visitVarInsn(LLOAD, 5);
 200         mv.visitVarInsn(LLOAD, 7);
 201         mv.visitMethodInsn(INVOKESPECIAL, Type.getInternalName(BASE_CLASS), "<init>",
 202                 MethodType.methodType(void.class, VarForm.class, boolean.class, long.class, long.class, long.class).toMethodDescriptorString(), false);
 203         //init dimensions
 204         for (int i = 0 ; i < dimensions ; i++) {
 205             mv.visitVarInsn(ALOAD, 0);
 206             mv.visitVarInsn(ALOAD, 9);
 207             mv.visitLdcInsn(i);
 208             mv.visitInsn(LALOAD);
 209             mv.visitFieldInsn(PUTFIELD, implClassName, "dim" + i, "J");
 210         }
 211         mv.visitInsn(RETURN);
 212         mv.visitMaxs(0, 0);
 213         mv.visitEnd();
 214     }
 215 
 216     void addAccessModeTypeMethod(BinderClassWriter cw) {
 217         MethodType modeMethType = MethodType.methodType(MethodType.class, VarHandle.AccessMode.class);
 218         MethodVisitor mv = cw.visitMethod(ACC_FINAL, "accessModeTypeUncached", modeMethType.toMethodDescriptorString(), null, null);
 219         mv.visitCode();
 220         mv.visitVarInsn(ALOAD, 1);
 221         mv.visitFieldInsn(GETFIELD, Type.getInternalName(VarHandle.AccessMode.class), "at", Type.getDescriptor(VarHandle.AccessType.class));
 222         mv.visitLdcInsn(cw.makeConstantPoolPatch(MemoryAddressProxy.class));
 223         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(Class.class));
 224         mv.visitLdcInsn(cw.makeConstantPoolPatch(carrier));
 225         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(Class.class));
 226 
 227         Class<?>[] dims = new Class<?>[dimensions];
 228         Arrays.fill(dims, long.class);
 229         mv.visitLdcInsn(cw.makeConstantPoolPatch(dims));
 230         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(Class[].class));
 231 
 232         mv.visitMethodInsn(INVOKEVIRTUAL, Type.getInternalName(VarHandle.AccessType.class),
 233                 "accessModeType", MethodType.methodType(MethodType.class, Class.class, Class.class, Class[].class).toMethodDescriptorString(), false);
 234 
 235         mv.visitInsn(ARETURN);
 236 
 237         mv.visitMaxs(0, 0);
 238         mv.visitEnd();
 239     }
 240 
 241     void addAccessModeMethodIfNeeded(VarHandle.AccessMode mode, BinderClassWriter cw) {
 242         String methName = mode.methodName();
 243         MethodType methType = form.getMethodType(mode.at.ordinal())
 244                 .insertParameterTypes(0, BASE_CLASS);
 245 
 246         try {
 247             MethodType helperType = methType.insertParameterTypes(2, long.class);
 248             if (dimensions > 0) {
 249                 helperType = helperType.dropParameterTypes(3, 3 + dimensions);
 250             }
 251             //try to resolve...
 252             String helperMethodName = methName + "0";
 253             MethodHandles.Lookup.IMPL_LOOKUP
 254                     .findStatic(helperClass,
 255                             helperMethodName,
 256                             helperType);
 257 
 258 
 259             MethodVisitor mv = cw.visitMethod(ACC_STATIC, methName, methType.toMethodDescriptorString(), null, null);
 260             mv.visitAnnotation(Type.getDescriptor(ForceInline.class), true);
 261             mv.visitCode();
 262 
 263             mv.visitVarInsn(ALOAD, 0); // handle impl
 264             mv.visitVarInsn(ALOAD, 1); // receiver
 265 
 266             // offset calculation
 267             int slot = 2;
 268             mv.visitVarInsn(ALOAD, 0); // load recv
 269             mv.visitFieldInsn(GETFIELD, Type.getInternalName(BASE_CLASS), "offset", "J");
 270             for (int i = 0 ; i < dimensions ; i++) {
 271                 mv.visitVarInsn(ALOAD, 0); // load recv
 272                 mv.visitTypeInsn(CHECKCAST, implClassName);
 273                 mv.visitFieldInsn(GETFIELD, implClassName, "dim" + i, "J");
 274                 mv.visitVarInsn(LLOAD, slot);
 275                 mv.visitInsn(LMUL);
 276                 mv.visitInsn(LADD);
 277                 slot += 2;
 278             }
 279 
 280             for (int i = 2 + dimensions; i < methType.parameterCount() ; i++) {
 281                 Class<?> param = methType.parameterType(i);
 282                 mv.visitVarInsn(loadInsn(param), slot); // load index
 283                 slot += getSlotsForType(param);
 284             }
 285 
 286             //call helper
 287             mv.visitMethodInsn(INVOKESTATIC, Type.getInternalName(helperClass), helperMethodName,
 288                     helperType.toMethodDescriptorString(), false);
 289 
 290             mv.visitInsn(returnInsn(helperType.returnType()));
 291 
 292             mv.visitMaxs(0, 0);
 293             mv.visitEnd();
 294         } catch (ReflectiveOperationException ex) {
 295             //not found, skip
 296         }
 297     }
 298 
 299     void addStridesAccessor(BinderClassWriter cw) {
 300         MethodVisitor mv = cw.visitMethod(ACC_FINAL, "strides", "()[J", null, null);
 301         mv.visitCode();
 302         iConstInsn(mv, dimensions);
 303         mv.visitIntInsn(NEWARRAY, T_LONG);
 304 
 305         for (int i = 0 ; i < dimensions ; i++) {
 306             mv.visitInsn(DUP);
 307             iConstInsn(mv, i);
 308             mv.visitVarInsn(ALOAD, 0);
 309             mv.visitFieldInsn(GETFIELD, implClassName, "dim" + i, "J");
 310             mv.visitInsn(LASTORE);
 311         }
 312 
 313         mv.visitInsn(ARETURN);
 314         mv.visitMaxs(0, 0);
 315         mv.visitEnd();
 316     }
 317 
 318     void addCarrierAccessor(BinderClassWriter cw) {
 319         MethodVisitor mv = cw.visitMethod(ACC_FINAL, "carrier", "()Ljava/lang/Class;", null, null);
 320         mv.visitCode();
 321         mv.visitLdcInsn(cw.makeConstantPoolPatch(carrier));
 322         mv.visitTypeInsn(CHECKCAST, Type.getInternalName(Class.class));
 323         mv.visitInsn(ARETURN);
 324         mv.visitMaxs(0, 0);
 325         mv.visitEnd();
 326     }
 327 
 328     //where
 329     private Class<?> defineClass(BinderClassWriter cw, byte[] classBytes) {
 330         try {
 331             if (DEBUG_DUMP_CLASSES_DIR != null) {
 332                 debugWriteClassToFile(classBytes);
 333             }
 334             Object[] patches = cw.resolvePatches(classBytes);
 335             Class<?> c = U.defineAnonymousClass(BASE_CLASS, classBytes, patches);
 336             return c;
 337         } catch (Throwable e) {
 338             debugPrintClass(classBytes);
 339             throw e;
 340         }
 341     }
 342 
 343     // shared code generation helpers
 344 
 345     private static int getSlotsForType(Class<?> c) {
 346         if (c == long.class || c == double.class) {
 347             return 2;
 348         }
 349         return 1;
 350     }
 351 
 352     /**
 353      * Emits an actual return instruction conforming to the given return type.
 354      */
 355     private int returnInsn(Class<?> type) {
 356         return switch (LambdaForm.BasicType.basicType(type)) {
 357             case I_TYPE -> Opcodes.IRETURN;
 358             case J_TYPE -> Opcodes.LRETURN;
 359             case F_TYPE -> Opcodes.FRETURN;
 360             case D_TYPE -> Opcodes.DRETURN;
 361             case L_TYPE -> Opcodes.ARETURN;
 362             case V_TYPE -> RETURN;
 363         };
 364     }
 365 
 366     private int loadInsn(Class<?> type) {
 367         return switch (LambdaForm.BasicType.basicType(type)) {
 368             case I_TYPE -> Opcodes.ILOAD;
 369             case J_TYPE -> LLOAD;
 370             case F_TYPE -> Opcodes.FLOAD;
 371             case D_TYPE -> Opcodes.DLOAD;
 372             case L_TYPE -> Opcodes.ALOAD;
 373             case V_TYPE -> throw new IllegalStateException("Cannot load void");
 374         };
 375     }
 376 
 377     private static void iConstInsn(MethodVisitor mv, int i) {
 378         switch (i) {
 379             case -1, 0, 1, 2, 3, 4, 5:
 380                 mv.visitInsn(ICONST_0 + i);
 381                 break;
 382             default:
 383                 if(i >= Byte.MIN_VALUE && i <= Byte.MAX_VALUE) {
 384                     mv.visitIntInsn(BIPUSH, i);
 385                 } else if (i >= Short.MIN_VALUE && i <= Short.MAX_VALUE) {
 386                     mv.visitIntInsn(SIPUSH, i);
 387                 } else {
 388                     mv.visitLdcInsn(i);
 389                 }
 390         }
 391     }
 392 
 393     // debug helpers
 394 
 395     private static String debugPrintClass(byte[] classFile) {
 396         ClassReader cr = new ClassReader(classFile);
 397         StringWriter sw = new StringWriter();
 398         cr.accept(new TraceClassVisitor(new PrintWriter(sw)), 0);
 399         return sw.toString();
 400     }
 401 
 402     private void debugWriteClassToFile(byte[] classFile) {
 403         File file = new File(DEBUG_DUMP_CLASSES_DIR, implClassName + ".class");
 404 
 405         if (DEBUG) {
 406             System.err.println("Dumping class " + implClassName + " to " + file);
 407         }
 408 
 409         try {
 410             debugWriteDataToFile(classFile, file);
 411         } catch (Exception e) {
 412             throw new RuntimeException("Failed to write class " + implClassName + " to file " + file);
 413         }
 414     }
 415 
 416     private void debugWriteDataToFile(byte[] data, File file) {
 417         if (file.exists()) {
 418             file.delete();
 419         }
 420         if (file.exists()) {
 421             throw new RuntimeException("Failed to remove pre-existing file " + file);
 422         }
 423 
 424         File parent = file.getParentFile();
 425         if (!parent.exists()) {
 426             parent.mkdirs();
 427         }
 428         if (!parent.exists()) {
 429             throw new RuntimeException("Failed to create " + parent);
 430         }
 431         if (!parent.isDirectory()) {
 432             throw new RuntimeException(parent + " is not a directory");
 433         }
 434 
 435         try (FileOutputStream fos = new FileOutputStream(file)) {
 436             fos.write(data);
 437         } catch (IOException e) {
 438             throw new RuntimeException("Failed to write class " + implClassName + " to file " + file);
 439         }
 440     }
 441 
 442     static class BinderClassWriter extends ClassWriter {
 443 
 444         private final ArrayList<ConstantPoolPatch> cpPatches = new ArrayList<>();
 445         private int curUniquePatchIndex = 0;
 446 
 447         BinderClassWriter() {
 448             super(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
 449         }
 450 
 451         public String makeConstantPoolPatch(Object o) {
 452             int myUniqueIndex = curUniquePatchIndex++;
 453             String cpPlaceholder = "CONSTANT_PLACEHOLDER_" + myUniqueIndex;
 454             int index = newConst(cpPlaceholder);
 455             cpPatches.add(new ConstantPoolPatch(index, cpPlaceholder, o));
 456             return cpPlaceholder;
 457         }
 458 
 459         public Object[] resolvePatches(byte[] classFile) {
 460             if (cpPatches.isEmpty()) {
 461                 return null;
 462             }
 463 
 464             int size = ((classFile[8] & 0xFF) << 8) | (classFile[9] & 0xFF);
 465 
 466             Object[] patches = new Object[size];
 467             for (ConstantPoolPatch p : cpPatches) {
 468                 if (p.index >= size) {
 469                     throw new InternalError("Failed to resolve constant pool patch entries");
 470                 }
 471                 patches[p.index] = p.value;
 472             }
 473 
 474             return patches;
 475         }
 476 
 477         static class ConstantPoolPatch {
 478             final int index;
 479             final String placeholder;
 480             final Object value;
 481 
 482             ConstantPoolPatch(int index, String placeholder, Object value) {
 483                 this.index = index;
 484                 this.placeholder = placeholder;
 485                 this.value = value;
 486             }
 487 
 488             @Override
 489             public String toString() {
 490                 return "CpPatch/index="+index+",placeholder="+placeholder+",value="+value;
 491             }
 492         }
 493     }
 494 }