src/java.base/share/classes/java/lang/invoke/DirectMethodHandle.java
changeset 49935 2ace90aec488
parent 48557 2e867226b914
child 50735 2f2af62dfac7
--- a/src/java.base/share/classes/java/lang/invoke/DirectMethodHandle.java	Mon Apr 30 18:10:24 2018 -0700
+++ b/src/java.base/share/classes/java/lang/invoke/DirectMethodHandle.java	Mon Apr 30 21:56:54 2018 -0400
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2008, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2008, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -72,26 +72,30 @@
     }
 
     // Factory methods:
-    static DirectMethodHandle make(byte refKind, Class<?> receiver, MemberName member) {
+    static DirectMethodHandle make(byte refKind, Class<?> refc, MemberName member, Class<?> callerClass) {
         MethodType mtype = member.getMethodOrFieldType();
         if (!member.isStatic()) {
-            if (!member.getDeclaringClass().isAssignableFrom(receiver) || member.isConstructor())
+            if (!member.getDeclaringClass().isAssignableFrom(refc) || member.isConstructor())
                 throw new InternalError(member.toString());
-            mtype = mtype.insertParameterTypes(0, receiver);
+            mtype = mtype.insertParameterTypes(0, refc);
         }
         if (!member.isField()) {
             switch (refKind) {
                 case REF_invokeSpecial: {
                     member = member.asSpecial();
-                    LambdaForm lform = preparedLambdaForm(member);
-                    return new Special(mtype, lform, member);
+                    LambdaForm lform = preparedLambdaForm(member, callerClass);
+                    Class<?> checkClass = refc;  // Class to use for receiver type check
+                    if (callerClass != null) {
+                        checkClass = callerClass;  // potentially strengthen to caller class
+                    }
+                    return new Special(mtype, lform, member, checkClass);
                 }
                 case REF_invokeInterface: {
-                    LambdaForm lform = preparedLambdaForm(member);
-                    return new Interface(mtype, lform, member, receiver);
+                    LambdaForm lform = preparedLambdaForm(member, callerClass);
+                    return new Interface(mtype, lform, member, refc);
                 }
                 default: {
-                    LambdaForm lform = preparedLambdaForm(member);
+                    LambdaForm lform = preparedLambdaForm(member, callerClass);
                     return new DirectMethodHandle(mtype, lform, member);
                 }
             }
@@ -108,11 +112,11 @@
             }
         }
     }
