8173470: [C2] Mask shift operands in ideal graph.
authorgoetz
Fri, 27 Jan 2017 11:19:52 +0100
changeset 46277 73607b4788cb
parent 46274 534d019edb92
child 46278 6db7a35e08ae
8173470: [C2] Mask shift operands in ideal graph. Reviewed-by: lucy, kvn
hotspot/src/cpu/s390/vm/s390.ad
hotspot/src/share/vm/opto/mulnode.cpp
--- a/hotspot/src/cpu/s390/vm/s390.ad	Fri Feb 17 14:47:46 2017 -0500
+++ b/hotspot/src/cpu/s390/vm/s390.ad	Fri Jan 27 11:19:52 2017 +0100
@@ -6770,6 +6770,7 @@
   format %{ "SLL     $dst,$src,$nbits\t# use RISC-like SLLG also for int" %}
   ins_encode %{
     int Nbit = $nbits$$constant;
+    assert((Nbit & (BitsPerJavaInteger - 1)) == Nbit, "Check shift mask in ideal graph");
     __ z_sllg($dst$$Register, $src$$Register, Nbit & (BitsPerJavaInteger - 1), Z_R0);
   %}
   ins_pipe(pipe_class_dummy);
@@ -6843,6 +6844,7 @@
   format %{ "SRA     $dst,$src" %}
   ins_encode %{
     int Nbit = $src$$constant;
+    assert((Nbit & (BitsPerJavaInteger - 1)) == Nbit, "Check shift mask in ideal graph");
     __ z_sra($dst$$Register, Nbit & (BitsPerJavaInteger - 1), Z_R0);
   %}
   ins_pipe(pipe_class_dummy);
@@ -6895,6 +6897,7 @@
   format %{ "SRL     $dst,$src" %}
   ins_encode %{
     int Nbit = $src$$constant;
+    assert((Nbit & (BitsPerJavaInteger - 1)) == Nbit, "Check shift mask in ideal graph");
     __ z_srl($dst$$Register, Nbit & (BitsPerJavaInteger - 1), Z_R0);
   %}
   ins_pipe(pipe_class_dummy);
--- a/hotspot/src/share/vm/opto/mulnode.cpp	Fri Feb 17 14:47:46 2017 -0500
+++ b/hotspot/src/share/vm/opto/mulnode.cpp	Fri Jan 27 11:19:52 2017 +0100
@@ -630,23 +630,42 @@
 }
 
 //=============================================================================
+
+static int getShiftCon(PhaseGVN *phase, Node *shiftNode, int retVal) {
+  const Type *t = phase->type(shiftNode->in(2));
+  if (t == Type::TOP) return retVal;       // Right input is dead.
+  const TypeInt *t2 = t->isa_int();
+  if (!t2 || !t2->is_con()) return retVal; // Right input is a constant.
+
+  return t2->get_con();
+}
+
+static int maskShiftAmount(PhaseGVN *phase, Node *shiftNode, int nBits) {
+  int       shift = getShiftCon(phase, shiftNode, 0);
+  int maskedShift = shift & (nBits - 1);
+
+  if (maskedShift == 0) return 0;         // Let Identity() handle 0 shift count.
+
+  if (shift != maskedShift) {
+    shiftNode->set_req(2, phase->intcon(maskedShift)); // Replace shift count with masked value.
+  }
+
+  return maskedShift;
+}
+
 //------------------------------Identity---------------------------------------
 Node* LShiftINode::Identity(PhaseGVN* phase) {
-  const TypeInt *ti = phase->type( in(2) )->isa_int();  // shift count is an int
-  return ( ti && ti->is_con() && ( ti->get_con() & ( BitsPerInt - 1 ) ) == 0 ) ? in(1) : this;
+  return ((getShiftCon(phase, this, -1) & (BitsPerJavaInteger - 1)) == 0) ? in(1) : this;
 }
 
 //------------------------------Ideal------------------------------------------
 // If the right input is a constant, and the left input is an add of a
 // constant, flatten the tree: (X+con1)<<con0 ==> X<<con0 + con1<<con0
 Node *LShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
-  const Type *t  = phase->type( in(2) );
-  if( t == Type::TOP ) return NULL;       // Right input is dead
-  const TypeInt *t2 = t->isa_int();
-  if( !t2 || !t2->is_con() ) return NULL; // Right input is a constant
-  const int con = t2->get_con() & ( BitsPerInt - 1 );  // masked shift count
-
-  if ( con == 0 )  return NULL; // let Identity() handle 0 shift count
+  int con = maskShiftAmount(phase, this, BitsPerJavaInteger);
+  if (con == 0) {
+    return NULL;
+  }
 
   // Left input is an add of a constant?
   Node *add1 = in(1);
