--- a/src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java Wed Nov 20 09:12:07 2019 +0100
+++ b/src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java Wed Nov 20 08:35:53 2019 +0000
@@ -31,6 +31,7 @@
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.lang.ref.Cleaner.Cleanable;
+import java.lang.reflect.Method;
import java.net.DatagramSocket;
import java.net.Inet4Address;
import java.net.Inet6Address;
@@ -54,12 +55,18 @@
import java.nio.channels.MembershipKey;
import java.nio.channels.NotYetConnectedException;
import java.nio.channels.SelectionKey;
+import java.nio.channels.spi.AbstractSelectableChannel;
import java.nio.channels.spi.SelectorProvider;
+import java.security.AccessController;
+import java.security.PrivilegedExceptionAction;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
+import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.Consumer;
import jdk.internal.ref.CleanerFactory;
import sun.net.ResourceManager;
@@ -113,10 +120,13 @@
private long readerThread;
private long writerThread;
- // Binding and remote address (when connected)
+ // Local and remote (connected) address
private InetSocketAddress localAddress;
private InetSocketAddress remoteAddress;
+ // Local address prior to connecting
+ private InetSocketAddress initialLocalAddress;
+
// Socket adaptor, created lazily
private static final VarHandle SOCKET;
static {
@@ -1103,6 +1113,9 @@
bindInternal(null);
}
+ // capture local address before connect
+ initialLocalAddress = localAddress;
+
int n = Net.connect(family,
fd,
isa.getAddress(),
@@ -1160,21 +1173,19 @@
remoteAddress = null;
state = ST_UNCONNECTED;
- // check whether rebind is needed
- InetSocketAddress isa = Net.localAddress(fd);
- if (isa.getPort() == 0) {
- // On Linux, if bound to ephemeral port,
- // disconnect does not preserve that port.
- // In this case, try to rebind to the previous port.
- int port = localAddress.getPort();
- localAddress = isa; // in case Net.bind fails
- Net.bind(family, fd, isa.getAddress(), port);
- isa = Net.localAddress(fd); // refresh address
- assert isa.getPort() == port;
+ // refresh localAddress, should be same as it was prior to connect
+ localAddress = Net.localAddress(fd);
+ try {
+ if (!localAddress.equals(initialLocalAddress)) {
+ // Workaround connect(2) issues on Linux and macOS
+ repairSocket(initialLocalAddress);
+ assert (localAddress != null)
+ && localAddress.equals(Net.localAddress(fd))
+ && localAddress.equals(initialLocalAddress);
+ }
+ } finally {
+ initialLocalAddress = null;
}
-
- // refresh localAddress
- localAddress = isa;
}
} finally {
writeLock.unlock();
@@ -1186,6 +1197,134 @@
}
/**
+ * "Repair" the channel's socket after a disconnect that didn't restore the
+ * local address.
+ *
+ * On Linux, connect(2) dissolves the association but changes the local port
+ * to 0 when it was initially bound to an ephemeral port. The workaround here
+ * is to rebind to the original port.
+ *
+ * On macOS, connect(2) dissolves the association but rebinds the socket to
+ * the wildcard address when it was initially bound to a specific address.
+ * The workaround here is to re-create the socket.
+ */
+ private void repairSocket(InetSocketAddress target)
+ throws IOException
+ {
+ assert Thread.holdsLock(stateLock);
+
+ // Linux: try to bind the socket to the original address/port
+ if (localAddress.getPort() == 0) {
+ assert localAddress.getAddress().equals(target.getAddress());
+ Net.bind(family, fd, target.getAddress(), target.getPort());
+ localAddress = Net.localAddress(fd);
+ return;
+ }
+
+ // capture the value of all existing socket options
+ Map<SocketOption<?>, Object> map = new HashMap<>();
+ for (SocketOption<?> option : supportedOptions()) {
+ Object value = getOption(option);
+ if (value != null) {
+ map.put(option, value);
+ }
+ }
+
+ // macOS: re-create the socket.
+ FileDescriptor newfd = Net.socket(family, false);
+ try {
+ // copy the socket options that are protocol family agnostic
+ for (Map.Entry<SocketOption<?>, Object> e : map.entrySet()) {
+ SocketOption<?> option = e.getKey();
+ if (SocketOptionRegistry.findOption(option, Net.UNSPEC) != null) {
+ Object value = e.getValue();
+ try {
+ Net.setSocketOption(newfd, Net.UNSPEC, option, value);
+ } catch (IOException ignore) { }
+ }
+ }
+
+ // copy the blocking mode
+ if (!isBlocking()) {
+ IOUtil.configureBlocking(newfd, false);
+ }
+
+ // dup this channel's socket to the new socket. If this succeeds then
+ // fd will reference the new socket. If it fails then it will still
+ // reference the old socket.
+ nd.dup(newfd, fd);
+ } finally {
+ // release the file descriptor
+ nd.close(newfd);
+ }
+
+ // bind to the original local address
+ try {
+ Net.bind(family, fd, target.getAddress(), target.getPort());
+ } catch (IOException ioe) {
+ // bind failed, socket is left unbound
+ localAddress = null;
+ throw ioe;
+ }
+
+ // restore local address
+ localAddress = Net.localAddress(fd);
+
+ // restore all socket options (including those set in first pass)
+ for (Map.Entry<SocketOption<?>, Object> e : map.entrySet()) {
+ @SuppressWarnings("unchecked")
+ SocketOption<Object> option = (SocketOption<Object>) e.getKey();
+ Object value = e.getValue();
+ try {
+ setOption(option, value);
+ } catch (IOException ignore) { }
+ }
+
+ // restore multicast group membership
+ MembershipRegistry registry = this.registry;
+ if (registry != null) {
+ registry.forEach(k -> {
+ if (k instanceof MembershipKeyImpl.Type6) {
+ MembershipKeyImpl.Type6 key6 = (MembershipKeyImpl.Type6) k;
+ Net.join6(fd, key6.groupAddress(), key6.index(), key6.source());
+ } else {
+ MembershipKeyImpl.Type4 key4 = (MembershipKeyImpl.Type4) k;
+ Net.join4(fd, key4.groupAddress(), key4.interfaceAddress(), key4.source());
+ }
+ });
+ }
+
+ // reset registration in all Selectors that this channel is registered with
+ AbstractSelectableChannels.forEach(this, SelectionKeyImpl::reset);
+ }
+
+ /**
+ * Defines static methods to access AbstractSelectableChannel non-public members.
+ */
+ private static class AbstractSelectableChannels {
+ private static final Method FOREACH;
+ static {
+ try {
+ PrivilegedExceptionAction<Method> pae = () -> {
+ Method m = AbstractSelectableChannel.class.getDeclaredMethod("forEach", Consumer.class);
+ m.setAccessible(true);
+ return m;
+ };
+ FOREACH = AccessController.doPrivileged(pae);
+ } catch (Exception e) {
+ throw new InternalError(e);
+ }
+ }
+ static void forEach(AbstractSelectableChannel ch, Consumer<SelectionKeyImpl> action) {
+ try {
+ FOREACH.invoke(ch, action);
+ } catch (Exception e) {
+ throw new InternalError(e);
+ }
+ }
+ }
+
+ /**
* Joins channel's socket to the given group/interface and
* optional source address.
*/