8009411: (reflect) Class.getMethods should not include static methods from interfaces
authorjfranck
Tue, 22 Oct 2013 10:34:12 +0200
changeset 21314 1a616b8bdb31
parent 21313 a07fdfb76d28
child 21315 6afab0f25c10
8009411: (reflect) Class.getMethods should not include static methods from interfaces Summary: Update getMethods() and getMethod() to filter out interface statics Reviewed-by: darcy Contributed-by: joel.franck@oracle.com, andreas.lundblad@oracle.com, amy.lu@oracle.com, peter.levart@gmail.com
jdk/src/share/classes/java/lang/Class.java
jdk/test/java/lang/reflect/DefaultStaticTest/DefaultStaticInvokeTest.java
jdk/test/java/lang/reflect/DefaultStaticTest/DefaultStaticTestData.java
jdk/test/java/lang/reflect/Method/InterfaceStatic/StaticInterfaceMethodInWayOfDefault.java
--- a/jdk/src/share/classes/java/lang/Class.java	Tue Oct 22 03:49:50 2013 +0000
+++ b/jdk/src/share/classes/java/lang/Class.java	Tue Oct 22 10:34:12 2013 +0200
@@ -1571,6 +1571,10 @@
      * <p> If this {@code Class} object represents a primitive type or void,
      * then the returned array has length 0.
      *
+     * <p> Static methods declared in superinterfaces of the class or interface
+     * represented by this {@code Class} object are not considered members of
+     * the class or interface.
+     *
      * <p> The elements in the returned array are not sorted and are not in any
      * particular order.
      *
@@ -1729,6 +1733,10 @@
      * <p> If this {@code Class} object represents an array type, then this
      * method does not find the {@code clone()} method.
      *
+     * <p> Static methods declared in superinterfaces of the class or interface
+     * represented by this {@code Class} object are not considered members of
+     * the class or interface.
+     *
      * @param name the name of the method
      * @param parameterTypes the list of parameters
      * @return the {@code Method} object that matches the specified
@@ -1752,7 +1760,7 @@
     public Method getMethod(String name, Class<?>... parameterTypes)
         throws NoSuchMethodException, SecurityException {
         checkMemberAccess(Member.PUBLIC, Reflection.getCallerClass(), true);
-        Method method = getMethod0(name, parameterTypes);
+        Method method = getMethod0(name, parameterTypes, true);
         if (method == null) {
             throw new NoSuchMethodException(getName() + "." + name + argumentTypesToString(parameterTypes));
         }
@@ -2727,6 +2735,14 @@
             }
         }
 
+        void addAllNonStatic(Method[] methods) {
+            for (Method candidate : methods) {
+                if (!Modifier.isStatic(candidate.getModifiers())) {
+                    add(candidate);
+                }
+            }
+        }
+
         int length() {
             return length;
         }
@@ -2797,7 +2813,7 @@
         MethodArray inheritedMethods = new MethodArray();
         Class<?>[] interfaces = getInterfaces();
         for (int i = 0; i < interfaces.length; i++) {
-            inheritedMethods.addAll(interfaces[i].privateGetPublicMethods());
+            inheritedMethods.addAllNonStatic(interfaces[i].privateGetPublicMethods());
         }
         if (!isInterface()) {
             Class<?> c = getSuperclass();
@@ -2900,7 +2916,7 @@
     }
 
 
