8160751: Optimize ConcurrentHashMap.keySet().removeAll
authordl
Tue, 26 Jul 2016 09:57:51 -0700
changeset 39779 4666307d3155
parent 39778 5cda06c52cdd
child 39780 18618975fbb6
8160751: Optimize ConcurrentHashMap.keySet().removeAll 8161372: ConcurrentHashMap.computeIfAbsent(k,f) locks bin when k present Reviewed-by: martin, psandoz, plevart
jdk/src/java.base/share/classes/java/util/concurrent/ConcurrentHashMap.java
jdk/test/java/util/concurrent/tck/ConcurrentHashMapTest.java
--- a/jdk/src/java.base/share/classes/java/util/concurrent/ConcurrentHashMap.java	Tue Jul 26 09:53:38 2016 -0700
+++ b/jdk/src/java.base/share/classes/java/util/concurrent/ConcurrentHashMap.java	Tue Jul 26 09:57:51 2016 -0700
@@ -1023,7 +1023,7 @@
         int hash = spread(key.hashCode());
         int binCount = 0;
         for (Node<K,V>[] tab = table;;) {
-            Node<K,V> f; int n, i, fh;
+            Node<K,V> f; int n, i, fh; K fk; V fv;
             if (tab == null || (n = tab.length) == 0)
                 tab = initTable();
             else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
@@ -1032,6 +1032,10 @@
             }
             else if ((fh = f.hash) == MOVED)
                 tab = helpTransfer(tab, f);
+            else if (onlyIfAbsent && fh == hash &&  // check first node
+                     ((fk = f.key) == key || fk != null && key.equals(fk)) &&
+                     (fv = f.val) != null)
+                return fv;
             else {
                 V oldVal = null;
                 synchronized (f) {
@@ -1702,7 +1706,7 @@
         V val = null;
         int binCount = 0;
         for (Node<K,V>[] tab = table;;) {
-            Node<K,V> f; int n, i, fh;
+            Node<K,V> f; int n, i, fh; K fk; V fv;
             if (tab == null || (n = tab.length) == 0)
                 tab = initTable();
             else if ((f = tabAt(tab, i = (n - 1) & h)) == null) {
@@ -1724,6 +1728,10 @@
             }
             else if ((fh = f.hash) == MOVED)
                 tab = helpTransfer(tab, f);
+            else if (fh == h &&                  // check first node
+                     ((fk = f.key) == key || fk != null && key.equals(fk)) &&
+                     (fv = f.val) != null)
+                return fv;
             else {
                 boolean added = false;
                 synchronized (f) {
@@ -4553,14 +4561,21 @@
             return true;
         }
 
-        public final boolean removeAll(Collection<?> c) {
+        public boolean removeAll(Collection<?> c) {
             if (c == null) throw new NullPointerException();
             boolean modified = false;
-            for (Iterator<E> it = iterator(); it.hasNext();) {
-                if (c.contains(it.next())) {
-                    it.remove();
-                    modified = true;
+            // Use (c instanceof Set) as a hint that lookup in c is as
+            // efficient as this view
+            if (c instanceof Set<?> && c.size() > map.table.length) {
+                for (Iterator<?> it = iterator(); it.hasNext(); ) {
+                    if (c.contains(it.next())) {
+                        it.remove();
+                        modified = true;
+                    }
                 }
+            } else {
+                for (Object e : c)
+                    modified |= remove(e);
             }
             return modified;
         }
@@ -4747,6 +4762,18 @@
             throw new UnsupportedOperationException();
         }
 
+        @Override public boolean removeAll(Collection<?> c) {
+            if (c == null) throw new NullPointerException();
+            boolean modified = false;
+            for (Iterator<V> it = iterator(); it.hasNext();) {
+                if (c.contains(it.next())) {
+                    it.remove();
+                    modified = true;
+                }
+            }
+            return modified;
+        }
+
         public boolean removeIf(Predicate<? super V> filter) {
             return map.removeValueIf(filter);
         }
--- a/jdk/test/java/util/concurrent/tck/ConcurrentHashMapTest.java	Tue Jul 26 09:53:38 2016 -0700
+++ b/jdk/test/java/util/concurrent/tck/ConcurrentHashMapTest.java	Tue Jul 26 09:57:51 2016 -0700
@@ -43,6 +43,8 @@
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 import junit.framework.Test;
 import junit.framework.TestSuite;
@@ -830,4 +832,47 @@
         }
     }
 
+    /**
+     * Tests performance of removeAll when the other collection is much smaller.
+     * ant -Djsr166.tckTestClass=ConcurrentHashMapTest -Djsr166.methodFilter=testRemoveAll_performance -Djsr166.expensiveTests=true tck
+     */
+    public void testRemoveAll_performance() {
+        final int mapSize = expensiveTests ? 1_000_000 : 100;
+        final int iterations = expensiveTests ? 500 : 2;
+        final ConcurrentHashMap<Integer, Integer> map = new ConcurrentHashMap<>();
+        for (int i = 0; i < mapSize; i++)
+            map.put(i, i);
+        Set<Integer> keySet = map.keySet();
+        Collection<Integer> removeMe = Arrays.asList(new Integer[] { -99, -86 });
+        for (int i = 0; i < iterations; i++)
+            assertFalse(keySet.removeAll(removeMe));
+        assertEquals(mapSize, map.size());
+    }
+
+    /**
+     * Tests performance of computeIfAbsent when the element is present.
+     * See JDK-8161372
+     * ant -Djsr166.tckTestClass=ConcurrentHashMapTest -Djsr166.methodFilter=testcomputeIfAbsent_performance -Djsr166.expensiveTests=true tck
+     */
+    public void testcomputeIfAbsent_performance() {
+        final int mapSize = 20;
+        final int iterations = expensiveTests ? (1 << 23) : mapSize * 2;
+        final int threads = expensiveTests ? 10 : 2;
+        final ConcurrentHashMap<Integer, Integer> map = new ConcurrentHashMap<>();
+        for (int i = 0; i < mapSize; i++)
+            map.put(i, i);
+        final ExecutorService pool = Executors.newFixedThreadPool(2);
+        try (PoolCleaner cleaner = cleaner(pool)) {
+            Runnable r = new CheckedRunnable() {
+                public void realRun() {
+                    int result = 0;
+                    for (int i = 0; i < iterations; i++)
+                        result += map.computeIfAbsent(i % mapSize, (k) -> k + k);
+                    if (result == -42) throw new Error();
+                }};
+            for (int i = 0; i < threads; i++)
+                pool.execute(r);
+        }
+    }
+
 }