--- a/test/jdk/java/util/concurrent/tck/MapTest.java Fri Nov 01 16:16:05 2019 +0100
+++ b/test/jdk/java/util/concurrent/tck/MapTest.java Fri Nov 01 09:04:04 2019 -0700
@@ -32,13 +32,17 @@
* http://creativecommons.org/publicdomain/zero/1.0/
*/
-import junit.framework.Test;
-
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.BiFunction;
+
+import junit.framework.Test;
/**
* Contains tests applicable to all Map implementations.
@@ -227,6 +231,71 @@
assertTrue(clone.isEmpty());
}
+ /**
+ * Concurrent access by compute methods behaves as expected
+ */
+ public void testConcurrentAccess() throws Throwable {
+ final Map map = impl.emptyMap();
+ final long testDurationMillis = expensiveTests ? 1000 : 2;
+ final int nTasks = impl.isConcurrent()
+ ? ThreadLocalRandom.current().nextInt(1, 10)
+ : 1;
+ final AtomicBoolean done = new AtomicBoolean(false);
+ final boolean remappingFunctionCalledAtMostOnce
+ = impl.remappingFunctionCalledAtMostOnce();
+ final List<CompletableFuture> futures = new ArrayList<>();
+ final AtomicLong expectedSum = new AtomicLong(0);
+ final Action[] tasks = {
+ // repeatedly increment values using compute()
+ () -> {
+ long[] invocations = new long[2];
+ ThreadLocalRandom rnd = ThreadLocalRandom.current();
+ BiFunction<Object, Object, Object> incValue = (k, v) -> {
+ invocations[1]++;
+ int vi = (v == null) ? 1 : impl.valueToInt(v) + 1;
+ return impl.makeValue(vi);
+ };
+ while (!done.getAcquire()) {
+ invocations[0]++;
+ Object key = impl.makeKey(3 * rnd.nextInt(10));
+ map.compute(key, incValue);
+ }
+ if (remappingFunctionCalledAtMostOnce)
+ assertEquals(invocations[0], invocations[1]);
+ expectedSum.getAndAdd(invocations[0]);
+ },
+ // repeatedly increment values using computeIfPresent()
+ () -> {
+ long[] invocations = new long[2];
+ ThreadLocalRandom rnd = ThreadLocalRandom.current();
+ BiFunction<Object, Object, Object> incValue = (k, v) -> {
+ invocations[1]++;
+ int vi = impl.valueToInt(v) + 1;
+ return impl.makeValue(vi);
+ };
+ while (!done.getAcquire()) {
+ Object key = impl.makeKey(3 * rnd.nextInt(10));
+ if (map.computeIfPresent(key, incValue) != null)
+ invocations[0]++;
+ }
+ if (remappingFunctionCalledAtMostOnce)
+ assertEquals(invocations[0], invocations[1]);
+ expectedSum.getAndAdd(invocations[0]);
+ },
+ };
+ for (int i = nTasks; i--> 0; ) {
+ Action task = chooseRandomly(tasks);
+ futures.add(CompletableFuture.runAsync(checkedRunnable(task)));
+ }
+ Thread.sleep(testDurationMillis);
+ done.setRelease(true);
+ for (var future : futures)
+ checkTimedGet(future, null);
+
+ long sum = map.values().stream().mapToLong(x -> (int) x).sum();
+ assertEquals(expectedSum.get(), sum);
+ }
+
// public void testFailsIntentionallyForDebugging() {
// fail(impl.klazz().getSimpleName());
// }