test/jdk/java/net/MulticastSocket/UnreferencedMulticastSockets.java
changeset 58423 54de0c861d32
parent 55047 3131927311ee
--- a/test/jdk/java/net/MulticastSocket/UnreferencedMulticastSockets.java	Tue Oct 01 12:27:14 2019 +0200
+++ b/test/jdk/java/net/MulticastSocket/UnreferencedMulticastSockets.java	Tue Oct 01 12:10:33 2019 +0100
@@ -50,6 +50,7 @@
 import java.util.ArrayDeque;
 import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.Phaser;
 import java.util.concurrent.TimeUnit;
 
 import jdk.test.lib.net.IPSupport;
@@ -72,11 +73,14 @@
     static class Server implements Runnable {
 
         MulticastSocket ss;
-
+        final int port;
+        final Phaser phaser = new Phaser(2);
         Server() throws IOException {
+            InetAddress loopback = InetAddress.getLoopbackAddress();
             InetSocketAddress serverAddress =
-                new InetSocketAddress(InetAddress.getLoopbackAddress(), 0);
+                new InetSocketAddress(loopback, 0);
             ss = new MulticastSocket(serverAddress);
+            port = ss.getLocalPort();
             System.out.printf("  DatagramServer addr: %s: %d%n",
                     this.getHost(), this.getPort());
             pendingSockets.add(new NamedWeak(ss, pendingQueue, "serverMulticastSocket"));
@@ -89,7 +93,7 @@
         }
 
         int getPort() {
-            return ss.getLocalPort();
+            return port;
         }
 
         // Receive a byte and send back a byte
@@ -98,12 +102,18 @@
                 byte[] buffer = new byte[50];
                 DatagramPacket p = new DatagramPacket(buffer, buffer.length);
                 ss.receive(p);
+                System.out.printf("Server: ping received from: %s%n", p.getSocketAddress());
+                phaser.arriveAndAwaitAdvance(); // await the client...
                 buffer[0] += 1;
+                System.out.printf("Server: sending echo to: %s%n", p.getSocketAddress());
                 ss.send(p);         // send back +1
 
+                System.out.printf("Server: awaiting client%n");
+                phaser.arriveAndAwaitAdvance(); // await the client...
                 // do NOT close but 'forget' the socket reference
+                System.out.printf("Server: forgetting socket...%n");
                 ss = null;
-            } catch (Exception ioe) {
+            } catch (Throwable ioe) {
                 ioe.printStackTrace();
             }
         }
@@ -112,8 +122,11 @@
     public static void main(String args[]) throws Exception {
         IPSupport.throwSkippedExceptionIfNonOperational();
 
+        InetSocketAddress clientAddress =
+                new InetSocketAddress(InetAddress.getLoopbackAddress(), 0);
+
         // Create and close a MulticastSocket to warm up the FD count for side effects.
-        try (MulticastSocket s = new MulticastSocket(0)) {
+        try (MulticastSocket s = new MulticastSocket(clientAddress)) {
             // no-op; close immediately
             s.getLocalPort();   // no-op
         }
@@ -126,8 +139,33 @@
         Thread thr = new Thread(svr);
         thr.start();
 
-        MulticastSocket client = new MulticastSocket(0);
-        System.out.printf("  client bound port: %d%n", client.getLocalPort());
+        // It is possible under some circumstances that the client
+        // might get bound to the same port than the server: this
+        // would make the test fail - so if this happen we try to
+        // bind to a specific port by incrementing the server port.
+        MulticastSocket client = null;
+        int serverPort = svr.getPort();
+        int maxtries = 20;
+        for (int i = 0; i < maxtries; i++) {
+            try {
+                System.out.printf("Trying to bind client to: %s%n", clientAddress);
+                client = new MulticastSocket(clientAddress);
+                if (client.getLocalPort() != svr.getPort()) break;
+                client.close();
+            } catch (IOException x) {
+                System.out.printf("Couldn't create client after %d attempts: %s%n", i, x);
+                if (i == maxtries) throw x;
+            }
+            if (i == maxtries) {
+                String msg = String.format("Couldn't create client after %d attempts", i);
+                System.out.println(msg);
+                throw new AssertionError(msg);
+            }
+            clientAddress = new InetSocketAddress(clientAddress.getAddress(), serverPort + i);
+        }
+
+        System.out.printf("  client bound port: %s:%d%n",
+                client.getLocalAddress(), client.getLocalPort());
         client.connect(svr.getHost(), svr.getPort());
         pendingSockets.add(new NamedWeak(client, pendingQueue, "clientMulticastSocket"));
         extractRefs(client, "clientMulticastSocket");
@@ -136,14 +174,17 @@
         msg[0] = 1;
         DatagramPacket p = new DatagramPacket(msg, msg.length, svr.getHost(), svr.getPort());
         client.send(p);
+        System.out.printf("  ping sent to: %s:%d%n", svr.getHost(), svr.getPort());
+        svr.phaser.arriveAndAwaitAdvance(); // wait until the server has received its packet
 
         p = new DatagramPacket(msg, msg.length);
         client.receive(p);
 
-        System.out.printf("echo received from: %s%n", p.getSocketAddress());
+        System.out.printf("  echo received from: %s%n", p.getSocketAddress());
         if (msg[0] != 2) {
             throw new AssertionError("incorrect data received: expected: 2, actual: " + msg[0]);
         }
+        svr.phaser.arriveAndAwaitAdvance(); // let the server null out its socket
 
         // Do NOT close the MulticastSocket; forget it