8161608: StampedLock should use storeStoreFence when acquiring write lock
authordl
Tue, 26 Jul 2016 10:02:05 -0700
changeset 39780 18618975fbb6
parent 39779 4666307d3155
child 39781 8190c004acbd
8161608: StampedLock should use storeStoreFence when acquiring write lock Reviewed-by: martin, psandoz, plevart
jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java
jdk/test/java/util/concurrent/tck/StampedLockTest.java
--- a/jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java	Tue Jul 26 09:57:51 2016 -0700
+++ b/jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java	Tue Jul 26 10:02:05 2016 -0700
@@ -256,8 +256,12 @@
      * method validate()) requires stricter ordering rules than apply
      * to normal volatile reads (of "state").  To force orderings of
      * reads before a validation and the validation itself in those
-     * cases where this is not already forced, we use
-     * VarHandle.acquireFence.
+     * cases where this is not already forced, we use acquireFence.
+     * Unlike in that paper, we allow writers to use plain writes.
+     * One would not expect reorderings of such writes with the lock
+     * acquisition CAS because there is a "control dependency", but it
+     * is theoretically possible, so we additionally add a
+     * storeStoreFence after lock acquisition CAS.
      *
      * The memory layout keeps lock state and queue pointers together
      * (normally on the same cache line). This usually works well for
@@ -355,6 +359,20 @@
         state = ORIGIN;
     }
 
+    private boolean casState(long expectedValue, long newValue) {
+        return STATE.compareAndSet(this, expectedValue, newValue);
+    }
+
+    private long tryWriteLock(long s) {
+        // assert (s & ABITS) == 0L;
+        long next;
+        if (casState(s, next = s | WBIT)) {
+            VarHandle.storeStoreFence();
+            return next;
+        }
+        return 0L;
+    }
+
     /**
      * Exclusively acquires the lock, blocking if necessary
      * until available.
@@ -363,10 +381,8 @@
      */
     @ReservedStackAccess
     public long writeLock() {
-        long s, next;  // bypass acquireWrite in fully unlocked case only
-        return ((((s = state) & ABITS) == 0L &&
-                 STATE.compareAndSet(this, s, next = s + WBIT)) ?
-                next : acquireWrite(false, 0L));
+        long next;
+        return ((next = tryWriteLock()) != 0L) ? next : acquireWrite(false, 0L);
     }
 
     /**
@@ -377,10 +393,8 @@
      */
     @ReservedStackAccess
     public long tryWriteLock() {
-        long s, next;
-        return ((((s = state) & ABITS) == 0L &&
-                 STATE.compareAndSet(this, s, next = s + WBIT)) ?
-                next : 0L);
+        long s;
+        return (((s = state) & ABITS) == 0L) ? tryWriteLock(s) : 0L;
     }
 
     /**
@@ -440,10 +454,13 @@
      */
     @ReservedStackAccess
     public long readLock() {
-        long s = state, next;  // bypass acquireRead on common uncontended case
-        return ((whead == wtail && (s & ABITS) < RFULL &&
-                 STATE.compareAndSet(this, s, next = s + RUNIT)) ?
-                next : acquireRead(false, 0L));
+        long s, next;
+        // bypass acquireRead on common uncontended case
+        return (whead == wtail
+                && ((s = state) & ABITS) < RFULL
+                && casState(s, next = s + RUNIT))
+            ? next
+            : acquireRead(false, 0L);
     }
 
     /**
@@ -457,7 +474,7 @@
         long s, m, next;
         while ((m = (s = state) & ABITS) != WBIT) {
             if (m < RFULL) {
-                if (STATE.compareAndSet(this, s, next = s + RUNIT))
+                if (casState(s, next = s + RUNIT))
                     return next;
             }
             else if ((next = tryIncReaderOverflow(s)) != 0L)
@@ -487,7 +504,7 @@
         if (!Thread.interrupted()) {
             if ((m = (s = state) & ABITS) != WBIT) {
                 if (m < RFULL) {
-                    if (STATE.compareAndSet(this, s, next = s + RUNIT))
+                    if (casState(s, next = s + RUNIT))
                         return next;
                 }
                 else if ((next = tryIncReaderOverflow(s)) != 0L)
@@ -514,10 +531,15 @@
      * before acquiring the lock
      */
     @ReservedStackAccess
