8188853: java/util/concurrent/ExecutorService/Invoke.java Assertion failure
Reviewed-by: martin, psandoz, dholmes
--- a/test/jdk/java/util/concurrent/ExecutorService/Invoke.java Fri Oct 13 18:07:47 2017 -0700
+++ b/test/jdk/java/util/concurrent/ExecutorService/Invoke.java Fri Oct 13 18:12:54 2017 -0700
@@ -28,12 +28,18 @@
* @author Martin Buchholz
*/
-import java.util.Arrays;
+import static java.util.concurrent.TimeUnit.NANOSECONDS;
+import static java.util.concurrent.TimeUnit.SECONDS;
+
import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
import java.util.concurrent.Callable;
+import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
+import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
public class Invoke {
@@ -61,36 +67,162 @@
check(condition, "Assertion failure");
}
+ static long secondsElapsedSince(long startTime) {
+ return NANOSECONDS.toSeconds(System.nanoTime() - startTime);
+ }
+
+ static void awaitInterrupt(long timeoutSeconds) {
+ long startTime = System.nanoTime();
+ try {
+ Thread.sleep(SECONDS.toMillis(timeoutSeconds));
+ fail("timed out waiting for interrupt");
+ } catch (InterruptedException expected) {
+ check(secondsElapsedSince(startTime) < timeoutSeconds);
+ }
+ }
+
public static void main(String[] args) {
try {
- final AtomicLong count = new AtomicLong(0);
- ExecutorService fixed = Executors.newFixedThreadPool(5);
- class Inc implements Callable<Long> {
- public Long call() throws Exception {
- Thread.sleep(200); // Catch IE from possible cancel
- return count.incrementAndGet();
- }
+ testInvokeAll();
+ testInvokeAny();
+ testInvokeAny_cancellationInterrupt();
+ } catch (Throwable t) { unexpected(t); }
+
+ if (failed > 0)
+ throw new Error(
+ String.format("Passed = %d, failed = %d", passed, failed));
+ }
+
+ static final long timeoutSeconds = 10L;
+
+ static void testInvokeAll() throws Throwable {
+ final ThreadLocalRandom rnd = ThreadLocalRandom.current();
+ final int nThreads = rnd.nextInt(2, 7);
+ final boolean timed = rnd.nextBoolean();
+ final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
+ final AtomicLong count = new AtomicLong(0);
+ class Task implements Callable<Long> {
+ public Long call() throws Exception {
+ return count.incrementAndGet();
}
- List<Inc> tasks = Arrays.asList(new Inc(), new Inc(), new Inc());
- List<Future<Long>> futures = fixed.invokeAll(tasks);
+ }
+
+ try {
+ final List<Task> tasks =
+ IntStream.range(0, nThreads)
+ .mapToObj(i -> new Task())
+ .collect(Collectors.toList());
+
+ List<Future<Long>> futures;
+ if (timed) {
+ long startTime = System.nanoTime();
+ futures = pool.invokeAll(tasks, timeoutSeconds, SECONDS);
+ check(secondsElapsedSince(startTime) < timeoutSeconds);
+ }
+ else
+ futures = pool.invokeAll(tasks);
check(futures.size() == tasks.size());
check(count.get() == tasks.size());
long gauss = 0;
for (Future<Long> future : futures) gauss += future.get();
- check(gauss == ((tasks.size()+1)*tasks.size())/2);
+ check(gauss == (tasks.size()+1)*tasks.size()/2);
+
+ pool.shutdown();
+ check(pool.awaitTermination(10L, SECONDS));
+ } finally {
+ pool.shutdownNow();
+ }
+ }
- ExecutorService single = Executors.newSingleThreadExecutor();
- long save = count.get();
- check(single.invokeAny(tasks) == save + 1);
- check(count.get() == save + 1);
+ static void testInvokeAny() throws Throwable {
+ final ThreadLocalRandom rnd = ThreadLocalRandom.current();
+ final boolean timed = rnd.nextBoolean();
+ final ExecutorService pool = Executors.newSingleThreadExecutor();
+ final AtomicLong count = new AtomicLong(0);
+ class Task implements Callable<Long> {
+ public Long call() throws Exception {
+ long x = count.incrementAndGet();
+ check(x <= 2);
+ if (x == 2)
+ // wait for main thread to interrupt us
+ awaitInterrupt(timeoutSeconds);
+ return x;
+ }
+ }
+
+ try {
+ final List<Task> tasks =
+ IntStream.range(0, rnd.nextInt(1, 7))
+ .mapToObj(i -> new Task())
+ .collect(Collectors.toList());
+
+ long val;
+ if (timed) {
+ long startTime = System.nanoTime();
+ val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
+ check(secondsElapsedSince(startTime) < timeoutSeconds);
+ }
+ else
+ val = pool.invokeAny(tasks);
+ check(val == 1);
+
+ // inherent race between main thread interrupt and
+ // start of second task
+ check(count.get() == 1 || count.get() == 2);
- fixed.shutdown();
- single.shutdown();
+ pool.shutdown();
+ check(pool.awaitTermination(timeoutSeconds, SECONDS));
+ } finally {
+ pool.shutdownNow();
+ }
+ }
- } catch (Throwable t) { unexpected(t); }
+ /**
+ * Every remaining running task is sent an interrupt for cancellation.
+ */
+ static void testInvokeAny_cancellationInterrupt() throws Throwable {
+ final ThreadLocalRandom rnd = ThreadLocalRandom.current();
+ final int nThreads = rnd.nextInt(2, 7);
+ final boolean timed = rnd.nextBoolean();
+ final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
+ final AtomicLong count = new AtomicLong(0);
+ final AtomicLong interruptedCount = new AtomicLong(0);
+ final CyclicBarrier allStarted = new CyclicBarrier(nThreads);
+ class Task implements Callable<Long> {
+ public Long call() throws Exception {
+ allStarted.await();
+ long x = count.incrementAndGet();
+ if (x > 1)
+ // main thread will interrupt us
+ awaitInterrupt(timeoutSeconds);
+ return x;
+ }
+ }
- System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
- if (failed > 0) throw new Error("Some tests failed");
+ try {
+ final List<Task> tasks =
+ IntStream.range(0, nThreads)
+ .mapToObj(i -> new Task())
+ .collect(Collectors.toList());
+
+ long val;
+ if (timed) {
+ long startTime = System.nanoTime();
+ val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
+ check(secondsElapsedSince(startTime) < timeoutSeconds);
+ }
+ else
+ val = pool.invokeAny(tasks);
+ check(val == 1);
+
+ pool.shutdown();
+ check(pool.awaitTermination(timeoutSeconds, SECONDS));
+
+ // Check after shutdown to avoid race
+ check(count.get() == nThreads);
+ } finally {
+ pool.shutdownNow();
+ }
}
}