@@ -744,21 +763,17 @@
 //=============================================================================
 //------------------------------Identity---------------------------------------
 Node* LShiftLNode::Identity(PhaseGVN* phase) {
-  const TypeInt *ti = phase->type( in(2) )->isa_int(); // shift count is an int
-  return ( ti && ti->is_con() && ( ti->get_con() & ( BitsPerLong - 1 ) ) == 0 ) ? in(1) : this;
+  return ((getShiftCon(phase, this, -1) & (BitsPerJavaLong - 1)) == 0) ? in(1) : this;
 }
 
 //------------------------------Ideal------------------------------------------
 // If the right input is a constant, and the left input is an add of a
 // constant, flatten the tree: (X+con1)<<con0 ==> X<<con0 + con1<<con0
 Node *LShiftLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
-  const Type *t  = phase->type( in(2) );
-  if( t == Type::TOP ) return NULL;       // Right input is dead
-  const TypeInt *t2 = t->isa_int();
-  if( !t2 || !t2->is_con() ) return NULL; // Right input is a constant
-  const int con = t2->get_con() & ( BitsPerLong - 1 );  // masked shift count
-
-  if ( con == 0 ) return NULL;  // let Identity() handle 0 shift count
+  int con = maskShiftAmount(phase, this, BitsPerJavaLong);
+  if (con == 0) {
+    return NULL;
+  }
 
   // Left input is an add of a constant?
   Node *add1 = in(1);