-    private Method getMethod0(String name, Class<?>[] parameterTypes) {
+    private Method getMethod0(String name, Class<?>[] parameterTypes, boolean includeStaticMethods) {
         // Note: the intent is that the search algorithm this routine
         // uses be equivalent to the ordering imposed by
         // privateGetPublicMethods(). It fetches only the declared
@@ -2913,25 +2929,23 @@
         if ((res = searchMethods(privateGetDeclaredMethods(true),
                                  name,
                                  parameterTypes)) != null) {
-            return res;
+            if (includeStaticMethods || !Modifier.isStatic(res.getModifiers()))
+                return res;
         }
         // Search superclass's methods
         if (!isInterface()) {
             Class<? super T> c = getSuperclass();
             if (c != null) {
-                if ((res = c.getMethod0(name, parameterTypes)) != null) {
+                if ((res = c.getMethod0(name, parameterTypes, true)) != null) {
                     return res;
                 }
             }
         }
         // Search superinterfaces' methods
         Class<?>[] interfaces = getInterfaces();
-        for (int i = 0; i < interfaces.length; i++) {
-            Class<?> c = interfaces[i];
-            if ((res = c.getMethod0(name, parameterTypes)) != null) {
+        for (Class<?> c : interfaces)
+            if ((res = c.getMethod0(name, parameterTypes, false)) != null)
                 return res;
-            }
-        }
         // Not found
         return null;
     }
--- a/jdk/test/java/lang/reflect/DefaultStaticTest/DefaultStaticInvokeTest.java	Tue Oct 22 03:49:50 2013 +0000
+++ b/jdk/test/java/lang/reflect/DefaultStaticTest/DefaultStaticInvokeTest.java	Tue Oct 22 10:34:12 2013 +0200
@@ -44,30 +44,86 @@
 import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.fail;
 import org.testng.annotations.Test;
 
 import static helper.Mod.*;
 import static helper.Declared.*;
 import helper.Mod;
 
+
 public class DefaultStaticInvokeTest {
 
+    // getMethods(): Make sure getMethods returns the expected methods.
     @Test(dataProvider = "testCasesAll",
             dataProviderClass = DefaultStaticTestData.class)
     public void testGetMethods(String testTarget, Object param)
             throws Exception {
-        // test the methods retrieved by getMethods()
         testMethods(ALL_METHODS, testTarget, param);
     }
 
+
+    // getDeclaredMethods(): Make sure getDeclaredMethods returns the expected methods.
     @Test(dataProvider = "testCasesAll",
             dataProviderClass = DefaultStaticTestData.class)
     public void testGetDeclaredMethods(String testTarget, Object param)
             throws Exception {
-        // test the methods retrieved by getDeclaredMethods()
         testMethods(DECLARED_ONLY, testTarget, param);
     }
 
+
+    // getMethod(): Make sure that getMethod finds all methods it should find.
+    @Test(dataProvider = "testCasesAll",
+            dataProviderClass = DefaultStaticTestData.class)
+    public void testGetMethod(String testTarget, Object param)
+            throws Exception {
+
+        Class<?> typeUnderTest = Class.forName(testTarget);
+
+        MethodDesc[] descs = typeUnderTest.getAnnotationsByType(MethodDesc.class);
+
+        for (MethodDesc desc : descs) {
+            assertTrue(isFoundByGetMethod(typeUnderTest,
+                                          desc.name(),
+                                          argTypes(param)));
+        }
+    }
+
+
+    // getMethod(): Make sure that getMethod does *not* find certain methods.
+    @Test(dataProvider = "testCasesAll",
+            dataProviderClass = DefaultStaticTestData.class)
+    public void testGetMethodSuperInterfaces(String testTarget, Object param)
+            throws Exception {
+
+        // Make sure static methods in superinterfaces are not found (unless the type under
+        // test declares a static method with the same signature).
+
+        Class<?> typeUnderTest = Class.forName(testTarget);
+
+        for (Class<?> interfaze : typeUnderTest.getInterfaces()) {
+
+            for (MethodDesc desc : interfaze.getAnnotationsByType(MethodDesc.class)) {
+
+                boolean isStatic = desc.mod() == STATIC;
+
+                boolean declaredInThisType = isMethodDeclared(typeUnderTest,
+                                                              desc.name());
+
+                boolean expectedToBeFound = !isStatic || declaredInThisType;
+
+                if (expectedToBeFound)
+                    continue; // already tested in testGetMethod()
+
+                assertFalse(isFoundByGetMethod(typeUnderTest,
+                                               desc.name(),
+                                               argTypes(param)));
+            }
+        }
+    }
+
+
+    // Method.invoke(): Make sure Method.invoke returns the expected value.
     @Test(dataProvider = "testCasesAll",
             dataProviderClass = DefaultStaticTestData.class)
     public void testMethodInvoke(String testTarget, Object param)
@@ -78,11 +134,13 @@
         // test the method retrieved by Class.getMethod(String, Object[])
         for (MethodDesc toTest : expectedMethods) {
             String name = toTest.name();
-            Method m = getTestMethod(typeUnderTest, name, param);
+            Method m = typeUnderTest.getMethod(name, argTypes(param));
             testThisMethod(toTest, m, typeUnderTest, param);
         }
     }
 
+
+    // MethodHandle.invoke(): Make sure MethodHandle.invoke returns the expected value.
     @Test(dataProvider = "testCasesAll",
             dataProviderClass = DefaultStaticTestData.class)
     public void testMethodHandleInvoke(String testTarget, Object param)
@@ -116,6 +174,7 @@
 
     }
 
+    // Lookup.findStatic / .findVirtual: Make sure IllegalAccessException is thrown as expected.
     @Test(dataProvider = "testClasses",
             dataProviderClass = DefaultStaticTestData.class)
     public void testIAE(String testTarget, Object param)
@@ -128,7 +187,7 @@
             String mName = toTest.name();
             Mod mod = toTest.mod();
             if (mod != STATIC && typeUnderTest.isInterface()) {
-                return;
+                continue;
             }
             Exception caught = null;
             try {
@@ -136,10 +195,12 @@
             } catch (Exception e) {
                 caught = e;
             }
-            assertTrue(caught != null);
+            assertNotNull(caught);
             assertEquals(caught.getClass(), IllegalAccessException.class);
         }
     }
