8206240: java.lang.Class.newInstance() is causing caller to leak
authormchung
Thu, 04 Oct 2018 13:02:58 -0700
changeset 52020 4c247dde38ed
parent 52019 d63efc278e93
child 52021 7b90af8664ca
8206240: java.lang.Class.newInstance() is causing caller to leak Reviewed-by: alanb
src/java.base/share/classes/java/lang/Class.java
src/java.base/share/classes/java/lang/reflect/AccessibleObject.java
src/java.base/share/classes/java/lang/reflect/Constructor.java
src/java.base/share/classes/java/lang/reflect/ReflectAccess.java
src/java.base/share/classes/jdk/internal/reflect/LangReflectAccess.java
src/java.base/share/classes/jdk/internal/reflect/ReflectionFactory.java
test/jdk/java/lang/StackWalker/ReflectionFrames.java
test/jdk/java/lang/reflect/callerCache/AccessTest.java
test/jdk/java/lang/reflect/callerCache/ReflectionCallerCacheTest.java
test/jdk/jdk/modules/open/Basic.java
--- a/src/java.base/share/classes/java/lang/Class.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/src/java.base/share/classes/java/lang/Class.java	Thu Oct 04 13:02:58 2018 -0700
@@ -64,7 +64,6 @@
 import jdk.internal.loader.BootLoader;
 import jdk.internal.loader.BuiltinClassLoader;
 import jdk.internal.misc.Unsafe;
-import jdk.internal.misc.VM;
 import jdk.internal.module.Resources;
 import jdk.internal.reflect.CallerSensitive;
 import jdk.internal.reflect.ConstantPool;
@@ -540,11 +539,9 @@
             checkMemberAccess(sm, Member.PUBLIC, Reflection.getCallerClass(), false);
         }
 
-        // NOTE: the following code may not be strictly correct under
-        // the current Java memory model.
-
         // Constructor lookup