@@ -853,26 +868,24 @@
 //=============================================================================
 //------------------------------Identity---------------------------------------
 Node* RShiftINode::Identity(PhaseGVN* phase) {
-  const TypeInt *t2 = phase->type(in(2))->isa_int();
-  if( !t2 ) return this;
-  if ( t2->is_con() && ( t2->get_con() & ( BitsPerInt - 1 ) ) == 0 )
-    return in(1);
+  int shift = getShiftCon(phase, this, -1);
+  if (shift == -1) return this;
+  if ((shift & (BitsPerJavaInteger - 1)) == 0) return in(1);
 
   // Check for useless sign-masking
-  if( in(1)->Opcode() == Op_LShiftI &&
+  if (in(1)->Opcode() == Op_LShiftI &&
       in(1)->req() == 3 &&
-      in(1)->in(2) == in(2) &&
-      t2->is_con() ) {
-    uint shift = t2->get_con();
+      in(1)->in(2) == in(2)) {
     shift &= BitsPerJavaInteger-1; // semantics of Java shifts
     // Compute masks for which this shifting doesn't change
-    int lo = (-1 << (BitsPerJavaInteger - shift-1)); // FFFF8000
+    int lo = (-1 << (BitsPerJavaInteger - ((uint)shift)-1)); // FFFF8000
     int hi = ~lo;               // 00007FFF
     const TypeInt *t11 = phase->type(in(1)->in(1))->isa_int();
-    if( !t11 ) return this;
+    if (!t11) return this;
     // Does actual value fit inside of mask?
-    if( lo <= t11->_lo && t11->_hi <= hi )
+    if (lo <= t11->_lo && t11->_hi <= hi) {
       return in(1)->in(1);      // Then shifting is a nop
+    }
   }
 
   return this;
@@ -881,15 +894,13 @@
 //------------------------------Ideal------------------------------------------
 Node *RShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
   // Inputs may be TOP if they are dead.
-  const TypeInt *t1 = phase->type( in(1) )->isa_int();
-  if( !t1 ) return NULL;        // Left input is an integer
-  const TypeInt *t2 = phase->type( in(2) )->isa_int();
-  if( !t2 || !t2->is_con() ) return NULL; // Right input is a constant
+  const TypeInt *t1 = phase->type(in(1))->isa_int();
+  if (!t1) return NULL;        // Left input is an integer
   const TypeInt *t3;  // type of in(1).in(2)
-  int shift = t2->get_con();
-  shift &= BitsPerJavaInteger-1;  // semantics of Java shifts
-
-  if ( shift == 0 ) return NULL;  // let Identity() handle 0 shift count
+  int shift = maskShiftAmount(phase, this, BitsPerJavaInteger);
+  if (shift == 0) {
+    return NULL;
+  }
 
   // Check for (x & 0xFF000000) >> 24, whose mask can be made smaller.
   // Such expressions arise normally from shift chains like (byte)(x >> 24).
@@ -1003,8 +1014,8 @@
 //=============================================================================
 //------------------------------Identity---------------------------------------
 Node* RShiftLNode::Identity(PhaseGVN* phase) {
-  const TypeInt *ti = phase->type( in(2) )->isa_int(); // shift count is an int
-  return ( ti && ti->is_con() && ( ti->get_con() & ( BitsPerLong - 1 ) ) == 0 ) ? in(1) : this;
+  const TypeInt *ti = phase->type(in(2))->isa_int(); // Shift count is an int.
+  return (ti && ti->is_con() && (ti->get_con() & (BitsPerJavaLong - 1)) == 0) ? in(1) : this;
 }
 
 //------------------------------Value------------------------------------------
@@ -1061,25 +1072,25 @@
 //=============================================================================
 //------------------------------Identity---------------------------------------
 Node* URShiftINode::Identity(PhaseGVN* phase) {
-  const TypeInt *ti = phase->type( in(2) )->isa_int();
-  if ( ti && ti->is_con() && ( ti->get_con() & ( BitsPerInt - 1 ) ) == 0 ) return in(1);
+  int shift = getShiftCon(phase, this, -1);
+  if ((shift & (BitsPerJavaInteger - 1)) == 0) return in(1);
 
   // Check for "((x << LogBytesPerWord) + (wordSize-1)) >> LogBytesPerWord" which is just "x".
   // Happens during new-array length computation.
   // Safe if 'x' is in the range [0..(max_int>>LogBytesPerWord)]
   Node *add = in(1);
-  if( add->Opcode() == Op_AddI ) {
-    const TypeInt *t2  = phase->type(add->in(2))->isa_int();
-    if( t2 && t2->is_con(wordSize - 1) &&
-        add->in(1)->Opcode() == Op_LShiftI ) {
-      // Check that shift_counts are LogBytesPerWord
+  if (add->Opcode() == Op_AddI) {
+    const TypeInt *t2 = phase->type(add->in(2))->isa_int();
+    if (t2 && t2->is_con(wordSize - 1) &&
+        add->in(1)->Opcode() == Op_LShiftI) {
+      // Check that shift_counts are LogBytesPerWord.
       Node          *lshift_count   = add->in(1)->in(2);
       const TypeInt *t_lshift_count = phase->type(lshift_count)->isa_int();
-      if( t_lshift_count && t_lshift_count->is_con(LogBytesPerWord) &&
-          t_lshift_count == phase->type(in(2)) ) {
+      if (t_lshift_count && t_lshift_count->is_con(LogBytesPerWord) &&
+          t_lshift_count == phase->type(in(2))) {
         Node          *x   = add->in(1)->in(1);
         const TypeInt *t_x = phase->type(x)->isa_int();
-        if( t_x != NULL && 0 <= t_x->_lo && t_x->_hi <= (max_jint>>LogBytesPerWord) ) {
+        if (t_x != NULL && 0 <= t_x->_lo && t_x->_hi <= (max_jint>>LogBytesPerWord)) {
           return x;
         }
       }
@@ -1091,10 +1102,11 @@
 
 //------------------------------Ideal------------------------------------------
 Node *URShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
-  const TypeInt *t2 = phase->type( in(2) )->isa_int();
-  if( !t2 || !t2->is_con() ) return NULL; // Right input is a constant
-  const int con = t2->get_con() & 31; // Shift count is always masked
-  if ( con == 0 ) return NULL;  // let Identity() handle a 0 shift count
+  int con = maskShiftAmount(phase, this, BitsPerJavaInteger);
+  if (con == 0) {
+    return NULL;
+  }
+
   // We'll be wanting the right-shift amount as a mask of that many bits
   const int mask = right_n_bits(BitsPerJavaInteger - con);
 
@@ -1117,7 +1129,8 @@
   // If Q is "X << z" the rounding is useless.  Look for patterns like
   // ((X<<Z) + Y) >>> Z  and replace with (X + Y>>>Z) & Z-mask.
   Node *add = in(1);
-  if( in1_op == Op_AddI ) {
+  const TypeInt *t2 = phase->type(in(2))->isa_int();
+  if (in1_op == Op_AddI) {
     Node *lshl = add->in(1);
     if( lshl->Opcode() == Op_LShiftI &&
         phase->type(lshl->in(2)) == t2 ) {
@@ -1231,17 +1244,16 @@
 //=============================================================================
 //------------------------------Identity---------------------------------------
 Node* URShiftLNode::Identity(PhaseGVN* phase) {
-  const TypeInt *ti = phase->type( in(2) )->isa_int(); // shift count is an int
-  return ( ti && ti->is_con() && ( ti->get_con() & ( BitsPerLong - 1 ) ) == 0 ) ? in(1) : this;
+  return ((getShiftCon(phase, this, -1) & (BitsPerJavaLong - 1)) == 0) ? in(1) : this;
 }
 
 //------------------------------Ideal------------------------------------------
 Node *URShiftLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
-  const TypeInt *t2 = phase->type( in(2) )->isa_int();
-  if( !t2 || !t2->is_con() ) return NULL; // Right input is a constant
-  const int con = t2->get_con() & ( BitsPerLong - 1 ); // Shift count is always masked
-  if ( con == 0 ) return NULL;  // let Identity() handle a 0 shift count
-                              // note: mask computation below does not work for 0 shift count
+  int con = maskShiftAmount(phase, this, BitsPerJavaLong);
+  if (con == 0) {
+    return NULL;
+  }
+
   // We'll be wanting the right-shift amount as a mask of that many bits
   const jlong mask = jlong(max_julong >> con);
 
@@ -1250,7 +1262,8 @@
   // If Q is "X << z" the rounding is useless.  Look for patterns like
   // ((X<<Z) + Y) >>> Z  and replace with (X + Y>>>Z) & Z-mask.
   Node *add = in(1);
-  if( add->Opcode() == Op_AddL ) {
+  const TypeInt *t2 = phase->type(in(2))->isa_int();
+  if (add->Opcode() == Op_AddL) {
     Node *lshl = add->in(1);
     if( lshl->Opcode() == Op_LShiftL &&
         phase->type(lshl->in(2)) == t2 ) {