-    public long readLockInterruptibly() throws InterruptedException {
-        long next;
-        if (!Thread.interrupted() &&
-            (next = acquireRead(true, 0L)) != INTERRUPTED)
+        public long readLockInterruptibly() throws InterruptedException {
+        long s, next;
+        if (!Thread.interrupted()
+            // bypass acquireRead on common uncontended case
+            && ((whead == wtail
+                 && ((s = state) & ABITS) < RFULL
+                 && casState(s, next = s + RUNIT))
+                ||
+                (next = acquireRead(true, 0L)) != INTERRUPTED))
             return next;
         throw new InterruptedException();
     }
@@ -598,7 +620,7 @@
                && (stamp & RBITS) > 0L
                && ((m = s & RBITS) > 0L)) {
             if (m < RFULL) {
-                if (STATE.compareAndSet(this, s, s - RUNIT)) {
+                if (casState(s, s - RUNIT)) {
                     if (m == RUNIT && (h = whead) != null && h.status != 0)
                         release(h);
                     return;
@@ -620,7 +642,7 @@
      */
     @ReservedStackAccess
     public void unlock(long stamp) {
-        if ((stamp & WBIT) != 0)
+        if ((stamp & WBIT) != 0L)
             unlockWrite(stamp);
         else
             unlockRead(stamp);
@@ -644,7 +666,7 @@
             if ((m = s & ABITS) == 0L) {
                 if (a != 0L)
                     break;
-                if (STATE.compareAndSet(this, s, next = s + WBIT))
+                if ((next = tryWriteLock(s)) != 0L)
                     return next;
             }
             else if (m == WBIT) {
@@ -653,8 +675,10 @@
                 return stamp;
             }
             else if (m == RUNIT && a != 0L) {
-                if (STATE.compareAndSet(this, s, next = s - RUNIT + WBIT))
+                if (casState(s, next = s - RUNIT + WBIT)) {
+                    VarHandle.storeStoreFence();
                     return next;
+                }
             }
             else
                 break;
@@ -688,7 +712,7 @@
             else if (a == 0L) {
                 // optimistic read stamp
                 if ((s & ABITS) < RFULL) {
-                    if (STATE.compareAndSet(this, s, next = s + RUNIT))
+                    if (casState(s, next = s + RUNIT))
                         return next;
                 }
                 else if ((next = tryIncReaderOverflow(s)) != 0L)
@@ -730,7 +754,7 @@
             else if ((m = s & ABITS) == 0L) // invalid read stamp
                 break;
             else if (m < RFULL) {
-                if (STATE.compareAndSet(this, s, next = s - RUNIT)) {
+                if (casState(s, next = s - RUNIT)) {
                     if (m == RUNIT && (h = whead) != null && h.status != 0)
                         release(h);
                     return next & SBITS;
@@ -771,7 +795,7 @@
         long s, m; WNode h;
         while ((m = (s = state) & ABITS) != 0L && m < WBIT) {
             if (m < RFULL) {
-                if (STATE.compareAndSet(this, s, s - RUNIT)) {
+                if (casState(s, s - RUNIT)) {
                     if (m == RUNIT && (h = whead) != null && h.status != 0)
                         release(h);
                     return true;
@@ -940,7 +964,7 @@
         long s, m; WNode h;
         while ((m = (s = state) & RBITS) > 0L) {
             if (m < RFULL) {
-                if (STATE.compareAndSet(this, s, s - RUNIT)) {
+                if (casState(s, s - RUNIT)) {
                     if (m == RUNIT && (h = whead) != null && h.status != 0)
                         release(h);
                     return;
@@ -971,7 +995,7 @@
     private long tryIncReaderOverflow(long s) {
         // assert (s & ABITS) >= RFULL;
         if ((s & ABITS) == RFULL) {
-            if (STATE.compareAndSet(this, s, s | RBITS)) {
+            if (casState(s, s | RBITS)) {
                 ++readerOverflow;
                 STATE.setVolatile(this, s);
                 return s;
@@ -993,7 +1017,7 @@
     private long tryDecReaderOverflow(long s) {
         // assert (s & ABITS) >= RFULL;
         if ((s & ABITS) == RFULL) {
-            if (STATE.compareAndSet(this, s, s | RBITS)) {
+            if (casState(s, s | RBITS)) {
                 int r; long next;
                 if ((r = readerOverflow) > 0) {
                     readerOverflow = r - 1;
@@ -1047,7 +1071,7 @@
         for (int spins = -1;;) { // spin while enqueuing
             long m, s, ns;
             if ((m = (s = state) & ABITS) == 0L) {
-                if (STATE.compareAndSet(this, s, ns = s + WBIT))
+                if ((ns = tryWriteLock(s)) != 0L)
                     return ns;
             }
             else if (spins < 0)
@@ -1082,7 +1106,7 @@
                 for (int k = spins; k > 0; --k) { // spin at head
                     long s, ns;
                     if (((s = state) & ABITS) == 0L) {
-                        if (STATE.compareAndSet(this, s, ns = s + WBIT)) {
+                        if ((ns = tryWriteLock(s)) != 0L) {
                             whead = node;
                             node.prev = null;
                             if (wasInterrupted)
@@ -1158,7 +1182,7 @@
             if ((h = whead) == (p = wtail)) {
                 for (long m, s, ns;;) {
                     if ((m = (s = state) & ABITS) < RFULL ?
-                        STATE.compareAndSet(this, s, ns = s + RUNIT) :
+                        casState(s, ns = s + RUNIT) :
                         (m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) {
                         if (wasInterrupted)
                             Thread.currentThread().interrupt();
@@ -1208,7 +1232,7 @@
                         long m, s, ns;
                         do {
                             if ((m = (s = state) & ABITS) < RFULL ?
-                                STATE.compareAndSet(this, s, ns = s + RUNIT) :
+                                casState(s, ns = s + RUNIT) :
                                 (m < WBIT &&
                                  (ns = tryIncReaderOverflow(s)) != 0L)) {
                                 if (wasInterrupted)
@@ -1260,7 +1284,7 @@
                 for (int k = spins;;) { // spin at head
                     long m, s, ns;
                     if ((m = (s = state) & ABITS) < RFULL ?
-                        STATE.compareAndSet(this, s, ns = s + RUNIT) :
+                        casState(s, ns = s + RUNIT) :
                         (m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) {
                         WNode c; Thread w;
                         whead = node;
--- a/jdk/test/java/util/concurrent/tck/StampedLockTest.java	Tue Jul 26 09:57:51 2016 -0700
+++ b/jdk/test/java/util/concurrent/tck/StampedLockTest.java	Tue Jul 26 10:02:05 2016 -0700
@@ -32,11 +32,18 @@
  * http://creativecommons.org/publicdomain/zero/1.0/
  */
 
+import static java.util.concurrent.TimeUnit.DAYS;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.StampedLock;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.Function;
 
 import junit.framework.Test;
 import junit.framework.TestSuite;
@@ -1078,4 +1085,121 @@
         assertThrows(IllegalMonitorStateException.class, actions);
     }
 
+    static long writeLockInterruptiblyUninterrupted(StampedLock sl) {
+        try { return sl.writeLockInterruptibly(); }
+        catch (InterruptedException ex) { throw new AssertionError(ex); }
+    }
+
+    static long tryWriteLockUninterrupted(StampedLock sl, long time, TimeUnit unit) {
+        try { return sl.tryWriteLock(time, unit); }
+        catch (InterruptedException ex) { throw new AssertionError(ex); }
+    }
+
+    static long readLockInterruptiblyUninterrupted(StampedLock sl) {
+        try { return sl.readLockInterruptibly(); }
+        catch (InterruptedException ex) { throw new AssertionError(ex); }
+    }
+
+    static long tryReadLockUninterrupted(StampedLock sl, long time, TimeUnit unit) {
+        try { return sl.tryReadLock(time, unit); }
+        catch (InterruptedException ex) { throw new AssertionError(ex); }
+    }
+
+    /**
+     * Invalid write stamps result in IllegalMonitorStateException
+     */
+    public void testInvalidWriteStampsThrowIllegalMonitorStateException() {
+        List<Function<StampedLock, Long>> writeLockers = new ArrayList<>();
+        writeLockers.add((sl) -> sl.writeLock());
+        writeLockers.add((sl) -> writeLockInterruptiblyUninterrupted(sl));
+        writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
+        writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, 0, DAYS));
+
+        List<BiConsumer<StampedLock, Long>> writeUnlockers = new ArrayList<>();
+        writeUnlockers.add((sl, stamp) -> sl.unlockWrite(stamp));
+        writeUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockWrite()));
+        writeUnlockers.add((sl, stamp) -> sl.asWriteLock().unlock());
+        writeUnlockers.add((sl, stamp) -> sl.unlock(stamp));
+
+        List<Consumer<StampedLock>> mutaters = new ArrayList<>();
+        mutaters.add((sl) -> {});
+        mutaters.add((sl) -> sl.readLock());
+        for (Function<StampedLock, Long> writeLocker : writeLockers)
+            mutaters.add((sl) -> writeLocker.apply(sl));
+
+        for (Function<StampedLock, Long> writeLocker : writeLockers)
+        for (BiConsumer<StampedLock, Long> writeUnlocker : writeUnlockers)
+        for (Consumer<StampedLock> mutater : mutaters) {
+            final StampedLock sl = new StampedLock();
+            final long stamp = writeLocker.apply(sl);
+            assertTrue(stamp != 0L);
+            assertThrows(IllegalMonitorStateException.class,
+                         () -> sl.unlockRead(stamp));
+            writeUnlocker.accept(sl, stamp);
+            mutater.accept(sl);
+            assertThrows(IllegalMonitorStateException.class,
+                         () -> sl.unlock(stamp),
+                         () -> sl.unlockRead(stamp),
+                         () -> sl.unlockWrite(stamp));
+        }
+    }
+
+    /**
+     * Invalid read stamps result in IllegalMonitorStateException
+     */
+    public void testInvalidReadStampsThrowIllegalMonitorStateException() {
+        List<Function<StampedLock, Long>> readLockers = new ArrayList<>();
+        readLockers.add((sl) -> sl.readLock());
+        readLockers.add((sl) -> readLockInterruptiblyUninterrupted(sl));
+        readLockers.add((sl) -> tryReadLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
+        readLockers.add((sl) -> tryReadLockUninterrupted(sl, 0, DAYS));
+
+        List<BiConsumer<StampedLock, Long>> readUnlockers = new ArrayList<>();
+        readUnlockers.add((sl, stamp) -> sl.unlockRead(stamp));
+        readUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockRead()));
+        readUnlockers.add((sl, stamp) -> sl.asReadLock().unlock());
+        readUnlockers.add((sl, stamp) -> sl.unlock(stamp));
+
+        List<Function<StampedLock, Long>> writeLockers = new ArrayList<>();
+        writeLockers.add((sl) -> sl.writeLock());
+        writeLockers.add((sl) -> writeLockInterruptiblyUninterrupted(sl));
+        writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
+        writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, 0, DAYS));
+
+        List<BiConsumer<StampedLock, Long>> writeUnlockers = new ArrayList<>();
+        writeUnlockers.add((sl, stamp) -> sl.unlockWrite(stamp));
+        writeUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockWrite()));
+        writeUnlockers.add((sl, stamp) -> sl.asWriteLock().unlock());
+        writeUnlockers.add((sl, stamp) -> sl.unlock(stamp));
+
+
+        for (Function<StampedLock, Long> readLocker : readLockers)
+        for (BiConsumer<StampedLock, Long> readUnlocker : readUnlockers)
+        for (Function<StampedLock, Long> writeLocker : writeLockers)
+        for (BiConsumer<StampedLock, Long> writeUnlocker : writeUnlockers) {
+            final StampedLock sl = new StampedLock();
+            final long stamp = readLocker.apply(sl);
+            assertTrue(stamp != 0L);
+            assertThrows(IllegalMonitorStateException.class,
+                         () -> sl.unlockWrite(stamp));
+            readUnlocker.accept(sl, stamp);
+            assertThrows(IllegalMonitorStateException.class,
+                         () -> sl.unlock(stamp),
+                         () -> sl.unlockRead(stamp),
+                         () -> sl.unlockWrite(stamp));
+            final long writeStamp = writeLocker.apply(sl);
+            assertTrue(writeStamp != 0L);
+            assertTrue(writeStamp != stamp);
+            assertThrows(IllegalMonitorStateException.class,
+                         () -> sl.unlock(stamp),
+                         () -> sl.unlockRead(stamp),
+                         () -> sl.unlockWrite(stamp));
+            writeUnlocker.accept(sl, writeStamp);
+            assertThrows(IllegalMonitorStateException.class,
+                         () -> sl.unlock(stamp),
+                         () -> sl.unlockRead(stamp),
+                         () -> sl.unlockWrite(stamp));
+        }
+    }
+
 }