-    static DirectMethodHandle make(Class<?> receiver, MemberName member) {
+    static DirectMethodHandle make(Class<?> refc, MemberName member) {
         byte refKind = member.getReferenceKind();
         if (refKind == REF_invokeSpecial)
             refKind =  REF_invokeVirtual;
-        return make(refKind, receiver, member);
+        return make(refKind, refc, member, null /* no callerClass context */);
     }
     static DirectMethodHandle make(MemberName member) {
         if (member.isConstructor())
@@ -161,7 +165,7 @@
      * Cache and share this structure among all methods with
      * the same basicType and refKind.
      */
-    private static LambdaForm preparedLambdaForm(MemberName m) {
+    private static LambdaForm preparedLambdaForm(MemberName m, Class<?> callerClass) {
         assert(m.isInvocable()) : m;  // call preparedFieldLambdaForm instead
         MethodType mtype = m.getInvocationType().basicType();
         assert(!m.isMethodHandleInvoke()) : m;
@@ -179,6 +183,9 @@
             preparedLambdaForm(mtype, which);
             which = LF_INVSTATIC_INIT;
         }
+        if (which == LF_INVSPECIAL && callerClass != null && callerClass.isInterface()) {
+            which = LF_INVSPECIAL_IFC;
+        }
         LambdaForm lform = preparedLambdaForm(mtype, which);
         maybeCompile(lform, m);
         assert(lform.methodType().dropParameterTypes(0, 1)
@@ -187,6 +194,10 @@
         return lform;
     }
 
+    private static LambdaForm preparedLambdaForm(MemberName m) {
+        return preparedLambdaForm(m, null);
+    }
+
     private static LambdaForm preparedLambdaForm(MethodType mtype, int which) {
         LambdaForm lform = mtype.form().cachedLambdaForm(which);
         if (lform != null)  return lform;
@@ -197,13 +208,16 @@
     static LambdaForm makePreparedLambdaForm(MethodType mtype, int which) {
         boolean needsInit = (which == LF_INVSTATIC_INIT);
         boolean doesAlloc = (which == LF_NEWINVSPECIAL);
-        boolean needsReceiverCheck = (which == LF_INVINTERFACE);
+        boolean needsReceiverCheck = (which == LF_INVINTERFACE ||
+                                      which == LF_INVSPECIAL_IFC);
+
         String linkerName;
         LambdaForm.Kind kind;
         switch (which) {
         case LF_INVVIRTUAL:    linkerName = "linkToVirtual";   kind = DIRECT_INVOKE_VIRTUAL;     break;
         case LF_INVSTATIC:     linkerName = "linkToStatic";    kind = DIRECT_INVOKE_STATIC;      break;
         case LF_INVSTATIC_INIT:linkerName = "linkToStatic";    kind = DIRECT_INVOKE_STATIC_INIT; break;
+        case LF_INVSPECIAL_IFC:linkerName = "linkToSpecial";   kind = DIRECT_INVOKE_SPECIAL_IFC; break;
         case LF_INVSPECIAL:    linkerName = "linkToSpecial";   kind = DIRECT_INVOKE_SPECIAL;     break;
         case LF_INVINTERFACE:  linkerName = "linkToInterface"; kind = DIRECT_INVOKE_INTERFACE;   break;
         case LF_NEWINVSPECIAL: linkerName = "linkToSpecial";   kind = DIRECT_NEW_INVOKE_SPECIAL; break;
@@ -376,8 +390,10 @@
 
     /** This subclass represents invokespecial instructions. */
     static class Special extends DirectMethodHandle {
-        private Special(MethodType mtype, LambdaForm form, MemberName member) {
+        private final Class<?> caller;
+        private Special(MethodType mtype, LambdaForm form, MemberName member, Class<?> caller) {
             super(mtype, form, member);
+            this.caller = caller;
         }
         @Override
         boolean isInvokeSpecial() {
@@ -385,7 +401,15 @@
         }
         @Override
         MethodHandle copyWith(MethodType mt, LambdaForm lf) {
-            return new Special(mt, lf, member);
+            return new Special(mt, lf, member, caller);
+        }
+        Object checkReceiver(Object recv) {
+            if (!caller.isInstance(recv)) {
+                String msg = String.format("Receiver class %s is not a subclass of caller class %s",
+                                           recv.getClass().getName(), caller.getName());
+                throw new IncompatibleClassChangeError(msg);
+            }
+            return recv;
         }
     }
 
@@ -401,17 +425,23 @@
         MethodHandle copyWith(MethodType mt, LambdaForm lf) {
             return new Interface(mt, lf, member, refc);
         }
-
+        @Override
         Object checkReceiver(Object recv) {
             if (!refc.isInstance(recv)) {
-                String msg = String.format("Class %s does not implement the requested interface %s",
-                        recv.getClass().getName(), refc.getName());
+                String msg = String.format("Receiver class %s does not implement the requested interface %s",
+                                           recv.getClass().getName(), refc.getName());
                 throw new IncompatibleClassChangeError(msg);
             }
             return recv;
         }
     }
 
+    /** Used for interface receiver type checks, by Interface and Special modes. */
+    Object checkReceiver(Object recv) {
+        throw new InternalError("Should only be invoked on a subclass");
+    }
+
+
     /** This subclass handles constructor references. */
     static class Constructor extends DirectMethodHandle {
         final MemberName initMethod;
@@ -823,10 +853,10 @@
                             MemberName.getFactory()
                                     .resolveOrFail(REF_getField, member, DirectMethodHandle.class, NoSuchMethodException.class));
                 case NF_checkReceiver:
-                    member = new MemberName(Interface.class, "checkReceiver", OBJ_OBJ_TYPE, REF_invokeVirtual);
+                    member = new MemberName(DirectMethodHandle.class, "checkReceiver", OBJ_OBJ_TYPE, REF_invokeVirtual);
                     return new NamedFunction(
                         MemberName.getFactory()
-                            .resolveOrFail(REF_invokeVirtual, member, Interface.class, NoSuchMethodException.class));
+                            .resolveOrFail(REF_invokeVirtual, member, DirectMethodHandle.class, NoSuchMethodException.class));
                 default:
                     throw newInternalError("Unknown function: " + func);
             }