-        if (cachedConstructor == null) {
+        Constructor<T> tmpConstructor = cachedConstructor;
+        if (tmpConstructor == null) {
             if (this == Class.class) {
                 throw new IllegalAccessException(
                     "Can not call newInstance() on the Class for java.lang.Class"
@@ -555,9 +552,7 @@
                 final Constructor<T> c = getReflectionFactory().copyConstructor(
                     getConstructor0(empty, Member.DECLARED));
                 // Disable accessibility checks on the constructor
-                // since we have to do the security check here anyway
-                // (the stack depth is wrong for the Constructor's
-                // security check to work)
+                // access check is done with the true caller
                 java.security.AccessController.doPrivileged(
                     new java.security.PrivilegedAction<>() {
                         public Void run() {
@@ -565,32 +560,24 @@
                                 return null;
                             }
                         });
-                cachedConstructor = c;
+                cachedConstructor = tmpConstructor = c;
             } catch (NoSuchMethodException e) {
                 throw (InstantiationException)
                     new InstantiationException(getName()).initCause(e);
             }
         }
-        Constructor<T> tmpConstructor = cachedConstructor;
-        // Security check (same as in java.lang.reflect.Constructor)
-        Class<?> caller = Reflection.getCallerClass();
-        if (newInstanceCallerCache != caller) {
-            int modifiers = tmpConstructor.getModifiers();
-            Reflection.ensureMemberAccess(caller, this, this, modifiers);
-            newInstanceCallerCache = caller;
-        }
-        // Run constructor
+
         try {
-            return tmpConstructor.newInstance((Object[])null);
+            Class<?> caller = Reflection.getCallerClass();
+            return getReflectionFactory().newInstance(tmpConstructor, null, caller);
         } catch (InvocationTargetException e) {
             Unsafe.getUnsafe().throwException(e.getTargetException());
             // Not reached
             return null;
         }
     }
+
     private transient volatile Constructor<T> cachedConstructor;
-    private transient volatile Class<?>       newInstanceCallerCache;
-
 
     /**
      * Determines if the specified {@code Object} is assignment-compatible
--- a/src/java.base/share/classes/java/lang/reflect/AccessibleObject.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/src/java.base/share/classes/java/lang/reflect/AccessibleObject.java	Thu Oct 04 13:02:58 2018 -0700
@@ -27,6 +27,7 @@
 
 import java.lang.annotation.Annotation;
 import java.lang.invoke.MethodHandle;
+import java.lang.ref.WeakReference;
 import java.security.AccessController;
 
 import jdk.internal.misc.VM;
@@ -567,21 +568,68 @@
     // Shared access checking logic.
 
     // For non-public members or members in package-private classes,
-    // it is necessary to perform somewhat expensive security checks.
-    // If the security check succeeds for a given class, it will
+    // it is necessary to perform somewhat expensive access checks.
+    // If the access check succeeds for a given class, it will
     // always succeed (it is not affected by the granting or revoking
     // of permissions); we speed up the check in the common case by
     // remembering the last Class for which the check succeeded.
     //
-    // The simple security check for Constructor is to see if
+    // The simple access check for Constructor is to see if
     // the caller has already been seen, verified, and cached.
-    // (See also Class.newInstance(), which uses a similar method.)
     //
-    // A more complicated security check cache is needed for Method and Field
-    // The cache can be either null (empty cache), a 2-array of {caller,targetClass},
+    // A more complicated access check cache is needed for Method and Field
+    // The cache can be either null (empty cache), {caller,targetClass} pair,
     // or a caller (with targetClass implicitly equal to memberClass).
-    // In the 2-array case, the targetClass is always different from the memberClass.
-    volatile Object securityCheckCache;
+    // In the {caller,targetClass} case, the targetClass is always different
+    // from the memberClass.
+    volatile Object accessCheckCache;
+
+    private static class Cache {
+        final WeakReference<Class<?>> callerRef;
+        final WeakReference<Class<?>> targetRef;
+
+        Cache(Class<?> caller, Class<?> target) {
+            this.callerRef = new WeakReference<>(caller);
+            this.targetRef = new WeakReference<>(target);
+        }
+
+        boolean isCacheFor(Class<?> caller, Class<?> refc) {
+            return callerRef.get() == caller && targetRef.get() == refc;
+        }
+
+        static Object protectedMemberCallerCache(Class<?> caller, Class<?> refc) {
+            return new Cache(caller, refc);
+        }
+    }
+
+    /*
+     * Returns true if the previous access check was verified for the
+     * given caller accessing a protected member with an instance of
+     * the given targetClass where the target class is different than
+     * the declaring member class.
+     */
+    private boolean isAccessChecked(Class<?> caller, Class<?> targetClass) {
+        Object cache = accessCheckCache;  // read volatile
+        if (cache instanceof Cache) {
+            return ((Cache) cache).isCacheFor(caller, targetClass);
+        }
+        return false;
+    }
+
+    /*
+     * Returns true if the previous access check was verified for the
+     * given caller accessing a static member or an instance member of
+     * the target class that is the same as the declaring member class.
+     */
+    private boolean isAccessChecked(Class<?> caller) {
+        Object cache = accessCheckCache;  // read volatile
+        if (cache instanceof WeakReference) {
+            @SuppressWarnings("unchecked")
+            WeakReference<Class<?>> ref = (WeakReference<Class<?>>) cache;
+            return ref.get() == caller;
+        }
+        return false;
+    }
 
     final void checkAccess(Class<?> caller, Class<?> memberClass,
                            Class<?> targetClass, int modifiers)
@@ -603,21 +651,13 @@
         if (caller == memberClass) {  // quick check
             return true;             // ACCESS IS OK
         }
-        Object cache = securityCheckCache;  // read volatile
         if (targetClass != null // instance member or constructor
             && Modifier.isProtected(modifiers)
             && targetClass != memberClass) {
-            // Must match a 2-list of { caller, targetClass }.
-            if (cache instanceof Class[]) {
-                Class<?>[] cache2 = (Class<?>[]) cache;
-                if (cache2[1] == targetClass &&
-                    cache2[0] == caller) {
-                    return true;     // ACCESS IS OK
-                }
-                // (Test cache[1] first since range check for [1]
-                // subsumes range check for [0].)
+            if (isAccessChecked(caller, targetClass)) {
+                return true;         // ACCESS IS OK
             }
-        } else if (cache == caller) {
+        } else if (isAccessChecked(caller)) {
             // Non-protected case (or targetClass == memberClass or static member).
             return true;             // ACCESS IS OK
         }
@@ -642,14 +682,9 @@
         Object cache = (targetClass != null
                         && Modifier.isProtected(modifiers)
                         && targetClass != memberClass)
-                        ? new Class<?>[] { caller, targetClass }
-                        : caller;
-
-        // Note:  The two cache elements are not volatile,
-        // but they are effectively final.  The Java memory model
-        // guarantees that the initializing stores for the cache
-        // elements will occur before the volatile write.
-        securityCheckCache = cache;         // write volatile
+                        ? Cache.protectedMemberCallerCache(caller, targetClass)
+                        : new WeakReference<>(caller);
+        accessCheckCache = cache;         // write volatile
         return true;
     }
 
--- a/src/java.base/share/classes/java/lang/reflect/Constructor.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/src/java.base/share/classes/java/lang/reflect/Constructor.java	Thu Oct 04 13:02:58 2018 -0700
@@ -476,18 +476,27 @@
         throws InstantiationException, IllegalAccessException,
                IllegalArgumentException, InvocationTargetException
     {
-        if (!override) {
-            Class<?> caller = Reflection.getCallerClass();
+        Class<?> caller = override ? null : Reflection.getCallerClass();
+        return newInstanceWithCaller(initargs, !override, caller);
+    }
+
+    /* package-private */
+    T newInstanceWithCaller(Object[] args, boolean checkAccess, Class<?> caller)
+        throws InstantiationException, IllegalAccessException,
+               InvocationTargetException
+    {
+        if (checkAccess)
             checkAccess(caller, clazz, clazz, modifiers);
-        }
+
         if ((clazz.getModifiers() & Modifier.ENUM) != 0)
             throw new IllegalArgumentException("Cannot reflectively create enum objects");
+
         ConstructorAccessor ca = constructorAccessor;   // read volatile
         if (ca == null) {
             ca = acquireConstructorAccessor();
         }
         @SuppressWarnings("unchecked")
-        T inst = (T) ca.newInstance(initargs);
+        T inst = (T) ca.newInstance(args);
         return inst;
     }
 
--- a/src/java.base/share/classes/java/lang/reflect/ReflectAccess.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/src/java.base/share/classes/java/lang/reflect/ReflectAccess.java	Thu Oct 04 13:02:58 2018 -0700
@@ -159,4 +159,10 @@
     public <T extends AccessibleObject> T getRoot(T obj) {
         return (T) obj.getRoot();
     }
+
+    public <T> T newInstance(Constructor<T> ctor, Object[] args, Class<?> caller)
+        throws IllegalAccessException, InstantiationException, InvocationTargetException
+    {
+        return ctor.newInstanceWithCaller(args, true, caller);
+    }
 }