+
+
     private static final String[] OBJECT_METHOD_NAMES = {
         "equals",
         "hashCode",
@@ -192,15 +253,15 @@
                 myMethods.put(mName, m);
             }
         }
-        assertEquals(expectedMethods.length, myMethods.size());
+
+        assertEquals(myMethods.size(), expectedMethods.length);
 
         for (MethodDesc toTest : expectedMethods) {
 
             String name = toTest.name();
-            Method candidate = myMethods.get(name);
+            Method candidate = myMethods.remove(name);
 
             assertNotNull(candidate);
-            myMethods.remove(name);
 
             testThisMethod(toTest, candidate, typeUnderTest, param);
 
@@ -210,6 +271,7 @@
         assertTrue(myMethods.isEmpty());
     }
 
+
     private void testThisMethod(MethodDesc toTest, Method method,
             Class<?> typeUnderTest, Object param) throws Exception {
         // Test modifiers, and invoke
@@ -256,37 +318,52 @@
                 assertFalse(method.isDefault());
                 break;
             default:
-                assertFalse(true); //this should never happen
+                fail(); //this should never happen
                 break;
         }
 
     }
 
+
+    private boolean isMethodDeclared(Class<?> type, String name) {
+        MethodDesc[] methDescs = type.getAnnotationsByType(MethodDesc.class);
+        for (MethodDesc desc : methDescs) {
+            if (desc.declared() == YES && desc.name().equals(name))
+                return true;
+        }
+        return false;
+    }
+
+
+    private boolean isFoundByGetMethod(Class<?> c, String method, Class<?>... argTypes) {
+        try {
+            c.getMethod(method, argTypes);
+            return true;
+        } catch (NoSuchMethodException notFound) {
+            return false;
+        }
+    }
+
+
+    private Class<?>[] argTypes(Object param) {
+        return param == null ? new Class[0] : new Class[] { Object.class };
+    }
+
+
     private Object tryInvoke(Method m, Class<?> receiverType, Object param)
             throws Exception {
         Object receiver = receiverType == null ? null : receiverType.newInstance();
-        Object result = null;
-        if (param == null) {
-            result = m.invoke(receiver);
-        } else {
-            result = m.invoke(receiver, param);
-        }
-        return result;
+        Object[] args = param == null ? new Object[0] : new Object[] { param };
+        return m.invoke(receiver, args);
     }
 
