8054492: Casting can result in redundant null checks in generated code
authorkvn
Fri, 31 Oct 2014 16:51:57 -0700
changeset 27450 603dbcf4f547
parent 27449 7ed11cfd9be8
child 27451 7e2e2b955d15
8054492: Casting can result in redundant null checks in generated code Summary: add C2 intrinsic for Class.cast() method and force inline it too. Reviewed-by: jrose, roland, drchase, iignatyev
hotspot/src/share/vm/classfile/vmSymbols.hpp
hotspot/src/share/vm/oops/method.cpp
hotspot/src/share/vm/opto/library_call.cpp
hotspot/src/share/vm/prims/whitebox.cpp
hotspot/test/TEST.groups
hotspot/test/compiler/intrinsics/classcast/NullCheckDroppingsTest.java
hotspot/test/testlibrary/whitebox/sun/hotspot/code/NMethod.java
--- a/hotspot/src/share/vm/classfile/vmSymbols.hpp	Fri Oct 31 12:01:27 2014 -1000
+++ b/hotspot/src/share/vm/classfile/vmSymbols.hpp	Fri Oct 31 16:51:57 2014 -0700
@@ -455,6 +455,7 @@
   template(object_void_signature,                     "(Ljava/lang/Object;)V")                    \
   template(object_int_signature,                      "(Ljava/lang/Object;)I")                    \
   template(object_boolean_signature,                  "(Ljava/lang/Object;)Z")                    \
+  template(object_object_signature,                   "(Ljava/lang/Object;)Ljava/lang/Object;")   \
   template(string_void_signature,                     "(Ljava/lang/String;)V")                    \
   template(string_int_signature,                      "(Ljava/lang/String;)I")                    \
   template(throwable_void_signature,                  "(Ljava/lang/Throwable;)V")                 \
@@ -746,6 +747,8 @@
    do_name(     isPrimitive_name,                                "isPrimitive")                                         \
   do_intrinsic(_getSuperclass,            java_lang_Class,        getSuperclass_name, void_class_signature,      F_RN)  \
    do_name(     getSuperclass_name,                              "getSuperclass")                                       \
+  do_intrinsic(_Class_cast,               java_lang_Class,        Class_cast_name, object_object_signature,      F_R)   \
+   do_name(     Class_cast_name,                                 "cast")                                                \
                                                                                                                         \
   do_intrinsic(_getClassAccessFlags,      sun_reflect_Reflection, getClassAccessFlags_name, class_int_signature, F_SN)  \
    do_name(     getClassAccessFlags_name,                        "getClassAccessFlags")                                 \
--- a/hotspot/src/share/vm/oops/method.cpp	Fri Oct 31 12:01:27 2014 -1000
+++ b/hotspot/src/share/vm/oops/method.cpp	Fri Oct 31 16:51:57 2014 -0700
@@ -1295,6 +1295,10 @@
   vmIntrinsics::ID id = vmIntrinsics::find_id(klass_id, name_id, sig_id, flags);
   if (id != vmIntrinsics::_none) {
     set_intrinsic_id(id);
+    if (id == vmIntrinsics::_Class_cast) {
+      // Even if the intrinsic is rejected, we want to inline this simple method.
+      set_force_inline(true);
+    }
     return;
   }
 
--- a/hotspot/src/share/vm/opto/library_call.cpp	Fri Oct 31 12:01:27 2014 -1000
+++ b/hotspot/src/share/vm/opto/library_call.cpp	Fri Oct 31 16:51:57 2014 -0700
@@ -268,6 +268,7 @@
   bool inline_fp_conversions(vmIntrinsics::ID id);
   bool inline_number_methods(vmIntrinsics::ID id);
   bool inline_reference_get();
+  bool inline_Class_cast();
   bool inline_aescrypt_Block(vmIntrinsics::ID id);
   bool inline_cipherBlockChaining_AESCrypt(vmIntrinsics::ID id);
   Node* inline_cipherBlockChaining_AESCrypt_predicate(bool decrypting);
@@ -869,6 +870,8 @@
 
   case vmIntrinsics::_Reference_get:            return inline_reference_get();
 
+  case vmIntrinsics::_Class_cast:               return inline_Class_cast();
+
   case vmIntrinsics::_aescrypt_encryptBlock:
   case vmIntrinsics::_aescrypt_decryptBlock:    return inline_aescrypt_Block(intrinsic_id());
 
@@ -3546,6 +3549,89 @@
   return true;
 }
 
