8174399: LambdaMetafactory should use types in implMethod.type()
authordlsmith
Mon, 13 Feb 2017 10:47:15 -0700
changeset 43789 43068ea5965e
parent 43788 22a618ec8268
child 43790 b9e56c7fba7e
child 43831 a6e823534165
8174399: LambdaMetafactory should use types in implMethod.type() Reviewed-by: psandoz
jdk/src/java.base/share/classes/java/lang/invoke/AbstractValidatingLambdaMetafactory.java
jdk/src/java.base/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java
jdk/test/java/lang/invoke/lambda/InheritedMethodTest.java
--- a/jdk/src/java.base/share/classes/java/lang/invoke/AbstractValidatingLambdaMetafactory.java	Mon Feb 13 17:18:48 2017 +0000
+++ b/jdk/src/java.base/share/classes/java/lang/invoke/AbstractValidatingLambdaMetafactory.java	Mon Feb 13 10:47:15 2017 -0700
@@ -26,6 +26,7 @@
 
 import sun.invoke.util.Wrapper;
 
+import static java.lang.invoke.MethodHandleInfo.*;
 import static sun.invoke.util.Wrapper.forPrimitiveType;
 import static sun.invoke.util.Wrapper.forWrapperType;
 import static sun.invoke.util.Wrapper.isWrapperType;
@@ -56,11 +57,11 @@
     final String samMethodName;               // Name of the SAM method "foo"
     final MethodType samMethodType;           // Type of the SAM method "(Object)Object"
     final MethodHandle implMethod;            // Raw method handle for the implementation method
+    final MethodType implMethodType;          // Type of the implMethod MethodHandle "(CC,int)String"
     final MethodHandleInfo implInfo;          // Info about the implementation method handle "MethodHandleInfo[5 CC.impl(int)String]"
     final int implKind;                       // Invocation kind for implementation "5"=invokevirtual
     final boolean implIsInstanceMethod;       // Is the implementation an instance method "true"
-    final Class<?> implDefiningClass;         // Type defining the implementation "class CC"
-    final MethodType implMethodType;          // Type of the implementation method "(int)String"
+    final Class<?> implClass;                 // Class for referencing the implementation method "class CC"
     final MethodType instantiatedMethodType;  // Instantiated erased functional interface method type "(Integer)Object"
     final boolean isSerializable;             // Should the returned instance be serializable
     final Class<?>[] markerInterfaces;        // Additional marker interfaces to be implemented
@@ -128,14 +129,34 @@
         this.samMethodType  = samMethodType;
 
         this.implMethod = implMethod;
+        this.implMethodType = implMethod.type();
         this.implInfo = caller.revealDirect(implMethod);
-        this.implKind = implInfo.getReferenceKind();
-        this.implIsInstanceMethod =
-                implKind == MethodHandleInfo.REF_invokeVirtual ||
-                implKind == MethodHandleInfo.REF_invokeSpecial ||
-                implKind == MethodHandleInfo.REF_invokeInterface;
-        this.implDefiningClass = implInfo.getDeclaringClass();
-        this.implMethodType = implInfo.getMethodType();
+        switch (implInfo.getReferenceKind()) {
+            case REF_invokeVirtual:
+            case REF_invokeInterface:
+                this.implClass = implMethodType.parameterType(0);
+                // reference kind reported by implInfo may not match implMethodType's first param
+                // Example: implMethodType is (Cloneable)String, implInfo is for Object.toString
+                this.implKind = implClass.isInterface() ? REF_invokeInterface : REF_invokeVirtual;
+                this.implIsInstanceMethod = true;
+                break;
+            case REF_invokeSpecial:
+                // JDK-8172817: should use referenced class here, but we don't know what it was
+                this.implClass = implInfo.getDeclaringClass();
+                this.implKind = REF_invokeSpecial;
+                this.implIsInstanceMethod = true;
+                break;
+            case REF_invokeStatic:
+            case REF_newInvokeSpecial:
+                // JDK-8172817: should use referenced class here for invokestatic, but we don't know what it was
+                this.implClass = implInfo.getDeclaringClass();
+                this.implKind = implInfo.getReferenceKind();
+                this.implIsInstanceMethod = false;
+                break;
+            default:
+                throw new LambdaConversionException(String.format("Unsupported MethodHandle kind: %s", implInfo));
+        }
+
         this.instantiatedMethodType = instantiatedMethodType;
         this.isSerializable = isSerializable;
         this.markerInterfaces = markerInterfaces;
