8135028: support for vectorizing double precision sqrt
author mcberg
Wed, 09 Sep 2015 10:34:17 -0700
changeset 32723 56534fb3d71a
parent 32582 56619bb8bcaa
child 32724 08ecb1e66e71
8135028: support for vectorizing double precision sqrt
Reviewed-by: kvn, twisti
hotspot/src/cpu/x86/vm/assembler_x86.cpp
hotspot/src/cpu/x86/vm/assembler_x86.hpp
hotspot/src/cpu/x86/vm/x86.ad
hotspot/src/share/vm/adlc/formssel.cpp
hotspot/src/share/vm/opto/classes.hpp
hotspot/src/share/vm/opto/superword.cpp
hotspot/src/share/vm/opto/vectornode.cpp
hotspot/src/share/vm/opto/vectornode.hpp
hotspot/test/compiler/loopopts/superword/SumRedSqrt_Double.java
--- a/hotspot/src/cpu/x86/vm/assembler_x86.cpp	Fri Sep 04 12:47:57 2015 +0200
+++ b/hotspot/src/cpu/x86/vm/assembler_x86.cpp	Wed Sep 09 10:34:17 2015 -0700
@@ -3993,6 +3993,26 @@
   emit_vex_arith(0x5E, dst, nds, src, VEX_SIMD_NONE, vector_len);
 }
 