+//-------------------------inline_Class_cast-------------------
+bool LibraryCallKit::inline_Class_cast() {
+  Node* mirror = argument(0); // Class
+  Node* obj    = argument(1);
+  const TypeInstPtr* mirror_con = _gvn.type(mirror)->isa_instptr();
+  if (mirror_con == NULL) {
+    return false;  // dead path (mirror->is_top()).
+  }
+  if (obj == NULL || obj->is_top()) {
+    return false;  // dead path
+  }
+  const TypeOopPtr* tp = _gvn.type(obj)->isa_oopptr();
+
+  // First, see if Class.cast() can be folded statically.
+  // java_mirror_type() returns non-null for compile-time Class constants.
+  ciType* tm = mirror_con->java_mirror_type();
+  if (tm != NULL && tm->is_klass() &&
+      tp != NULL && tp->klass() != NULL) {
+    if (!tp->klass()->is_loaded()) {
+      // Don't use intrinsic when class is not loaded.
+      return false;
+    } else {
+      int static_res = C->static_subtype_check(tm->as_klass(), tp->klass());
+      if (static_res == Compile::SSC_always_true) {
+        // isInstance() is true - fold the code.
+        set_result(obj);
+        return true;
+      } else if (static_res == Compile::SSC_always_false) {
+        // Don't use intrinsic, have to throw ClassCastException.
+        // If the reference is null, the non-intrinsic bytecode will
+        // be optimized appropriately.
+        return false;
+      }
+    }
+  }
+
+  // Bailout intrinsic and do normal inlining if exception path is frequent.
+  if (too_many_traps(Deoptimization::Reason_intrinsic)) {
+    return false;
+  }
+
+  // Generate dynamic checks.
+  // Class.cast() is java implementation of _checkcast bytecode.
+  // Do checkcast (Parse::do_checkcast()) optimizations here.
+
+  mirror = null_check(mirror);
+  // If mirror is dead, only null-path is taken.
+  if (stopped()) {
+    return true;
+  }
+
+  // Not-subtype or the mirror's klass ptr is NULL (in case it is a primitive).
+  enum { _bad_type_path = 1, _prim_path = 2, PATH_LIMIT };
+  RegionNode* region = new RegionNode(PATH_LIMIT);
+  record_for_igvn(region);
+
+  // Now load the mirror's klass metaobject, and null-check it.
+  // If kls is null, we have a primitive mirror and
+  // nothing is an instance of a primitive type.
+  Node* kls = load_klass_from_mirror(mirror, false, region, _prim_path);
+
+  Node* res = top();
+  if (!stopped()) {
+    Node* bad_type_ctrl = top();
+    // Do checkcast optimizations.
+    res = gen_checkcast(obj, kls, &bad_type_ctrl);
+    region->init_req(_bad_type_path, bad_type_ctrl);
+  }
+  if (region->in(_prim_path) != top() ||
+      region->in(_bad_type_path) != top()) {
+    // Let Interpreter throw ClassCastException.
+    PreserveJVMState pjvms(this);
+    set_control(_gvn.transform(region));
+    uncommon_trap(Deoptimization::Reason_intrinsic,
+                  Deoptimization::Action_maybe_recompile);
+  }
+  if (!stopped()) {
+    set_result(res);
+  }
+  return true;
+}
+
+
 //--------------------------inline_native_subtype_check------------------------
 // This intrinsic takes the JNI calls out of the heart of
 // UnsafeFieldAccessorImpl.set, which improves Field.set, readObject, etc.
--- a/hotspot/src/share/vm/prims/whitebox.cpp	Fri Oct 31 12:01:27 2014 -1000
+++ b/hotspot/src/share/vm/prims/whitebox.cpp	Fri Oct 31 16:51:57 2014 -0700
@@ -803,20 +803,24 @@
   ThreadToNativeFromVM ttn(thread);
   jclass clazz = env->FindClass(vmSymbols::java_lang_Object()->as_C_string());
   CHECK_JNI_EXCEPTION_(env, NULL);
-  result = env->NewObjectArray(2, clazz, NULL);
+  result = env->NewObjectArray(3, clazz, NULL);
   if (result == NULL) {
     return result;
   }
 
-  jobject obj = integerBox(thread, env, code->comp_level());
+  jobject level = integerBox(thread, env, code->comp_level());
   CHECK_JNI_EXCEPTION_(env, NULL);