--- a/src/java.base/share/classes/jdk/internal/reflect/LangReflectAccess.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/src/java.base/share/classes/jdk/internal/reflect/LangReflectAccess.java	Thu Oct 04 13:02:58 2018 -0700
@@ -118,4 +118,8 @@
 
     /** Gets the root of the given AccessibleObject object; null if arg is the root */
     public <T extends AccessibleObject> T getRoot(T obj);
+
+    /** Returns a new instance created by the given constructor with access check */
+    public <T> T newInstance(Constructor<T> ctor, Object[] args, Class<?> caller)
+        throws IllegalAccessException, InstantiationException, InvocationTargetException;
 }
--- a/src/java.base/share/classes/jdk/internal/reflect/ReflectionFactory.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/src/java.base/share/classes/jdk/internal/reflect/ReflectionFactory.java	Thu Oct 04 13:02:58 2018 -0700
@@ -398,6 +398,12 @@
         return langReflectAccess().getExecutableSharedParameterTypes(ex);
     }
 
+    public <T> T newInstance(Constructor<T> ctor, Object[] args, Class<?> caller)
+        throws IllegalAccessException, InstantiationException, InvocationTargetException
+    {
+        return langReflectAccess().newInstance(ctor, args, caller);
+    }
+
     //--------------------------------------------------------------------------
     //
     // Routines used by serialization
