jdk/src/share/classes/javax/management/modelmbean/RequiredModelMBean.java

Print this page
rev 5696 : 8000537: Contextualize RequiredModelMBean class
Reviewed-by: ahgross, dsamersoff, skoivu
Contributed-by: Jaroslav Bachorik <jaroslav.bachorik@oracle.com>

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 2000, 2008, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2000, 2012, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.  Oracle designates this

@@ -37,15 +37,17 @@
 import java.io.FileOutputStream;
 import java.io.PrintStream;
 import java.lang.reflect.InvocationTargetException;
 
 import java.lang.reflect.Method;
+import java.security.AccessControlContext;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
 
 import java.util.Date;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.logging.Level;
 import java.util.Map;
 import java.util.Set;
 
 import java.util.Vector;

@@ -76,10 +78,12 @@
 import javax.management.ReflectionException;
 import javax.management.RuntimeErrorException;
 import javax.management.RuntimeOperationsException;
 import javax.management.ServiceNotFoundException;
 import javax.management.loading.ClassLoaderRepository;
+import sun.misc.JavaSecurityAccess;
+import sun.misc.SharedSecrets;
 
 import sun.reflect.misc.MethodUtil;
 import sun.reflect.misc.ReflectUtil;
 
 /**

@@ -136,10 +140,13 @@
 
     /* records the registering in MBeanServer */
     private boolean registered = false;
     private transient MBeanServer server = null;
 
+    private final static JavaSecurityAccess javaSecurityAccess = SharedSecrets.getJavaSecurityAccess();
+    final private AccessControlContext acc = AccessController.getContext();
+
     /*************************************/
     /* constructors                      */
     /*************************************/
 
     /**

@@ -1023,15 +1030,36 @@
 
             final Class<?> targetClass;
 
             if (opClassName != null) {
                 try {
+                    AccessControlContext stack = AccessController.getContext();
+                    final Object obj = targetObject;
+                    final String className = opClassName;
+                    final ClassNotFoundException[] caughtException = new ClassNotFoundException[1];
+
+                    targetClass = javaSecurityAccess.doIntersectionPrivilege(new PrivilegedAction<Class<?>>() {
+
+                        @Override
+                        public Class<?> run() {
+                            try {
+                                ReflectUtil.checkPackageAccess(className);
                     final ClassLoader targetClassLoader =
-                        targetObject.getClass().getClassLoader();
-                    targetClass = Class.forName(opClassName, false,
+                                    obj.getClass().getClassLoader();
+                                return Class.forName(className, false,
                                                 targetClassLoader);
                 } catch (ClassNotFoundException e) {
+                                caughtException[0] = e;
+                            }
+                            return null;
+                        }
+                    }, stack, acc);
+
+                    if (caughtException[0] != null) {
+                        throw caughtException[0];
+                    }
+                } catch (ClassNotFoundException e) {
                     final String msg =
                         "class for invoke " + opName + " not found";
                     throw new ReflectionException(e, msg);
                 }
             } else

@@ -1059,13 +1087,13 @@
             cacheResult(opInfo, opDescr, result);
 
         return result;
     }
 
-    private static Method resolveMethod(Class<?> targetClass,
+    private Method resolveMethod(Class<?> targetClass,
                                         String opMethodName,
-                                        String[] sig)
+                                        final String[] sig)
             throws ReflectionException {
         final boolean tracing = MODELMBEAN_LOGGER.isLoggable(Level.FINER);
 
         if (tracing) {
             MODELMBEAN_LOGGER.logp(Level.FINER,

@@ -1076,34 +1104,49 @@
         final Class<?>[] argClasses;
 
         if (sig == null)
             argClasses = null;
         else {
+            final AccessControlContext stack = AccessController.getContext();
+            final ReflectionException[] caughtException = new ReflectionException[1];
             final ClassLoader targetClassLoader = targetClass.getClassLoader();
             argClasses = new Class<?>[sig.length];
+
+            javaSecurityAccess.doIntersectionPrivilege(new PrivilegedAction<Void>() {
+
+                @Override
+                public Void run() {
             for (int i = 0; i < sig.length; i++) {
                 if (tracing) {
                     MODELMBEAN_LOGGER.logp(Level.FINER,
                         RequiredModelMBean.class.getName(),"resolveMethod",
                             "resolve type " + sig[i]);
                 }
                 argClasses[i] = (Class<?>) primitiveClassMap.get(sig[i]);
                 if (argClasses[i] == null) {
                     try {
+                                ReflectUtil.checkPackageAccess(sig[i]);
                         argClasses[i] =
                             Class.forName(sig[i], false, targetClassLoader);
                     } catch (ClassNotFoundException e) {
                         if (tracing) {
                             MODELMBEAN_LOGGER.logp(Level.FINER,
                                     RequiredModelMBean.class.getName(),
                                     "resolveMethod",
                                     "class not found");
                         }
                         final String msg = "Parameter class not found";
-                        throw new ReflectionException(e, msg);
+                                caughtException[0] = new ReflectionException(e, msg);
+                            }
+                        }
                     }
+                    return null;
                 }
+            }, stack, acc);
+
+            if (caughtException[0] != null) {
+                throw caughtException[0];
             }
         }
 
         try {
             return targetClass.getMethod(opMethodName, argClasses);

@@ -1131,11 +1174,11 @@
     }
 
     /* Find a method in RequiredModelMBean as determined by the given
        parameters.  Return null if there is none, or if the parameters
        exclude using it.  Called from invoke. */
