8190800: Support vectorization of Math.sqrt() on floats
authorrlupusoru
Wed, 22 Nov 2017 14:43:37 +0300
changeset 48089 22c9856fc2c2
parent 48088 2cd1c2b03782
child 48090 cce885f4baab
8190800: Support vectorization of Math.sqrt() on floats Reviewed-by: vlivanov, kvn
src/hotspot/cpu/x86/assembler_x86.cpp
src/hotspot/cpu/x86/assembler_x86.hpp
src/hotspot/cpu/x86/x86.ad
src/hotspot/share/adlc/formssel.cpp
src/hotspot/share/opto/classes.hpp
src/hotspot/share/opto/convertnode.cpp
src/hotspot/share/opto/convertnode.hpp
src/hotspot/share/opto/subnode.cpp
src/hotspot/share/opto/subnode.hpp
src/hotspot/share/opto/superword.cpp
src/hotspot/share/opto/vectornode.cpp
src/hotspot/share/opto/vectornode.hpp
src/hotspot/share/runtime/vmStructs.cpp
--- a/src/hotspot/cpu/x86/assembler_x86.cpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/cpu/x86/assembler_x86.cpp	Wed Nov 22 14:43:37 2017 +0300
@@ -5203,6 +5203,24 @@
   emit_operand(dst, src);
 }
 