-    private Method getTestMethod(Class clazz, String methodName, Object param)
-            throws NoSuchMethodException {
-        Class[] paramsType = (param != null)
-                ? new Class[]{Object.class}
-                : new Class[]{};
-        return clazz.getMethod(methodName, paramsType);
-    }
 
     private MethodHandle getTestMH(Class clazz, String methodName, Object param)
             throws Exception {
         return getTestMH(clazz, methodName, param, false);
     }
 
+
     private MethodHandle getTestMH(Class clazz, String methodName,
             Object param, boolean isNegativeTest)
             throws Exception {
--- a/jdk/test/java/lang/reflect/DefaultStaticTest/DefaultStaticTestData.java	Tue Oct 22 03:49:50 2013 +0000
+++ b/jdk/test/java/lang/reflect/DefaultStaticTest/DefaultStaticTestData.java	Tue Oct 22 10:34:12 2013 +0200
@@ -172,7 +172,7 @@
 
 @MethodDesc(name = "defaultMethod", retval = "TestIF8.TestClass8", mod = DEFAULT, declared = NO)
 class TestClass8<T> implements TestIF8<T> {
-};
+}
 
 @MethodDesc(name = "defaultMethod", retval = "TestIF9.defaultMethod", mod = DEFAULT, declared = YES)
 interface TestIF9 extends TestIF1 {
@@ -218,7 +218,6 @@
 }
 
 @MethodDesc(name = "defaultMethod", retval = "TestIF12.defaultMethod", mod = DEFAULT, declared = YES)
