test/jdk/java/net/Socket/Timeouts.java
changeset 55102 59567035d279
parent 54289 6183f835b9b6
child 58679 9c3209ff7550
--- a/test/jdk/java/net/Socket/Timeouts.java	Wed May 29 22:17:48 2019 -0400
+++ b/test/jdk/java/net/Socket/Timeouts.java	Thu May 30 07:19:19 2019 +0100
@@ -23,9 +23,10 @@
 
 /*
  * @test
+ * @bug 8221481
  * @library /test/lib
  * @build jdk.test.lib.Utils
- * @run testng Timeouts
+ * @run testng/timeout=180 Timeouts
  * @summary Test Socket timeouts
  */
 
@@ -34,12 +35,17 @@
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.net.ConnectException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
 import java.net.SocketAddress;
 import java.net.SocketException;
 import java.net.SocketTimeoutException;
 import java.util.concurrent.Executors;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 
@@ -54,7 +60,7 @@
      * Test timed connect where connection is established
      */
     public void testTimedConnect1() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             try (Socket s = new Socket()) {
                 s.connect(ss.getLocalSocketAddress(), 2000);
             }
@@ -77,7 +83,7 @@
      * Test connect with a timeout of Integer.MAX_VALUE
      */
     public void testTimedConnect3() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             try (Socket s = new Socket()) {
                 s.connect(ss.getLocalSocketAddress(), Integer.MAX_VALUE);
             }
@@ -88,12 +94,10 @@
      * Test connect with a negative timeout.
      */
     public void testTimedConnect4() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             try (Socket s = new Socket()) {
-                try {
-                    s.connect(ss.getLocalSocketAddress(), -1);
-                    assertTrue(false);
-                } catch (IllegalArgumentException expected) { }
+                expectThrows(IllegalArgumentException.class,
+                             () -> s.connect(ss.getLocalSocketAddress(), -1));
             }
         }
     }
@@ -128,10 +132,10 @@
     public void testTimedRead3() throws IOException {
         withConnection((s1, s2) -> {
             s2.setSoTimeout(2000);
-            try {
-                s2.getInputStream().read();
-                assertTrue(false);
-            } catch (SocketTimeoutException expected) { }
+            long startMillis = millisTime();
+            expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read());
+            int timeout = s2.getSoTimeout();
+            checkDuration(startMillis, timeout-100, timeout+2000);
         });
     }
 