+void Assembler::vsqrtpd(XMMRegister dst, XMMRegister src, int vector_len) { // packed-double sqrt, register source
+  assert(VM_Version::supports_avx(), ""); // this VEX-encoded form is only emitted when AVX is available
+  if (VM_Version::supports_evex()) {
+    emit_vex_arith_q(0x51, dst, xnoreg, src, VEX_SIMD_66, vector_len); // NOTE(review): _q variant presumably sets the W bit for 64-bit elements - confirm
+  } else {
+    emit_vex_arith(0x51, dst, xnoreg, src, VEX_SIMD_66, vector_len); // xnoreg fills the nds slot: sqrt has no second source register
+  }
+}
+
+void Assembler::vsqrtpd(XMMRegister dst, Address src, int vector_len) { // packed-double sqrt, memory source
+  assert(VM_Version::supports_avx(), ""); // this VEX-encoded form is only emitted when AVX is available
+  if (VM_Version::supports_evex()) {
+    tuple_type = EVEX_FV; // must be set before emitting; NOTE(review): presumably drives EVEX displacement scaling - confirm
+    input_size_in_bits = EVEX_64bit; // NOTE(review): presumably the 64-bit (double) element size - confirm
+    emit_vex_arith_q(0x51, dst, xnoreg, src, VEX_SIMD_66, vector_len); // xnoreg fills the nds slot: sqrt has no second source register
+  } else {
+    emit_vex_arith(0x51, dst, xnoreg, src, VEX_SIMD_66, vector_len); // non-EVEX path sets no tuple/input-size state
+  }
+}
+
 void Assembler::andpd(XMMRegister dst, XMMRegister src) {
   NOT_LP64(assert(VM_Version::supports_sse2(), ""));
   if (VM_Version::supports_evex() && VM_Version::supports_avx512dq()) {
--- a/hotspot/src/cpu/x86/vm/assembler_x86.hpp	Fri Sep 04 12:47:57 2015 +0200
+++ b/hotspot/src/cpu/x86/vm/assembler_x86.hpp	Wed Sep 09 10:34:17 2015 -0700
@@ -1920,6 +1920,10 @@
   void vdivpd(XMMRegister dst, XMMRegister nds, Address src, int vector_len);
   void vdivps(XMMRegister dst, XMMRegister nds, Address src, int vector_len);
 
+  // Sqrt Packed Floating-Point Values - Double precision only (dst = sqrt(src); no nds operand)
+  void vsqrtpd(XMMRegister dst, XMMRegister src, int vector_len); // register source; x86.ad passes vector_len 0/1/2 for 2/4/8 doubles
+  void vsqrtpd(XMMRegister dst, Address src, int vector_len);     // memory source; same vector_len values
+
   // Bitwise Logical AND of Packed Floating-Point Values
   void andpd(XMMRegister dst, XMMRegister src);
   void andps(XMMRegister dst, XMMRegister src);
--- a/hotspot/src/cpu/x86/vm/x86.ad	Fri Sep 04 12:47:57 2015 +0200
+++ b/hotspot/src/cpu/x86/vm/x86.ad	Wed Sep 09 10:34:17 2015 -0700
@@ -1691,6 +1691,10 @@
       if (UseSSE < 1) // requires at least SSE
         return false;
     break;
+    case Op_SqrtVD:
+      if (UseAVX < 1) // enabled for AVX only
+        return false;
+    break;
     case Op_CompareAndSwapL:
 #ifdef _LP64
     case Op_CompareAndSwapP:
@@ -7474,6 +7478,75 @@
   ins_pipe( pipe_slow );
 %}
 
+// --------------------------------- Sqrt --------------------------------------
+
+// Floating point vector sqrt - double precision only
+instruct vsqrt2D_reg(vecX dst, vecX src) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 2); // AVX enabled and a 2-element vector
+  match(Set dst (SqrtVD src)); // SqrtVD whose input is already in a vecX register
+  format %{ "vsqrtpd  $dst,$src\t! sqrt packed2D" %}
+  ins_encode %{
+    int vector_len = 0; // value this file pairs with vecX (2 x double) operands
+    __ vsqrtpd($dst$$XMMRegister, $src$$XMMRegister, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt2D_mem(vecX dst, memory mem) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 2); // AVX enabled and a 2-element vector
+  match(Set dst (SqrtVD (LoadVector mem))); // folds the vector load into the sqrt: source is a memory operand
+  format %{ "vsqrtpd  $dst,$mem\t! sqrt packed2D" %}
+  ins_encode %{
+    int vector_len = 0; // value this file pairs with vecX (2 x double) operands
+    __ vsqrtpd($dst$$XMMRegister, $mem$$Address, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt4D_reg(vecY dst, vecY src) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 4); // AVX enabled and a 4-element vector
+  match(Set dst (SqrtVD src)); // SqrtVD whose input is already in a vecY register
+  format %{ "vsqrtpd  $dst,$src\t! sqrt packed4D" %}
+  ins_encode %{
+    int vector_len = 1; // value this file pairs with vecY (4 x double) operands
+    __ vsqrtpd($dst$$XMMRegister, $src$$XMMRegister, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt4D_mem(vecY dst, memory mem) %{
+  predicate(UseAVX > 0 && n->as_Vector()->length() == 4); // AVX enabled and a 4-element vector
+  match(Set dst (SqrtVD (LoadVector mem))); // folds the vector load into the sqrt: source is a memory operand
+  format %{ "vsqrtpd  $dst,$mem\t! sqrt packed4D" %}
+  ins_encode %{
+    int vector_len = 1; // value this file pairs with vecY (4 x double) operands
+    __ vsqrtpd($dst$$XMMRegister, $mem$$Address, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt8D_reg(vecZ dst, vecZ src) %{
+  predicate(UseAVX > 2 && n->as_Vector()->length() == 8); // stricter than the 2D/4D rules: UseAVX must exceed 2; 8-element vector
+  match(Set dst (SqrtVD src)); // SqrtVD whose input is already in a vecZ register
+  format %{ "vsqrtpd  $dst,$src\t! sqrt packed8D" %}
+  ins_encode %{
+    int vector_len = 2; // value this file pairs with vecZ (8 x double) operands
+    __ vsqrtpd($dst$$XMMRegister, $src$$XMMRegister, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vsqrt8D_mem(vecZ dst, memory mem) %{
+  predicate(UseAVX > 2 && n->as_Vector()->length() == 8); // stricter than the 2D/4D rules: UseAVX must exceed 2; 8-element vector
+  match(Set dst (SqrtVD (LoadVector mem))); // folds the vector load into the sqrt: source is a memory operand
+  format %{ "vsqrtpd  $dst,$mem\t! sqrt packed8D" %}
+  ins_encode %{
+    int vector_len = 2; // value this file pairs with vecZ (8 x double) operands
+    __ vsqrtpd($dst$$XMMRegister, $mem$$Address, vector_len);
+  %}
+  ins_pipe( pipe_slow );
+%}
+
 // ------------------------------ LeftShift -----------------------------------
 
 // Shorts/Chars vector left shift
--- a/hotspot/src/share/vm/adlc/formssel.cpp	Fri Sep 04 12:47:57 2015 +0200
+++ b/hotspot/src/share/vm/adlc/formssel.cpp	Wed Sep 09 10:34:17 2015 -0700
@@ -4143,6 +4143,7 @@
     "SubVB","SubVS","SubVI","SubVL","SubVF","SubVD",
     "MulVS","MulVI","MulVL","MulVF","MulVD",
     "DivVF","DivVD",
+    "SqrtVD",
     "AndV" ,"XorV" ,"OrV",
     "AddReductionVI", "AddReductionVL",
     "AddReductionVF", "AddReductionVD",
--- a/hotspot/src/share/vm/opto/classes.hpp	Fri Sep 04 12:47:57 2015 +0200
+++ b/hotspot/src/share/vm/opto/classes.hpp	Wed Sep 09 10:34:17 2015 -0700
@@ -290,6 +290,7 @@
 macro(MulReductionVD)
 macro(DivVF)
 macro(DivVD)
+macro(SqrtVD)
 macro(LShiftCntV)
 macro(RShiftCntV)
 macro(LShiftVB)
--- a/hotspot/src/share/vm/opto/superword.cpp	Fri Sep 04 12:47:57 2015 +0200
+++ b/hotspot/src/share/vm/opto/superword.cpp	Wed Sep 09 10:34:17 2015 -0700
@@ -1858,6 +1858,11 @@
           vn = VectorNode::make(opc, in1, in2, vlen, velt_basic_type(n));
           vlen_in_bytes = vn->as_Vector()->length_in_bytes();
         }
+      } else if (opc == Op_SqrtD) {
+        // Promote operand to vector (Sqrt is a 2 address instruction)
+        Node* in = vector_opd(p, 1);
+        vn = VectorNode::make(opc, in, NULL, vlen, velt_basic_type(n));
+        vlen_in_bytes = vn->as_Vector()->length_in_bytes();
       } else {
         ShouldNotReachHere();
       }
--- a/hotspot/src/share/vm/opto/vectornode.cpp	Fri Sep 04 12:47:57 2015 +0200
+++ b/hotspot/src/share/vm/opto/vectornode.cpp	Wed Sep 09 10:34:17 2015 -0700
@@ -92,6 +92,9 @@
   case Op_DivD:
     assert(bt == T_DOUBLE, "must be");
     return Op_DivVD;
+  case Op_SqrtD:
+    assert(bt == T_DOUBLE, "must be");
+    return Op_SqrtVD;
   case Op_LShiftI:
     switch (bt) {
     case T_BOOLEAN:
@@ -277,6 +280,9 @@
   case Op_DivVF: return new DivVFNode(n1, n2, vt);
   case Op_DivVD: return new DivVDNode(n1, n2, vt);
 
+  // Currently only supports double precision sqrt
+  case Op_SqrtVD: return new SqrtVDNode(n1, vt);
+
   case Op_LShiftVB: return new LShiftVBNode(n1, n2, vt);
   case Op_LShiftVS: return new LShiftVSNode(n1, n2, vt);
   case Op_LShiftVI: return new LShiftVINode(n1, n2, vt);
--- a/hotspot/src/share/vm/opto/vectornode.hpp	Fri Sep 04 12:47:57 2015 +0200
+++ b/hotspot/src/share/vm/opto/vectornode.hpp	Wed Sep 09 10:34:17 2015 -0700
@@ -309,6 +309,14 @@
   virtual int Opcode() const;
 };
 
+//------------------------------SqrtVDNode--------------------------------------
+// Vector Sqrt double - unary node: one vector input, unlike the two-input Div/Mul vector nodes
+class SqrtVDNode : public VectorNode {
+ public:
+  SqrtVDNode(Node* in, const TypeVect* vt) : VectorNode(in,vt) {} // forwards the sole operand and the vector type to VectorNode
+  virtual int Opcode() const; // NOTE(review): body not in this changeset view; presumably generated from macro(SqrtVD) - confirm
+};
+
 //------------------------------LShiftVBNode-----------------------------------
 // Vector left shift bytes
 class LShiftVBNode : public VectorNode {
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/hotspot/test/compiler/loopopts/superword/SumRedSqrt_Double.java	Wed Sep 09 10:34:17 2015 -0700
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2014, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ *
+ */
+
+/**
+* @test
+* @summary Add C2 x86 Superword support for scalar sum reduction optimizations : double sqrt test
+*
+* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:+SuperWordReductions -XX:LoopUnrollLimit=250 -XX:LoopMaxUnroll=2 -XX:CompileThresholdScaling=0.1 SumRedSqrt_Double
+* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-SuperWordReductions -XX:LoopUnrollLimit=250 -XX:LoopMaxUnroll=2 -XX:CompileThresholdScaling=0.1 SumRedSqrt_Double
+*
+* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:+SuperWordReductions -XX:LoopUnrollLimit=250 -XX:LoopMaxUnroll=4 -XX:CompileThresholdScaling=0.1 SumRedSqrt_Double
+* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-SuperWordReductions -XX:LoopUnrollLimit=250 -XX:LoopMaxUnroll=4 -XX:CompileThresholdScaling=0.1 SumRedSqrt_Double
+*
+* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:+SuperWordReductions -XX:LoopUnrollLimit=250 -XX:LoopMaxUnroll=8 -XX:CompileThresholdScaling=0.1 SumRedSqrt_Double
+* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-SuperWordReductions -XX:LoopUnrollLimit=250 -XX:LoopMaxUnroll=8 -XX:CompileThresholdScaling=0.1 SumRedSqrt_Double
+*
+* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:+SuperWordReductions -XX:LoopUnrollLimit=250 -XX:LoopMaxUnroll=16 -XX:CompileThresholdScaling=0.1 SumRedSqrt_Double
+* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:-SuperWordReductions -XX:LoopUnrollLimit=250 -XX:LoopMaxUnroll=16 -XX:CompileThresholdScaling=0.1 SumRedSqrt_Double
+*/
+
+public class SumRedSqrt_Double
+{
+  public static void main(String[] args) throws Exception { // runs the sqrt-sum kernel 2000 times and checks the accumulated total
+    double[] a = new double[256*1024];
+    double[] b = new double[256*1024];
+    double[] c = new double[256*1024];
+    double[] d = new double[256*1024]; // output array, written by sumReductionImplement
+    sumReductionInit(a,b,c);
+    double total = 0;
+    double valid = 2.06157643776E14; // = 2000 * 3 * sum(i, i < 262144); all intermediates are integers below 2^53, so == is exact
+    for(int j = 0; j < 2000; j++) { // total is carried from one call into the next
+      total = sumReductionImplement(a,b,c,d,total);
+    }
+    if(total == valid) {
+      System.out.println("Success");
+    } else {
+      System.out.println("Invalid sum of elements variable in total: " + total);
+      System.out.println("Expected value = " + valid);
+      throw new Exception("Failed"); // uncaught: propagates out of main
+    }
+  }
+
+  public static void sumReductionInit( // fills a, b and c; only a.length bounds the loop
+    double[] a,
+    double[] b,
+    double[] c)
+  {
+    for(int j = 0; j < 1; j++) // single pass: j is always 0
+    {
+      for(int i = 0; i < a.length; i++)
+      {
+        a[i] = i * 1 + j; // with j == 0 this is just i
+        b[i] = i * 1 - j; // likewise i
+        c[i] = i + j;     // likewise i, so a[i] == b[i] == c[i] == i
+      }
+    }
+  }
+
+  public static double sumReductionImplement( // writes d[i] and returns total plus the sum of all d[i]
+    double[] a,
+    double[] b,
+    double[] c,
+    double[] d,
+    double total)
+  {
+    for(int i = 0; i < a.length; i++)
+    {
+      d[i]= Math.sqrt(a[i] * b[i]) + Math.sqrt(a[i] * c[i]) + Math.sqrt(b[i] * c[i]); // three double sqrt calls per element; equals 3*i for the init data
+      total += d[i]; // running sum - the reduction carried across iterations
+    }
+    return total;
+  }
+
+}