+void Assembler::vsqrtps(XMMRegister dst, XMMRegister src, int vector_len) {
+  assert(VM_Version::supports_avx(), "");
+  InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true);
+  int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_0F, &attributes);
+  emit_int8(0x51);
+  emit_int8((unsigned char)(0xC0 | encode));
+}
+
+void Assembler::vsqrtps(XMMRegister dst, Address src, int vector_len) {
+  assert(VM_Version::supports_avx(), "");
+  InstructionMark im(this);
+  InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true);
+  attributes.set_address_attributes(/* tuple_type */ EVEX_FV, /* input_size_in_bits */ EVEX_64bit);
+  vex_prefix(src, 0, dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_0F, &attributes);
+  emit_int8(0x51);
+  emit_operand(dst, src);
+}
+
 void Assembler::andpd(XMMRegister dst, XMMRegister src) {
   NOT_LP64(assert(VM_Version::supports_sse2(), ""));
   InstructionAttr attributes(AVX_128bit, /* rex_w */ !_legacy_mode_dq, /* legacy_mode */ _legacy_mode_dq, /* no_mask_reg */ false, /* uses_vl */ true);
--- a/src/hotspot/cpu/x86/assembler_x86.hpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/cpu/x86/assembler_x86.hpp	Wed Nov 22 14:43:37 2017 +0300
@@ -1919,9 +1919,11 @@
   void vdivpd(XMMRegister dst, XMMRegister nds, Address src, int vector_len);
   void vdivps(XMMRegister dst, XMMRegister nds, Address src, int vector_len);
 
-  // Sqrt Packed Floating-Point Values - Double precision only
+  // Sqrt Packed Floating-Point Values
   void vsqrtpd(XMMRegister dst, XMMRegister src, int vector_len);
   void vsqrtpd(XMMRegister dst, Address src, int vector_len);
+  void vsqrtps(XMMRegister dst, XMMRegister src, int vector_len);
+  void vsqrtps(XMMRegister dst, Address src, int vector_len);
 
   // Bitwise Logical AND of Packed Floating-Point Values
   void andpd(XMMRegister dst, XMMRegister src);
--- a/src/hotspot/cpu/x86/x86.ad	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/cpu/x86/x86.ad	Wed Nov 22 14:43:37 2017 +0300
@@ -1252,6 +1252,7 @@
         ret_value = false;
       break;
     case Op_SqrtVD:
+    case Op_SqrtVF:
       if (UseAVX < 1) // enabled for AVX only
         ret_value = false;
       break;
@@ -2580,7 +2581,7 @@
 
 instruct sqrtF_reg(regF dst, regF src) %{
   predicate(UseSSE>=1);
-  match(Set dst (ConvD2F (SqrtD (ConvF2D src))));
+  match(Set dst (SqrtF src));
 
   format %{ "sqrtss  $dst, $src" %}
   ins_cost(150);
@@ -2592,7 +2593,7 @@
 
 instruct sqrtF_mem(regF dst, memory src) %{
   predicate(UseSSE>=1);
-  match(Set dst (ConvD2F (SqrtD (ConvF2D (LoadF src)))));
+  match(Set dst (SqrtF (LoadF src)));
 
   format %{ "sqrtss  $dst, $src" %}
   ins_cost(150);
@@ -2604,7 +2605,8 @@
 
 instruct sqrtF_imm(regF dst, immF con) %{
   predicate(UseSSE>=1);
-  match(Set dst (ConvD2F (SqrtD (ConvF2D con))));
+  match(Set dst (SqrtF con));
+
   format %{ "sqrtss  $dst, [$constantaddress]\t# load from constant table: float=$con" %}
   ins_cost(150);
   ins_encode %{
@@ -8388,7 +8390,7 @@
 
 // --------------------------------- Sqrt --------------------------------------
 
-// Floating point vector sqrt - double precision only
+// Floating point vector sqrt
 instruct vsqrt2D_reg(vecX dst, vecX src) %{
   predicate(UseAVX > 0 && n->as_Vector()->length() == 2);
   match(Set dst (SqrtVD src));
@@ -8455,6 +8457,94 @@
   ins_pipe( pipe_slow );
 %}
 
+instruct vsqrt2F_reg(vecD dst, vecD src) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 2);
+  match(Set dst (SqrtVF src));
+  format %{ "vsqrtps  $dst,$src\t! sqrt packed2F" %}
+  ins_encode %{
+    int vector_len = 0;
+    __ vsqrtps($dst$$XMMRegister, $src$$XMMRegister, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt2F_mem(vecD dst, memory mem) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 2);
+  match(Set dst (SqrtVF (LoadVector mem)));
+  format %{ "vsqrtps  $dst,$mem\t! sqrt packed2F" %}
+  ins_encode %{
+    int vector_len = 0;
+    __ vsqrtps($dst$$XMMRegister, $mem$$Address, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt4F_reg(vecX dst, vecX src) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 4);
+  match(Set dst (SqrtVF src));
+  format %{ "vsqrtps  $dst,$src\t! sqrt packed4F" %}
+  ins_encode %{
+    int vector_len = 0;
+    __ vsqrtps($dst$$XMMRegister, $src$$XMMRegister, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt4F_mem(vecX dst, memory mem) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 4);
+  match(Set dst (SqrtVF (LoadVector mem)));
+  format %{ "vsqrtps  $dst,$mem\t! sqrt packed4F" %}
+  ins_encode %{
+    int vector_len = 0;
+    __ vsqrtps($dst$$XMMRegister, $mem$$Address, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt8F_reg(vecY dst, vecY src) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 8);
+  match(Set dst (SqrtVF src));
+  format %{ "vsqrtps  $dst,$src\t! sqrt packed8F" %}
+  ins_encode %{
+    int vector_len = 1;
+    __ vsqrtps($dst$$XMMRegister, $src$$XMMRegister, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt8F_mem(vecY dst, memory mem) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 8);
+  match(Set dst (SqrtVF (LoadVector mem)));
+  format %{ "vsqrtps  $dst,$mem\t! sqrt packed8F" %}
+  ins_encode %{
+    int vector_len = 1;
+    __ vsqrtps($dst$$XMMRegister, $mem$$Address, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt16F_reg(vecZ dst, vecZ src) %{
+  predicate(UseAVX > 2 && n->as_Vector()->length() == 16);
+  match(Set dst (SqrtVF src));
+  format %{ "vsqrtps  $dst,$src\t! sqrt packed16F" %}
+  ins_encode %{
+    int vector_len = 2;
+    __ vsqrtps($dst$$XMMRegister, $src$$XMMRegister, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt16F_mem(vecZ dst, memory mem) %{
+  predicate(UseAVX > 2 && n->as_Vector()->length() == 16);
+  match(Set dst (SqrtVF (LoadVector mem)));
+  format %{ "vsqrtps  $dst,$mem\t! sqrt packed16F" %}
+  ins_encode %{
+    int vector_len = 2;
+    __ vsqrtps($dst$$XMMRegister, $mem$$Address, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
 // ------------------------------ LeftShift -----------------------------------
 
 // Shorts/Chars vector left shift
--- a/src/hotspot/share/adlc/formssel.cpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/adlc/formssel.cpp	Wed Nov 22 14:43:37 2017 +0300
@@ -4034,6 +4034,7 @@
         strcmp(opType,"ModF")==0 ||
         strcmp(opType,"ModI")==0 ||
         strcmp(opType,"SqrtD")==0 ||
+        strcmp(opType,"SqrtF")==0 ||
         strcmp(opType,"TanD")==0 ||
         strcmp(opType,"ConvD2F")==0 ||
         strcmp(opType,"ConvD2I")==0 ||
@@ -4167,7 +4168,7 @@
     "DivVF","DivVD",
     "AbsVF","AbsVD",
     "NegVF","NegVD",
-    "SqrtVD",
+    "SqrtVD","SqrtVF",
     "AndV" ,"XorV" ,"OrV",
     "AddReductionVI", "AddReductionVL",
     "AddReductionVF", "AddReductionVD",
--- a/src/hotspot/share/opto/classes.hpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/opto/classes.hpp	Wed Nov 22 14:43:37 2017 +0300
@@ -252,6 +252,7 @@
 macro(SafePointScalarObject)
 macro(SCMemProj)
 macro(SqrtD)
+macro(SqrtF)
 macro(Start)
 macro(StartOSR)
 macro(StoreB)
@@ -320,6 +321,7 @@
 macro(NegVF)
 macro(NegVD)
 macro(SqrtVD)
+macro(SqrtVF)
 macro(LShiftCntV)
 macro(RShiftCntV)
 macro(LShiftVB)
--- a/src/hotspot/share/opto/convertnode.cpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/opto/convertnode.cpp	Wed Nov 22 14:43:37 2017 +0300
@@ -73,6 +73,21 @@
   return TypeF::make( (float)td->getd() );
 }
 
+//------------------------------Ideal------------------------------------------
+// If we see pattern ConvF2D SomeDoubleOp ConvD2F, do operation as float.
+Node *ConvD2FNode::Ideal(PhaseGVN *phase, bool can_reshape) {
+  if ( in(1)->Opcode() == Op_SqrtD ) {
+    Node* sqrtd = in(1);
+    if ( sqrtd->in(1)->Opcode() == Op_ConvF2D ) {
+      if ( Matcher::match_rule_supported(Op_SqrtF) ) {
+        Node* convf2d = sqrtd->in(1);
+        return new SqrtFNode(phase->C, sqrtd->in(0), convf2d->in(1));
+      }
+    }
+  }
+  return NULL;
+}
+
 //------------------------------Identity---------------------------------------
 // Float's can be converted to doubles with no loss of bits.  Hence
 // converting a float to a double and back to a float is a NOP.
--- a/src/hotspot/share/opto/convertnode.hpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/opto/convertnode.hpp	Wed Nov 22 14:43:37 2017 +0300
@@ -51,6 +51,7 @@
   virtual const Type *bottom_type() const { return Type::FLOAT; }
   virtual const Type* Value(PhaseGVN* phase) const;
   virtual Node* Identity(PhaseGVN* phase);
+  virtual Node *Ideal(PhaseGVN *phase, bool can_reshape);
   virtual uint  ideal_reg() const { return Op_RegF; }
 };
 
--- a/src/hotspot/share/opto/subnode.cpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/opto/subnode.cpp	Wed Nov 22 14:43:37 2017 +0300
@@ -1595,3 +1595,12 @@
   if( d < 0.0 ) return Type::DOUBLE;
   return TypeD::make( sqrt( d ) );
 }
+
+const Type* SqrtFNode::Value(PhaseGVN* phase) const {
+  const Type *t1 = phase->type( in(1) );
+  if( t1 == Type::TOP ) return Type::TOP;
+  if( t1->base() != Type::FloatCon ) return Type::FLOAT;
+  float f = t1->getf();
+  if( f < 0.0f ) return Type::FLOAT;
+  return TypeF::make( (float)sqrt( (double)f ) );
+}
--- a/src/hotspot/share/opto/subnode.hpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/opto/subnode.hpp	Wed Nov 22 14:43:37 2017 +0300
@@ -442,6 +442,20 @@
   virtual const Type* Value(PhaseGVN* phase) const;
 };
 
+//------------------------------SqrtFNode--------------------------------------
+// square root a float
+class SqrtFNode : public Node {
+public:
+  SqrtFNode(Compile* C, Node *c, Node *in1) : Node(c, in1) {
+    init_flags(Flag_is_expensive);
+    C->add_expensive_node(this);
+  }
+  virtual int Opcode() const;
+  const Type *bottom_type() const { return Type::FLOAT; }
+  virtual uint ideal_reg() const { return Op_RegF; }
+  virtual const Type* Value(PhaseGVN* phase) const;
+};
+
 //-------------------------------ReverseBytesINode--------------------------------
 // reverse bytes of an integer
 class ReverseBytesINode : public Node {
--- a/src/hotspot/share/opto/superword.cpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/opto/superword.cpp	Wed Nov 22 14:43:37 2017 +0300
@@ -2307,7 +2307,7 @@
           vn = VectorNode::make(opc, in1, in2, vlen, velt_basic_type(n));
           vlen_in_bytes = vn->as_Vector()->length_in_bytes();
         }
-      } else if (opc == Op_SqrtD || opc == Op_AbsF || opc == Op_AbsD || opc == Op_NegF || opc == Op_NegD) {
+      } else if (opc == Op_SqrtF || opc == Op_SqrtD || opc == Op_AbsF || opc == Op_AbsD || opc == Op_NegF || opc == Op_NegD) {
         // Promote operand to vector (Sqrt/Abs/Neg are 2 address instructions)
         Node* in = vector_opd(p, 1);
         vn = VectorNode::make(opc, in, NULL, vlen, velt_basic_type(n));
--- a/src/hotspot/share/opto/vectornode.cpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/opto/vectornode.cpp	Wed Nov 22 14:43:37 2017 +0300
@@ -113,6 +113,9 @@
   case Op_NegD:
     assert(bt == T_DOUBLE, "must be");
     return Op_NegVD;
+  case Op_SqrtF:
+    assert(bt == T_FLOAT, "must be");
+    return Op_SqrtVF;
   case Op_SqrtD:
     assert(bt == T_DOUBLE, "must be");
     return Op_SqrtVD;
@@ -316,7 +319,7 @@
   case Op_NegVF: return new NegVFNode(n1, vt);
   case Op_NegVD: return new NegVDNode(n1, vt);
 
-  // Currently only supports double precision sqrt
+  case Op_SqrtVF: return new SqrtVFNode(n1, vt);
   case Op_SqrtVD: return new SqrtVDNode(n1, vt);
 
   case Op_LShiftVB: return new LShiftVBNode(n1, n2, vt);
--- a/src/hotspot/share/opto/vectornode.hpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/opto/vectornode.hpp	Wed Nov 22 14:43:37 2017 +0300
@@ -373,6 +373,14 @@
   virtual int Opcode() const;
 };
 
+//------------------------------SqrtVFNode--------------------------------------
+// Vector Sqrt float
+class SqrtVFNode : public VectorNode {
+ public:
+  SqrtVFNode(Node* in, const TypeVect* vt) : VectorNode(in,vt) {}
+  virtual int Opcode() const;
+};
+
 //------------------------------SqrtVDNode--------------------------------------
 // Vector Sqrt double
 class SqrtVDNode : public VectorNode {
--- a/src/hotspot/share/runtime/vmStructs.cpp	Wed Nov 22 01:12:23 2017 -0800
+++ b/src/hotspot/share/runtime/vmStructs.cpp	Wed Nov 22 14:43:37 2017 +0300
@@ -1958,6 +1958,7 @@
   declare_c2_type(NegFNode, NegNode)                                      \
   declare_c2_type(NegDNode, NegNode)                                      \
   declare_c2_type(AtanDNode, Node)                                        \
+  declare_c2_type(SqrtFNode, Node)                                        \
   declare_c2_type(SqrtDNode, Node)                                        \
   declare_c2_type(ReverseBytesINode, Node)                                \
   declare_c2_type(ReverseBytesLNode, Node)                                \