@@ -141,10 +145,7 @@
     public void testTimedRead4() throws IOException {
         withConnection((s1, s2) -> {
             s2.setSoTimeout(2000);
-            try {
-                s2.getInputStream().read();
-                assertTrue(false);
-            } catch (SocketTimeoutException e) { }
+            expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read());
             s1.getOutputStream().write(99);
             int b = s2.getInputStream().read();
             assertTrue(b == 99);
@@ -158,10 +159,7 @@
     public void testTimedRead5() throws IOException {
         withConnection((s1, s2) -> {
             s2.setSoTimeout(2000);
-            try {
-                s2.getInputStream().read();
-                assertTrue(false);
-            } catch (SocketTimeoutException e) { }
+            expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read());
             s2.setSoTimeout(30*3000);
             scheduleWrite(s1.getOutputStream(), 99, 2000);
             int b = s2.getInputStream().read();
@@ -175,10 +173,7 @@
     public void testTimedRead6() throws IOException {
         withConnection((s1, s2) -> {
             s2.setSoTimeout(2000);
-            try {
-                s2.getInputStream().read();
-                assertTrue(false);
-            } catch (SocketTimeoutException e) { }
+            expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read());
             s1.getOutputStream().write(99);
             s2.setSoTimeout(0);
             int b = s2.getInputStream().read();
@@ -193,10 +188,7 @@
     public void testTimedRead7() throws IOException {
         withConnection((s1, s2) -> {
             s2.setSoTimeout(2000);
-            try {
-                s2.getInputStream().read();
-                assertTrue(false);
-            } catch (SocketTimeoutException e) { }
+            expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read());
             scheduleWrite(s1.getOutputStream(), 99, 2000);
             s2.setSoTimeout(0);
             int b = s2.getInputStream().read();
@@ -211,10 +203,7 @@
         withConnection((s1, s2) -> {
             s2.setSoTimeout(30*1000);
             scheduleClose(s2, 2000);
-            try {
-                s2.getInputStream().read();
-                assertTrue(false);
-            } catch (SocketException expected) { }
+            expectThrows(SocketException.class, () -> s2.getInputStream().read());
         });
     }
 
@@ -280,7 +269,7 @@
     public void testTimedAccept1() throws IOException {
         Socket s1 = null;
         Socket s2 = null;
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             s1 = new Socket();
             s1.connect(ss.getLocalSocketAddress());
             ss.setSoTimeout(30*1000);
@@ -295,7 +284,7 @@
      * Test timed accept where a connection is established after a short delay
      */
     public void testTimedAccept2() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             ss.setSoTimeout(30*1000);
             scheduleConnect(ss.getLocalSocketAddress(), 2000);
             Socket s = ss.accept();
@@ -307,13 +296,17 @@
      * Test timed accept where the accept times out
      */
     public void testTimedAccept3() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             ss.setSoTimeout(2000);
+            long startMillis = millisTime();
             try {
                 Socket s = ss.accept();
                 s.close();
-                assertTrue(false);
-            } catch (SocketTimeoutException expected) { }
+                fail();
+            } catch (SocketTimeoutException expected) {
+                int timeout = ss.getSoTimeout();
+                checkDuration(startMillis, timeout-100, timeout+2000);
+            }
         }
     }
 
@@ -322,12 +315,12 @@
      * previous accept timed out.
      */
     public void testTimedAccept4() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             ss.setSoTimeout(2000);
             try {
                 Socket s = ss.accept();
                 s.close();
-                assertTrue(false);
+                fail();
             } catch (SocketTimeoutException expected) { }
             try (Socket s1 = new Socket()) {
                 s1.connect(ss.getLocalSocketAddress());
@@ -342,12 +335,12 @@
      * accept timed out
      */
     public void testTimedAccept5() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             ss.setSoTimeout(2000);
             try {
                 Socket s = ss.accept();
                 s.close();
-                assertTrue(false);
+                fail();
             } catch (SocketTimeoutException expected) { }
             ss.setSoTimeout(0);
             try (Socket s1 = new Socket()) {
@@ -363,12 +356,12 @@
      * accept timed out and after a short delay
      */
     public void testTimedAccept6() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             ss.setSoTimeout(2000);
             try {
                 Socket s = ss.accept();
                 s.close();
-                assertTrue(false);
+                fail();
             } catch (SocketTimeoutException expected) { }
             ss.setSoTimeout(0);
             scheduleConnect(ss.getLocalSocketAddress(), 2000);
@@ -381,13 +374,134 @@
      * Test async close of a timed accept
      */
     public void testTimedAccept7() throws IOException {
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             ss.setSoTimeout(30*1000);
-            scheduleClose(ss, 2000);
+            long delay = 2000;
+            scheduleClose(ss, delay);
+            long startMillis = millisTime();
             try {
                 ss.accept().close();
-                assertTrue(false);
-            } catch (SocketException expected) { }
+                fail();
+            } catch (SocketException expected) {
+                checkDuration(startMillis, delay-100, delay+2000);
+            }
+        }
+    }
+
+    /**
+     * Test timed accept with the thread interrupt status set.
+     */
+    public void testTimedAccept8() throws IOException {
+        try (ServerSocket ss = boundServerSocket()) {
+            ss.setSoTimeout(2000);
+            Thread.currentThread().interrupt();
+            long startMillis = millisTime();
+            try {
+                Socket s = ss.accept();
+                s.close();
+                fail();
+            } catch (SocketTimeoutException expected) {
+                // accept should have blocked for 2 seconds
+                int timeout = ss.getSoTimeout();
+                checkDuration(startMillis, timeout-100, timeout+2000);
+                assertTrue(Thread.currentThread().isInterrupted());
+            } finally {
+                Thread.interrupted(); // clear interrupt status
+            }
+        }
+    }
+
+    /**
+     * Test interrupt of thread blocked in timed accept.
+     */
+    public void testTimedAccept9() throws IOException {
+        try (ServerSocket ss = boundServerSocket()) {
+            ss.setSoTimeout(4000);
+            // interrupt thread after 1 second
+            Future<?> interrupter = scheduleInterrupt(Thread.currentThread(), 1000);
+            long startMillis = millisTime();
+            try {
+                Socket s = ss.accept();   // should block for 4 seconds
+                s.close();
+                fail();
+            } catch (SocketTimeoutException expected) {
+                // accept should have blocked for 4 seconds
+                int timeout = ss.getSoTimeout();
+                checkDuration(startMillis, timeout-100, timeout+2000);
+                assertTrue(Thread.currentThread().isInterrupted());
+            } finally {
+                interrupter.cancel(true);
+                Thread.interrupted(); // clear interrupt status
+            }
+        }
+    }
+
+    /**
+     * Test two threads blocked in timed accept where no connection is established.
+     */
+    public void testTimedAccept10() throws Exception {
+        ExecutorService pool = Executors.newFixedThreadPool(2);
+        try (ServerSocket ss = boundServerSocket()) {
+            ss.setSoTimeout(4000);
+
+            long startMillis = millisTime();
+
+            Future<Socket> result1 = pool.submit(ss::accept);
+            Future<Socket> result2 = pool.submit(ss::accept);
+
+            // both tasks should complete with SocketTimeoutException
+            Throwable e = expectThrows(ExecutionException.class, result1::get);
+            assertTrue(e.getCause() instanceof SocketTimeoutException);
+            e = expectThrows(ExecutionException.class, result2::get);
+            assertTrue(e.getCause() instanceof SocketTimeoutException);
+
+            // should get here in 4 seconds, not 8 seconds
+            int timeout = ss.getSoTimeout();
+            checkDuration(startMillis, timeout-100, timeout+2000);
+        } finally {
+            pool.shutdown();
+        }
+    }
+
+    /**
+     * Test two threads blocked in timed accept where one connection is established.
+     */
+    public void testTimedAccept11() throws Exception {
+        ExecutorService pool = Executors.newFixedThreadPool(2);
+        try (ServerSocket ss = boundServerSocket()) {
+            ss.setSoTimeout(4000);
+
+            long startMillis = millisTime();
+
+            Future<Socket> result1 = pool.submit(ss::accept);
+            Future<Socket> result2 = pool.submit(ss::accept);
+
+            // establish connection after 2 seconds
+            scheduleConnect(ss.getLocalSocketAddress(), 2000);
+
+            // one task should have accepted the connection, the other should
+            // have completed with SocketTimeoutException
+            Socket s1 = null;
+            try {
+                s1 = result1.get();
+                s1.close();
+            } catch (ExecutionException e) {
+                assertTrue(e.getCause() instanceof SocketTimeoutException);
+            }
+            Socket s2 = null;
+            try {
+                s2 = result2.get();
+                s2.close();
+            } catch (ExecutionException e) {
+                assertTrue(e.getCause() instanceof SocketTimeoutException);
+            }
+            assertTrue((s1 != null) ^ (s2 != null));
+
+            // should get here in 4 seconds, not 8 seconds
+            int timeout = ss.getSoTimeout();
+            checkDuration(startMillis, timeout-100, timeout+2000);
+        } finally {
+            pool.shutdown();
         }
     }
 
@@ -411,6 +525,19 @@
         }
     }
 
+    /**
+     * Returns a ServerSocket bound to a port on the loopback address
+     */
+    static ServerSocket boundServerSocket() throws IOException {
+        var loopback = InetAddress.getLoopbackAddress();
+        ServerSocket ss = new ServerSocket();
+        ss.bind(new InetSocketAddress(loopback, 0));
+        return ss;
+    }
+
+    /**
+     * An operation that accepts two arguments and may throw IOException
+     */
     interface ThrowingBiConsumer<T, U> {
         void accept(T t, U u) throws IOException;
     }
@@ -423,7 +550,7 @@
     {
         Socket s1 = null;
         Socket s2 = null;
-        try (ServerSocket ss = new ServerSocket(0)) {
+        try (ServerSocket ss = boundServerSocket()) {
             s1 = new Socket();
             s1.connect(ss.getLocalSocketAddress());
             s2 = ss.accept();
@@ -446,6 +573,13 @@
     }
 
     /**
+     * Schedule thread to be interrupted after a delay
+     */
+    static Future<?> scheduleInterrupt(Thread thread, long delay) {
+        return schedule(() -> thread.interrupt(), delay);
+    }
+
+    /**
      * Schedule a thread to connect to the given end point after a delay
      */
     static void scheduleConnect(SocketAddress remote, long delay) {
@@ -482,12 +616,36 @@
         scheduleWrite(out, new byte[] { (byte)b }, delay);
     }
 
-    static void schedule(Runnable task, long delay) {
+    static Future<?> schedule(Runnable task, long delay) {
         ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
         try {
-            executor.schedule(task, delay, TimeUnit.MILLISECONDS);
+            return executor.schedule(task, delay, TimeUnit.MILLISECONDS);
         } finally {
             executor.shutdown();
         }
     }
+
+    /**
+     * Returns the current time in milliseconds.
+     */
+    private static long millisTime() {
+        long now = System.nanoTime();
+        return TimeUnit.MILLISECONDS.convert(now, TimeUnit.NANOSECONDS);
+    }
+
+    /**
+     * Check the duration of a task
+     * @param start start time, in milliseconds
+     * @param min minimum expected duration, in milliseconds
+     * @param max maximum expected duration, in milliseconds
+     * @return the duration (now - start), in milliseconds
+     */
+    private static long checkDuration(long start, long min, long max) {
+        long duration = millisTime() - start;
+        assertTrue(duration >= min,
+                "Duration " + duration + "ms, expected >= " + min + "ms");
+        assertTrue(duration <= max,
+                "Duration " + duration + "ms, expected <= " + max + "ms");
+        return duration;
+    }
 }