1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @library /test/lib
  27  * @modules java.base/jdk.internal.org.objectweb.asm
  28  * @build  DefineClassWithClassData
  29  * @run testng/othervm DefineClassWithClassData
  30  */
  31 
  32 import java.lang.invoke.MethodHandles;
  33 import java.lang.invoke.MethodHandles.Lookup;
  34 import java.lang.reflect.InvocationTargetException;
  35 import java.lang.reflect.Method;
  36 import java.util.List;
  37 import java.util.stream.Stream;
  38 
  39 import jdk.internal.org.objectweb.asm.*;
  40 import org.testng.annotations.Test;
  41 
  42 import static java.lang.invoke.MethodHandles.Lookup.ClassOption.*;
  43 import static java.lang.invoke.MethodHandles.Lookup.PRIVATE;
  44 import static jdk.internal.org.objectweb.asm.Opcodes.*;
  45 import static org.testng.Assert.*;
  46 
  47 public class DefineClassWithClassData {
  48     private static final byte[] T_CLASS_BYTES = ClassByteBuilder.classBytes("T");
  49     private static final byte[] T2_CLASS_BYTES = ClassByteBuilder.classBytes("T2");
  50 
  51     private int privMethod() { return 1234; }
  52 
  53     /*
  54      * invoke int test(DefineClassWithClassData o) method defined in the injected class
  55      */
  56     private int testInjectedClass(Class<?> c) throws Throwable {
  57         try {
  58             Method m = c.getMethod("test", DefineClassWithClassData.class);
  59             return (int) m.invoke(c.newInstance(), this);
  60         } catch (InvocationTargetException e) {
  61             throw e.getCause();
  62         }
  63     }
  64 
  65     /*
  66      * Returns the value of the static final "data" field in the injected class
  67      */
  68     private Object injectedData(Class<?> c) throws Throwable {
  69         return c.getDeclaredField("data").get(null);
  70     }
  71 
  72     private static final List<String> classData = List.of("nestmate", "classdata");
  73 
  74     @Test
  75     public void defineNestMate() throws Throwable {
  76         // define a nestmate
  77         Lookup lookup = MethodHandles.lookup().defineHiddenClassWithClassData(T_CLASS_BYTES, classData, NESTMATE);
  78         Class<?> c = lookup.lookupClass();
  79         assertTrue(c.getNestHost() == DefineClassWithClassData.class);
  80         assertEquals(classData, injectedData(c));
  81 
  82         // invoke int test(DefineClassWithClassData o)
  83         int x = testInjectedClass(c);
  84         assertTrue(x == privMethod());
  85 
  86         // dynamic nestmate is not listed in the return array of getNestMembers
  87         assertTrue(Stream.of(c.getNestHost().getNestMembers()).noneMatch(k -> k == c));
  88         assertTrue(c.isNestmateOf(DefineClassWithClassData.class));
  89     }
  90 
  91     @Test
  92     public void defineHiddenClass() throws Throwable {
  93         // define a hidden class
  94         Lookup lookup = MethodHandles.lookup().defineHiddenClassWithClassData(T_CLASS_BYTES, classData, NESTMATE);
  95         Class<?> c = lookup.lookupClass();
  96         assertTrue(c.getNestHost() == DefineClassWithClassData.class);
  97         assertTrue(c.isHiddenClass());
  98         assertEquals(classData, injectedData(c));
  99 
 100         // invoke int test(DefineClassWithClassData o)
 101         int x = testInjectedClass(c);
 102         assertTrue(x == privMethod());
 103 
 104         // dynamic nestmate is not listed in the return array of getNestMembers
 105         assertTrue(Stream.of(c.getNestHost().getNestMembers()).noneMatch(k -> k == c));
 106         assertTrue(c.isNestmateOf(DefineClassWithClassData.class));
 107     }
 108 
 109     @Test
 110     public void defineWeakClass() throws Throwable {
 111         // define a weak class
 112         Lookup lookup = MethodHandles.lookup().defineHiddenClassWithClassData(T_CLASS_BYTES, classData, WEAK);
 113         Class<?> c = lookup.lookupClass();
 114         assertTrue(c.getNestHost() == c);
 115         assertTrue(c.isHiddenClass());
 116     }
 117 
 118     @Test(expectedExceptions = IllegalAccessException.class)
 119     public void noPrivateLookupAccess() throws Throwable {
 120         Lookup lookup = MethodHandles.lookup().dropLookupMode(Lookup.PRIVATE);
 121         lookup.defineHiddenClassWithClassData(T2_CLASS_BYTES, classData, NESTMATE);
 122     }
 123 
 124     @Test
 125     public void teleportToNestmate() throws Throwable {
 126         Lookup lookup = MethodHandles.lookup()
 127             .defineHiddenClassWithClassData(T_CLASS_BYTES, classData, NESTMATE);
 128         Class<?> c = lookup.lookupClass();
 129         assertTrue(c.getNestHost() == DefineClassWithClassData.class);
 130         assertEquals(classData, injectedData(c));
 131         assertTrue(c.isHiddenClass());
 132 
 133         // Teleport to a nestmate
 134         Lookup lookup2 =  MethodHandles.lookup().in(DefineClassWithClassData.class);
 135         assertTrue((lookup2.lookupModes() & PRIVATE) != 0);
 136         Lookup lc = lookup2.defineHiddenClassWithClassData(T2_CLASS_BYTES, classData, NESTMATE);
 137         assertTrue(lc.lookupClass().getNestHost() == DefineClassWithClassData.class);
 138         assertTrue(lc.lookupClass().isHiddenClass());
 139     }
 140 
 141     static class ClassByteBuilder {
 142         static final String OBJECT_CLS = "java/lang/Object";
 143         static final String STRING_CLS = "java/lang/String";
 144         static final String LIST_CLS = "java/util/List";
 145         static final String MH_CLS = "java/lang/invoke/MethodHandles";
 146         static final String LOOKUP_CLS = "java/lang/invoke/MethodHandles$Lookup";
 147         static final String LOOKUP_SIG = "Ljava/lang/invoke/MethodHandles$Lookup;";
 148         static final String LIST_SIG = "Ljava/util/List;";
 149 
 150         static byte[] classBytes(String classname) {
 151             ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
 152             MethodVisitor mv;
 153             FieldVisitor fv;
 154 
 155             String hostClassName = DefineClassWithClassData.class.getName();
 156 
 157             cw.visit(V11, ACC_FINAL, classname, null, OBJECT_CLS, null);
 158             {
 159                 fv = cw.visitField(ACC_STATIC | ACC_FINAL, "data", LIST_SIG, null, null);
 160                 fv.visitEnd();
 161             }
 162             {
 163                 mv = cw.visitMethod(ACC_STATIC, "<clinit>", "()V", null, null);
 164                 mv.visitCode();
 165 
 166                 // set up try block
 167                 Label lTryBlockStart =   new Label();
 168                 Label lTryBlockEnd =     new Label();
 169                 Label lCatchBlockStart = new Label();
 170                 Label lCatchBlockEnd =   new Label();
 171                 mv.visitTryCatchBlock(lTryBlockStart, lTryBlockEnd, lCatchBlockStart, "java/lang/IllegalAccessException");
 172 
 173                 mv.visitLabel(lTryBlockStart);
 174                 mv.visitMethodInsn(INVOKESTATIC, MH_CLS, "lookup", "()" + LOOKUP_SIG, false);
 175                 mv.visitLdcInsn(Type.getType(List.class));
 176                 mv.visitMethodInsn(INVOKEVIRTUAL, LOOKUP_CLS, "classData", "(Ljava/lang/Class;)Ljava/lang/Object;", false);
 177                 mv.visitTypeInsn(CHECKCAST, LIST_CLS);
 178                 mv.visitFieldInsn(PUTSTATIC, classname, "data", LIST_SIG);
 179                 mv.visitLabel(lTryBlockEnd);
 180                 mv.visitJumpInsn(GOTO, lCatchBlockEnd);
 181 
 182                 mv.visitLabel(lCatchBlockStart);
 183                 mv.visitVarInsn(ASTORE, 0);
 184                 mv.visitTypeInsn(NEW, "java/lang/Error");
 185                 mv.visitInsn(DUP);
 186                 mv.visitVarInsn(ALOAD, 0);
 187                 mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Error", "<init>", "(Ljava/lang/Throwable;)V", false);
 188                 mv.visitInsn(ATHROW);
 189                 mv.visitLabel(lCatchBlockEnd);
 190                 mv.visitInsn(RETURN);
 191                 mv.visitMaxs(0, 0);
 192                 mv.visitEnd();
 193             }
 194 
 195             {
 196                 mv = cw.visitMethod(ACC_PUBLIC, "<init>", "()V", null, null);
 197                 mv.visitCode();
 198                 mv.visitVarInsn(ALOAD, 0);
 199                 mv.visitMethodInsn(INVOKESPECIAL, OBJECT_CLS, "<init>", "()V", false);
 200                 mv.visitInsn(RETURN);
 201                 mv.visitMaxs(0, 0);
 202                 mv.visitEnd();
 203             }
 204             {
 205                 mv = cw.visitMethod(ACC_PUBLIC, "test", "(L" + hostClassName + ";)I", null, null);
 206                 mv.visitCode();
 207                 mv.visitVarInsn(ALOAD, 0);
 208                 mv.visitVarInsn(ALOAD, 1);
 209                 mv.visitMethodInsn(INVOKEVIRTUAL, hostClassName, "privMethod", "()I", false);
 210                 mv.visitInsn(IRETURN);
 211                 mv.visitMaxs(0, 0);
 212                 mv.visitEnd();
 213             }
 214 
 215             {
 216                 mv = cw.visitMethod(ACC_PUBLIC | ACC_STATIC, "printData", "()V", null, null);
 217                 mv.visitCode();
 218                 mv.visitFieldInsn(GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
 219                 mv.visitFieldInsn(GETSTATIC, classname, "data", LIST_SIG);
 220                 mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/Object;)V", false);
 221                 mv.visitInsn(RETURN);
 222                 mv.visitMaxs(0, 0);
 223                 mv.visitEnd();
 224             }
 225             cw.visitEnd();
 226             return cw.toByteArray();
 227         }
 228     }
 229 }
 230 
 231