1 /*
   2  * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 package jdk.test;
  25 
  26 import java.lang.invoke.MethodHandle;
  27 import java.lang.invoke.MethodHandles;
  28 import java.lang.invoke.MethodHandles.Lookup;
  29 import java.lang.invoke.MethodType;
  30 import java.lang.reflect.Method;
  31 import java.security.Permission;
  32 import java.security.PermissionCollection;
  33 import java.security.Permissions;
  34 import java.security.Policy;
  35 import java.security.ProtectionDomain;
  36 import java.util.CSM.Result;
  37 import java.util.function.Supplier;
  38 
  39 /**
  40  * This test invokes StackWalker::getCallerClass via static reference,
  41  * reflection, MethodHandle, lambda.  Also verify that
  42  * StackWalker::getCallerClass can't be called from @CallerSensitive method.
  43  */
  44 public class CallerSensitiveTest {
  45     private static final String NON_CSM_CALLER_METHOD = "getCallerClass";
  46     private static final String CSM_CALLER_METHOD = "caller";
  47 
  48     public static void main(String... args) throws Throwable {
  49         boolean sm = false;
  50         if (args.length > 0 && args[0].equals("sm")) {
  51             sm = true;
  52             PermissionCollection perms = new Permissions();
  53             perms.add(new StackFramePermission("retainClassReference"));
  54             Policy.setPolicy(new Policy() {
  55                 @Override
  56                 public boolean implies(ProtectionDomain domain, Permission p) {
  57                     return perms.implies(p);
  58                 }
  59             });
  60             System.setSecurityManager(new SecurityManager());
  61         }
  62 
  63         System.err.format("Test %s security manager.%n",
  64                           sm ? "with" : "without");
  65 
  66         CallerSensitiveTest cstest = new CallerSensitiveTest();
  67         // test static call to java.util.CSM::caller and CSM::getCallerClass
  68         cstest.staticMethodCall();
  69         // test java.lang.reflect.Method call
  70         cstest.reflectMethodCall();
  71         // test java.lang.invoke.MethodHandle
  72         cstest.invokeMethodHandle(Lookup1.lookup);
  73         cstest.invokeMethodHandle(Lookup2.lookup);
  74         // test method ref
  75         cstest.lambda();
  76 
  77         LambdaTest.lambda();
  78 
  79 
  80         if (failed > 0) {
  81             throw new RuntimeException(failed + " test cases failed.");
  82         }
  83     }
  84 
  85     void staticMethodCall() {
  86         java.util.CSM.caller();
  87 
  88         Result result = java.util.CSM.getCallerClass();
  89         checkNonCSMCaller(CallerSensitiveTest.class, result);
  90     }
  91 
  92     void reflectMethodCall() throws Throwable {
  93         Method method1 = java.util.CSM.class.getMethod(CSM_CALLER_METHOD);
  94         method1.invoke(null);
  95 
  96         Method method2 = java.util.CSM.class.getMethod(NON_CSM_CALLER_METHOD);
  97         Result result = (Result) method2.invoke(null);
  98         checkNonCSMCaller(CallerSensitiveTest.class, result);
  99     }
 100 
 101     void invokeMethodHandle(Lookup lookup) throws Throwable {
 102         MethodHandle mh1 = lookup.findStatic(java.util.CSM.class, CSM_CALLER_METHOD,
 103             MethodType.methodType(Class.class));
 104         Class<?> c = (Class<?>)mh1.invokeExact();
 105 
 106         MethodHandle mh2 = lookup.findStatic(java.util.CSM.class, NON_CSM_CALLER_METHOD,
 107             MethodType.methodType(Result.class));
 108         Result result = (Result)mh2.invokeExact();
 109         checkNonCSMCaller(CallerSensitiveTest.class, result);
 110     }
 111 
 112     void lambda() {
 113         Result result = LambdaTest.getCallerClass.get();
 114         checkNonCSMCaller(CallerSensitiveTest.class, result);
 115 
 116         LambdaTest.caller.get();
 117     }
 118 
 119     static int failed = 0;
 120 
 121     static void checkNonCSMCaller(Class<?> expected, Result result) {
 122         if (result.callers.size() != 1) {
 123             throw new RuntimeException("Expected result.callers contain one element");
 124         }
 125         if (expected != result.callers.get(0)) {
 126             System.err.format("ERROR: Expected %s but got %s%n", expected,
 127                 result.callers);
 128             result.frames.stream()
 129                 .forEach(f -> System.err.println("   " + f));
 130             failed++;
 131         }
 132     }
 133 
 134     static class Lookup1 {
 135         static Lookup lookup = MethodHandles.lookup();
 136     }
 137 
 138     static class Lookup2 {
 139         static Lookup lookup = MethodHandles.lookup();
 140     }
 141 
 142     static class LambdaTest {
 143         static Supplier<Class<?>> caller = java.util.CSM::caller;
 144         static Supplier<Result> getCallerClass = java.util.CSM::getCallerClass;
 145 
 146         static void caller() {
 147             caller.get();
 148         }
 149         static Result getCallerClass() {
 150             return getCallerClass.get();
 151         }
 152 
 153         static void lambda() {
 154             Result result = LambdaTest.getCallerClass();
 155             checkNonCSMCaller(LambdaTest.class, result);
 156 
 157             LambdaTest.caller();
 158         }
 159     }
 160 }