src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java
changeset 59146 455612b3161a
parent 59000 612c58965775
child 59329 289000934908
--- a/src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java	Wed Nov 20 09:12:07 2019 +0100
+++ b/src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java	Wed Nov 20 08:35:53 2019 +0000
@@ -31,6 +31,7 @@
 import java.lang.invoke.MethodHandles;
 import java.lang.invoke.VarHandle;
 import java.lang.ref.Cleaner.Cleanable;
+import java.lang.reflect.Method;
 import java.net.DatagramSocket;
 import java.net.Inet4Address;
 import java.net.Inet6Address;
@@ -54,12 +55,18 @@
 import java.nio.channels.MembershipKey;
 import java.nio.channels.NotYetConnectedException;
 import java.nio.channels.SelectionKey;
+import java.nio.channels.spi.AbstractSelectableChannel;
 import java.nio.channels.spi.SelectorProvider;
+import java.security.AccessController;
+import java.security.PrivilegedExceptionAction;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.Consumer;
 
 import jdk.internal.ref.CleanerFactory;
 import sun.net.ResourceManager;
@@ -113,10 +120,13 @@
     private long readerThread;
     private long writerThread;
 
-    // Binding and remote address (when connected)
+    // Local and remote (connected) address
     private InetSocketAddress localAddress;
     private InetSocketAddress remoteAddress;
 
+    // Local address prior to connecting
+    private InetSocketAddress initialLocalAddress;
+
     // Socket adaptor, created lazily
     private static final VarHandle SOCKET;
     static {
@@ -1103,6 +1113,9 @@
                         bindInternal(null);
                     }
 