-@MethodDesc(name = "staticMethod", retval = "TestIF2.staticMethod", mod = STATIC, declared = NO)
 interface TestIF12 extends TestIF2 {
 
     default String defaultMethod() {
@@ -299,7 +298,7 @@
 
 @MethodDesc(name = "defaultMethod", retval = "TestIF16.defaultMethod", mod = DEFAULT, declared = NO)
 class TestClass16 implements TestIF16 {
-};
+}
 
 @MethodDesc(name = "defaultMethod", retval = "TestIF17.defaultMethod", mod = DEFAULT, declared = YES)
 @MethodDesc(name = "staticMethod", retval = "TestIF17.staticMethod", mod = STATIC, declared = YES)
@@ -318,6 +317,12 @@
 class TestClass17 implements TestIF17 {
 }
 
+
+@MethodDesc(name = "defaultMethod", retval = "TestIF17.defaultMethod", mod = DEFAULT, declared = NO)
+class TestClass18 extends TestClass17 {
+}
+
+
 @Retention(RetentionPolicy.RUNTIME)
 @Repeatable(MethodDescs.class)
 @interface MethodDesc {
@@ -332,6 +337,41 @@
     MethodDesc[] value();
 }
 
+//Diamond Case for static method
+@MethodDesc(name = "staticMethod", retval = "TestIF2A.staticMethod", mod = STATIC, declared = YES)
+interface TestIF2A extends TestIF2 {
+    static String staticMethod() {
+        return "TestIF2A.staticMethod";
+    }
+}
+
+@MethodDesc(name = "method", retval = "", mod = ABSTRACT, declared = YES)
+interface TestIF2B extends TestIF2 {
+    String method();
+}
+
+@MethodDesc(name = "method", retval = "", mod = ABSTRACT, declared = YES)
+interface TestIF18 extends TestIF10, TestIF2A {
+    String method();
+}
+
+@MethodDesc(name = "method", retval = "", mod = ABSTRACT, declared = NO)
+@MethodDesc(name = "defaultMethod", retval = "TestIF12.defaultMethod", mod = DEFAULT, declared = NO)
+interface TestIF19 extends TestIF12, TestIF2B {
+}
+
+@MethodDesc(name = "staticMethod", retval = "TestIF20.staticMethod", mod = STATIC, declared = YES)
+@MethodDesc(name = "defaultMethod", retval = "TestIF12.defaultMethod", mod = DEFAULT, declared = NO)
+interface TestIF20 extends TestIF12, TestIF2A {
+    static String staticMethod() {
+        return "TestIF20.staticMethod";
+    }
+}
+
+@MethodDesc(name = "method", retval = "", mod = ABSTRACT, declared = NO)
+interface TestIF21 extends TestIF2A, TestIF2B {
+}
+
 public class DefaultStaticTestData {
 
     /**
@@ -343,22 +383,23 @@
     static Object[][] testClasses() {
         return new Object[][]{
             {"TestClass1", null},
-            //{"TestClass2", null}, @ignore due to JDK-8009411
+            {"TestClass2", null},
             {"TestClass3", null},
-            //{"TestClass4", null}, @ignore due to JDK-8009411
-            //{"TestClass5", null}, @ignore due to JDK-8009411
-            //{"TestClass6", null}, @ignore due to JDK-8009411
+            {"TestClass4", null},
+            {"TestClass5", null},
+            {"TestClass6", null},
             {"TestClass7", "TestIF7.TestClass7"},
             {"TestClass8", "TestIF8.TestClass8"},
             {"TestClass9", null},
             {"TestClass91", null},
-            //{"TestClass11", null}, @ignore due to JDK-8009411
-            //{"TestClass12", null}, @ignore due to JDK-8009411
+            {"TestClass11", null},
+            {"TestClass12", null},
             {"TestClass13", null},
             {"TestClass14", null},
             {"TestClass15", null},
-            {"TestClass16", null}
-        //{"TestClass17", null} @ignore due to JDK-8009411
+            {"TestClass16", null},
+            {"TestClass17", null},
+            {"TestClass18", null},
         };
     }
 
@@ -372,6 +413,8 @@
         return new Object[][]{
             {"TestIF1", null},
             {"TestIF2", null},
+            {"TestIF2A", null},
+            {"TestIF2B", null},
             {"TestIF3", null},
             {"TestIF4", null},
             {"TestIF5", null},
@@ -388,7 +431,12 @@
             {"TestIF1D", null},
             {"TestIF15", null},
             {"TestIF16", null},
-            {"TestIF17", null},};
+            {"TestIF17", null},
+            {"TestIF18", null},
+            {"TestIF19", null},
+            {"TestIF20", null},
+            {"TestIF21", null},
+        };
     }
 
     @DataProvider
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/jdk/test/java/lang/reflect/Method/InterfaceStatic/StaticInterfaceMethodInWayOfDefault.java	Tue Oct 22 10:34:12 2013 +0200
@@ -0,0 +1,183 @@
+/*
+ * Copyright (c) 2013, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @bug 8009411
+ * @summary Test that a static method on an interface doesn't hide a default
+ *          method with the same name and signature in a separate compilation
+ *          scenario.
+ */
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.concurrent.Callable;
+
+import sun.misc.IOUtils;
+
+public class StaticInterfaceMethodInWayOfDefault {
+    public interface A_v1 {
+    }
+
+    public interface A_v2 {
+        default void m() {
+            System.err.println("A.m() called");
+        }
+    }
+
+    public interface B  extends A_v1 {
+        static void m() {
+            System.err.println("B.m() called");
+        }
+    }
+
+    public interface C_v1 extends B {
+        default void m() {
+            System.err.println("C.m() called");
+        }
+    }
+
+    public interface C_v2 extends B {
+    }
+
+    public static class TestTask implements Callable<String> {
+        @Override
+        public String call() {
+            try {
+                Method m = C_v1.class.getMethod("m", (Class<?>[])null);
+                return  m.getDeclaringClass().getSimpleName();
+            } catch (NoSuchMethodException e) {
+                System.err.println("Couldn't find method");
+                return "ERROR";
+            }
+        }
+    }
+
+    public static void main(String[] args) throws Exception {
+        int errors = 0;
+        Callable<String> v1Task = new TestTask();
+
+        ClassLoader v2Loader = new V2ClassLoader(
+            StaticInterfaceMethodInWayOfDefault.class.getClassLoader());
+        Callable<String> v2Task = (Callable<String>) Class.forName(
+            TestTask.class.getName(),
+            true,
+            v2Loader).newInstance();
+
+        System.err.println("Running using _v1 classes:");
+        String res = v1Task.call();
+        if(!res.equals("C_v1")) {
+            System.err.println("Got wrong method, expecting C_v1, got: " + res);
+            errors++;
+        }
+
+        System.err.println("Running using _v2 classes:");
+        res = v2Task.call();
+        if(!res.equals("A_v1")) {
+            System.err.println("Got wrong method, expecting A_v1, got: " + res);
+            errors++;
+        }
+
+        if (errors != 0)
+            throw new RuntimeException("Errors found, check log for details");
+    }
+
+    /**
+     * A ClassLoader implementation that loads alternative implementations of
+     * classes. If class name ends with "_v1" it locates instead a class with
+     * name ending with "_v2" and loads that class instead.
+     */
+    static class V2ClassLoader extends ClassLoader {
+        V2ClassLoader(ClassLoader parent) {
+            super(parent);
+        }
+
+        @Override
+        protected Class<?> loadClass(String name, boolean resolve)
+            throws ClassNotFoundException {
+            if (name.indexOf('.') < 0) { // root package is our class
+                synchronized (getClassLoadingLock(name)) {
+                    // First, check if the class has already been loaded
+                    Class<?> c = findLoadedClass(name);
+                    if (c == null) {
+                        c = findClass(name);
+                    }
+                    if (resolve) {
+                        resolveClass(c);
+                    }
+                    return c;
+                }
+            }
+            else { // not our class
+                return super.loadClass(name, resolve);
+            }
+        }
+
+        @Override
+        protected Class<?> findClass(String name)
+            throws ClassNotFoundException {
+            // special class name -> replace it with alternative name
+            if (name.endsWith("_v1")) {
+                String altName = name.substring(0, name.length() - 3) + "_v2";
+                String altPath = altName.replace('.', '/').concat(".class");
+                try (InputStream is = getResourceAsStream(altPath)) {
+                    if (is != null) {
+                        byte[] bytes = IOUtils.readFully(is, -1, true);
+                        // patch class bytes to contain original name
+                        for (int i = 0; i < bytes.length - 2; i++) {
+                            if (bytes[i] == '_' &&
+                                bytes[i + 1] == 'v' &&
+                                bytes[i + 2] == '2') {
+                                bytes[i + 2] = '1';
+                            }
+                        }
+                        return defineClass(name, bytes, 0, bytes.length);
+                    }
+                    else {
+                        throw new ClassNotFoundException(name);
+                    }
+                }
+                catch (IOException e) {
+                    throw new ClassNotFoundException(name, e);
+                }
+            }
+            else { // not special class name -> just load the class
+                String path = name.replace('.', '/').concat(".class");
+                try (InputStream is = getResourceAsStream(path)) {
+                    if (is != null) {
+                        byte[] bytes = IOUtils.readFully(is, -1, true);
+                        return defineClass(name, bytes, 0, bytes.length);
+                    }
+                    else {
+                        throw new ClassNotFoundException(name);
+                    }
+                }
+                catch (IOException e) {
+                    throw new ClassNotFoundException(name, e);
+                }
+            }
+        }
+    }
+}