test/jdk/java/util/concurrent/ExecutorService/Invoke.java
changeset 47341 ed1fd45b6eb5
parent 47216 71c04702a3d5
child 47727 53020d8cdf5b
--- a/test/jdk/java/util/concurrent/ExecutorService/Invoke.java	Fri Oct 13 18:07:47 2017 -0700
+++ b/test/jdk/java/util/concurrent/ExecutorService/Invoke.java	Fri Oct 13 18:12:54 2017 -0700
@@ -28,12 +28,18 @@
  * @author  Martin Buchholz
  */
 
-import java.util.Arrays;
+import static java.util.concurrent.TimeUnit.NANOSECONDS;
+import static java.util.concurrent.TimeUnit.SECONDS;
+
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import java.util.concurrent.Callable;
+import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.atomic.AtomicLong;
 
 public class Invoke {
@@ -61,36 +67,162 @@
         check(condition, "Assertion failure");
     }
 
+    static long secondsElapsedSince(long startTime) {
+        return NANOSECONDS.toSeconds(System.nanoTime() - startTime);
+    }
+
+    static void awaitInterrupt(long timeoutSeconds) {
+        long startTime = System.nanoTime();
+        try {
+            Thread.sleep(SECONDS.toMillis(timeoutSeconds));
+            fail("timed out waiting for interrupt");
+        } catch (InterruptedException expected) {
+            check(secondsElapsedSince(startTime) < timeoutSeconds);
+        }
+    }
+
     public static void main(String[] args) {
         try {
-            final AtomicLong count = new AtomicLong(0);
-            ExecutorService fixed = Executors.newFixedThreadPool(5);
-            class Inc implements Callable<Long> {
-                public Long call() throws Exception {
-                    Thread.sleep(200); // Catch IE from possible cancel
-                    return count.incrementAndGet();
-                }
+            testInvokeAll();
+            testInvokeAny();
+            testInvokeAny_cancellationInterrupt();
+        } catch (Throwable t) {  unexpected(t); }
+
+        if (failed > 0)
+            throw new Error(
+                    String.format("Passed = %d, failed = %d", passed, failed));
+    }
+
+    static final long timeoutSeconds = 10L;
+
+    static void testInvokeAll() throws Throwable {
+        final ThreadLocalRandom rnd = ThreadLocalRandom.current();
+        final int nThreads = rnd.nextInt(2, 7);
+        final boolean timed = rnd.nextBoolean();
+        final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
+        final AtomicLong count = new AtomicLong(0);
+        class Task implements Callable<Long> {
+            public Long call() throws Exception {
+                return count.incrementAndGet();
             }
-            List<Inc> tasks = Arrays.asList(new Inc(), new Inc(), new Inc());
-            List<Future<Long>> futures = fixed.invokeAll(tasks);
+        }
+
+        try {
+            final List<Task> tasks =
+                IntStream.range(0, nThreads)
+                .mapToObj(i -> new Task())
+                .collect(Collectors.toList());
+
+            List<Future<Long>> futures;
+            if (timed) {
+                long startTime = System.nanoTime();
+                futures = pool.invokeAll(tasks, timeoutSeconds, SECONDS);
+                check(secondsElapsedSince(startTime) < timeoutSeconds);
+            }
+            else
+                futures = pool.invokeAll(tasks);
             check(futures.size() == tasks.size());
             check(count.get() == tasks.size());
 
             long gauss = 0;
             for (Future<Long> future : futures) gauss += future.get();
-            check(gauss == ((tasks.size()+1)*tasks.size())/2);
+            check(gauss == (tasks.size()+1)*tasks.size()/2);
+
+            pool.shutdown();
+            check(pool.awaitTermination(10L, SECONDS));
+        } finally {
+            pool.shutdownNow();
+        }
+    }
 
-            ExecutorService single = Executors.newSingleThreadExecutor();
-            long save = count.get();
-            check(single.invokeAny(tasks) == save + 1);
-            check(count.get() == save + 1);
+    static void testInvokeAny() throws Throwable {
+        final ThreadLocalRandom rnd = ThreadLocalRandom.current();
+        final boolean timed = rnd.nextBoolean();
+        final ExecutorService pool = Executors.newSingleThreadExecutor();
+        final AtomicLong count = new AtomicLong(0);
+        class Task implements Callable<Long> {
+            public Long call() throws Exception {
+                long x = count.incrementAndGet();
+                check(x <= 2);
+                if (x == 2)
+                    // wait for main thread to interrupt us
+                    awaitInterrupt(timeoutSeconds);
+                return x;
+            }
+        }
+
+        try {
+            final List<Task> tasks =
+                IntStream.range(0, rnd.nextInt(1, 7))
+                .mapToObj(i -> new Task())
+                .collect(Collectors.toList());
+
+            long val;
+            if (timed) {
+                long startTime = System.nanoTime();
+                val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
+                check(secondsElapsedSince(startTime) < timeoutSeconds);
+            }
+            else
+                val = pool.invokeAny(tasks);
+            check(val == 1);
+
+            // inherent race between main thread interrupt and
+            // start of second task
+            check(count.get() == 1 || count.get() == 2);
 
-            fixed.shutdown();
-            single.shutdown();
+            pool.shutdown();
+            check(pool.awaitTermination(timeoutSeconds, SECONDS));
+        } finally {
+            pool.shutdownNow();
+        }
+    }
 
-        } catch (Throwable t) { unexpected(t); }
+    /**
+     * Every remaining running task is sent an interrupt for cancellation.
+     */
+    static void testInvokeAny_cancellationInterrupt() throws Throwable {
+        final ThreadLocalRandom rnd = ThreadLocalRandom.current();
+        final int nThreads = rnd.nextInt(2, 7);
+        final boolean timed = rnd.nextBoolean();
+        final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
+        final AtomicLong count = new AtomicLong(0);
+        final AtomicLong interruptedCount = new AtomicLong(0);
+        final CyclicBarrier allStarted = new CyclicBarrier(nThreads);
+        class Task implements Callable<Long> {
+            public Long call() throws Exception {
+                allStarted.await();
+                long x = count.incrementAndGet();
+                if (x > 1)
+                    // main thread will interrupt us
+                    awaitInterrupt(timeoutSeconds);
+                return x;
+            }
+        }
 
-        System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
-        if (failed > 0) throw new Error("Some tests failed");
+        try {
+            final List<Task> tasks =
+                IntStream.range(0, nThreads)
+                .mapToObj(i -> new Task())
+                .collect(Collectors.toList());
+
+            long val;
+            if (timed) {
+                long startTime = System.nanoTime();
+                val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
+                check(secondsElapsedSince(startTime) < timeoutSeconds);
+            }
+            else
+                val = pool.invokeAny(tasks);
+            check(val == 1);
+
+            pool.shutdown();
+            check(pool.awaitTermination(timeoutSeconds, SECONDS));
+
+            // Check after shutdown to avoid race
+            check(count.get() == nThreads);
+        } finally {
+            pool.shutdownNow();
+        }
     }
 }