-    private static Method findRMMBMethod(String opMethodName,
+    private Method findRMMBMethod(String opMethodName,
                                          Object targetObjectField,
                                          String opClassName,
                                          String[] sig) {
         final boolean tracing = MODELMBEAN_LOGGER.isLoggable(Level.FINER);
 

@@ -1153,38 +1196,71 @@
         final Class<RequiredModelMBean> rmmbClass = RequiredModelMBean.class;
         final Class<?> targetClass;
         if (opClassName == null)
             targetClass = rmmbClass;
         else {
+            AccessControlContext stack = AccessController.getContext();
+            final String className = opClassName;
+            targetClass = javaSecurityAccess.doIntersectionPrivilege(new PrivilegedAction<Class<?>>() {
+
+                @Override
+                public Class<?> run() {
             try {
+                        ReflectUtil.checkPackageAccess(className);
                 final ClassLoader targetClassLoader =
                     rmmbClass.getClassLoader();
-                targetClass = Class.forName(opClassName, false,
+                        Class clz = Class.forName(className, false,
                                             targetClassLoader);
-                if (!rmmbClass.isAssignableFrom(targetClass))
+                        if (!rmmbClass.isAssignableFrom(clz))
                     return null;
+                        return clz;
             } catch (ClassNotFoundException e) {
                 return null;
             }
         }
+            }, stack, acc);
+        }
         try {
-            return resolveMethod(targetClass, opMethodName, sig);
+            return targetClass != null ? resolveMethod(targetClass, opMethodName, sig) : null;
         } catch (ReflectionException e) {
             return null;
         }
     }
 
     /*
      * Invoke the given method, and throw the somewhat unpredictable
      * appropriate exception if the method itself gets an exception.
      */
-    private Object invokeMethod(String opName, Method method,
-                                Object targetObject, Object[] opArgs)
+    private Object invokeMethod(String opName, final Method method,
+                                final Object targetObject, final Object[] opArgs)
             throws MBeanException, ReflectionException {
         try {
+            final Throwable[] caughtException = new Throwable[1];
+            AccessControlContext stack = AccessController.getContext();
+            Object rslt = javaSecurityAccess.doIntersectionPrivilege(new PrivilegedAction<Object>() {
+
+                @Override
+                public Object run() {
+                    try {
             ReflectUtil.checkPackageAccess(method.getDeclaringClass());
             return MethodUtil.invoke(method, targetObject, opArgs);
+                    } catch (InvocationTargetException e) {
+                        caughtException[0] = e;
+                    } catch (IllegalAccessException e) {
+                        caughtException[0] = e;
+                    }
+                    return null;
+                }
+            }, stack, acc);
+            if (caughtException[0] != null) {
+                if (caughtException[0] instanceof Exception) {
+                    throw (Exception)caughtException[0];
+                } else if(caughtException[0] instanceof Error) {
+                    throw (Error)caughtException[0];
+                }
+            }
+            return rslt;
         } catch (RuntimeErrorException ree) {
             throw new RuntimeOperationsException(ree,
                       "RuntimeException occurred in RequiredModelMBean "+
                       "while trying to invoke operation " + opName);
         } catch (RuntimeException re) {

@@ -1565,11 +1641,11 @@
                         // !! cast response to right class
                     }
                 }
 
                 // make sure response class matches type field
-                String respType = attrInfo.getType();
+                final String respType = attrInfo.getType();
                 if (response != null) {
                     String responseClass = response.getClass().getName();
                     if (!respType.equals(responseClass)) {
                         boolean wrongType = false;
                         boolean primitiveType = false;

@@ -1588,13 +1664,35 @@
                                 wrongType = true;
                         } else {
                             // inequality may come from type subclassing
                             boolean subtype;
                             try {
+                                final Class respClass = response.getClass();
+                                final Exception[] caughException = new Exception[1];
+
+                                AccessControlContext stack = AccessController.getContext();
+
+                                Class c = javaSecurityAccess.doIntersectionPrivilege(new PrivilegedAction<Class<?>>() {
+
+                                    @Override
+                                    public Class<?> run() {
+                                        try {
+                                            ReflectUtil.checkPackageAccess(respType);
                                 ClassLoader cl =
-                                    response.getClass().getClassLoader();
-                                Class<?> c = Class.forName(respType, true, cl);
+                                                respClass.getClassLoader();
+                                            return Class.forName(respType, true, cl);
+                                        } catch (Exception e) {
+                                            caughException[0] = e;
+                                        }
+                                        return null;
+                                    }
+                                }, stack, acc);
+
+                                if (caughException[0] != null) {
+                                    throw caughException[0];
+                                }
+
                                 subtype = c.isInstance(response);
                             } catch (Exception e) {
                                 subtype = false;
 
                                 if (tracing) {

@@ -2743,20 +2841,41 @@
      */
     protected ClassLoaderRepository getClassLoaderRepository() {
         return MBeanServerFactory.getClassLoaderRepository(server);
     }
 
-    private Class<?> loadClass(String className)
+    private Class<?> loadClass(final String className)
         throws ClassNotFoundException {
+        AccessControlContext stack = AccessController.getContext();
+        final ClassNotFoundException[] caughtException = new ClassNotFoundException[1];
+
+        Class c = javaSecurityAccess.doIntersectionPrivilege(new PrivilegedAction<Class<?>>() {
+
+            @Override
+            public Class<?> run() {
         try {
+                    ReflectUtil.checkPackageAccess(className);
             return Class.forName(className);
         } catch (ClassNotFoundException e) {
             final ClassLoaderRepository clr =
                 getClassLoaderRepository();
+                    try {
             if (clr == null) throw new ClassNotFoundException(className);
             return clr.loadClass(className);
+                    } catch (ClassNotFoundException ex) {
+                        caughtException[0] = ex;
+                    }
         }
+                return null;
+            }
+        }, stack, acc);
+
+        if (caughtException[0] != null) {
+            throw caughtException[0];
+        }
+
+        return c;
     }
 
 
     /*************************************/
     /* MBeanRegistration Interface       */