8005419: Improve intrinsics code performance on x86 by using AVX2
authorkvn
Tue, 08 Jan 2013 11:30:51 -0800
changeset 15117 625397df6f4f
parent 15116 af423dcb739c
child 15118 1a1a6d1dfaab
8005419: Improve intrinsics code performance on x86 by using AVX2 Summary: use 256bit vpxor,vptest instructions in String.compareTo() and equals() intrinsics. Reviewed-by: twisti
hotspot/src/cpu/x86/vm/assembler_x86.cpp
hotspot/src/cpu/x86/vm/assembler_x86.hpp
hotspot/src/cpu/x86/vm/macroAssembler_x86.cpp
hotspot/src/cpu/x86/vm/macroAssembler_x86.hpp
hotspot/test/compiler/8005419/Test8005419.java
--- a/hotspot/src/cpu/x86/vm/assembler_x86.cpp	Mon Jan 07 14:08:28 2013 -0800
+++ b/hotspot/src/cpu/x86/vm/assembler_x86.cpp	Tue Jan 08 11:30:51 2013 -0800
@@ -2468,6 +2468,26 @@
   emit_int8((unsigned char)(0xC0 | encode));
 }
 
+void Assembler::vptest(XMMRegister dst, Address src) {
+  assert(VM_Version::supports_avx(), "");
+  InstructionMark im(this);
+  bool vector256 = true;
+  assert(dst != xnoreg, "sanity");
+  int dst_enc = dst->encoding();
+  // swap src<->dst for encoding
+  vex_prefix(src, dst_enc, dst_enc, VEX_SIMD_66, VEX_OPCODE_0F_38, false, vector256);
+  emit_int8(0x17);
+  emit_operand(dst, src);
+}
+
+void Assembler::vptest(XMMRegister dst, XMMRegister src) {
+  assert(VM_Version::supports_avx(), "");
+  bool vector256 = true;
+  int encode = vex_prefix_and_encode(dst, xnoreg, src, VEX_SIMD_66, vector256, VEX_OPCODE_0F_38);
+  emit_int8(0x17);
+  emit_int8((unsigned char)(0xC0 | encode));
+}
+
 void Assembler::punpcklbw(XMMRegister dst, Address src) {
   NOT_LP64(assert(VM_Version::supports_sse2(), ""));
   assert((UseAVX > 0), "SSE mode requires address alignment 16 bytes");
--- a/hotspot/src/cpu/x86/vm/assembler_x86.hpp	Mon Jan 07 14:08:28 2013 -0800
+++ b/hotspot/src/cpu/x86/vm/assembler_x86.hpp	Tue Jan 08 11:30:51 2013 -0800
@@ -1444,9 +1444,12 @@
   // Shift Right by bytes Logical DoubleQuadword Immediate
   void psrldq(XMMRegister dst, int shift);
 
-  // Logical Compare Double Quadword
+  // Logical Compare 128bit
   void ptest(XMMRegister dst, XMMRegister src);
   void ptest(XMMRegister dst, Address src);
+  // Logical Compare 256bit
+  void vptest(XMMRegister dst, XMMRegister src);
+  void vptest(XMMRegister dst, Address src);
 
   // Interleave Low Bytes
   void punpcklbw(XMMRegister dst, XMMRegister src);
--- a/hotspot/src/cpu/x86/vm/macroAssembler_x86.cpp	Mon Jan 07 14:08:28 2013 -0800
+++ b/hotspot/src/cpu/x86/vm/macroAssembler_x86.cpp	Tue Jan 08 11:30:51 2013 -0800
@@ -5675,42 +5675,114 @@
   testl(cnt2, cnt2);
   jcc(Assembler::zero, LENGTH_DIFF_LABEL);
 
-  // Load first characters
+  // Compare first characters
   load_unsigned_short(result, Address(str1, 0));
   load_unsigned_short(cnt1, Address(str2, 0));
-
-  // Compare first characters
   subl(result, cnt1);
   jcc(Assembler::notZero,  POP_LABEL);
-  decrementl(cnt2);
-  jcc(Assembler::zero, LENGTH_DIFF_LABEL);
-
-  {
-    // Check after comparing first character to see if strings are equivalent
-    Label LSkip2;
-    // Check if the strings start at same location
-    cmpptr(str1, str2);
-    jccb(Assembler::notEqual, LSkip2);
-
-    // Check if the length difference is zero (from stack)
-    cmpl(Address(rsp, 0), 0x0);
-    jcc(Assembler::equal,  LENGTH_DIFF_LABEL);
-
-    // Strings might not be equivalent
-    bind(LSkip2);
-  }
+  cmpl(cnt2, 1);
+  jcc(Assembler::equal, LENGTH_DIFF_LABEL);
+
+  // Check if the strings start at the same location.
+  cmpptr(str1, str2);
+  jcc(Assembler::equal, LENGTH_DIFF_LABEL);
 
   Address::ScaleFactor scale = Address::times_2;
   int stride = 8;
 
-  // Advance to next element
-  addptr(str1, 16/stride);
-  addptr(str2, 16/stride);
-
-  if (UseSSE42Intrinsics) {
+  if (UseAVX >= 2) {
+    Label COMPARE_WIDE_VECTORS, VECTOR_NOT_EQUAL, COMPARE_WIDE_TAIL, COMPARE_SMALL_STR;
+    Label COMPARE_WIDE_VECTORS_LOOP, COMPARE_16_CHARS, COMPARE_INDEX_CHAR;
+    Label COMPARE_TAIL_LONG;
+    int pcmpmask = 0x19;
+
+    // Setup to compare 16-chars (32-bytes) vectors,
+    // start from first character again because it has aligned address.
+    int stride2 = 16;
+    int adr_stride  = stride  << scale;
+    int adr_stride2 = stride2 << scale;
+
+    assert(result == rax && cnt2 == rdx && cnt1 == rcx, "pcmpestri");
+    // rax and rdx are used by pcmpestri as elements counters
+    movl(result, cnt2);
+    andl(cnt2, ~(stride2-1));   // cnt2 holds the vector count
+    jcc(Assembler::zero, COMPARE_TAIL_LONG);
+
+    // fast path : compare first 2 8-char vectors.
+    bind(COMPARE_16_CHARS);
+    movdqu(vec1, Address(str1, 0));
+    pcmpestri(vec1, Address(str2, 0), pcmpmask);
+    jccb(Assembler::below, COMPARE_INDEX_CHAR);
+
+    movdqu(vec1, Address(str1, adr_stride));
+    pcmpestri(vec1, Address(str2, adr_stride), pcmpmask);
+    jccb(Assembler::aboveEqual, COMPARE_WIDE_VECTORS);
+    addl(cnt1, stride);
+
+    // Compare the characters at index in cnt1
+    bind(COMPARE_INDEX_CHAR); //cnt1 has the offset of the mismatching character
+    load_unsigned_short(result, Address(str1, cnt1, scale));
+    load_unsigned_short(cnt2, Address(str2, cnt1, scale));
+    subl(result, cnt2);
+    jmp(POP_LABEL);
+
+    // Setup the registers to start vector comparison loop
+    bind(COMPARE_WIDE_VECTORS);
+    lea(str1, Address(str1, result, scale));
+    lea(str2, Address(str2, result, scale));
+    subl(result, stride2);
+    subl(cnt2, stride2);
+    jccb(Assembler::zero, COMPARE_WIDE_TAIL);
+    negptr(result);
+
+    //  In a loop, compare 16-chars (32-bytes) at once using (vpxor+vptest)
+    bind(COMPARE_WIDE_VECTORS_LOOP);
+    vmovdqu(vec1, Address(str1, result, scale));
+    vpxor(vec1, Address(str2, result, scale));
+    vptest(vec1, vec1);
+    jccb(Assembler::notZero, VECTOR_NOT_EQUAL);
+    addptr(result, stride2);
+    subl(cnt2, stride2);
+    jccb(Assembler::notZero, COMPARE_WIDE_VECTORS_LOOP);
+
+    // compare wide vectors tail
+    bind(COMPARE_WIDE_TAIL);
+    testptr(result, result);
+    jccb(Assembler::zero, LENGTH_DIFF_LABEL);
+
+    movl(result, stride2);
+    movl(cnt2, result);
+    negptr(result);
+    jmpb(COMPARE_WIDE_VECTORS_LOOP);
+
+    // Identifies the mismatching (higher or lower)16-bytes in the 32-byte vectors.
+    bind(VECTOR_NOT_EQUAL);
+    lea(str1, Address(str1, result, scale));
+    lea(str2, Address(str2, result, scale));
+    jmp(COMPARE_16_CHARS);
+
+    // Compare tail chars, length between 1 to 15 chars
+    bind(COMPARE_TAIL_LONG);
+    movl(cnt2, result);
+    cmpl(cnt2, stride);
+    jccb(Assembler::less, COMPARE_SMALL_STR);
+
+    movdqu(vec1, Address(str1, 0));
+    pcmpestri(vec1, Address(str2, 0), pcmpmask);
+    jcc(Assembler::below, COMPARE_INDEX_CHAR);
+    subptr(cnt2, stride);
+    jccb(Assembler::zero, LENGTH_DIFF_LABEL);
+    lea(str1, Address(str1, result, scale));
+    lea(str2, Address(str2, result, scale));
+    negptr(cnt2);
+    jmpb(WHILE_HEAD_LABEL);
+
+    bind(COMPARE_SMALL_STR);
+  } else if (UseSSE42Intrinsics) {
     Label COMPARE_WIDE_VECTORS, VECTOR_NOT_EQUAL, COMPARE_TAIL;
     int pcmpmask = 0x19;
-    // Setup to compare 16-byte vectors
+    // Setup to compare 8-char (16-byte) vectors,
+    // start from first character again because it has aligned address.
     movl(result, cnt2);
     andl(cnt2, ~(stride - 1));   // cnt2 holds the vector count
     jccb(Assembler::zero, COMPARE_TAIL);
@@ -5742,7 +5814,7 @@
     jccb(Assembler::notZero, COMPARE_WIDE_VECTORS);
 
     // compare wide vectors tail
-    testl(result, result);
+    testptr(result, result);
     jccb(Assembler::zero, LENGTH_DIFF_LABEL);
 
     movl(cnt2, stride);
@@ -5754,21 +5826,20 @@
 
     // Mismatched characters in the vectors
     bind(VECTOR_NOT_EQUAL);
-    addptr(result, cnt1);
-    movptr(cnt2, result);
-    load_unsigned_short(result, Address(str1, cnt2, scale));
-    load_unsigned_short(cnt1, Address(str2, cnt2, scale));
-    subl(result, cnt1);
+    addptr(cnt1, result);
+    load_unsigned_short(result, Address(str1, cnt1, scale));
+    load_unsigned_short(cnt2, Address(str2, cnt1, scale));
+    subl(result, cnt2);
     jmpb(POP_LABEL);
 
     bind(COMPARE_TAIL); // limit is zero
     movl(cnt2, result);
     // Fallthru to tail compare
   }
-
   // Shift str2 and str1 to the end of the arrays, negate min
-  lea(str1, Address(str1, cnt2, scale, 0));
-  lea(str2, Address(str2, cnt2, scale, 0));
+  lea(str1, Address(str1, cnt2, scale));
+  lea(str2, Address(str2, cnt2, scale));
+  decrementl(cnt2);  // first character was compared already
   negptr(cnt2);
 
   // Compare the rest of the elements
@@ -5833,7 +5904,44 @@
   shll(limit, 1);      // byte count != 0
   movl(result, limit); // copy
 
-  if (UseSSE42Intrinsics) {
+  if (UseAVX >= 2) {
+    // With AVX2, use 32-byte vector compare
+    Label COMPARE_WIDE_VECTORS, COMPARE_TAIL;
+
+    // Compare 32-byte vectors
+    andl(result, 0x0000001e);  //   tail count (in bytes)
+    andl(limit, 0xffffffe0);   // vector count (in bytes)
+    jccb(Assembler::zero, COMPARE_TAIL);
+
+    lea(ary1, Address(ary1, limit, Address::times_1));
+    lea(ary2, Address(ary2, limit, Address::times_1));
+    negptr(limit);
+
+    bind(COMPARE_WIDE_VECTORS);
+    vmovdqu(vec1, Address(ary1, limit, Address::times_1));
+    vmovdqu(vec2, Address(ary2, limit, Address::times_1));
+    vpxor(vec1, vec2);
+
+    vptest(vec1, vec1);
+    jccb(Assembler::notZero, FALSE_LABEL);
+    addptr(limit, 32);
+    jcc(Assembler::notZero, COMPARE_WIDE_VECTORS);
+
+    testl(result, result);
+    jccb(Assembler::zero, TRUE_LABEL);
+
+    vmovdqu(vec1, Address(ary1, result, Address::times_1, -32));
+    vmovdqu(vec2, Address(ary2, result, Address::times_1, -32));
+    vpxor(vec1, vec2);
+
+    vptest(vec1, vec1);
+    jccb(Assembler::notZero, FALSE_LABEL);
+    jmpb(TRUE_LABEL);
+
+    bind(COMPARE_TAIL); // limit is zero
+    movl(limit, result);
+    // Fallthru to tail compare
+  } else if (UseSSE42Intrinsics) {
     // With SSE4.2, use double quad vector compare
     Label COMPARE_WIDE_VECTORS, COMPARE_TAIL;
 
--- a/hotspot/src/cpu/x86/vm/macroAssembler_x86.hpp	Mon Jan 07 14:08:28 2013 -0800
+++ b/hotspot/src/cpu/x86/vm/macroAssembler_x86.hpp	Tue Jan 08 11:30:51 2013 -0800
@@ -1011,6 +1011,10 @@
       Assembler::vxorpd(dst, nds, src, vector256);
   }
 
+  // Simple version for AVX2 256bit vectors
+  void vpxor(XMMRegister dst, XMMRegister src) { Assembler::vpxor(dst, dst, src, true); }
+  void vpxor(XMMRegister dst, Address src) { Assembler::vpxor(dst, dst, src, true); }
+
   // Move packed integer values from low 128 bit to hign 128 bit in 256 bit vector.
   void vinserti128h(XMMRegister dst, XMMRegister nds, XMMRegister src) {
     if (UseAVX > 1) // vinserti128h is available only in AVX2
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/hotspot/test/compiler/8005419/Test8005419.java	Tue Jan 08 11:30:51 2013 -0800
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @bug 8005419
+ * @summary Improve intrinsics code performance on x86 by using AVX2
+ * @run main/othervm -Xbatch -Xmx64m Test8005419
+ *
+ */
+
+public class Test8005419 {
+    public static int SIZE = 64;
+
+    public static void main(String[] args) {
+        char[] a = new char[SIZE];
+        char[] b = new char[SIZE];
+
+        for (int i = 16; i < SIZE; i++) {
+          a[i] = (char)i;
+          b[i] = (char)i;
+        }
+        String s1 = new String(a);
+        String s2 = new String(b);
+
+        // Warm up
+        boolean failed = false;
+        int result = 0;
+        for (int i = 0; i < 10000; i++) {
+          result += test(s1, s2);
+        }
+        for (int i = 0; i < 10000; i++) {
+          result += test(s1, s2);
+        }
+        for (int i = 0; i < 10000; i++) {
+          result += test(s1, s2);
+        }
+        if (result != 0) failed = true;
+
+        System.out.println("Start testing");
+        // Compare same string
+        result = test(s1, s1);
+        if (result != 0) {
+          failed = true;
+          System.out.println("Failed same: result = " + result + ", expected 0");
+        }
+        // Compare equal strings
+        for (int i = 1; i <= SIZE; i++) {
+          s1 = new String(a, 0, i);
+          s2 = new String(b, 0, i);
+          result = test(s1, s2);
+          if (result != 0) {
+            failed = true;
+            System.out.println("Failed equals s1[" + i + "], s2[" + i + "]: result = " + result + ", expected 0");
+          }
+        }
+        // Compare equal strings but different sizes
+        for (int i = 1; i <= SIZE; i++) {
+          s1 = new String(a, 0, i);
+          for (int j = 1; j <= SIZE; j++) {
+            s2 = new String(b, 0, j);
+            result = test(s1, s2);
+            if (result != (i-j)) {
+              failed = true;
+              System.out.println("Failed diff size s1[" + i + "], s2[" + j + "]: result = " + result + ", expected " + (i-j));
+            }
+          }
+        }
+        // Compare strings with one char different and different sizes
+        for (int i = 1; i <= SIZE; i++) {
+          s1 = new String(a, 0, i);
+          for (int j = 0; j < i; j++) {
+            b[j] -= 3; // change char
+            s2 = new String(b, 0, i);
+            result = test(s1, s2);
+            int chdiff = a[j] - b[j];
+            if (result != chdiff) {
+              failed = true;
+              System.out.println("Failed diff char s1[" + i + "], s2[" + i + "]: result = " + result + ", expected " + chdiff);
+            }
+            result = test(s2, s1);
+            chdiff = b[j] - a[j];
+            if (result != chdiff) {
+              failed = true;
+              System.out.println("Failed diff char s2[" + i + "], s1[" + i + "]: result = " + result + ", expected " + chdiff);
+            }
+            b[j] += 3; // restore
+          }
+        }
+        if (failed) {
+          System.out.println("FAILED");
+          System.exit(97);
+        }
+        System.out.println("PASSED");
+    }
+
+    private static int test(String str1, String str2) {
+        return str1.compareTo(str2);
+    }
+}