1 /*
   2  * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 package jdk.test;
  25 
  26 import java.lang.invoke.MethodHandle;
  27 import java.lang.invoke.MethodHandles;
  28 import java.lang.invoke.MethodHandles.Lookup;
  29 import java.lang.invoke.MethodType;
  30 import java.lang.reflect.Method;
  31 import java.security.Permission;
  32 import java.security.PermissionCollection;
  33 import java.security.Permissions;
  34 import java.security.Policy;
  35 import java.security.ProtectionDomain;
  36 import java.util.CSM.Result;
  37 import java.util.function.Supplier;
  38 
  39 /**
  40  * This test invokes StackWalker::getCallerClass via static reference,
  41  * reflection, MethodHandle, lambda.  Also verify that
  42  * StackWalker::getCallerClass can't be called from @CallerSensitive method.
  43  */
  44 public class CallerSensitiveTest {
  45     private static final String NON_CSM_CALLER_METHOD = "getCallerClass";
  46     private static final String CSM_CALLER_METHOD = "caller";
  47 
  48     public static void main(String... args) throws Throwable {
  49         boolean sm = false;
  50         if (args.length > 0 && args[0].equals("sm")) {
  51             sm = true;
  52             PermissionCollection perms = new Permissions();
  53             perms.add(new RuntimePermission("getStackWalkerWithClassReference"));
  54             Policy.setPolicy(new Policy() {
  55                 @Override
  56                 public boolean implies(ProtectionDomain domain, Permission p) {
  57                     return perms.implies(p);
  58                 }
  59             });
  60             System.setSecurityManager(new SecurityManager());
  61         }
  62 
  63         System.err.format("Test %s security manager.%n",
  64                           sm ? "with" : "without");
  65 
  66         CallerSensitiveTest cstest = new CallerSensitiveTest();
  67         // test static call to java.util.CSM::caller and CSM::getCallerClass
  68         cstest.staticMethodCall();
  69         // test java.lang.reflect.Method call
  70         cstest.reflectMethodCall();
  71         // test java.lang.invoke.MethodHandle
  72         cstest.invokeMethodHandle(Lookup1.lookup);
  73         cstest.invokeMethodHandle(Lookup2.lookup);
  74         // test method ref
  75         cstest.lambda();
  76 
  77         LambdaTest.lambda();
  78 
  79         if (failed > 0) {
  80             throw new RuntimeException(failed + " test cases failed.");
  81         }
  82     }
  83 
  84     void staticMethodCall() {
  85         java.util.CSM.caller();
  86 
  87         Result result = java.util.CSM.getCallerClass();
  88         checkNonCSMCaller(CallerSensitiveTest.class, result);
  89     }
  90 
  91     void reflectMethodCall() throws Throwable {
  92         Method method1 = java.util.CSM.class.getMethod(CSM_CALLER_METHOD);
  93         method1.invoke(null);
  94 
  95         Method method2 = java.util.CSM.class.getMethod(NON_CSM_CALLER_METHOD);
  96         Result result = (Result) method2.invoke(null);
  97         checkNonCSMCaller(CallerSensitiveTest.class, result);
  98     }
  99 
 100     void invokeMethodHandle(Lookup lookup) throws Throwable {
 101         MethodHandle mh1 = lookup.findStatic(java.util.CSM.class, CSM_CALLER_METHOD,
 102             MethodType.methodType(Class.class));
 103         Class<?> c = (Class<?>)mh1.invokeExact();
 104 
 105         MethodHandle mh2 = lookup.findStatic(java.util.CSM.class, NON_CSM_CALLER_METHOD,
 106             MethodType.methodType(Result.class));
 107         Result result = (Result)mh2.invokeExact();
 108         checkNonCSMCaller(CallerSensitiveTest.class, result);
 109     }
 110 
 111     void lambda() {
 112         Result result = LambdaTest.getCallerClass.get();
 113         checkNonCSMCaller(CallerSensitiveTest.class, result);
 114 
 115         LambdaTest.caller.get();
 116     }
 117 
 118     static int failed = 0;
 119 
 120     static void checkNonCSMCaller(Class<?> expected, Result result) {
 121         if (result.callers.size() != 1) {
 122             throw new RuntimeException("Expected result.callers contain one element");
 123         }
 124         if (expected != result.callers.get(0)) {
 125             System.err.format("ERROR: Expected %s but got %s%n", expected,
 126                 result.callers);
 127             result.frames.stream()
 128                 .forEach(f -> System.err.println("   " + f));
 129             failed++;
 130         }
 131     }
 132 
 133     static class Lookup1 {
 134         static Lookup lookup = MethodHandles.lookup();
 135     }
 136 
 137     static class Lookup2 {
 138         static Lookup lookup = MethodHandles.lookup();
 139     }
 140 
 141     static class LambdaTest {
 142         static Supplier<Class<?>> caller = java.util.CSM::caller;
 143         static Supplier<Result> getCallerClass = java.util.CSM::getCallerClass;
 144 
 145         static void caller() {
 146             caller.get();
 147         }
 148         static Result getCallerClass() {
 149             return getCallerClass.get();
 150         }
 151 
 152         static void lambda() {
 153             Result result = LambdaTest.getCallerClass();
 154             checkNonCSMCaller(LambdaTest.class, result);
 155 
 156             LambdaTest.caller();
 157         }
 158     }
 159 }