hotspot/src/share/vm/opto/library_call.cpp
changeset 14132 3c1437abcefd
parent 14126 afb8a3a86f1f
child 14621 fd9265ab0f67
--- a/hotspot/src/share/vm/opto/library_call.cpp	Tue Oct 23 13:06:37 2012 -0700
+++ b/hotspot/src/share/vm/opto/library_call.cpp	Wed Oct 24 14:33:22 2012 -0700
@@ -44,18 +44,22 @@
  public:
  private:
   bool             _is_virtual;
+  bool             _is_predicted;
   vmIntrinsics::ID _intrinsic_id;
 
  public:
-  LibraryIntrinsic(ciMethod* m, bool is_virtual, vmIntrinsics::ID id)
+  LibraryIntrinsic(ciMethod* m, bool is_virtual, bool is_predicted, vmIntrinsics::ID id)
     : InlineCallGenerator(m),
       _is_virtual(is_virtual),
+      _is_predicted(is_predicted),
       _intrinsic_id(id)
   {
   }
   virtual bool is_intrinsic() const { return true; }
   virtual bool is_virtual()   const { return _is_virtual; }
+  virtual bool is_predicted()   const { return _is_predicted; }
   virtual JVMState* generate(JVMState* jvms);
+  virtual Node* generate_predicate(JVMState* jvms);
   vmIntrinsics::ID intrinsic_id() const { return _intrinsic_id; }
 };
 
@@ -83,6 +87,7 @@
   int               arg_size()  const    { return callee()->arg_size(); }
 
   bool try_to_inline();
+  Node* try_to_predicate();
 
   // Helper functions to inline natives
   void push_result(RegionNode* region, PhiNode* value);
@@ -148,6 +153,7 @@
   CallJavaNode* generate_method_call_virtual(vmIntrinsics::ID method_id) {
     return generate_method_call(method_id, true, false);
   }
+  Node * load_field_from_object(Node * fromObj, const char * fieldName, const char * fieldTypeString, bool is_exact, bool is_static);
 
   Node* make_string_method_node(int opcode, Node* str1_start, Node* cnt1, Node* str2_start, Node* cnt2);
   Node* make_string_method_node(int opcode, Node* str1, Node* str2);
@@ -253,6 +259,10 @@
   bool inline_reverseBytes(vmIntrinsics::ID id);
 
   bool inline_reference_get();
+  bool inline_aescrypt_Block(vmIntrinsics::ID id);
+  bool inline_cipherBlockChaining_AESCrypt(vmIntrinsics::ID id);
+  Node* inline_cipherBlockChaining_AESCrypt_predicate(bool decrypting);
+  Node* get_key_start_from_aescrypt_object(Node* aescrypt_object);
 };
 
 
@@ -306,6 +316,8 @@
     }
   }
 
+  bool is_predicted = false;
+
   switch (id) {
   case vmIntrinsics::_compareTo:
     if (!SpecialStringCompareTo)  return NULL;
@@ -413,6 +425,18 @@
     break;
 #endif
 
+  case vmIntrinsics::_aescrypt_encryptBlock:
+  case vmIntrinsics::_aescrypt_decryptBlock:
+    if (!UseAESIntrinsics) return NULL;
+    break;
+
+  case vmIntrinsics::_cipherBlockChaining_encryptAESCrypt:
+  case vmIntrinsics::_cipherBlockChaining_decryptAESCrypt:
+    if (!UseAESIntrinsics) return NULL;
+    // these two require the predicated logic
+    is_predicted = true;
+    break;
+
  default:
     assert(id <= vmIntrinsics::LAST_COMPILER_INLINE, "caller responsibility");
     assert(id != vmIntrinsics::_Object_init && id != vmIntrinsics::_invoke, "enum out of order?");
@@ -444,7 +468,7 @@
     if (!InlineUnsafeOps)  return NULL;
   }
 
-  return new LibraryIntrinsic(m, is_virtual, (vmIntrinsics::ID) id);
+  return new LibraryIntrinsic(m, is_virtual, is_predicted, (vmIntrinsics::ID) id);
 }
 
 //----------------------register_library_intrinsics-----------------------
@@ -496,6 +520,47 @@
   return NULL;
 }
 
