1 /*
   2  * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8140450 8152893
  27  * @summary Basic test for StackWalker.getCallerClass()
  28  * @run main/othervm GetCallerClassTest
  29  * @run main/othervm GetCallerClassTest sm
  30  */
  31 
import static java.lang.StackWalker.Option.*;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.security.Permission;
import java.security.PermissionCollection;
import java.security.Permissions;
import java.security.Policy;
import java.security.ProtectionDomain;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
  46 
  47 public class GetCallerClassTest {
  48     private final StackWalker walker;
  49     private final boolean expectUOE;
  50 
  51     public GetCallerClassTest(StackWalker sw, boolean expect) {
  52         this.walker = sw;
  53         this.expectUOE = expect;
  54     }
  55     public static void main(String... args) throws Exception {
  56         if (args.length > 0 && args[0].equals("sm")) {
  57             PermissionCollection perms = new Permissions();
  58             perms.add(new StackFramePermission("retainClassReference"));
  59             Policy.setPolicy(new Policy() {
  60                 @Override
  61                 public boolean implies(ProtectionDomain domain, Permission p) {
  62                     return perms.implies(p);
  63                 }
  64             });
  65             System.setSecurityManager(new SecurityManager());
  66         }
  67         new GetCallerClassTest(StackWalker.getInstance(), true).test();
  68         new GetCallerClassTest(StackWalker.getInstance(RETAIN_CLASS_REFERENCE), false).test();
  69         new GetCallerClassTest(StackWalker.getInstance(EnumSet.of(RETAIN_CLASS_REFERENCE,
  70                                                                   SHOW_HIDDEN_FRAMES)), false).test();
  71     }
  72 
  73     public void test() {
  74         new TopLevelCaller().run();
  75         new LambdaTest().run();
  76         new Nested().createNestedCaller().run();
  77         new InnerClassCaller().run();
  78         new ReflectionTest().run();
  79 
  80         List<Thread> threads = Arrays.asList(
  81                 new Thread(new TopLevelCaller()),
  82                 new Thread(new LambdaTest()),
  83                 new Thread(new Nested().createNestedCaller()),
  84                 new Thread(new InnerClassCaller()),
  85                 new Thread(new ReflectionTest())
  86         );
  87         threads.stream().forEach(Thread::start);
  88         threads.stream().forEach(t -> {
  89             try {
  90                 t.join();
  91             } catch (InterruptedException e) {
  92                 throw new RuntimeException(e);
  93             }
  94         });
  95     }
  96 
  97     public static void staticGetCallerClass(StackWalker stackWalker,
  98                                             Class<?> expected,
  99                                             boolean expectUOE) {
 100         try {
 101             Class<?> c = stackWalker.getCallerClass();
 102             assertEquals(c, expected);
 103             if (expectUOE) { // Should have thrown
 104                 throw new RuntimeException("Didn't get expected exception");
 105             }
 106         } catch (RuntimeException e) { // also catches UOE
 107             if (expectUOE && causeIsUOE(e)) {
 108                 return; /* expected */
 109             }
 110             System.err.println("Unexpected exception:");
 111             throw e;
 112         }
 113     }
 114 
 115     public static void reflectiveGetCallerClass(StackWalker stackWalker,
 116                                                 Class<?> expected,
 117                                                 boolean expectUOE) {
 118         try {
 119             Method m = StackWalker.class.getMethod("getCallerClass");
 120             Class<?> c = (Class<?>) m.invoke(stackWalker);
 121             assertEquals(c, expected);
 122             if (expectUOE) { // Should have thrown
 123                 throw new RuntimeException("Didn't get expected exception");
 124             }
 125         } catch (Throwable e) {
 126             if (expectUOE && causeIsUOE(e)) {
 127                 return; /* expected */
 128             }
 129             System.err.println("Unexpected exception:");
 130             throw new RuntimeException(e);
 131         }
 132     }
 133 
 134     public static void methodHandleGetCallerClass(StackWalker stackWalker,
 135                                                   Class<?> expected,
 136                                                   boolean expectUOE) {
 137         MethodHandles.Lookup lookup = MethodHandles.lookup();
 138         try {
 139             MethodHandle mh = lookup.findVirtual(StackWalker.class, "getCallerClass",
 140                                                  MethodType.methodType(Class.class));
 141             Class<?> c = (Class<?>) mh.invokeExact(stackWalker);
 142             assertEquals(c, expected);
 143             if (expectUOE) { // Should have thrown
 144                 throw new RuntimeException("Didn't get expected exception");
 145             }
 146         } catch (Throwable e) {
 147             if (expectUOE && causeIsUOE(e)) {
 148                 return; /* expected */
 149             }
 150             System.err.println("Unexpected exception:");
 151             throw new RuntimeException(e);
 152         }
 153     }
 154 
 155     public static void assertEquals(Class<?> c, Class<?> expected) {
 156         if (expected != c) {
 157             throw new RuntimeException("Got " + c + ", but expected " + expected);
 158         }
 159     }
 160 
 161     /** Is there an UnsupportedOperationException in there? */
 162     public static boolean causeIsUOE(Throwable t) {
 163         while (t != null) {
 164             if (t instanceof UnsupportedOperationException) {
 165                 return true;
 166             }
 167             t = t.getCause();
 168         }
 169         return false;
 170     }
 171 
 172     class TopLevelCaller implements Runnable {
 173         public void run() {
 174             GetCallerClassTest.staticGetCallerClass(walker, this.getClass(), expectUOE);
 175             GetCallerClassTest.reflectiveGetCallerClass(walker, this.getClass(), expectUOE);
 176             GetCallerClassTest.methodHandleGetCallerClass(walker, this.getClass(), expectUOE);
 177         }
 178     }
 179 
 180     class LambdaTest implements Runnable {
 181         public void run() {
 182             Runnable lambdaRunnable = () -> {
 183                 try {
 184                     Class<?> c = walker.getCallerClass();
 185 
 186                     assertEquals(c, LambdaTest.class);
 187                     if (expectUOE) { // Should have thrown
 188                         throw new RuntimeException("Didn't get expected exception");
 189                     }
 190                 } catch (Throwable e) {
 191                     if (expectUOE && causeIsUOE(e)) {
 192                         return; /* expected */
 193                     }
 194                     System.err.println("Unexpected exception:");
 195                     throw new RuntimeException(e);
 196                 }
 197             };
 198             lambdaRunnable.run();
 199         }
 200     }
 201 
 202     class Nested {
 203         NestedClassCaller createNestedCaller() { return new NestedClassCaller(); }
 204         class NestedClassCaller implements Runnable {
 205             public void run() {
 206                 GetCallerClassTest.staticGetCallerClass(walker, this.getClass(), expectUOE);
 207                 GetCallerClassTest.reflectiveGetCallerClass(walker, this.getClass(), expectUOE);
 208                 GetCallerClassTest.methodHandleGetCallerClass(walker, this.getClass(), expectUOE);
 209             }
 210         }
 211     }
 212 
 213     class InnerClassCaller implements Runnable {
 214         public void run() {
 215             new Inner().test();
 216         }
 217         class Inner {
 218             void test() {
 219                 GetCallerClassTest.staticGetCallerClass(walker, this.getClass(), expectUOE);
 220                 GetCallerClassTest.reflectiveGetCallerClass(walker, this.getClass(), expectUOE);
 221                 GetCallerClassTest.methodHandleGetCallerClass(walker, this.getClass(), expectUOE);
 222             }
 223         }
 224     }
 225 
 226     class ReflectionTest implements Runnable {
 227         final MethodType methodType =
 228             MethodType.methodType(void.class, StackWalker.class, Class.class, boolean.class);
 229 
 230         public void run() {
 231             callMethodHandle();
 232             callMethodHandleRefl();
 233             callMethodInvoke();
 234             callMethodInvokeRefl();
 235         }
 236         void callMethodHandle() {
 237             MethodHandles.Lookup lookup = MethodHandles.publicLookup();
 238             try {
 239                 MethodHandle mh = lookup.findStatic(GetCallerClassTest.class,
 240                                                     "staticGetCallerClass",
 241                                                     methodType);
 242                 mh.invokeExact(walker, ReflectionTest.class, expectUOE);
 243             } catch (Throwable e) {
 244                 throw new RuntimeException(e);
 245             }
 246         }
 247         void callMethodHandleRefl() {
 248             MethodHandles.Lookup lookup = MethodHandles.publicLookup();
 249             try {
 250                 MethodHandle mh = lookup.findStatic(GetCallerClassTest.class,
 251                                                     "reflectiveGetCallerClass",
 252                                                     methodType);
 253                 mh.invokeExact(walker, ReflectionTest.class, expectUOE);
 254             } catch (Throwable e) {
 255                 throw new RuntimeException(e);
 256             }
 257         }
 258         void callMethodInvoke() {
 259             try {
 260                 Method m = GetCallerClassTest.class.getMethod("staticGetCallerClass",
 261                                StackWalker.class, Class.class, boolean.class);
 262                 m.invoke(null, walker, ReflectionTest.class, expectUOE);
 263             } catch (NoSuchMethodException|IllegalAccessException|InvocationTargetException e) {
 264                 throw new RuntimeException(e);
 265             }
 266         }
 267         void callMethodInvokeRefl() {
 268             try {
 269                 Method m = GetCallerClassTest.class.getMethod("reflectiveGetCallerClass",
 270                                StackWalker.class, Class.class, boolean.class);
 271                 m.invoke(null, walker, ReflectionTest.class, expectUOE);
 272             } catch (UnsupportedOperationException e) {
 273                 throw e;
 274             } catch (NoSuchMethodException|IllegalAccessException|InvocationTargetException e) {
 275                 throw new RuntimeException(e);
 276             }
 277         }
 278     }
 279 }