--- a/jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java Tue Jul 26 09:57:51 2016 -0700
+++ b/jdk/src/java.base/share/classes/java/util/concurrent/locks/StampedLock.java Tue Jul 26 10:02:05 2016 -0700
@@ -256,8 +256,12 @@
* method validate()) requires stricter ordering rules than apply
* to normal volatile reads (of "state"). To force orderings of
* reads before a validation and the validation itself in those
- * cases where this is not already forced, we use
- * VarHandle.acquireFence.
+ * cases where this is not already forced, we use acquireFence.
+ * Unlike in that paper, we allow writers to use plain writes.
+ * One would not expect reorderings of such writes with the lock
+ * acquisition CAS because there is a "control dependency", but it
+ * is theoretically possible, so we additionally add a
+ * storeStoreFence after lock acquisition CAS.
*
* The memory layout keeps lock state and queue pointers together
* (normally on the same cache line). This usually works well for
@@ -355,6 +359,20 @@
state = ORIGIN;
}
+ private boolean casState(long expectedValue, long newValue) {
+ return STATE.compareAndSet(this, expectedValue, newValue);
+ }
+
+ private long tryWriteLock(long s) {
+ // assert (s & ABITS) == 0L;
+ long next;
+ if (casState(s, next = s | WBIT)) {
+ VarHandle.storeStoreFence();
+ return next;
+ }
+ return 0L;
+ }
+
/**
* Exclusively acquires the lock, blocking if necessary
* until available.
@@ -363,10 +381,8 @@
*/
@ReservedStackAccess
public long writeLock() {
- long s, next; // bypass acquireWrite in fully unlocked case only
- return ((((s = state) & ABITS) == 0L &&
- STATE.compareAndSet(this, s, next = s + WBIT)) ?
- next : acquireWrite(false, 0L));
+ long next;
+ return ((next = tryWriteLock()) != 0L) ? next : acquireWrite(false, 0L);
}
/**
@@ -377,10 +393,8 @@
*/
@ReservedStackAccess
public long tryWriteLock() {
- long s, next;
- return ((((s = state) & ABITS) == 0L &&
- STATE.compareAndSet(this, s, next = s + WBIT)) ?
- next : 0L);
+ long s;
+ return (((s = state) & ABITS) == 0L) ? tryWriteLock(s) : 0L;
}
/**
@@ -440,10 +454,13 @@
*/
@ReservedStackAccess
public long readLock() {
- long s = state, next; // bypass acquireRead on common uncontended case
- return ((whead == wtail && (s & ABITS) < RFULL &&
- STATE.compareAndSet(this, s, next = s + RUNIT)) ?
- next : acquireRead(false, 0L));
+ long s, next;
+ // bypass acquireRead on common uncontended case
+ return (whead == wtail
+ && ((s = state) & ABITS) < RFULL
+ && casState(s, next = s + RUNIT))
+ ? next
+ : acquireRead(false, 0L);
}
/**
@@ -457,7 +474,7 @@
long s, m, next;
while ((m = (s = state) & ABITS) != WBIT) {
if (m < RFULL) {
- if (STATE.compareAndSet(this, s, next = s + RUNIT))
+ if (casState(s, next = s + RUNIT))
return next;
}
else if ((next = tryIncReaderOverflow(s)) != 0L)
@@ -487,7 +504,7 @@
if (!Thread.interrupted()) {
if ((m = (s = state) & ABITS) != WBIT) {
if (m < RFULL) {
- if (STATE.compareAndSet(this, s, next = s + RUNIT))
+ if (casState(s, next = s + RUNIT))
return next;
}
else if ((next = tryIncReaderOverflow(s)) != 0L)
@@ -514,10 +531,15 @@
* before acquiring the lock
*/
@ReservedStackAccess
- public long readLockInterruptibly() throws InterruptedException {
- long next;
- if (!Thread.interrupted() &&
- (next = acquireRead(true, 0L)) != INTERRUPTED)
+ public long readLockInterruptibly() throws InterruptedException {
+ long s, next;
+ if (!Thread.interrupted()
+ // bypass acquireRead on common uncontended case
+ && ((whead == wtail
+ && ((s = state) & ABITS) < RFULL
+ && casState(s, next = s + RUNIT))
+ ||
+ (next = acquireRead(true, 0L)) != INTERRUPTED))
return next;
throw new InterruptedException();
}
@@ -598,7 +620,7 @@
&& (stamp & RBITS) > 0L
&& ((m = s & RBITS) > 0L)) {
if (m < RFULL) {
- if (STATE.compareAndSet(this, s, s - RUNIT)) {
+ if (casState(s, s - RUNIT)) {
if (m == RUNIT && (h = whead) != null && h.status != 0)
release(h);
return;
@@ -620,7 +642,7 @@
*/
@ReservedStackAccess
public void unlock(long stamp) {
- if ((stamp & WBIT) != 0)
+ if ((stamp & WBIT) != 0L)
unlockWrite(stamp);
else
unlockRead(stamp);
@@ -644,7 +666,7 @@
if ((m = s & ABITS) == 0L) {
if (a != 0L)
break;
- if (STATE.compareAndSet(this, s, next = s + WBIT))
+ if ((next = tryWriteLock(s)) != 0L)
return next;
}
else if (m == WBIT) {
@@ -653,8 +675,10 @@
return stamp;
}
else if (m == RUNIT && a != 0L) {
- if (STATE.compareAndSet(this, s, next = s - RUNIT + WBIT))
+ if (casState(s, next = s - RUNIT + WBIT)) {
+ VarHandle.storeStoreFence();
return next;
+ }
}
else
break;
@@ -688,7 +712,7 @@
else if (a == 0L) {
// optimistic read stamp
if ((s & ABITS) < RFULL) {
- if (STATE.compareAndSet(this, s, next = s + RUNIT))
+ if (casState(s, next = s + RUNIT))
return next;
}
else if ((next = tryIncReaderOverflow(s)) != 0L)
@@ -730,7 +754,7 @@
else if ((m = s & ABITS) == 0L) // invalid read stamp
break;
else if (m < RFULL) {
- if (STATE.compareAndSet(this, s, next = s - RUNIT)) {
+ if (casState(s, next = s - RUNIT)) {
if (m == RUNIT && (h = whead) != null && h.status != 0)
release(h);
return next & SBITS;
@@ -771,7 +795,7 @@
long s, m; WNode h;
while ((m = (s = state) & ABITS) != 0L && m < WBIT) {
if (m < RFULL) {
- if (STATE.compareAndSet(this, s, s - RUNIT)) {
+ if (casState(s, s - RUNIT)) {
if (m == RUNIT && (h = whead) != null && h.status != 0)
release(h);
return true;
@@ -940,7 +964,7 @@
long s, m; WNode h;
while ((m = (s = state) & RBITS) > 0L) {
if (m < RFULL) {
- if (STATE.compareAndSet(this, s, s - RUNIT)) {
+ if (casState(s, s - RUNIT)) {
if (m == RUNIT && (h = whead) != null && h.status != 0)
release(h);
return;
@@ -971,7 +995,7 @@
private long tryIncReaderOverflow(long s) {
// assert (s & ABITS) >= RFULL;
if ((s & ABITS) == RFULL) {
- if (STATE.compareAndSet(this, s, s | RBITS)) {
+ if (casState(s, s | RBITS)) {
++readerOverflow;
STATE.setVolatile(this, s);
return s;
@@ -993,7 +1017,7 @@
private long tryDecReaderOverflow(long s) {
// assert (s & ABITS) >= RFULL;
if ((s & ABITS) == RFULL) {
- if (STATE.compareAndSet(this, s, s | RBITS)) {
+ if (casState(s, s | RBITS)) {
int r; long next;
if ((r = readerOverflow) > 0) {
readerOverflow = r - 1;
@@ -1047,7 +1071,7 @@
for (int spins = -1;;) { // spin while enqueuing
long m, s, ns;
if ((m = (s = state) & ABITS) == 0L) {
- if (STATE.compareAndSet(this, s, ns = s + WBIT))
+ if ((ns = tryWriteLock(s)) != 0L)
return ns;
}
else if (spins < 0)
@@ -1082,7 +1106,7 @@
for (int k = spins; k > 0; --k) { // spin at head
long s, ns;
if (((s = state) & ABITS) == 0L) {
- if (STATE.compareAndSet(this, s, ns = s + WBIT)) {
+ if ((ns = tryWriteLock(s)) != 0L) {
whead = node;
node.prev = null;
if (wasInterrupted)
@@ -1158,7 +1182,7 @@
if ((h = whead) == (p = wtail)) {
for (long m, s, ns;;) {
if ((m = (s = state) & ABITS) < RFULL ?
- STATE.compareAndSet(this, s, ns = s + RUNIT) :
+ casState(s, ns = s + RUNIT) :
(m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) {
if (wasInterrupted)
Thread.currentThread().interrupt();
@@ -1208,7 +1232,7 @@
long m, s, ns;
do {
if ((m = (s = state) & ABITS) < RFULL ?
- STATE.compareAndSet(this, s, ns = s + RUNIT) :
+ casState(s, ns = s + RUNIT) :
(m < WBIT &&
(ns = tryIncReaderOverflow(s)) != 0L)) {
if (wasInterrupted)
@@ -1260,7 +1284,7 @@
for (int k = spins;;) { // spin at head
long m, s, ns;
if ((m = (s = state) & ABITS) < RFULL ?
- STATE.compareAndSet(this, s, ns = s + RUNIT) :
+ casState(s, ns = s + RUNIT) :
(m < WBIT && (ns = tryIncReaderOverflow(s)) != 0L)) {
WNode c; Thread w;
whead = node;
--- a/jdk/test/java/util/concurrent/tck/StampedLockTest.java Tue Jul 26 09:57:51 2016 -0700
+++ b/jdk/test/java/util/concurrent/tck/StampedLockTest.java Tue Jul 26 10:02:05 2016 -0700
@@ -32,11 +32,18 @@
* http://creativecommons.org/publicdomain/zero/1.0/
*/
+import static java.util.concurrent.TimeUnit.DAYS;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.StampedLock;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.Function;
import junit.framework.Test;
import junit.framework.TestSuite;
@@ -1078,4 +1085,121 @@
assertThrows(IllegalMonitorStateException.class, actions);
}
+ static long writeLockInterruptiblyUninterrupted(StampedLock sl) {
+ try { return sl.writeLockInterruptibly(); }
+ catch (InterruptedException ex) { throw new AssertionError(ex); }
+ }
+
+ static long tryWriteLockUninterrupted(StampedLock sl, long time, TimeUnit unit) {
+ try { return sl.tryWriteLock(time, unit); }
+ catch (InterruptedException ex) { throw new AssertionError(ex); }
+ }
+
+ static long readLockInterruptiblyUninterrupted(StampedLock sl) {
+ try { return sl.readLockInterruptibly(); }
+ catch (InterruptedException ex) { throw new AssertionError(ex); }
+ }
+
+ static long tryReadLockUninterrupted(StampedLock sl, long time, TimeUnit unit) {
+ try { return sl.tryReadLock(time, unit); }
+ catch (InterruptedException ex) { throw new AssertionError(ex); }
+ }
+
+ /**
+ * Invalid write stamps result in IllegalMonitorStateException
+ */
+ public void testInvalidWriteStampsThrowIllegalMonitorStateException() {
+ List<Function<StampedLock, Long>> writeLockers = new ArrayList<>();
+ writeLockers.add((sl) -> sl.writeLock());
+ writeLockers.add((sl) -> writeLockInterruptiblyUninterrupted(sl));
+ writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
+ writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, 0, DAYS));
+
+ List<BiConsumer<StampedLock, Long>> writeUnlockers = new ArrayList<>();
+ writeUnlockers.add((sl, stamp) -> sl.unlockWrite(stamp));
+ writeUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockWrite()));
+ writeUnlockers.add((sl, stamp) -> sl.asWriteLock().unlock());
+ writeUnlockers.add((sl, stamp) -> sl.unlock(stamp));
+
+ List<Consumer<StampedLock>> mutaters = new ArrayList<>();
+ mutaters.add((sl) -> {});
+ mutaters.add((sl) -> sl.readLock());
+ for (Function<StampedLock, Long> writeLocker : writeLockers)
+ mutaters.add((sl) -> writeLocker.apply(sl));
+
+ for (Function<StampedLock, Long> writeLocker : writeLockers)
+ for (BiConsumer<StampedLock, Long> writeUnlocker : writeUnlockers)
+ for (Consumer<StampedLock> mutater : mutaters) {
+ final StampedLock sl = new StampedLock();
+ final long stamp = writeLocker.apply(sl);
+ assertTrue(stamp != 0L);
+ assertThrows(IllegalMonitorStateException.class,
+ () -> sl.unlockRead(stamp));
+ writeUnlocker.accept(sl, stamp);
+ mutater.accept(sl);
+ assertThrows(IllegalMonitorStateException.class,
+ () -> sl.unlock(stamp),
+ () -> sl.unlockRead(stamp),
+ () -> sl.unlockWrite(stamp));
+ }
+ }
+
+ /**
+ * Invalid read stamps result in IllegalMonitorStateException
+ */
+ public void testInvalidReadStampsThrowIllegalMonitorStateException() {
+ List<Function<StampedLock, Long>> readLockers = new ArrayList<>();
+ readLockers.add((sl) -> sl.readLock());
+ readLockers.add((sl) -> readLockInterruptiblyUninterrupted(sl));
+ readLockers.add((sl) -> tryReadLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
+ readLockers.add((sl) -> tryReadLockUninterrupted(sl, 0, DAYS));
+
+ List<BiConsumer<StampedLock, Long>> readUnlockers = new ArrayList<>();
+ readUnlockers.add((sl, stamp) -> sl.unlockRead(stamp));
+ readUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockRead()));
+ readUnlockers.add((sl, stamp) -> sl.asReadLock().unlock());
+ readUnlockers.add((sl, stamp) -> sl.unlock(stamp));
+
+ List<Function<StampedLock, Long>> writeLockers = new ArrayList<>();
+ writeLockers.add((sl) -> sl.writeLock());
+ writeLockers.add((sl) -> writeLockInterruptiblyUninterrupted(sl));
+ writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, Long.MIN_VALUE, DAYS));
+ writeLockers.add((sl) -> tryWriteLockUninterrupted(sl, 0, DAYS));
+
+ List<BiConsumer<StampedLock, Long>> writeUnlockers = new ArrayList<>();
+ writeUnlockers.add((sl, stamp) -> sl.unlockWrite(stamp));
+ writeUnlockers.add((sl, stamp) -> assertTrue(sl.tryUnlockWrite()));
+ writeUnlockers.add((sl, stamp) -> sl.asWriteLock().unlock());
+ writeUnlockers.add((sl, stamp) -> sl.unlock(stamp));
+
+
+ for (Function<StampedLock, Long> readLocker : readLockers)
+ for (BiConsumer<StampedLock, Long> readUnlocker : readUnlockers)
+ for (Function<StampedLock, Long> writeLocker : writeLockers)
+ for (BiConsumer<StampedLock, Long> writeUnlocker : writeUnlockers) {
+ final StampedLock sl = new StampedLock();
+ final long stamp = readLocker.apply(sl);
+ assertTrue(stamp != 0L);
+ assertThrows(IllegalMonitorStateException.class,
+ () -> sl.unlockWrite(stamp));
+ readUnlocker.accept(sl, stamp);
+ assertThrows(IllegalMonitorStateException.class,
+ () -> sl.unlock(stamp),
+ () -> sl.unlockRead(stamp),
+ () -> sl.unlockWrite(stamp));
+ final long writeStamp = writeLocker.apply(sl);
+ assertTrue(writeStamp != 0L);
+ assertTrue(writeStamp != stamp);
+ assertThrows(IllegalMonitorStateException.class,
+ () -> sl.unlock(stamp),
+ () -> sl.unlockRead(stamp),
+ () -> sl.unlockWrite(stamp));
+ writeUnlocker.accept(sl, writeStamp);
+ assertThrows(IllegalMonitorStateException.class,
+ () -> sl.unlock(stamp),
+ () -> sl.unlockRead(stamp),
+ () -> sl.unlockWrite(stamp));
+ }
+ }
+
}