@@ -183,24 +204,12 @@
      * @throws LambdaConversionException if there are improper conversions
      */
     void validateMetafactoryArgs() throws LambdaConversionException {
-        switch (implKind) {
-            case MethodHandleInfo.REF_invokeInterface:
-            case MethodHandleInfo.REF_invokeVirtual:
-            case MethodHandleInfo.REF_invokeStatic:
-            case MethodHandleInfo.REF_newInvokeSpecial:
-            case MethodHandleInfo.REF_invokeSpecial:
-                break;
-            default:
-                throw new LambdaConversionException(String.format("Unsupported MethodHandle kind: %s", implInfo));
-        }
-
-        // Check arity: optional-receiver + captured + SAM == impl
+        // Check arity: captured + SAM == impl
         final int implArity = implMethodType.parameterCount();
-        final int receiverArity = implIsInstanceMethod ? 1 : 0;
         final int capturedArity = invokedType.parameterCount();
         final int samArity = samMethodType.parameterCount();
         final int instantiatedArity = instantiatedMethodType.parameterCount();
-        if (implArity + receiverArity != capturedArity + samArity) {
+        if (implArity != capturedArity + samArity) {
             throw new LambdaConversionException(
                     String.format("Incorrect number of parameters for %s method %s; %d captured parameters, %d functional interface method parameters, %d implementation parameters",
                                   implIsInstanceMethod ? "instance" : "static", implInfo,
@@ -221,8 +230,8 @@
         }
 
         // If instance: first captured arg (receiver) must be subtype of class where impl method is defined
-        final int capturedStart;
-        final int samStart;
+        final int capturedStart; // index of first non-receiver capture parameter in implMethodType
+        final int samStart; // index of first non-receiver sam parameter in implMethodType
         if (implIsInstanceMethod) {
             final Class<?> receiverClass;
 
@@ -235,45 +244,36 @@
             } else {
                 // receiver is a captured variable
                 capturedStart = 1;
-                samStart = 0;
+                samStart = capturedArity;
                 receiverClass = invokedType.parameterType(0);
             }
 
             // check receiver type
-            if (!implDefiningClass.isAssignableFrom(receiverClass)) {
+            if (!implClass.isAssignableFrom(receiverClass)) {
                 throw new LambdaConversionException(
                         String.format("Invalid receiver type %s; not a subtype of implementation type %s",
-                                      receiverClass, implDefiningClass));
-            }
-
-           Class<?> implReceiverClass = implMethod.type().parameterType(0);
-           if (implReceiverClass != implDefiningClass && !implReceiverClass.isAssignableFrom(receiverClass)) {
-               throw new LambdaConversionException(
-                       String.format("Invalid receiver type %s; not a subtype of implementation receiver type %s",
-                                     receiverClass, implReceiverClass));
+                                      receiverClass, implClass));
             }
         } else {
             // no receiver
             capturedStart = 0;
-            samStart = 0;
+            samStart = capturedArity;
         }
 
         // Check for exact match on non-receiver captured arguments
-        final int implFromCaptured = capturedArity - capturedStart;
-        for (int i=0; i<implFromCaptured; i++) {
+        for (int i=capturedStart; i<capturedArity; i++) {
             Class<?> implParamType = implMethodType.parameterType(i);
-            Class<?> capturedParamType = invokedType.parameterType(i + capturedStart);
+            Class<?> capturedParamType = invokedType.parameterType(i);
             if (!capturedParamType.equals(implParamType)) {
                 throw new LambdaConversionException(
                         String.format("Type mismatch in captured lambda parameter %d: expecting %s, found %s",
                                       i, capturedParamType, implParamType));
             }
         }
-        // Check for adaptation match on SAM arguments
-        final int samOffset = samStart - implFromCaptured;
-        for (int i=implFromCaptured; i<implArity; i++) {
+        // Check for adaptation match on non-receiver SAM arguments
+        for (int i=samStart; i<implArity; i++) {
             Class<?> implParamType = implMethodType.parameterType(i);
-            Class<?> instantiatedParamType = instantiatedMethodType.parameterType(i + samOffset);
+            Class<?> instantiatedParamType = instantiatedMethodType.parameterType(i - capturedArity);
             if (!isAdaptableTo(instantiatedParamType, implParamType, true)) {
                 throw new LambdaConversionException(
                         String.format("Type mismatch for lambda argument %d: %s is not convertible to %s",
@@ -283,10 +283,7 @@
 
         // Adaptation match: return type
         Class<?> expectedType = instantiatedMethodType.returnType();
-        Class<?> actualReturnType =
-                (implKind == MethodHandleInfo.REF_newInvokeSpecial)
-                  ? implDefiningClass
-                  : implMethodType.returnType();
+        Class<?> actualReturnType = implMethodType.returnType();
         if (!isAdaptableToAsReturn(actualReturnType, expectedType)) {
             throw new LambdaConversionException(
                     String.format("Type mismatch for lambda return: %s is not convertible to %s",
--- a/jdk/src/java.base/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java	Mon Feb 13 17:18:48 2017 +0000
+++ b/jdk/src/java.base/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java	Mon Feb 13 10:47:15 2017 -0700
@@ -96,7 +96,6 @@
     private final String implMethodClassName;        // Name of type containing implementation "CC"
     private final String implMethodName;             // Name of implementation method "impl"
     private final String implMethodDesc;             // Type descriptor for implementation methods "(I)Ljava/lang/String;"
-    private final Class<?> implMethodReturnClass;    // class for implementation method return type "Ljava/lang/String;"
     private final MethodType constructorType;        // Generated class constructor type "(CC)void"
     private final ClassWriter cw;                    // ASM class writer
     private final String[] argNames;                 // Generated names for the constructor arguments
@@ -153,12 +152,9 @@
         super(caller, invokedType, samMethodName, samMethodType,
               implMethod, instantiatedMethodType,
               isSerializable, markerInterfaces, additionalBridges);
-        implMethodClassName = implDefiningClass.getName().replace('.', '/');
+        implMethodClassName = implClass.getName().replace('.', '/');
         implMethodName = implInfo.getName();
-        implMethodDesc = implMethodType.toMethodDescriptorString();
-        implMethodReturnClass = (implKind == MethodHandleInfo.REF_newInvokeSpecial)
-                ? implDefiningClass
-                : implMethodType.returnType();
+        implMethodDesc = implInfo.getMethodType().toMethodDescriptorString();
         constructorType = invokedType.changeReturnType(Void.TYPE);
         lambdaClassName = targetClass.getName().replace('.', '/') + "$$Lambda$" + counter.incrementAndGet();
         cw = new ClassWriter(ClassWriter.COMPUTE_MAXS);
@@ -467,13 +463,14 @@
             // Invoke the method we want to forward to
             visitMethodInsn(invocationOpcode(), implMethodClassName,
                             implMethodName, implMethodDesc,
-                            implDefiningClass.isInterface());
+                            implClass.isInterface());
 
             // Convert the return value (if any) and return it
             // Note: if adapting from non-void to void, the 'return'
             // instruction will pop the unneeded result
+            Class<?> implReturnClass = implMethodType.returnType();
             Class<?> samReturnClass = methodType.returnType();
-            convertType(implMethodReturnClass, samReturnClass, samReturnClass);
+            convertType(implReturnClass, samReturnClass, samReturnClass);
             visitInsn(getReturnOpcode(samReturnClass));
             // Maxs computed by ClassWriter.COMPUTE_MAXS,these arguments ignored
             visitMaxs(-1, -1);
@@ -482,23 +479,13 @@
 
         private void convertArgumentTypes(MethodType samType) {
             int lvIndex = 0;
-            boolean samIncludesReceiver = implIsInstanceMethod &&
-                                                   invokedType.parameterCount() == 0;
-            int samReceiverLength = samIncludesReceiver ? 1 : 0;
-            if (samIncludesReceiver) {
-                // push receiver
-                Class<?> rcvrType = samType.parameterType(0);
-                visitVarInsn(getLoadOpcode(rcvrType), lvIndex + 1);
-                lvIndex += getParameterSize(rcvrType);
-                convertType(rcvrType, implDefiningClass, instantiatedMethodType.parameterType(0));
-            }
             int samParametersLength = samType.parameterCount();
-            int argOffset = implMethodType.parameterCount() - samParametersLength;
-            for (int i = samReceiverLength; i < samParametersLength; i++) {
+            int captureArity = invokedType.parameterCount();
+            for (int i = 0; i < samParametersLength; i++) {
                 Class<?> argType = samType.parameterType(i);
                 visitVarInsn(getLoadOpcode(argType), lvIndex + 1);
                 lvIndex += getParameterSize(argType);
-                convertType(argType, implMethodType.parameterType(argOffset + i), instantiatedMethodType.parameterType(i));
+                convertType(argType, implMethodType.parameterType(captureArity + i), instantiatedMethodType.parameterType(i));
             }
         }
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/jdk/test/java/lang/invoke/lambda/InheritedMethodTest.java	Mon Feb 13 10:47:15 2017 -0700
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @bug 8174399
+ * @summary LambdaMetafactory should be able to handle inherited methods as 'implMethod'
+ */
+import java.lang.ReflectiveOperationException;
+import java.lang.invoke.*;
+
+public class InheritedMethodTest {
+
+    public static MethodType mt(Class<?> ret, Class<?>... params) { return MethodType.methodType(ret, params); }
+
+    public interface StringFactory {
+        String get();
+    }
+
+    public interface I {
+        String iString();
+    }
+
+    public interface J extends I {}
+
+    public static abstract class C implements I {}
+
+    public static class D extends C implements J {
+        public String toString() { return "a"; }
+        public String iString() { return "b"; }
+    }
+
+    private static final MethodHandles.Lookup lookup = MethodHandles.lookup();
+
+    public static void main(String... args) throws Throwable {
+        test(lookup.findVirtual(C.class, "toString", mt(String.class)), "a");
+        test(lookup.findVirtual(C.class, "iString", mt(String.class)), "b");
+        test(lookup.findVirtual(J.class, "toString", mt(String.class)), "a");
+        test(lookup.findVirtual(J.class, "iString", mt(String.class)), "b");
+        test(lookup.findVirtual(I.class, "toString", mt(String.class)), "a");
+        test(lookup.findVirtual(I.class, "iString", mt(String.class)), "b");
+    }
+
+    static void test(MethodHandle implMethod, String expected) throws Throwable {
+        testMetafactory(implMethod, expected);
+        testAltMetafactory(implMethod, expected);
+    }
+
+    static void testMetafactory(MethodHandle implMethod, String expected) throws Throwable {
+        CallSite cs = LambdaMetafactory.metafactory(lookup, "get", mt(StringFactory.class, D.class), mt(String.class),
+                                                    implMethod, mt(String.class));
+        StringFactory factory = (StringFactory) cs.dynamicInvoker().invokeExact(new D());
+        String actual = factory.get();
+        if (!expected.equals(actual)) throw new AssertionError("Unexpected result: " + actual);
+    }
+
+    static void testAltMetafactory(MethodHandle implMethod, String expected) throws Throwable {
+        CallSite cs = LambdaMetafactory.altMetafactory(lookup, "get", mt(StringFactory.class, D.class), mt(String.class),
+                                                       implMethod, mt(String.class), LambdaMetafactory.FLAG_SERIALIZABLE);
+        StringFactory factory = (StringFactory) cs.dynamicInvoker().invokeExact(new D());
+        String actual = factory.get();
+        if (!expected.equals(actual)) throw new AssertionError("Unexpected result: " + actual);
+    }
+
+}