+Node* LibraryIntrinsic::generate_predicate(JVMState* jvms) {
+  LibraryCallKit kit(jvms, this);
+  Compile* C = kit.C;
+  int nodes = C->unique();
+#ifndef PRODUCT
+  assert(is_predicted(), "sanity");
+  if ((PrintIntrinsics || PrintInlining NOT_PRODUCT( || PrintOptoInlining) ) && Verbose) {
+    char buf[1000];
+    const char* str = vmIntrinsics::short_name_as_C_string(intrinsic_id(), buf, sizeof(buf));
+    tty->print_cr("Predicate for intrinsic %s", str);
+  }
+#endif
+
+  Node* slow_ctl = kit.try_to_predicate();
+  if (!kit.failing()) {
+    if (C->log()) {
+      C->log()->elem("predicate_intrinsic id='%s'%s nodes='%d'",
+                     vmIntrinsics::name_at(intrinsic_id()),
+                     (is_virtual() ? " virtual='1'" : ""),
+                     C->unique() - nodes);
+    }
+    return slow_ctl; // Could be NULL if the check folds.
+  }
+
+  // The intrinsic bailed out
+  if (PrintIntrinsics || PrintInlining NOT_PRODUCT( || PrintOptoInlining) ) {
+    if (jvms->has_method()) {
+      // Not a root compile.
+      const char* msg = "failed to generate predicate for intrinsic";
+      CompileTask::print_inlining(kit.callee(), jvms->depth() - 1, kit.bci(), msg);
+    } else {
+      // Root compile
+      tty->print("Did not generate predicate for intrinsic %s%s at bci:%d in",
+               vmIntrinsics::name_at(intrinsic_id()),
+               (is_virtual() ? " (virtual)" : ""), kit.bci());
+    }
+  }
+  C->gather_intrinsic_statistics(intrinsic_id(), is_virtual(), Compile::_intrinsic_failed);
+  return NULL;
+}
+
 bool LibraryCallKit::try_to_inline() {
   // Handle symbolic names for otherwise undistinguished boolean switches:
   const bool is_store       = true;
@@ -767,6 +832,14 @@
   case vmIntrinsics::_Reference_get:
     return inline_reference_get();
 
+  case vmIntrinsics::_aescrypt_encryptBlock:
+  case vmIntrinsics::_aescrypt_decryptBlock:
+    return inline_aescrypt_Block(intrinsic_id());
+
+  case vmIntrinsics::_cipherBlockChaining_encryptAESCrypt:
+  case vmIntrinsics::_cipherBlockChaining_decryptAESCrypt:
+    return inline_cipherBlockChaining_AESCrypt(intrinsic_id());
+
   default:
     // If you get here, it may be that someone has added a new intrinsic
     // to the list in vmSymbols.hpp without implementing it here.
@@ -780,6 +853,36 @@
   }
 }
 
+Node* LibraryCallKit::try_to_predicate() {
+  if (!jvms()->has_method()) {
+    // Root JVMState has a null method.
+    assert(map()->memory()->Opcode() == Op_Parm, "");
+    // Insert the memory aliasing node
+    set_all_memory(reset_memory());
+  }
+  assert(merged_memory(), "");
+
+  switch (intrinsic_id()) {
+  case vmIntrinsics::_cipherBlockChaining_encryptAESCrypt:
+    return inline_cipherBlockChaining_AESCrypt_predicate(false);
+  case vmIntrinsics::_cipherBlockChaining_decryptAESCrypt:
+    return inline_cipherBlockChaining_AESCrypt_predicate(true);
+
+  default:
+    // If you get here, it may be that someone has added a new intrinsic
+    // to the list in vmSymbols.hpp without implementing it here.
+#ifndef PRODUCT
+    if ((PrintMiscellaneous && (Verbose || WizardMode)) || PrintOpto) {
+      tty->print_cr("*** Warning: Unimplemented predicate for intrinsic %s(%d)",
+                    vmIntrinsics::name_at(intrinsic_id()), intrinsic_id());
+    }
+#endif
+    Node* slow_ctl = control();
+    set_control(top()); // No fast path instrinsic
+    return slow_ctl;
+  }
+}
+
 //------------------------------push_result------------------------------
 // Helper function for finishing intrinsics.
 void LibraryCallKit::push_result(RegionNode* region, PhiNode* value) {
@@ -5613,3 +5716,265 @@
   push(result);
   return true;
 }