--- a/test/jdk/java/lang/StackWalker/ReflectionFrames.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/test/jdk/java/lang/StackWalker/ReflectionFrames.java	Thu Oct 04 13:02:58 2018 -0700
@@ -48,6 +48,16 @@
 
 public class ReflectionFrames {
     final static boolean verbose = false;
+    final static Class<?> REFLECT_ACCESS = findClass("java.lang.reflect.ReflectAccess");
+    final static Class<?> REFLECTION_FACTORY = findClass("jdk.internal.reflect.ReflectionFactory");
+
+    private static Class<?> findClass(String cn) {
+        try {
+            return Class.forName(cn);
+        } catch (ClassNotFoundException e) {
+            throw new AssertionError(e);
+        }
+    }
 
     /**
      * This test invokes new StackInspector() directly from
@@ -327,6 +337,8 @@
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
                              Constructor.class.getName()
+                                 +"::newInstanceWithCaller",
+                             Constructor.class.getName()
                                  +"::newInstance",
                              StackInspector.Caller.class.getName()
                                  +"::create",
@@ -355,6 +367,8 @@
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
                              Constructor.class.getName()
+                                 +"::newInstanceWithCaller",
+                             Constructor.class.getName()
                                  +"::newInstance",
                              StackInspector.Caller.class.getName()
                                  +"::create",
@@ -387,6 +401,8 @@
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
                              Constructor.class.getName()
+                                 +"::newInstanceWithCaller",
+                             Constructor.class.getName()
                                  +"::newInstance",
                              StackInspector.Caller.class.getName()
                                  +"::create",
@@ -436,15 +452,19 @@
         assertEquals(obj.collectedFrames,
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
+                             REFLECT_ACCESS.getName()
+                                 +"::newInstance",
+                             REFLECTION_FACTORY.getName()
+                                 +"::newInstance",
                              Class.class.getName()
                                  +"::newInstance",
                              StackInspector.Caller.class.getName()
                                  +"::create",
                              ReflectionFrames.class.getName()
                                  +"::testNewInstance"));
-        // Because Class.newInstance is not filtered, then the
-        // caller is Class.class
-        assertEquals(obj.cls, Class.class);
+        // Because implementation frames are not filtered, then the
+        // caller is ReflectAccess.class
+        assertEquals(obj.cls, REFLECT_ACCESS);
         assertEquals(obj.filtered, 0);
 
         // Calls the StackInspector.reflect method through reflection
@@ -464,6 +484,10 @@
         assertEquals(obj.collectedFrames,
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
+                             REFLECT_ACCESS.getName()
+                                 +"::newInstance",
+                             REFLECTION_FACTORY.getName()
+                                 +"::newInstance",
                              Class.class.getName()
                                  +"::newInstance",
                              StackInspector.Caller.class.getName()
@@ -473,9 +497,9 @@
                              ReflectionFrames.class.getName()
                                  +"::testNewInstance"));
 
-        // Because Class.newInstance is not filtered, then the
-        // caller is Class.class
-        assertEquals(obj.cls, Class.class);
+        // Because implementation frames are not filtered, then the
+        // caller is ReflectAccess.class
+        assertEquals(obj.cls, REFLECT_ACCESS);
         assertEquals(obj.filtered, 0);
 
         // Calls the StackInspector.handle method through reflection
@@ -495,6 +519,10 @@
         assertEquals(obj.collectedFrames,
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
+                             REFLECT_ACCESS.getName()
+                                 +"::newInstance",
+                             REFLECTION_FACTORY.getName()
+                                 +"::newInstance",
                              Class.class.getName()
                                  +"::newInstance",
                              StackInspector.Caller.class.getName()
@@ -504,9 +532,9 @@
                              ReflectionFrames.class.getName()
                                  +"::testNewInstance"));
 
-        // Because Class.newInstance is not filtered, then the
-        // caller is Class.class
-        assertEquals(obj.cls, Class.class);
+        // Because implementation frames are not filtered, then the
+        // caller is ReflectAccess.class
+        assertEquals(obj.cls, REFLECT_ACCESS);
         assertEquals(obj.filtered, 0);
 
         // Sets a non-default walker configured to show
@@ -529,6 +557,8 @@
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
                              Constructor.class.getName()
+                                 +"::newInstanceWithCaller",
+                             REFLECT_ACCESS.getName()
                                  +"::newInstance",
                              Class.class.getName()
                                  +"::newInstance",
@@ -538,9 +568,9 @@
                                  +"::invoke",
                              ReflectionFrames.class.getName()
                                  +"::testNewInstance"));
-        // Because Class.newInstance is not filtered, then the
-        // caller is Class.class
-        assertEquals(obj.cls, Class.class);
+        // Because implementation frames are not filtered, then the
+        // caller is ReflectAccess.class
+        assertEquals(obj.cls, REFLECT_ACCESS);
         assertNotEquals(obj.filtered, 0);
 
         // Calls the StackInspector.reflect method through reflection
@@ -557,10 +587,13 @@
         obj = ((StackInspector)StackInspector.Caller.class
                              .getMethod("reflect", How.class)
                              .invoke(null, How.CLASS));
+        System.out.println(obj.collectedFrames);
         assertEquals(obj.collectedFrames,
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
                              Constructor.class.getName()
+                                 +"::newInstanceWithCaller",
+                             REFLECT_ACCESS.getName()
                                  +"::newInstance",
                              Class.class.getName()
                                  +"::newInstance",
@@ -575,9 +608,9 @@
                              ReflectionFrames.class.getName()
                                  +"::testNewInstance"));
 
-        // Because Class.newInstance is not filtered, then the
-        // caller is Class.class
-        assertEquals(obj.cls, Class.class);
+        // Because implementation frames are not filtered, then the
+        // caller is ReflectAccess.class
+        assertEquals(obj.cls, REFLECT_ACCESS);
         assertNotEquals(obj.filtered, 0);
 
         // Calls the StackInspector.handle method through reflection
@@ -598,6 +631,8 @@
                      List.of(StackInspector.class.getName()
                                  +"::<init>",
                              Constructor.class.getName()
+                                 +"::newInstanceWithCaller",
+                             REFLECT_ACCESS.getName()
                                  +"::newInstance",
                              Class.class.getName()
                                  +"::newInstance",
@@ -611,9 +646,9 @@
                              ReflectionFrames.class.getName()
                                  +"::testNewInstance"));
 
-        // Because Class.newInstance is not filtered, then the
-        // caller is Class.class
-        assertEquals(obj.cls, Class.class);
+        // Because implementation frames are not filtered, then the
+        // caller is ReflectAccess.class
+        assertEquals(obj.cls, REFLECT_ACCESS);
         assertNotEquals(obj.filtered, 0);
     }
 
--- a/test/jdk/java/lang/reflect/callerCache/AccessTest.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/test/jdk/java/lang/reflect/callerCache/AccessTest.java	Thu Oct 04 13:02:58 2018 -0700
@@ -149,4 +149,11 @@
             super("privateStaticFinalField");
         }
     }
+
+    public static class NewInstance implements Callable<Object> {
+        public Object call() throws Exception {
+            return Members.class.newInstance();
+        }
+    }
+
 }
--- a/test/jdk/java/lang/reflect/callerCache/ReflectionCallerCacheTest.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/test/jdk/java/lang/reflect/callerCache/ReflectionCallerCacheTest.java	Thu Oct 04 13:02:58 2018 -0700
@@ -77,7 +77,8 @@
             { "AccessTest$PublicFinalField"},
             { "AccessTest$PrivateFinalField"},
             { "AccessTest$PublicStaticFinalField"},
-            { "AccessTest$PrivateStaticFinalField"}
+            { "AccessTest$PrivateStaticFinalField"},
+            { "AccessTest$NewInstance"}
         };
     }
 
--- a/test/jdk/jdk/modules/open/Basic.java	Thu Oct 04 10:19:01 2018 -0700
+++ b/test/jdk/jdk/modules/open/Basic.java	Thu Oct 04 13:02:58 2018 -0700
@@ -98,6 +98,9 @@
         ctor.setAccessible(true);
         ctor.newInstance();
 
+        // Class::newInstance
+        clazz.newInstance();
+
         // method handles
         findNoArgConstructorAndInvoke(clazz, MethodHandles.publicLookup());
         findNoArgConstructorAndInvoke(clazz, MethodHandles.lookup());
@@ -122,6 +125,12 @@
         ctor.setAccessible(true);
         ctor.newInstance();
 
+        // Class::newInstance
+        try {
+            clazz.newInstance();
+            assertTrue(false);
+        } catch (IllegalAccessException expected) { }
+
         // method handles
         try {
             findNoArgConstructorAndInvoke(clazz, MethodHandles.publicLookup());
@@ -150,6 +159,9 @@
         ctor.setAccessible(true);
         ctor.newInstance();
 
+        // Class::newInstance
+        clazz.newInstance();
+
         // method handles
         findNoArgConstructorAndInvoke(clazz, MethodHandles.publicLookup());
         findNoArgConstructorAndInvoke(clazz, MethodHandles.lookup());
@@ -174,6 +186,12 @@
         ctor.setAccessible(true);
         ctor.newInstance();
 
+        // Class::newInstance
+        try {
+            clazz.newInstance();
+            assertTrue(false);
+        } catch (IllegalAccessException expected) { }
+
         // method handles
         try {
             findNoArgConstructorAndInvoke(clazz, MethodHandles.publicLookup());
@@ -200,6 +218,7 @@
         // core reflection
         Class<?> clazz = q.PublicType.class;
         clazz.getConstructor().newInstance();
+        clazz.newInstance();
 
         // method handles
         findNoArgConstructorAndInvoke(clazz, MethodHandles.publicLookup());
@@ -226,6 +245,12 @@
         ctor.setAccessible(true);
         ctor.newInstance();
 
+        // Class::newInstance
+        try {
+            clazz.newInstance();
+            assertTrue(false);
+        } catch (IllegalAccessException expected) { }
+
         // method handles
         try {
             findNoArgConstructorAndInvoke(clazz, MethodHandles.publicLookup());
@@ -256,6 +281,12 @@
             assertTrue(false);
         } catch (InaccessibleObjectException expected) { }
 
+        // Class::newInstance
+        try {
+            clazz.newInstance();
+            assertTrue(false);
+        } catch (IllegalAccessException expected) { }
+
         // method handles
         try {
             findNoArgConstructorAndInvoke(clazz, MethodHandles.publicLookup());
@@ -288,6 +319,12 @@
             assertTrue(false);
         } catch (InaccessibleObjectException expected) { }
 
+        // Class::newInstance
+        try {
+            clazz.newInstance();
+            assertTrue(false);
+        } catch (IllegalAccessException expected) { }
+
         // method handles
         try {
             findNoArgConstructorAndInvoke(clazz, MethodHandles.publicLookup());