# HG changeset patch # User rlupusoru # Date 1511351017 -10800 # Node ID 22c9856fc2c2df133481c614b19fe980e7d5ea0b # Parent 2cd1c2b037825cf720a8867e187d9fc07812c323 8190800: Support vectorization of Math.sqrt() on floats Reviewed-by: vlivanov, kvn diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/cpu/x86/assembler_x86.cpp --- a/src/hotspot/cpu/x86/assembler_x86.cpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/cpu/x86/assembler_x86.cpp Wed Nov 22 14:43:37 2017 +0300 @@ -5203,6 +5203,24 @@ emit_operand(dst, src); } +void Assembler::vsqrtps(XMMRegister dst, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_0F, &attributes); + emit_int8(0x51); + emit_int8((unsigned char)(0xC0 | encode)); +} + +void Assembler::vsqrtps(XMMRegister dst, Address src, int vector_len) { + assert(VM_Version::supports_avx(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true); + attributes.set_address_attributes(/* tuple_type */ EVEX_FV, /* input_size_in_bits */ EVEX_64bit); + vex_prefix(src, 0, dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_0F, &attributes); + emit_int8(0x51); + emit_operand(dst, src); +} + void Assembler::andpd(XMMRegister dst, XMMRegister src) { NOT_LP64(assert(VM_Version::supports_sse2(), "")); InstructionAttr attributes(AVX_128bit, /* rex_w */ !_legacy_mode_dq, /* legacy_mode */ _legacy_mode_dq, /* no_mask_reg */ false, /* uses_vl */ true); diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/cpu/x86/assembler_x86.hpp --- a/src/hotspot/cpu/x86/assembler_x86.hpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/cpu/x86/assembler_x86.hpp Wed Nov 22 14:43:37 2017 +0300 @@ -1919,9 +1919,11 @@ void vdivpd(XMMRegister dst, XMMRegister nds, Address src, int vector_len); void vdivps(XMMRegister dst, XMMRegister nds, Address src, int vector_len); - // Sqrt Packed Floating-Point Values - Double precision only + // Sqrt Packed Floating-Point Values void vsqrtpd(XMMRegister dst, XMMRegister src, int vector_len); void vsqrtpd(XMMRegister dst, Address src, int vector_len); + void vsqrtps(XMMRegister dst, XMMRegister src, int vector_len); + void vsqrtps(XMMRegister dst, Address src, int vector_len); // Bitwise Logical AND of Packed Floating-Point Values void andpd(XMMRegister dst, XMMRegister src); diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/cpu/x86/x86.ad --- a/src/hotspot/cpu/x86/x86.ad Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/cpu/x86/x86.ad Wed Nov 22 14:43:37 2017 +0300 @@ -1252,6 +1252,7 @@ ret_value = false; break; case Op_SqrtVD: + case Op_SqrtVF: if (UseAVX < 1) // enabled for AVX only ret_value = false; break; @@ -2580,7 +2581,7 @@ instruct sqrtF_reg(regF dst, regF src) %{ predicate(UseSSE>=1); - match(Set dst (ConvD2F (SqrtD (ConvF2D src)))); + match(Set dst (SqrtF src)); format %{ "sqrtss $dst, $src" %} ins_cost(150); @@ -2592,7 +2593,7 @@ instruct sqrtF_mem(regF dst, memory src) %{ predicate(UseSSE>=1); - match(Set dst (ConvD2F (SqrtD (ConvF2D (LoadF src))))); + match(Set dst (SqrtF (LoadF src))); format %{ "sqrtss $dst, $src" %} ins_cost(150); @@ -2604,7 +2605,8 @@ instruct sqrtF_imm(regF dst, immF con) %{ predicate(UseSSE>=1); - match(Set dst (ConvD2F (SqrtD (ConvF2D con)))); + match(Set dst (SqrtF con)); + format %{ "sqrtss $dst, [$constantaddress]\t# load from constant table: float=$con" %} ins_cost(150); ins_encode %{ @@ -8388,7 +8390,7 @@ // --------------------------------- Sqrt -------------------------------------- -// Floating point vector sqrt - double precision only +// Floating point vector sqrt instruct vsqrt2D_reg(vecX dst, vecX src) %{ predicate(UseAVX > 0 && n->as_Vector()->length() == 2); match(Set dst (SqrtVD src)); @@ -8455,6 +8457,94 @@ ins_pipe( pipe_slow ); %} +instruct vsqrt2F_reg(vecD dst, vecD src) %{ + predicate(UseAVX > 0 && n->as_Vector()->length() == 2); + match(Set dst (SqrtVF src)); + format %{ "vsqrtps $dst,$src\t! sqrt packed2F" %} + ins_encode %{ + int vector_len = 0; + __ vsqrtps($dst$$XMMRegister, $src$$XMMRegister, vector_len); + %} + ins_pipe( pipe_slow ); +%} + +instruct vsqrt2F_mem(vecD dst, memory mem) %{ + predicate(UseAVX > 0 && n->as_Vector()->length() == 2); + match(Set dst (SqrtVF (LoadVector mem))); + format %{ "vsqrtps $dst,$mem\t! sqrt packed2F" %} + ins_encode %{ + int vector_len = 0; + __ vsqrtps($dst$$XMMRegister, $mem$$Address, vector_len); + %} + ins_pipe( pipe_slow ); +%} + +instruct vsqrt4F_reg(vecX dst, vecX src) %{ + predicate(UseAVX > 0 && n->as_Vector()->length() == 4); + match(Set dst (SqrtVF src)); + format %{ "vsqrtps $dst,$src\t! sqrt packed4F" %} + ins_encode %{ + int vector_len = 0; + __ vsqrtps($dst$$XMMRegister, $src$$XMMRegister, vector_len); + %} + ins_pipe( pipe_slow ); +%} + +instruct vsqrt4F_mem(vecX dst, memory mem) %{ + predicate(UseAVX > 0 && n->as_Vector()->length() == 4); + match(Set dst (SqrtVF (LoadVector mem))); + format %{ "vsqrtps $dst,$mem\t! sqrt packed4F" %} + ins_encode %{ + int vector_len = 0; + __ vsqrtps($dst$$XMMRegister, $mem$$Address, vector_len); + %} + ins_pipe( pipe_slow ); +%} + +instruct vsqrt8F_reg(vecY dst, vecY src) %{ + predicate(UseAVX > 0 && n->as_Vector()->length() == 8); + match(Set dst (SqrtVF src)); + format %{ "vsqrtps $dst,$src\t! sqrt packed8F" %} + ins_encode %{ + int vector_len = 1; + __ vsqrtps($dst$$XMMRegister, $src$$XMMRegister, vector_len); + %} + ins_pipe( pipe_slow ); +%} + +instruct vsqrt8F_mem(vecY dst, memory mem) %{ + predicate(UseAVX > 0 && n->as_Vector()->length() == 8); + match(Set dst (SqrtVF (LoadVector mem))); + format %{ "vsqrtps $dst,$mem\t! sqrt packed8F" %} + ins_encode %{ + int vector_len = 1; + __ vsqrtps($dst$$XMMRegister, $mem$$Address, vector_len); + %} + ins_pipe( pipe_slow ); +%} + +instruct vsqrt16F_reg(vecZ dst, vecZ src) %{ + predicate(UseAVX > 2 && n->as_Vector()->length() == 16); + match(Set dst (SqrtVF src)); + format %{ "vsqrtps $dst,$src\t! sqrt packed16F" %} + ins_encode %{ + int vector_len = 2; + __ vsqrtps($dst$$XMMRegister, $src$$XMMRegister, vector_len); + %} + ins_pipe( pipe_slow ); +%} + +instruct vsqrt16F_mem(vecZ dst, memory mem) %{ + predicate(UseAVX > 2 && n->as_Vector()->length() == 16); + match(Set dst (SqrtVF (LoadVector mem))); + format %{ "vsqrtps $dst,$mem\t! sqrt packed16F" %} + ins_encode %{ + int vector_len = 2; + __ vsqrtps($dst$$XMMRegister, $mem$$Address, vector_len); + %} + ins_pipe( pipe_slow ); +%} + // ------------------------------ LeftShift ----------------------------------- // Shorts/Chars vector left shift diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/adlc/formssel.cpp --- a/src/hotspot/share/adlc/formssel.cpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/adlc/formssel.cpp Wed Nov 22 14:43:37 2017 +0300 @@ -4034,6 +4034,7 @@ strcmp(opType,"ModF")==0 || strcmp(opType,"ModI")==0 || strcmp(opType,"SqrtD")==0 || + strcmp(opType,"SqrtF")==0 || strcmp(opType,"TanD")==0 || strcmp(opType,"ConvD2F")==0 || strcmp(opType,"ConvD2I")==0 || @@ -4167,7 +4168,7 @@ "DivVF","DivVD", "AbsVF","AbsVD", "NegVF","NegVD", - "SqrtVD", + "SqrtVD","SqrtVF", "AndV" ,"XorV" ,"OrV", "AddReductionVI", "AddReductionVL", "AddReductionVF", "AddReductionVD", diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/opto/classes.hpp --- a/src/hotspot/share/opto/classes.hpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/opto/classes.hpp Wed Nov 22 14:43:37 2017 +0300 @@ -252,6 +252,7 @@ macro(SafePointScalarObject) macro(SCMemProj) macro(SqrtD) +macro(SqrtF) macro(Start) macro(StartOSR) macro(StoreB) @@ -320,6 +321,7 @@ macro(NegVF) macro(NegVD) macro(SqrtVD) +macro(SqrtVF) macro(LShiftCntV) macro(RShiftCntV) macro(LShiftVB) diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/opto/convertnode.cpp --- a/src/hotspot/share/opto/convertnode.cpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/opto/convertnode.cpp Wed Nov 22 14:43:37 2017 +0300 @@ -73,6 +73,21 @@ return TypeF::make( (float)td->getd() ); } +//------------------------------Ideal------------------------------------------ +// If we see pattern ConvF2D SomeDoubleOp ConvD2F, do operation as float. +Node *ConvD2FNode::Ideal(PhaseGVN *phase, bool can_reshape) { + if ( in(1)->Opcode() == Op_SqrtD ) { + Node* sqrtd = in(1); + if ( sqrtd->in(1)->Opcode() == Op_ConvF2D ) { + if ( Matcher::match_rule_supported(Op_SqrtF) ) { + Node* convf2d = sqrtd->in(1); + return new SqrtFNode(phase->C, sqrtd->in(0), convf2d->in(1)); + } + } + } + return NULL; +} + //------------------------------Identity--------------------------------------- // Float's can be converted to doubles with no loss of bits. Hence // converting a float to a double and back to a float is a NOP. diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/opto/convertnode.hpp --- a/src/hotspot/share/opto/convertnode.hpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/opto/convertnode.hpp Wed Nov 22 14:43:37 2017 +0300 @@ -51,6 +51,7 @@ virtual const Type *bottom_type() const { return Type::FLOAT; } virtual const Type* Value(PhaseGVN* phase) const; virtual Node* Identity(PhaseGVN* phase); + virtual Node *Ideal(PhaseGVN *phase, bool can_reshape); virtual uint ideal_reg() const { return Op_RegF; } }; diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/opto/subnode.cpp --- a/src/hotspot/share/opto/subnode.cpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/opto/subnode.cpp Wed Nov 22 14:43:37 2017 +0300 @@ -1595,3 +1595,12 @@ if( d < 0.0 ) return Type::DOUBLE; return TypeD::make( sqrt( d ) ); } + +const Type* SqrtFNode::Value(PhaseGVN* phase) const { + const Type *t1 = phase->type( in(1) ); + if( t1 == Type::TOP ) return Type::TOP; + if( t1->base() != Type::FloatCon ) return Type::FLOAT; + float f = t1->getf(); + if( f < 0.0f ) return Type::FLOAT; + return TypeF::make( (float)sqrt( (double)f ) ); +} diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/opto/subnode.hpp --- a/src/hotspot/share/opto/subnode.hpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/opto/subnode.hpp Wed Nov 22 14:43:37 2017 +0300 @@ -442,6 +442,20 @@ virtual const Type* Value(PhaseGVN* phase) const; }; +//------------------------------SqrtFNode-------------------------------------- +// square root a float +class SqrtFNode : public Node { +public: + SqrtFNode(Compile* C, Node *c, Node *in1) : Node(c, in1) { + init_flags(Flag_is_expensive); + C->add_expensive_node(this); + } + virtual int Opcode() const; + const Type *bottom_type() const { return Type::FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + virtual const Type* Value(PhaseGVN* phase) const; +}; + //-------------------------------ReverseBytesINode-------------------------------- // reverse bytes of an integer class ReverseBytesINode : public Node { diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/opto/superword.cpp --- a/src/hotspot/share/opto/superword.cpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/opto/superword.cpp Wed Nov 22 14:43:37 2017 +0300 @@ -2307,7 +2307,7 @@ vn = VectorNode::make(opc, in1, in2, vlen, velt_basic_type(n)); vlen_in_bytes = vn->as_Vector()->length_in_bytes(); } - } else if (opc == Op_SqrtD || opc == Op_AbsF || opc == Op_AbsD || opc == Op_NegF || opc == Op_NegD) { + } else if (opc == Op_SqrtF || opc == Op_SqrtD || opc == Op_AbsF || opc == Op_AbsD || opc == Op_NegF || opc == Op_NegD) { // Promote operand to vector (Sqrt/Abs/Neg are 2 address instructions) Node* in = vector_opd(p, 1); vn = VectorNode::make(opc, in, NULL, vlen, velt_basic_type(n)); diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/opto/vectornode.cpp --- a/src/hotspot/share/opto/vectornode.cpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/opto/vectornode.cpp Wed Nov 22 14:43:37 2017 +0300 @@ -113,6 +113,9 @@ case Op_NegD: assert(bt == T_DOUBLE, "must be"); return Op_NegVD; + case Op_SqrtF: + assert(bt == T_FLOAT, "must be"); + return Op_SqrtVF; case Op_SqrtD: assert(bt == T_DOUBLE, "must be"); return Op_SqrtVD; @@ -316,7 +319,7 @@ case Op_NegVF: return new NegVFNode(n1, vt); case Op_NegVD: return new NegVDNode(n1, vt); - // Currently only supports double precision sqrt + case Op_SqrtVF: return new SqrtVFNode(n1, vt); case Op_SqrtVD: return new SqrtVDNode(n1, vt); case Op_LShiftVB: return new LShiftVBNode(n1, n2, vt); diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/opto/vectornode.hpp --- a/src/hotspot/share/opto/vectornode.hpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/opto/vectornode.hpp Wed Nov 22 14:43:37 2017 +0300 @@ -373,6 +373,14 @@ virtual int Opcode() const; }; +//------------------------------SqrtVFNode-------------------------------------- +// Vector Sqrt float +class SqrtVFNode : public VectorNode { + public: + SqrtVFNode(Node* in, const TypeVect* vt) : VectorNode(in,vt) {} + virtual int Opcode() const; +}; + //------------------------------SqrtVDNode-------------------------------------- // Vector Sqrt double class SqrtVDNode : public VectorNode { diff -r 2cd1c2b03782 -r 22c9856fc2c2 src/hotspot/share/runtime/vmStructs.cpp --- a/src/hotspot/share/runtime/vmStructs.cpp Wed Nov 22 01:12:23 2017 -0800 +++ b/src/hotspot/share/runtime/vmStructs.cpp Wed Nov 22 14:43:37 2017 +0300 @@ -1958,6 +1958,7 @@ declare_c2_type(NegFNode, NegNode) \ declare_c2_type(NegDNode, NegNode) \ declare_c2_type(AtanDNode, Node) \ + declare_c2_type(SqrtFNode, Node) \ declare_c2_type(SqrtDNode, Node) \ declare_c2_type(ReverseBytesINode, Node) \ declare_c2_type(ReverseBytesLNode, Node) \