+
+
+Node * LibraryCallKit::load_field_from_object(Node * fromObj, const char * fieldName, const char * fieldTypeString,
+                                              bool is_exact=true, bool is_static=false) {
+
+  const TypeInstPtr* tinst = _gvn.type(fromObj)->isa_instptr();
+  assert(tinst != NULL, "obj is null");
+  assert(tinst->klass()->is_loaded(), "obj is not loaded");
+  assert(!is_exact || tinst->klass_is_exact(), "klass not exact");
+
+  ciField* field = tinst->klass()->as_instance_klass()->get_field_by_name(ciSymbol::make(fieldName),
+                                                                          ciSymbol::make(fieldTypeString),
+                                                                          is_static);
+  if (field == NULL) return (Node *) NULL;
+  assert (field != NULL, "undefined field");
+
+  // Next code  copied from Parse::do_get_xxx():
+
+  // Compute address and memory type.
+  int offset  = field->offset_in_bytes();
+  bool is_vol = field->is_volatile();
+  ciType* field_klass = field->type();
+  assert(field_klass->is_loaded(), "should be loaded");
+  const TypePtr* adr_type = C->alias_type(field)->adr_type();
+  Node *adr = basic_plus_adr(fromObj, fromObj, offset);
+  BasicType bt = field->layout_type();
+
+  // Build the resultant type of the load
+  const Type *type = TypeOopPtr::make_from_klass(field_klass->as_klass());
+
+  // Build the load.
+  Node* loadedField = make_load(NULL, adr, type, bt, adr_type, is_vol);
+  return loadedField;
+}
+
+
+//------------------------------inline_aescrypt_Block-----------------------
+bool LibraryCallKit::inline_aescrypt_Block(vmIntrinsics::ID id) {
+  address stubAddr;
+  const char *stubName;
+  assert(UseAES, "need AES instruction support");
+
+  switch(id) {
+  case vmIntrinsics::_aescrypt_encryptBlock:
+    stubAddr = StubRoutines::aescrypt_encryptBlock();
+    stubName = "aescrypt_encryptBlock";
+    break;
+  case vmIntrinsics::_aescrypt_decryptBlock:
+    stubAddr = StubRoutines::aescrypt_decryptBlock();
+    stubName = "aescrypt_decryptBlock";
+    break;
+  }
+  if (stubAddr == NULL) return false;
+
+  // Restore the stack and pop off the arguments.
+  int nargs = 5;  // this + 2 oop/offset combos
+  assert(callee()->signature()->size() == nargs-1, "encryptBlock has 4 arguments");
+
+  Node *aescrypt_object  = argument(0);
+  Node *src         = argument(1);
+  Node *src_offset  = argument(2);
+  Node *dest        = argument(3);
+  Node *dest_offset = argument(4);
+
+  // (1) src and dest are arrays.
+  const Type* src_type = src->Value(&_gvn);
+  const Type* dest_type = dest->Value(&_gvn);
+  const TypeAryPtr* top_src = src_type->isa_aryptr();
+  const TypeAryPtr* top_dest = dest_type->isa_aryptr();
+  assert (top_src  != NULL && top_src->klass()  != NULL &&  top_dest != NULL && top_dest->klass() != NULL, "args are strange");
+
+  // for the quick and dirty code we will skip all the checks.
+  // we are just trying to get the call to be generated.
+  Node* src_start  = src;
+  Node* dest_start = dest;
+  if (src_offset != NULL || dest_offset != NULL) {
+    assert(src_offset != NULL && dest_offset != NULL, "");
+    src_start  = array_element_address(src,  src_offset,  T_BYTE);
+    dest_start = array_element_address(dest, dest_offset, T_BYTE);
+  }
+
+  // now need to get the start of its expanded key array
+  // this requires a newer class file that has this array as littleEndian ints, otherwise we revert to java
+  Node* k_start = get_key_start_from_aescrypt_object(aescrypt_object);
+  if (k_start == NULL) return false;
+
+  // Call the stub.
+  make_runtime_call(RC_LEAF|RC_NO_FP, OptoRuntime::aescrypt_block_Type(),
+                    stubAddr, stubName, TypePtr::BOTTOM,
+                    src_start, dest_start, k_start);
+
+  return true;
+}
+
+//------------------------------inline_cipherBlockChaining_AESCrypt-----------------------
+bool LibraryCallKit::inline_cipherBlockChaining_AESCrypt(vmIntrinsics::ID id) {
+  address stubAddr;
+  const char *stubName;
+
+  assert(UseAES, "need AES instruction support");
+
+  switch(id) {
+  case vmIntrinsics::_cipherBlockChaining_encryptAESCrypt:
+    stubAddr = StubRoutines::cipherBlockChaining_encryptAESCrypt();
+    stubName = "cipherBlockChaining_encryptAESCrypt";
+    break;
+  case vmIntrinsics::_cipherBlockChaining_decryptAESCrypt:
+    stubAddr = StubRoutines::cipherBlockChaining_decryptAESCrypt();
+    stubName = "cipherBlockChaining_decryptAESCrypt";
+    break;
+  }
+  if (stubAddr == NULL) return false;
+
+
+  // Restore the stack and pop off the arguments.
+  int nargs = 6;  // this + oop/offset + len + oop/offset
+  assert(callee()->signature()->size() == nargs-1, "wrong number of arguments");
+  Node *cipherBlockChaining_object  = argument(0);
+  Node *src         = argument(1);
+  Node *src_offset  = argument(2);
+  Node *len         = argument(3);
+  Node *dest        = argument(4);
+  Node *dest_offset = argument(5);
+
+  // (1) src and dest are arrays.
+  const Type* src_type = src->Value(&_gvn);
+  const Type* dest_type = dest->Value(&_gvn);
+  const TypeAryPtr* top_src = src_type->isa_aryptr();
+  const TypeAryPtr* top_dest = dest_type->isa_aryptr();
+  assert (top_src  != NULL && top_src->klass()  != NULL
+          &&  top_dest != NULL && top_dest->klass() != NULL, "args are strange");
+
+  // checks are the responsibility of the caller
+  Node* src_start  = src;
+  Node* dest_start = dest;
+  if (src_offset != NULL || dest_offset != NULL) {
+    assert(src_offset != NULL && dest_offset != NULL, "");
+    src_start  = array_element_address(src,  src_offset,  T_BYTE);
+    dest_start = array_element_address(dest, dest_offset, T_BYTE);
+  }
+
+  // if we are in this set of code, we "know" the embeddedCipher is an AESCrypt object
+  // (because of the predicated logic executed earlier).
+  // so we cast it here safely.
+  // this requires a newer class file that has this array as littleEndian ints, otherwise we revert to java
+
+  Node* embeddedCipherObj = load_field_from_object(cipherBlockChaining_object, "embeddedCipher", "Lcom/sun/crypto/provider/SymmetricCipher;", /*is_exact*/ false);
+  if (embeddedCipherObj == NULL) return false;
+
+  // cast it to what we know it will be at runtime
+  const TypeInstPtr* tinst = _gvn.type(cipherBlockChaining_object)->isa_instptr();
+  assert(tinst != NULL, "CBC obj is null");
+  assert(tinst->klass()->is_loaded(), "CBC obj is not loaded");
+  ciKlass* klass_AESCrypt = tinst->klass()->as_instance_klass()->find_klass(ciSymbol::make("com/sun/crypto/provider/AESCrypt"));
+  if (!klass_AESCrypt->is_loaded()) return false;
+
+  ciInstanceKlass* instklass_AESCrypt = klass_AESCrypt->as_instance_klass();
+  const TypeKlassPtr* aklass = TypeKlassPtr::make(instklass_AESCrypt);
+  const TypeOopPtr* xtype = aklass->as_instance_type();
+  Node* aescrypt_object = new(C) CheckCastPPNode(control(), embeddedCipherObj, xtype);
+  aescrypt_object = _gvn.transform(aescrypt_object);
+
+  // we need to get the start of the aescrypt_object's expanded key array
+  Node* k_start = get_key_start_from_aescrypt_object(aescrypt_object);
+  if (k_start == NULL) return false;
+
+  // similarly, get the start address of the r vector
+  Node* objRvec = load_field_from_object(cipherBlockChaining_object, "r", "[B", /*is_exact*/ false);
+  if (objRvec == NULL) return false;
+  Node* r_start = array_element_address(objRvec, intcon(0), T_BYTE);
+
+  // Call the stub, passing src_start, dest_start, k_start, r_start and src_len
+  make_runtime_call(RC_LEAF|RC_NO_FP,
+                    OptoRuntime::cipherBlockChaining_aescrypt_Type(),
+                    stubAddr, stubName, TypePtr::BOTTOM,
+                    src_start, dest_start, k_start, r_start, len);
+
+  // return is void so no result needs to be pushed
+
+  return true;
+}
+
+//------------------------------get_key_start_from_aescrypt_object-----------------------
+Node * LibraryCallKit::get_key_start_from_aescrypt_object(Node *aescrypt_object) {
+  Node* objAESCryptKey = load_field_from_object(aescrypt_object, "K", "[I", /*is_exact*/ false);
+  assert (objAESCryptKey != NULL, "wrong version of com.sun.crypto.provider.AESCrypt");
+  if (objAESCryptKey == NULL) return (Node *) NULL;
+
+  // now have the array, need to get the start address of the K array
+  Node* k_start = array_element_address(objAESCryptKey, intcon(0), T_INT);
+  return k_start;
+}
+
+//----------------------------inline_cipherBlockChaining_AESCrypt_predicate----------------------------
+// Return node representing slow path of predicate check.
+// the pseudo code we want to emulate with this predicate is:
+// for encryption:
+//    if (embeddedCipherObj instanceof AESCrypt) do_intrinsic, else do_javapath
+// for decryption:
+//    if ((embeddedCipherObj instanceof AESCrypt) && (cipher!=plain)) do_intrinsic, else do_javapath
+//    note cipher==plain is more conservative than the original java code but that's OK
+//
+Node* LibraryCallKit::inline_cipherBlockChaining_AESCrypt_predicate(bool decrypting) {
+  // First, check receiver for NULL since it is virtual method.
+  int nargs = arg_size();
+  Node* objCBC = argument(0);
+  _sp += nargs;
+  objCBC = do_null_check(objCBC, T_OBJECT);
+  _sp -= nargs;
+
+  if (stopped()) return NULL; // Always NULL
+
+  // Load embeddedCipher field of CipherBlockChaining object.
+  Node* embeddedCipherObj = load_field_from_object(objCBC, "embeddedCipher", "Lcom/sun/crypto/provider/SymmetricCipher;", /*is_exact*/ false);
+
+  // get AESCrypt klass for instanceOf check
+  // AESCrypt might not be loaded yet if some other SymmetricCipher got us to this compile point
+  // will have same classloader as CipherBlockChaining object
+  const TypeInstPtr* tinst = _gvn.type(objCBC)->isa_instptr();
+  assert(tinst != NULL, "CBCobj is null");
+  assert(tinst->klass()->is_loaded(), "CBCobj is not loaded");
+
+  // we want to do an instanceof comparison against the AESCrypt class
+  ciKlass* klass_AESCrypt = tinst->klass()->as_instance_klass()->find_klass(ciSymbol::make("com/sun/crypto/provider/AESCrypt"));
+  if (!klass_AESCrypt->is_loaded()) {
+    // if AESCrypt is not even loaded, we never take the intrinsic fast path
+    Node* ctrl = control();
+    set_control(top()); // no regular fast path
+    return ctrl;
+  }
+  ciInstanceKlass* instklass_AESCrypt = klass_AESCrypt->as_instance_klass();
+
+  _sp += nargs;          // gen_instanceof might do an uncommon trap
+  Node* instof = gen_instanceof(embeddedCipherObj, makecon(TypeKlassPtr::make(instklass_AESCrypt)));
+  _sp -= nargs;
+  Node* cmp_instof  = _gvn.transform(new (C) CmpINode(instof, intcon(1)));
+  Node* bool_instof  = _gvn.transform(new (C) BoolNode(cmp_instof, BoolTest::ne));
+
+  Node* instof_false = generate_guard(bool_instof, NULL, PROB_MIN);
+
+  // for encryption, we are done
+  if (!decrypting)
+    return instof_false;  // even if it is NULL
+
+  // for decryption, we need to add a further check to avoid
+  // taking the intrinsic path when cipher and plain are the same
+  // see the original java code for why.
+  RegionNode* region = new(C) RegionNode(3);
+  region->init_req(1, instof_false);
+  Node* src = argument(1);
+  Node *dest = argument(4);
+  Node* cmp_src_dest = _gvn.transform(new (C) CmpPNode(src, dest));
+  Node* bool_src_dest = _gvn.transform(new (C) BoolNode(cmp_src_dest, BoolTest::eq));
+  Node* src_dest_conjoint = generate_guard(bool_src_dest, NULL, PROB_MIN);
+  region->init_req(2, src_dest_conjoint);
+
+  record_for_igvn(region);
+  return _gvn.transform(region);
+
+}
+
+