-  env->SetObjectArrayElement(result, 0, obj);
+  env->SetObjectArrayElement(result, 0, level);
 
   jbyteArray insts = env->NewByteArray(insts_size);
   CHECK_JNI_EXCEPTION_(env, NULL);
   env->SetByteArrayRegion(insts, 0, insts_size, (jbyte*) code->insts_begin());
   env->SetObjectArrayElement(result, 1, insts);
 
+  jobject id = integerBox(thread, env, code->compile_id());
+  CHECK_JNI_EXCEPTION_(env, NULL);
+  env->SetObjectArrayElement(result, 2, id);
+
   return result;
 WB_END
 
--- a/hotspot/test/TEST.groups	Fri Oct 31 12:01:27 2014 -1000
+++ b/hotspot/test/TEST.groups	Fri Oct 31 16:51:57 2014 -0700
@@ -479,6 +479,7 @@
   compiler/intrinsics/mathexact/SubExactILoopDependentTest.java \
   compiler/intrinsics/stringequals/TestStringEqualsBadLength.java \
   compiler/intrinsics/unsafe/UnsafeGetAddressTest.java \
+  compiler/intrinsics/classcast/NullCheckDroppingsTest.java \
   compiler/jsr292/ConcurrentClassLoadingTest.java \
   compiler/jsr292/CreatesInterfaceDotEqualsCallInfo.java \
   compiler/loopopts/TestLogSum.java \
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/hotspot/test/compiler/intrinsics/classcast/NullCheckDroppingsTest.java	Fri Oct 31 16:51:57 2014 -0700
@@ -0,0 +1,346 @@
+/*
+ * Copyright (c) 2014, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test NullCheckDroppingsTest
+ * @bug 8054492
+ * @summary "Casting can result in redundant null checks in generated code"
+ * @library /testlibrary /testlibrary/whitebox /testlibrary/com/oracle/java/testlibrary
+ * @build NullCheckDroppingsTest
+ * @run main ClassFileInstaller sun.hotspot.WhiteBox
+ *                              sun.hotspot.WhiteBox$WhiteBoxPermission
+ * @run main ClassFileInstaller com.oracle.java.testlibrary.Platform
+ * @run main/othervm -Xbootclasspath/a:. -XX:+IgnoreUnrecognizedVMOptions -XX:+UnlockDiagnosticVMOptions -XX:+WhiteBoxAPI
+ *                   -Xmixed -XX:-BackgroundCompilation -XX:-TieredCompilation -XX:CompileThreshold=1000
+ *                   -XX:CompileCommand=exclude,NullCheckDroppingsTest::runTest NullCheckDroppingsTest
+ */
+
+import sun.hotspot.WhiteBox;
+import sun.hotspot.code.NMethod;
+import com.oracle.java.testlibrary.Platform;
+
+import java.lang.reflect.Method;
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
+import java.util.function.BiFunction;
+
+public class NullCheckDroppingsTest {
+
+    private static final WhiteBox WHITE_BOX = WhiteBox.getWhiteBox();
+
+    static final BiFunction<Class, Object, Object> fCast = (c, o) -> c.cast(o);
+
+    static final MethodHandle SET_SSINK;
+    static final MethodHandle MH_CAST;
+
+    static {
+        try {
+            SET_SSINK = MethodHandles.lookup().findSetter(NullCheckDroppingsTest.class, "ssink", String.class);
+            MH_CAST = MethodHandles.lookup().findVirtual(Class.class,
+                                                         "cast",
+                                                         MethodType.methodType(Object.class, Object.class));
+        }
+        catch (Exception e) {
+            throw new Error(e);
+        }
+    }
+
+    static volatile String svalue = "A";
+    static volatile String snull = null;
+    static volatile Integer iobj = new Integer(0);
+    static volatile int[] arr = new int[2];
+    static volatile Class objClass = String.class;
+    static volatile Class nullClass = null;
+
+    String  ssink;
+    Integer isink;
+    int[]   asink;
+
+    public static void main(String[] args) throws Exception {
+
+        // Only test C2 in Server VM
+        if (!Platform.isServer()) {
+            return;
+        }
+        // Make sure background compilation is disabled
+        if (WHITE_BOX.getBooleanVMFlag("BackgroundCompilation")) {
+            throw new AssertionError("Background compilation enabled");
+        }
+        // Make sure Tiered compilation is disabled
+        if (WHITE_BOX.getBooleanVMFlag("TieredCompilation")) {
+            throw new AssertionError("Tiered compilation enabled");
+        }
+
+        Method methodClassCast = NullCheckDroppingsTest.class.getDeclaredMethod("testClassCast", String.class);
+        Method methodMHCast    = NullCheckDroppingsTest.class.getDeclaredMethod("testMHCast",    String.class);
+        Method methodMHSetter  = NullCheckDroppingsTest.class.getDeclaredMethod("testMHSetter",  String.class);
+        Method methodFunction  = NullCheckDroppingsTest.class.getDeclaredMethod("testFunction",  String.class);
+
+        NullCheckDroppingsTest t = new NullCheckDroppingsTest();
+        t.runTest(methodClassCast, false);
+        t.runTest(methodMHCast,    false);
+        t.runTest(methodMHSetter,  false);
+        t.runTest(methodFunction,  false);
+
+        // Edge cases
+        Method methodClassCastNull = NullCheckDroppingsTest.class.getDeclaredMethod("testClassCastNull", String.class);
+        Method methodNullClassCast = NullCheckDroppingsTest.class.getDeclaredMethod("testNullClassCast", String.class);
+        Method methodClassCastObj  = NullCheckDroppingsTest.class.getDeclaredMethod("testClassCastObj",  Object.class);
+        Method methodObjClassCast  = NullCheckDroppingsTest.class.getDeclaredMethod("testObjClassCast",  String.class);
+        Method methodVarClassCast  = NullCheckDroppingsTest.class.getDeclaredMethod("testVarClassCast",  String.class);
+        Method methodClassCastInt  = NullCheckDroppingsTest.class.getDeclaredMethod("testClassCastInt",  Object.class);
+        Method methodIntClassCast  = NullCheckDroppingsTest.class.getDeclaredMethod("testIntClassCast",  Object.class);
+        Method methodClassCastint  = NullCheckDroppingsTest.class.getDeclaredMethod("testClassCastint",  Object.class);
+        Method methodintClassCast  = NullCheckDroppingsTest.class.getDeclaredMethod("testintClassCast",  Object.class);
+        Method methodClassCastPrim = NullCheckDroppingsTest.class.getDeclaredMethod("testClassCastPrim", Object.class);
+        Method methodPrimClassCast = NullCheckDroppingsTest.class.getDeclaredMethod("testPrimClassCast", Object.class);
+
+        t.runTest(methodClassCastNull, false);
+        t.runTest(methodNullClassCast, false);
+        t.runTest(methodClassCastObj,  false);
+        t.runTest(methodObjClassCast,  true);
+        t.runTest(methodVarClassCast,  true);
+        t.runTest(methodClassCastInt,  false);
+        t.runTest(methodIntClassCast,  true);
+        t.runTest(methodClassCastint,  false);
+        t.runTest(methodintClassCast,  false);
+        t.runTest(methodClassCastPrim, false);
+        t.runTest(methodPrimClassCast, true);
+    }
+
+    void testClassCast(String s) {
+        try {
+            ssink = String.class.cast(s);
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testClassCastNull(String s) {
+        try {
+            ssink = String.class.cast(null);
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testNullClassCast(String s) {
+        try {
+            ssink = (String)nullClass.cast(s);
+            throw new AssertionError("NullPointerException is not thrown");
+        } catch (NullPointerException t) {
+            // Ignore NullPointerException
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testClassCastObj(Object s) {
+        try {
+            ssink = String.class.cast(s);
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testObjClassCast(String s) {
+        try {
+            ssink = (String)objClass.cast(s);
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testVarClassCast(String s) {
+        Class cl = (s == null) ? null : String.class;
+        try {
+            ssink = (String)cl.cast(svalue);
+            if (s == null) {
+                throw new AssertionError("NullPointerException is not thrown");
+            }
+        } catch (NullPointerException t) {
+            // Ignore NullPointerException
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testClassCastInt(Object s) {
+        try {
+            ssink = String.class.cast(iobj);
+            throw new AssertionError("ClassCastException is not thrown");
+        } catch (ClassCastException t) {
+            // Ignore ClassCastException: Cannot cast java.lang.Integer to java.lang.String
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testIntClassCast(Object s) {
+        try {
+            isink = Integer.class.cast(s);
+            if (s != null) {
+                throw new AssertionError("ClassCastException is not thrown");
+            }
+        } catch (ClassCastException t) {
+            // Ignore ClassCastException: Cannot cast java.lang.String to java.lang.Integer
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testClassCastint(Object s) {
+        try {
+            ssink = String.class.cast(45);
+            throw new AssertionError("ClassCastException is not thrown");
+        } catch (ClassCastException t) {
+            // Ignore ClassCastException: Cannot cast java.lang.Integer to java.lang.String
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testintClassCast(Object s) {
+        try {
+            isink = int.class.cast(s);
+            if (s != null) {
+                throw new AssertionError("ClassCastException is not thrown");
+            }
+        } catch (ClassCastException t) {
+            // Ignore ClassCastException: Cannot cast java.lang.String to java.lang.Integer
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testClassCastPrim(Object s) {
+        try {
+            ssink = String.class.cast(arr);
+            throw new AssertionError("ClassCastException is not thrown");
+        } catch (ClassCastException t) {
+            // Ignore ClassCastException: Cannot cast [I to java.lang.String
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testPrimClassCast(Object s) {
+        try {
+            asink = int[].class.cast(s);
+            if (s != null) {
+                throw new AssertionError("ClassCastException is not thrown");
+            }
+        } catch (ClassCastException t) {
+            // Ignore ClassCastException: Cannot cast java.lang.String to [I
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testMHCast(String s) {
+        try {
+            ssink = (String) (Object) MH_CAST.invokeExact(String.class, (Object) s);
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testMHSetter(String s) {
+        try {
+            SET_SSINK.invokeExact(this, s);
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void testFunction(String s) {
+        try {
+            ssink = (String) fCast.apply(String.class, s);
+        } catch (Throwable t) {
+            throw new Error(t);
+        }
+    }
+
+    void runTest(Method method, boolean deopt) {
+        if (method == null) {
+            throw new AssertionError("method was not found");
+        }
+        // Ensure method is compiled
+        WHITE_BOX.testSetDontInlineMethod(method, true);
+        for (int i = 0; i < 3000; i++) {
+            try {
+                method.invoke(this, svalue);
+            } catch (Exception e) {
+                throw new Error("Unexpected exception: ", e);
+            }
+        }
+        NMethod nm = getNMethod(method);
+
+        // Passing null should cause a de-optimization
+        // if method is compiled with a null-check.
+        try {
+            method.invoke(this, snull);
+        } catch (Exception e) {
+            throw new Error("Unexpected exception: ", e);
+        }
+        checkDeoptimization(method, nm, deopt);
+    }
+
+    static NMethod getNMethod(Method test) {
+        // Because background compilation is disabled, method should now be compiled
+        if (!WHITE_BOX.isMethodCompiled(test)) {
+            throw new AssertionError(test + " not compiled");
+        }
+
+        NMethod nm = NMethod.get(test, false); // not OSR nmethod
+        if (nm == null) {
+            throw new AssertionError(test + " missing nmethod?");
+        }
+        if (nm.comp_level != 4) {
+            throw new AssertionError(test + " compiled by not C2: " + nm);
+        }
+        return nm;
+    }
+
+    static void checkDeoptimization(Method method, NMethod nmOrig, boolean deopt) {
+        // Check deoptimization event (intrinsic Class.cast() works).
+        if (WHITE_BOX.isMethodCompiled(method) == deopt) {
+            throw new AssertionError(method + " was" + (deopt ? " not" : "") + " deoptimized");
+        }
+        if (deopt) {
+            return;
+        }
+        // Ensure no recompilation when no deoptimization is expected.
+        NMethod nm = NMethod.get(method, false); // not OSR nmethod
+        if (nm == null) {
+            throw new AssertionError(method + " missing nmethod?");
+        }
+        if (nm.comp_level != 4) {
+            throw new AssertionError(method + " compiled by not C2: " + nm);
+        }
+        if (nm.compile_id != nmOrig.compile_id) {
+            throw new AssertionError(method + " was recompiled: old nmethod=" + nmOrig + ", new nmethod=" + nm);
+        }
+    }
+}
--- a/hotspot/test/testlibrary/whitebox/sun/hotspot/code/NMethod.java	Fri Oct 31 12:01:27 2014 -1000
+++ b/hotspot/test/testlibrary/whitebox/sun/hotspot/code/NMethod.java	Fri Oct 31 16:51:57 2014 -0700
@@ -34,18 +34,21 @@
     return obj == null ? null : new NMethod(obj);
   }
   private NMethod(Object[] obj) {
-    assert obj.length == 2;
+    assert obj.length == 3;
     comp_level = (Integer) obj[0];
     insts = (byte[]) obj[1];
+    compile_id = (Integer) obj[2];
   }
   public byte[] insts;
   public int comp_level;
+  public int compile_id;
 
   @Override
   public String toString() {
     return "NMethod{" +
         "insts=" + insts +
         ", comp_level=" + comp_level +
+        ", compile_id=" + compile_id +
         '}';
   }
 }