+                    // capture local address before connect
+                    initialLocalAddress = localAddress;
+
                     int n = Net.connect(family,
                                         fd,
                                         isa.getAddress(),
@@ -1160,21 +1173,19 @@
                     remoteAddress = null;
                     state = ST_UNCONNECTED;
 
-                    // check whether rebind is needed
-                    InetSocketAddress isa = Net.localAddress(fd);
-                    if (isa.getPort() == 0) {
-                        // On Linux, if bound to ephemeral port,
-                        // disconnect does not preserve that port.
-                        // In this case, try to rebind to the previous port.
-                        int port = localAddress.getPort();
-                        localAddress = isa; // in case Net.bind fails
-                        Net.bind(family, fd, isa.getAddress(), port);
-                        isa = Net.localAddress(fd); // refresh address
-                        assert isa.getPort() == port;
+                    // refresh localAddress, should be same as it was prior to connect
+                    localAddress = Net.localAddress(fd);
+                    try {
+                        if (!localAddress.equals(initialLocalAddress)) {
+                            // Workaround connect(2) issues on Linux and macOS
+                            repairSocket(initialLocalAddress);
+                            assert (localAddress != null)
+                                    && localAddress.equals(Net.localAddress(fd))
+                                    && localAddress.equals(initialLocalAddress);
+                        }
+                    } finally {
+                        initialLocalAddress = null;
                     }
-
-                    // refresh localAddress
-                    localAddress = isa;
                 }
             } finally {
                 writeLock.unlock();
@@ -1186,6 +1197,134 @@
     }
 
     /**
+     * "Repair" the channel's socket after a disconnect that didn't restore the
+     * local address.
+     *
+     * On Linux, connect(2) dissolves the association but changes the local port
+     * to 0 when it was initially bound to an ephemeral port. The workaround here
+     * is to rebind to the original port.
+     *
+     * On macOS, connect(2) dissolves the association but rebinds the socket to
+     * the wildcard address when it was initially bound to a specific address.
+     * The workaround here is to re-create the socket.
+     */
+    private void repairSocket(InetSocketAddress target)
+        throws IOException
+    {
+        assert Thread.holdsLock(stateLock);
+
+        // Linux: try to bind the socket to the original address/port
+        if (localAddress.getPort() == 0) {
+            assert localAddress.getAddress().equals(target.getAddress());
+            Net.bind(family, fd, target.getAddress(), target.getPort());
+            localAddress = Net.localAddress(fd);
+            return;
+        }
+
+        // capture the value of all existing socket options
+        Map<SocketOption<?>, Object> map = new HashMap<>();
+        for (SocketOption<?> option : supportedOptions()) {
+            Object value = getOption(option);
+            if (value != null) {
+                map.put(option, value);
+            }
+        }
+
+        // macOS: re-create the socket.
+        FileDescriptor newfd = Net.socket(family, false);
+        try {
+            // copy the socket options that are protocol family agnostic
+            for (Map.Entry<SocketOption<?>, Object> e : map.entrySet()) {
+                SocketOption<?> option = e.getKey();
+                if (SocketOptionRegistry.findOption(option, Net.UNSPEC) != null) {
+                    Object value = e.getValue();
+                    try {
+                        Net.setSocketOption(newfd, Net.UNSPEC, option, value);
+                    } catch (IOException ignore) { }
+                }
+            }
+
+            // copy the blocking mode
+            if (!isBlocking()) {
+                IOUtil.configureBlocking(newfd, false);
+            }
+
+            // dup this channel's socket to the new socket. If this succeeds then
+            // fd will reference the new socket. If it fails then it will still
+            // reference the old socket.
+            nd.dup(newfd, fd);
+        } finally {
+            // release the file descriptor
+            nd.close(newfd);
+        }
+
+        // bind to the original local address
+        try {
+            Net.bind(family, fd, target.getAddress(), target.getPort());
+        } catch (IOException ioe) {
+            // bind failed, socket is left unbound
+            localAddress = null;
+            throw ioe;
+        }
+
+        // restore local address
+        localAddress = Net.localAddress(fd);
+
+        // restore all socket options (including those set in first pass)
+        for (Map.Entry<SocketOption<?>, Object> e : map.entrySet()) {
+            @SuppressWarnings("unchecked")
+            SocketOption<Object> option = (SocketOption<Object>) e.getKey();
+            Object value = e.getValue();
+            try {
+                setOption(option, value);
+            } catch (IOException ignore) { }
+        }
+
+        // restore multicast group membership
+        MembershipRegistry registry = this.registry;
+        if (registry != null) {
+            registry.forEach(k -> {
+                if (k instanceof MembershipKeyImpl.Type6) {
+                    MembershipKeyImpl.Type6 key6 = (MembershipKeyImpl.Type6) k;
+                    Net.join6(fd, key6.groupAddress(), key6.index(), key6.source());
+                } else {
+                    MembershipKeyImpl.Type4 key4 = (MembershipKeyImpl.Type4) k;
+                    Net.join4(fd, key4.groupAddress(), key4.interfaceAddress(), key4.source());
+                }
+            });
+        }
+
+        // reset registration in all Selectors that this channel is registered with
+        AbstractSelectableChannels.forEach(this, SelectionKeyImpl::reset);
+    }
+
+    /**
+     * Defines static methods to access AbstractSelectableChannel non-public members.
+     */
+    private static class AbstractSelectableChannels {
+        private static final Method FOREACH;
+        static {
+            try {
+                PrivilegedExceptionAction<Method> pae = () -> {
+                    Method m = AbstractSelectableChannel.class.getDeclaredMethod("forEach", Consumer.class);
+                    m.setAccessible(true);
+                    return m;
+                };
+                FOREACH = AccessController.doPrivileged(pae);
+            } catch (Exception e) {
+                throw new InternalError(e);
+            }
+        }
+        static void forEach(AbstractSelectableChannel ch, Consumer<SelectionKeyImpl> action) {
+            try {
+                FOREACH.invoke(ch, action);
+            } catch (Exception e) {
+                throw new InternalError(e);
+            }
+        }
+    }
+
+    /**
      * Joins channel's socket to the given group/interface and
      * optional source address.
      */