6625723: Excessive ThreadLocal storage used by ReentrantReadWriteLock
authormartin
Mon, 10 Mar 2008 23:23:47 -0700
changeset 62 ea448c54b34b
parent 61 5691b03db1ea
child 63 919733f875ac
6625723: Excessive ThreadLocal storage used by ReentrantReadWriteLock Reviewed-by: dholmes Contributed-by: Doug Lea <dl@cs.oswego.edu>
jdk/src/share/classes/java/util/concurrent/locks/ReentrantReadWriteLock.java
jdk/test/java/util/concurrent/locks/ReentrantReadWriteLock/Count.java
--- a/jdk/src/share/classes/java/util/concurrent/locks/ReentrantReadWriteLock.java	Mon Mar 10 23:23:47 2008 -0700
+++ b/jdk/src/share/classes/java/util/concurrent/locks/ReentrantReadWriteLock.java	Mon Mar 10 23:23:47 2008 -0700
@@ -222,7 +222,7 @@
     /** Inner class providing writelock */
     private final ReentrantReadWriteLock.WriteLock writerLock;
     /** Performs all synchronization mechanics */
-    private final Sync sync;
+    final Sync sync;
 
     /**
      * Creates a new {@code ReentrantReadWriteLock} with
@@ -239,7 +239,7 @@
      * @param fair {@code true} if this lock should use a fair ordering policy
      */
     public ReentrantReadWriteLock(boolean fair) {
-        sync = (fair)? new FairSync() : new NonfairSync();
+        sync = fair ? new FairSync() : new NonfairSync();
         readerLock = new ReadLock(this);
         writerLock = new WriteLock(this);
     }
@@ -256,8 +256,8 @@
 
         /*
          * Read vs write count extraction constants and functions.
-         * Lock state is logically divided into two shorts: The lower
-         * one representing the exclusive (writer) lock hold count,
+         * Lock state is logically divided into two unsigned shorts:
+         * The lower one representing the exclusive (writer) lock hold count,
          * and the upper the shared (reader) hold count.
          */
 
@@ -279,13 +279,6 @@
             int count;
             // Use id, not reference, to avoid garbage retention
             final long tid = Thread.currentThread().getId();
-            /** Decrement if positive; return previous value */
-            int tryDecrement() {
-                int c = count;
-                if (c > 0)
-                    count = c - 1;
-                return c;
-            }
         }
 
         /**
@@ -303,7 +296,7 @@
          * The number of read locks held by current thread.
          * Initialized only in constructor and readObject.
          */
-        transient ThreadLocalHoldCounter readHolds;
+        private transient ThreadLocalHoldCounter readHolds;
 
         /**
          * The hold count of the last thread to successfully acquire
@@ -312,7 +305,17 @@
          * acquire. This is non-volatile since it is just used
          * as a heuristic, and would be great for threads to cache.
          */
-        transient HoldCounter cachedHoldCounter;
+        private transient HoldCounter cachedHoldCounter;
+
+        /**
+         * firstReader is the first thread to have acquired the read lock.
+         * firstReaderHoldCount is firstReader's hold count.
+         * This allows tracking of read holds for uncontended read
+         * locks to be very cheap.
+         */
+        private final static long INVALID_THREAD_ID = -1;
+        private transient long firstReader = INVALID_THREAD_ID;
+        private transient int firstReaderHoldCount;
 
         Sync() {
             readHolds = new ThreadLocalHoldCounter();
@@ -390,12 +393,25 @@
         }
 
         protected final boolean tryReleaseShared(int unused) {
-            HoldCounter rh = cachedHoldCounter;
-            Thread current = Thread.currentThread();
-            if (rh == null || rh.tid != current.getId())
-                rh = readHolds.get();
-            if (rh.tryDecrement() <= 0)
-                throw new IllegalMonitorStateException();
+            long tid = Thread.currentThread().getId();
+            if (firstReader == tid) {
+                // assert firstReaderHoldCount > 0;
+                if (firstReaderHoldCount == 1)
+                    firstReader = INVALID_THREAD_ID;
+                else
+                    firstReaderHoldCount--;
+            } else {
+                HoldCounter rh = cachedHoldCounter;
+                if (rh == null || rh.tid != tid)
+                    rh = readHolds.get();
+                int count = rh.count;
+                if (count <= 1) {
+                    readHolds.remove();
+                    if (count <= 0)
+                        throw unmatchedUnlockException();
+                }
+                --rh.count;
+            }
             for (;;) {
                 int c = getState();
                 int nextc = c - SHARED_UNIT;
@@ -404,12 +420,16 @@
             }
         }
 
