test/jdk/java/util/concurrent/tck/MapTest.java
changeset 58892 35bac2745d04
parent 58138 1e4270f875ee
--- a/test/jdk/java/util/concurrent/tck/MapTest.java	Fri Nov 01 16:16:05 2019 +0100
+++ b/test/jdk/java/util/concurrent/tck/MapTest.java	Fri Nov 01 09:04:04 2019 -0700
@@ -32,13 +32,17 @@
  * http://creativecommons.org/publicdomain/zero/1.0/
  */
 
-import junit.framework.Test;
-
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.BiFunction;
+
+import junit.framework.Test;
 
 /**
  * Contains tests applicable to all Map implementations.
@@ -227,6 +231,71 @@
         assertTrue(clone.isEmpty());
     }
 
+    /**
+     * Concurrent access by compute methods behaves as expected
+     */
+    public void testConcurrentAccess() throws Throwable {
+        final Map map = impl.emptyMap();
+        final long testDurationMillis = expensiveTests ? 1000 : 2;
+        final int nTasks = impl.isConcurrent()
+            ? ThreadLocalRandom.current().nextInt(1, 10)
+            : 1;
+        final AtomicBoolean done = new AtomicBoolean(false);
+        final boolean remappingFunctionCalledAtMostOnce
+            = impl.remappingFunctionCalledAtMostOnce();
+        final List<CompletableFuture> futures = new ArrayList<>();
+        final AtomicLong expectedSum = new AtomicLong(0);
+        final Action[] tasks = {
+            // repeatedly increment values using compute()
+            () -> {
+                long[] invocations = new long[2];
+                ThreadLocalRandom rnd = ThreadLocalRandom.current();
+                BiFunction<Object, Object, Object> incValue = (k, v) -> {
+                    invocations[1]++;
+                    int vi = (v == null) ? 1 : impl.valueToInt(v) + 1;
+                    return impl.makeValue(vi);
+                };
+                while (!done.getAcquire()) {
+                    invocations[0]++;
+                    Object key = impl.makeKey(3 * rnd.nextInt(10));
+                    map.compute(key, incValue);
+                }
+                if (remappingFunctionCalledAtMostOnce)
+                    assertEquals(invocations[0], invocations[1]);
+                expectedSum.getAndAdd(invocations[0]);
+            },
+            // repeatedly increment values using computeIfPresent()
+            () -> {
+                long[] invocations = new long[2];
+                ThreadLocalRandom rnd = ThreadLocalRandom.current();
+                BiFunction<Object, Object, Object> incValue = (k, v) -> {
+                    invocations[1]++;
+                    int vi = impl.valueToInt(v) + 1;
+                    return impl.makeValue(vi);
+                };
+                while (!done.getAcquire()) {
+                    Object key = impl.makeKey(3 * rnd.nextInt(10));
+                    if (map.computeIfPresent(key, incValue) != null)
+                        invocations[0]++;
+                }
+                if (remappingFunctionCalledAtMostOnce)
+                    assertEquals(invocations[0], invocations[1]);
+                expectedSum.getAndAdd(invocations[0]);
+            },
+        };
+        for (int i = nTasks; i--> 0; ) {
+            Action task = chooseRandomly(tasks);
+            futures.add(CompletableFuture.runAsync(checkedRunnable(task)));
+        }
+        Thread.sleep(testDurationMillis);
+        done.setRelease(true);
+        for (var future : futures)
+            checkTimedGet(future, null);
+
+        long sum = map.values().stream().mapToLong(x -> (int) x).sum();
+        assertEquals(expectedSum.get(), sum);
+    }
+
 //     public void testFailsIntentionallyForDebugging() {
 //         fail(impl.klazz().getSimpleName());
 //     }