8221168: java/util/concurrent/CountDownLatch/Basic.java fails
authordl
Sat, 14 Sep 2019 11:24:14 -0700
changeset 58136 f689a48dba4b
parent 58135 2081ff900d65
child 58137 6a556bcd94fc
8221168: java/util/concurrent/CountDownLatch/Basic.java fails Reviewed-by: martin, alanb
test/jdk/java/util/concurrent/CountDownLatch/Basic.java
--- a/test/jdk/java/util/concurrent/CountDownLatch/Basic.java	Sat Sep 14 11:20:57 2019 -0700
+++ b/test/jdk/java/util/concurrent/CountDownLatch/Basic.java	Sat Sep 14 11:24:14 2019 -0700
@@ -23,70 +23,57 @@
 
 /*
  * @test
- * @bug 6332435
+ * @bug 6332435 8221168
  * @summary Basic tests for CountDownLatch
  * @library /test/lib
  * @author Seetharam Avadhanam, Martin Buchholz
  */
 
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
 import jdk.test.lib.Utils;
 
 public class Basic {
     static final long LONG_DELAY_MS = Utils.adjustTimeout(10_000);
 
-    interface AwaiterFactory {
-        Awaiter getAwaiter();
-    }
-
     abstract static class Awaiter extends Thread {
-        private volatile Throwable result = null;
-        protected void result(Throwable result) { this.result = result; }
-        public Throwable result() { return this.result; }
-    }
-
-    private void toTheStartingGate(CountDownLatch gate) {
-        try {
-            gate.await();
-        }
-        catch (Throwable t) { fail(t); }
+        volatile Throwable exception;
+        volatile boolean interrupted;
+        abstract void realRun() throws Exception;
+        public final void run() {
+            try { realRun(); }
+            catch (Throwable ex) { exception = ex; }
+            interrupted = Thread.interrupted();
+        };
     }
 
-    private Awaiter awaiter(final CountDownLatch latch,
-                            final CountDownLatch gate) {
-        return new Awaiter() { public void run() {
-            System.out.println("without millis: " + latch.toString());
-            gate.countDown();
-
-            try {
+    static Awaiter awaiter(CountDownLatch latch,
+                           CountDownLatch gate) {
+        return new Awaiter() {
+            public void realRun() throws InterruptedException {
+                gate.countDown();
                 latch.await();
-                System.out.println("without millis - ComingOut");
-            }
-            catch (Throwable result) { result(result); }}};
+            }};
     }
 
-    private Awaiter awaiter(final CountDownLatch latch,
-                            final CountDownLatch gate,
-                            final long millis) {
-        return new Awaiter() { public void run() {
-            System.out.println("with millis: "+latch.toString());
-            gate.countDown();
-
-            try {
-                latch.await(millis, TimeUnit.MILLISECONDS);
-                System.out.println("with millis - ComingOut");
-            }
-            catch (Throwable result) { result(result); }}};
+    static Awaiter awaiter(CountDownLatch latch,
+                           CountDownLatch gate,
+                           long timeoutMillis) {
+        return new Awaiter() {
+            public void realRun() throws InterruptedException {
+                gate.countDown();
+                latch.await(timeoutMillis, TimeUnit.MILLISECONDS);
+            }};
     }
 
-    AwaiterFactory awaiterFactory(CountDownLatch latch, CountDownLatch gate) {
-        return () -> awaiter(latch, gate);
-    }
-
-    AwaiterFactory timedAwaiterFactory(CountDownLatch latch, CountDownLatch gate) {
-        return () -> awaiter(latch, gate, LONG_DELAY_MS);
+    static Supplier<Awaiter> randomAwaiterSupplier(
+            CountDownLatch latch, CountDownLatch gate) {
+        return () -> (ThreadLocalRandom.current().nextBoolean())
+            ? awaiter(latch, gate)
+            : awaiter(latch, gate, LONG_DELAY_MS);
     }
 
     //----------------------------------------------------------------
@@ -94,28 +81,24 @@
     //----------------------------------------------------------------
     public static void normalUse() throws Throwable {
         int count = 0;
-        Basic test = new Basic();
         CountDownLatch latch = new CountDownLatch(3);
         Awaiter[] a = new Awaiter[12];
 
         for (int i = 0; i < 3; i++) {
             CountDownLatch gate = new CountDownLatch(4);
-            AwaiterFactory factory1 = test.awaiterFactory(latch, gate);
-            AwaiterFactory factory2 = test.timedAwaiterFactory(latch, gate);
-            a[count] = factory1.getAwaiter(); a[count++].start();
-            a[count] = factory1.getAwaiter(); a[count++].start();
-            a[count] = factory2.getAwaiter(); a[count++].start();
-            a[count] = factory2.getAwaiter(); a[count++].start();
-            test.toTheStartingGate(gate);
-            System.out.println("Main Thread: " + latch.toString());
+            Supplier<Awaiter> s = randomAwaiterSupplier(latch, gate);
+            a[count] = s.get(); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            gate.await();
             latch.countDown();
             checkCount(latch, 2-i);
         }
-        for (int i = 0; i < 12; i++)
-            a[i].join();
-
-        for (int i = 0; i < 12; i++)
-            checkResult(a[i], null);
+        for (Awaiter awaiter : a)
+            awaiter.join();
+        for (Awaiter awaiter : a)
+            checkException(awaiter, null);
     }
 
     //----------------------------------------------------------------
@@ -123,38 +106,38 @@
     //----------------------------------------------------------------
     public static void threadInterrupted() throws Throwable {
         int count = 0;
-        Basic test = new Basic();
         CountDownLatch latch = new CountDownLatch(3);
         Awaiter[] a = new Awaiter[12];
 
         for (int i = 0; i < 3; i++) {
             CountDownLatch gate = new CountDownLatch(4);
-            AwaiterFactory factory1 = test.awaiterFactory(latch, gate);
-            AwaiterFactory factory2 = test.timedAwaiterFactory(latch, gate);
-            a[count] = factory1.getAwaiter(); a[count++].start();
-            a[count] = factory1.getAwaiter(); a[count++].start();
-            a[count] = factory2.getAwaiter(); a[count++].start();
-            a[count] = factory2.getAwaiter(); a[count++].start();
+            Supplier<Awaiter> s = randomAwaiterSupplier(latch, gate);
+            a[count] = s.get(); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            gate.await();
             a[count-1].interrupt();
-            test.toTheStartingGate(gate);
-            System.out.println("Main Thread: " + latch.toString());
             latch.countDown();
             checkCount(latch, 2-i);
         }
-        for (int i = 0; i < 12; i++)
-            a[i].join();
-
-        for (int i = 0; i < 12; i++)
-            checkResult(a[i],
-                        (i % 4) == 3 ? InterruptedException.class : null);
+        for (Awaiter awaiter : a)
+            awaiter.join();
+        for (int i = 0; i < a.length; i++) {
+            Awaiter awaiter = a[i];
+            Throwable ex = awaiter.exception;
+            if ((i % 4) == 3 && !awaiter.interrupted)
+                checkException(awaiter, InterruptedException.class);
+            else
+                checkException(awaiter, null);
+        }
     }
 
     //----------------------------------------------------------------
     // One thread timed out
     //----------------------------------------------------------------
     public static void timeOut() throws Throwable {
-        int count =0;
-        Basic test = new Basic();
+        int count = 0;
         CountDownLatch latch = new CountDownLatch(3);
         Awaiter[] a = new Awaiter[12];
 
@@ -162,54 +145,56 @@
 
         for (int i = 0; i < 3; i++) {
             CountDownLatch gate = new CountDownLatch(4);
-            AwaiterFactory factory1 = test.awaiterFactory(latch, gate);
-            AwaiterFactory factory2 = test.timedAwaiterFactory(latch, gate);
-            a[count] = test.awaiter(latch, gate, timeout[i]); a[count++].start();
-            a[count] = factory1.getAwaiter(); a[count++].start();
-            a[count] = factory2.getAwaiter(); a[count++].start();
-            a[count] = factory2.getAwaiter(); a[count++].start();
-            test.toTheStartingGate(gate);
-            System.out.println("Main Thread: " + latch.toString());
+            Supplier<Awaiter> s = randomAwaiterSupplier(latch, gate);
+            a[count] = awaiter(latch, gate, timeout[i]); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            a[count] = s.get(); a[count++].start();
+            gate.await();
             latch.countDown();
             checkCount(latch, 2-i);
         }
-        for (int i = 0; i < 12; i++)
-            a[i].join();
-
-        for (int i = 0; i < 12; i++)
-            checkResult(a[i], null);
+        for (Awaiter awaiter : a)
+            awaiter.join();
+        for (Awaiter awaiter : a)
+            checkException(awaiter, null);
     }
 
     public static void main(String[] args) throws Throwable {
-        normalUse();
-        threadInterrupted();
-        timeOut();
+        try {
+            normalUse();
+        } catch (Throwable ex) { fail(ex); }
+        try {
+            threadInterrupted();
+        } catch (Throwable ex) { fail(ex); }
+        try {
+            timeOut();
+        } catch (Throwable ex) { fail(ex); }
+
         if (failures.get() > 0L)
             throw new AssertionError(failures.get() + " failures");
     }
 
-    private static final AtomicInteger failures = new AtomicInteger(0);
+    static final AtomicInteger failures = new AtomicInteger(0);
 
-    private static void fail(String msg) {
+    static void fail(String msg) {
         fail(new AssertionError(msg));
     }
 
-    private static void fail(Throwable t) {
+    static void fail(Throwable t) {
         t.printStackTrace();
         failures.getAndIncrement();
     }
 
-    private static void checkCount(CountDownLatch b, int expected) {
+    static void checkCount(CountDownLatch b, int expected) {
         if (b.getCount() != expected)
             fail("Count = " + b.getCount() +
                  ", expected = " + expected);
     }
 
-    private static void checkResult(Awaiter a, Class c) {
-        Throwable t = a.result();
-        if (! ((t == null && c == null) || c.isInstance(t))) {
-            System.out.println("Mismatch: " + t + ", " + c.getName());
-            failures.getAndIncrement();
-        }
+    static void checkException(Awaiter awaiter, Class<? extends Throwable> c) {
+        Throwable ex = awaiter.exception;
+        if (! ((ex == null && c == null) || c.isInstance(ex)))
+            fail("Expected: " + c + ", got: " + ex);
     }
 }