+        private IllegalMonitorStateException unmatchedUnlockException() {
+            return new IllegalMonitorStateException(
+                "attempt to unlock read lock, not locked by current thread");
+        }
+
         protected final int tryAcquireShared(int unused) {
             /*
              * Walkthrough:
              * 1. If write lock held by another thread, fail.
-             * 2. If count saturated, throw error.
-             * 3. Otherwise, this thread is eligible for
+             * 2. Otherwise, this thread is eligible for
              *    lock wrt state, so ask if it should block
              *    because of queue policy. If not, try
              *    to grant by CASing state and updating count.
@@ -417,23 +437,33 @@
              *    acquires, which is postponed to full version
              *    to avoid having to check hold count in
              *    the more typical non-reentrant case.
-             * 4. If step 3 fails either because thread
-             *    apparently not eligible or CAS fails,
-             *    chain to version with full retry loop.
+             * 3. If step 2 fails either because thread
+             *    apparently not eligible or CAS fails or count
+             *    saturated, chain to version with full retry loop.
              */
             Thread current = Thread.currentThread();
             int c = getState();
             if (exclusiveCount(c) != 0 &&
                 getExclusiveOwnerThread() != current)
                 return -1;
-            if (sharedCount(c) == MAX_COUNT)
-                throw new Error("Maximum lock count exceeded");
+            int r = sharedCount(c);
             if (!readerShouldBlock() &&
+                r < MAX_COUNT &&
                 compareAndSetState(c, c + SHARED_UNIT)) {
-                HoldCounter rh = cachedHoldCounter;
-                if (rh == null || rh.tid != current.getId())
-                    cachedHoldCounter = rh = readHolds.get();
-                rh.count++;
+                long tid = current.getId();
+                if (r == 0) {
+                    firstReader = tid;
+                    firstReaderHoldCount = 1;
+                } else if (firstReader == tid) {
+                    firstReaderHoldCount++;
+                } else {
+                    HoldCounter rh = cachedHoldCounter;
+                    if (rh == null || rh.tid != tid)
+                        cachedHoldCounter = rh = readHolds.get();
+                    else if (rh.count == 0)
+                        readHolds.set(rh);
+                    rh.count++;
+                }
                 return 1;
             }
             return fullTryAcquireShared(current);
@@ -450,20 +480,56 @@
              * complicating tryAcquireShared with interactions between
              * retries and lazily reading hold counts.
              */
-            HoldCounter rh = cachedHoldCounter;
-            if (rh == null || rh.tid != current.getId())
-                rh = readHolds.get();
+            HoldCounter rh = null;
             for (;;) {
                 int c = getState();
-                int w = exclusiveCount(c);
-                if ((w != 0 && getExclusiveOwnerThread() != current) ||
-                    ((rh.count | w) == 0 && readerShouldBlock()))
-                    return -1;
+                if (exclusiveCount(c) != 0) {
+                    if (getExclusiveOwnerThread() != current)
+                        //if (removeNeeded) readHolds.remove();
+                        return -1;
+                    // else we hold the exclusive lock; blocking here
+                    // would cause deadlock.
+                } else if (readerShouldBlock()) {
+                    // Make sure we're not acquiring read lock reentrantly
+                    long tid = current.getId();
+                    if (firstReader == tid) {
+                        // assert firstReaderHoldCount > 0;
+                    } else {
+                        if (rh == null) {
+                            rh = cachedHoldCounter;
+                            if (rh == null || rh.tid != tid) {
+                                rh = readHolds.get();
+                                if (rh.count == 0)
+                                    readHolds.remove();
+                            }
+                        }
+                        if (rh.count == 0)
+                            return -1;
+                    }
+                }
                 if (sharedCount(c) == MAX_COUNT)
                     throw new Error("Maximum lock count exceeded");
                 if (compareAndSetState(c, c + SHARED_UNIT)) {
-                    cachedHoldCounter = rh; // cache for release
-                    rh.count++;
+                    long tid = current.getId();
+                    if (sharedCount(c) == 0) {
+                        firstReader = tid;
+                        firstReaderHoldCount = 1;
+                    } else if (firstReader == tid) {
+                        firstReaderHoldCount++;
+                    } else {
+                        if (rh == null) {
+                            rh = cachedHoldCounter;
+                            if (rh != null && rh.tid == tid) {
+                                if (rh.count == 0)
+                                    readHolds.set(rh);
+                            } else {
+                                rh = readHolds.get();
+                            }
+                        } else if (rh.count == 0)
+                            readHolds.set(rh);
+                        cachedHoldCounter = rh; // cache for release
+                        rh.count++;
+                    }
                     return 1;
                 }
             }
@@ -472,14 +538,14 @@
         /**
          * Performs tryLock for write, enabling barging in both modes.
          * This is identical in effect to tryAcquire except for lack
-         * of calls to writerShouldBlock
+         * of calls to writerShouldBlock.
          */
         final boolean tryWriteLock() {
             Thread current = Thread.currentThread();
             int c = getState();
             if (c != 0) {
                 int w = exclusiveCount(c);
-                if (w == 0 ||current != getExclusiveOwnerThread())
+                if (w == 0 || current != getExclusiveOwnerThread())
                     return false;
                 if (w == MAX_COUNT)
                     throw new Error("Maximum lock count exceeded");
@@ -493,7 +559,7 @@
         /**
          * Performs tryLock for read, enabling barging in both modes.
          * This is identical in effect to tryAcquireShared except for
-         * lack of calls to readerShouldBlock
+         * lack of calls to readerShouldBlock.
          */
         final boolean tryReadLock() {
             Thread current = Thread.currentThread();
@@ -502,13 +568,24 @@
                 if (exclusiveCount(c) != 0 &&
                     getExclusiveOwnerThread() != current)
                     return false;
-                if (sharedCount(c) == MAX_COUNT)
+                int r = sharedCount(c);
+                if (r == MAX_COUNT)
                     throw new Error("Maximum lock count exceeded");
                 if (compareAndSetState(c, c + SHARED_UNIT)) {
-                    HoldCounter rh = cachedHoldCounter;
-                    if (rh == null || rh.tid != current.getId())
-                        cachedHoldCounter = rh = readHolds.get();
-                    rh.count++;
+                    long tid = current.getId();
+                    if (r == 0) {
+                        firstReader = tid;
+                        firstReaderHoldCount = 1;
+                    } else if (firstReader == tid) {
+                        firstReaderHoldCount++;
+                    } else {
+                        HoldCounter rh = cachedHoldCounter;
+                        if (rh == null || rh.tid != tid)
+                            cachedHoldCounter = rh = readHolds.get();
+                        else if (rh.count == 0)
+                            readHolds.set(rh);
+                        rh.count++;
+                    }
                     return true;
                 }
             }
@@ -546,7 +623,20 @@
         }
 
         final int getReadHoldCount() {
-            return getReadLockCount() == 0? 0 : readHolds.get().count;
+            if (getReadLockCount() == 0)
+                return 0;
+
+            long tid = Thread.currentThread().getId();
+            if (firstReader == tid)
+                return firstReaderHoldCount;
+
+            HoldCounter rh = cachedHoldCounter;
+            if (rh != null && rh.tid == tid)
+                return rh.count;
+
+            int count = readHolds.get().count;
+            if (count == 0) readHolds.remove();
+            return count;
         }
 
         /**
@@ -557,6 +647,7 @@
             throws java.io.IOException, ClassNotFoundException {
             s.defaultReadObject();
             readHolds = new ThreadLocalHoldCounter();
+            firstReader = INVALID_THREAD_ID;
             setState(0); // reset to unlocked state
         }
 
--- a/jdk/test/java/util/concurrent/locks/ReentrantReadWriteLock/Count.java	Mon Mar 10 23:23:47 2008 -0700
+++ b/jdk/test/java/util/concurrent/locks/ReentrantReadWriteLock/Count.java	Mon Mar 10 23:23:47 2008 -0700
@@ -1,5 +1,5 @@
 /*
- * Copyright 2005-2006 Sun Microsystems, Inc.  All Rights Reserved.
+ * Copyright 2005-2008 Sun Microsystems, Inc.  All Rights Reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -23,21 +23,89 @@
 
 /*
  * @test
- * @bug 6207928 6328220 6378321
+ * @bug 6207928 6328220 6378321 6625723
  * @summary Recursive lock invariant sanity checks
  * @author Martin Buchholz
  */
 
+import java.io.*;
 import java.util.*;
 import java.util.concurrent.*;
 import java.util.concurrent.locks.*;
 
 // I am the Cownt, and I lahve to cownt.
 public class Count {
-    private static void realMain(String[] args) throws Throwable {
-        final ReentrantLock rl = new ReentrantLock();
-        final ReentrantReadWriteLock rwl = new ReentrantReadWriteLock();
+    final Random rnd = new Random();
+
+    void lock(Lock lock) {
+        try {
+            switch (rnd.nextInt(4)) {
+            case 0: lock.lock(); break;
+            case 1: lock.lockInterruptibly(); break;
+            case 2: check(lock.tryLock()); break;
+            case 3: check(lock.tryLock(45, TimeUnit.MINUTES)); break;
+            }
+        } catch (Throwable t) { unexpected(t); }
+    }
+
+    void test(String[] args) throws Throwable {
+        for (boolean fair : new boolean[] { true, false })
+            for (boolean serialClone : new boolean[] { true, false }) {
+                testReentrantLocks(fair, serialClone);
+                testConcurrentReadLocks(fair, serialClone);
+            }
+    }
+
+    void testConcurrentReadLocks(final boolean fair,
+                                 final boolean serialClone) throws Throwable {
+        final int nThreads = 10;
+        final CyclicBarrier barrier = new CyclicBarrier(nThreads);
+        final ExecutorService es = Executors.newFixedThreadPool(nThreads);
+        final ReentrantReadWriteLock rwl = serialClone ?
+            serialClone(new ReentrantReadWriteLock(fair)) :
+            new ReentrantReadWriteLock(fair);
+        for (int i = 0; i < nThreads; i++) {
+            es.submit(new Runnable() { public void run() {
+                try {
+                    int n = 5;
+                    for (int i = 0; i < n; i++) {
+                        barrier.await();
+                        equal(rwl.getReadHoldCount(), i);
+                        equal(rwl.getWriteHoldCount(), 0);
+                        check(! rwl.isWriteLocked());
+                        equal(rwl.getReadLockCount(), nThreads * i);
+                        barrier.await();
+                        lock(rwl.readLock());
+                    }
+                    for (int i = 0; i < n; i++) {
+                        rwl.readLock().unlock();
+                        barrier.await();
+                        equal(rwl.getReadHoldCount(), n-i-1);
+                        equal(rwl.getReadLockCount(), nThreads*(n-i-1));
+                        equal(rwl.getWriteHoldCount(), 0);
+                        check(! rwl.isWriteLocked());
+                        barrier.await();
+                    }
+                    THROWS(IllegalMonitorStateException.class,
+                           new F(){void f(){rwl.readLock().unlock();}},
+                           new F(){void f(){rwl.writeLock().unlock();}});
+                    barrier.await();
+                } catch (Throwable t) { unexpected(t); }}});}
+        es.shutdown();
+        check(es.awaitTermination(10, TimeUnit.SECONDS));
+    }
+
+    void testReentrantLocks(final boolean fair,
+                            final boolean serialClone) throws Throwable {
+        final ReentrantLock rl = serialClone ?
+            serialClone(new ReentrantLock(fair)) :
+            new ReentrantLock(fair);
+        final ReentrantReadWriteLock rwl = serialClone ?
+            serialClone(new ReentrantReadWriteLock(fair)) :
+            new ReentrantReadWriteLock(fair);
         final int depth = 10;
+        equal(rl.isFair(), fair);
+        equal(rwl.isFair(), fair);
         check(! rl.isLocked());
         check(! rwl.isWriteLocked());
         check(! rl.isHeldByCurrentThread());
@@ -50,28 +118,11 @@
             equal(rwl.getReadHoldCount(), i);
             equal(rwl.getWriteHoldCount(), i);
             equal(rwl.writeLock().getHoldCount(), i);
-            switch (i%4) {
-            case 0:
-                rl.lock();
-                rwl.writeLock().lock();
-                rwl.readLock().lock();
-                break;
-            case 1:
-                rl.lockInterruptibly();
-                rwl.writeLock().lockInterruptibly();
-                rwl.readLock().lockInterruptibly();
-                break;
-            case 2:
-                check(rl.tryLock());
-                check(rwl.writeLock().tryLock());
-                check(rwl.readLock().tryLock());
-                break;
-            case 3:
-                check(rl.tryLock(454, TimeUnit.MILLISECONDS));
-                check(rwl.writeLock().tryLock(454, TimeUnit.NANOSECONDS));
-                check(rwl.readLock().tryLock(454, TimeUnit.HOURS));
-                break;
-            }
+            equal(rl.isLocked(), i > 0);
+            equal(rwl.isWriteLocked(), i > 0);
+            lock(rl);
+            lock(rwl.writeLock());
+            lock(rwl.readLock());
         }
 
         for (int i = depth; i > 0; i--) {
@@ -95,20 +146,48 @@
             rwl.writeLock().unlock();
             rl.unlock();
         }
+        THROWS(IllegalMonitorStateException.class,
+               new F(){void f(){rl.unlock();}},
+               new F(){void f(){rwl.readLock().unlock();}},
+               new F(){void f(){rwl.writeLock().unlock();}});
     }
 
     //--------------------- Infrastructure ---------------------------
-    static volatile int passed = 0, failed = 0;
-    static void pass() {passed++;}
-    static void fail() {failed++; Thread.dumpStack();}
-    static void fail(String msg) {System.out.println(msg); fail();}
-    static void unexpected(Throwable t) {failed++; t.printStackTrace();}
-    static void check(boolean cond) {if (cond) pass(); else fail();}
-    static void equal(Object x, Object y) {
+    volatile int passed = 0, failed = 0;
+    void pass() {passed++;}
+    void fail() {failed++; Thread.dumpStack();}
+    void fail(String msg) {System.err.println(msg); fail();}
+    void unexpected(Throwable t) {failed++; t.printStackTrace();}
+    void check(boolean cond) {if (cond) pass(); else fail();}
+    void equal(Object x, Object y) {
         if (x == null ? y == null : x.equals(y)) pass();
         else fail(x + " not equal to " + y);}
     public static void main(String[] args) throws Throwable {
-        try {realMain(args);} catch (Throwable t) {unexpected(t);}
+        new Count().instanceMain(args);}
+    void instanceMain(String[] args) throws Throwable {
+        try {test(args);} catch (Throwable t) {unexpected(t);}
         System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
         if (failed > 0) throw new AssertionError("Some tests failed");}
+    abstract class F {abstract void f() throws Throwable;}
+    void THROWS(Class<? extends Throwable> k, F... fs) {
+        for (F f : fs)
+            try {f.f(); fail("Expected " + k.getName() + " not thrown");}
+            catch (Throwable t) {
+                if (k.isAssignableFrom(t.getClass())) pass();
+                else unexpected(t);}}
+
+    static byte[] serializedForm(Object obj) {
+        try {
+            ByteArrayOutputStream baos = new ByteArrayOutputStream();
+            new ObjectOutputStream(baos).writeObject(obj);
+            return baos.toByteArray();
+        } catch (IOException e) { throw new RuntimeException(e); }}
+    static Object readObject(byte[] bytes)
+        throws IOException, ClassNotFoundException {
+        InputStream is = new ByteArrayInputStream(bytes);
+        return new ObjectInputStream(is).readObject();}
+    @SuppressWarnings("unchecked")
+    static <T> T serialClone(T obj) {
+        try { return (T) readObject(serializedForm(obj)); }
+        catch (Exception e) { throw new RuntimeException(e); }}
 }