Eliminate need to cache traffic class niosocketimpl-branch
authoralanb
Sun, 21 Apr 2019 07:05:04 +0100
branchniosocketimpl-branch
changeset 57336 766140c67efa
parent 57322 4744fdcf458c
child 57338 8684e6479b20
Eliminate need to cache traffic class
src/java.base/share/classes/java/net/Socket.java
src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java
src/java.base/share/classes/sun/nio/ch/SocketChannelImpl.java
test/jdk/java/net/SocketOption/OptionsTest.java
--- a/src/java.base/share/classes/java/net/Socket.java	Sat Apr 13 07:23:18 2019 +0100
+++ b/src/java.base/share/classes/java/net/Socket.java	Sun Apr 21 07:05:04 2019 +0100
@@ -1469,8 +1469,9 @@
         try {
             getImpl().setOption(SocketOptions.IP_TOS, tc);
         } catch (SocketException se) {
-            // may not be supported to change when socket is connected
-            if (!isConnected())
+            // not supported if socket already connected
+            // Solaris returns error in such cases
+            if(!isConnected())
                 throw se;
         }
     }
@@ -1858,14 +1859,7 @@
      * @since 9
      */
     public <T> Socket setOption(SocketOption<T> name, T value) throws IOException {
-        try {
-            getImpl().setOption(name, value);
-        } catch (SocketException se) {
-            // may not be supported to change when socket is connected
-            if (name != StandardSocketOptions.IP_TOS || !isConnected()) {
-                throw se;
-            }
-        }
+        getImpl().setOption(name, value);
         return this;
     }
 
--- a/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java	Sat Apr 13 07:23:18 2019 +0100
+++ b/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java	Sun Apr 21 07:05:04 2019 +0100
@@ -115,9 +115,6 @@
     // used when SO_REUSEADDR is emulated, protected by stateLock
     private boolean isReuseAddress;
 
-    // cached value of IPV6_TCLASS or IP_TOS socket option, protected by stateLock
-    private int trafficClass;
-
     // read or accept timeout in millis
     private volatile int timeout;
 
@@ -198,7 +195,13 @@
     private void configureNonBlocking(FileDescriptor fd) throws IOException {
         if (!nonBlocking) {
             assert readLock.isHeldByCurrentThread() || writeLock.isHeldByCurrentThread();
-            IOUtil.configureBlocking(fd, false);
+            stateLock.lock();
+            try {
+                ensureOpen();
+                IOUtil.configureBlocking(fd, false);
+            } finally {
+                stateLock.unlock();
+            }
             nonBlocking = true;
         }
     }
@@ -980,9 +983,7 @@
             ensureOpen();
             if (opt == StandardSocketOptions.IP_TOS) {
                 // maps to IP_TOS or IPV6_TCLASS
-                int i = (int) value;
-                Net.setSocketOption(fd, family(), opt, i);
-                trafficClass = i;
+                Net.setSocketOption(fd, family(), opt, value);
             } else if (opt == StandardSocketOptions.SO_REUSEADDR) {
                 boolean b = (boolean) value;
                 if (Net.useExclusiveBind()) {
@@ -1007,7 +1008,7 @@
         try {
             ensureOpen();
             if (opt == StandardSocketOptions.IP_TOS) {
-                return (T) Integer.valueOf(trafficClass);
+                return (T) Net.getSocketOption(fd, family(), opt);
             } else if (opt == StandardSocketOptions.SO_REUSEADDR) {
                 if (Net.useExclusiveBind()) {
                     return (T) Boolean.valueOf(isReuseAddress);
@@ -1063,7 +1064,6 @@
                 case IP_TOS: {
                     int i = intValue(value, "IP_TOS");
                     Net.setSocketOption(fd, family(), StandardSocketOptions.IP_TOS, i);
-                    trafficClass = i;
                     break;
                 }
                 case TCP_NODELAY: {
@@ -1159,7 +1159,7 @@
                 case SO_RCVBUF:
                     return Net.getSocketOption(fd, StandardSocketOptions.SO_RCVBUF);
                 case IP_TOS:
-                    return trafficClass;
+                    return Net.getSocketOption(fd, family(), StandardSocketOptions.IP_TOS);
                 case SO_KEEPALIVE:
                     return Net.getSocketOption(fd, StandardSocketOptions.SO_KEEPALIVE);
                 case SO_REUSEPORT:
--- a/src/java.base/share/classes/sun/nio/ch/SocketChannelImpl.java	Sat Apr 13 07:23:18 2019 +0100
+++ b/src/java.base/share/classes/sun/nio/ch/SocketChannelImpl.java	Sun Apr 21 07:05:04 2019 +0100
@@ -792,15 +792,13 @@
                     boolean connected = false;
                     try {
                         beginFinishConnect(blocking);
-                        boolean polled;
                         if (blocking) {
                             do {
-                                polled = Net.pollConnect(fd, -1);
-                            } while (!polled && isOpen());
+                                connected = Net.pollConnect(fd, -1);
+                            } while (!connected && isOpen());
                         } else {
-                            polled = Net.pollConnect(fd, 0);
+                            connected = Net.pollConnect(fd, 0);
                         }
-                        connected = polled && isOpen();
                     } finally {
                         endFinishConnect(blocking, connected);
                     }
--- a/test/jdk/java/net/SocketOption/OptionsTest.java	Sat Apr 13 07:23:18 2019 +0100
+++ b/test/jdk/java/net/SocketOption/OptionsTest.java	Sun Apr 21 07:05:04 2019 +0100
@@ -36,23 +36,24 @@
 
 public class OptionsTest {
 
-    static class Test {
-        Test(SocketOption<?> option, Object testValue) {
+    static class Test<T> {
+        final SocketOption<T> option;
+        final T value;
+        Test(SocketOption<T> option, T value) {
             this.option = option;
-            this.testValue = testValue;
+            this.value = value;
         }
-        static Test create (SocketOption<?> option, Object testValue) {
-            return new Test(option, testValue);
+        static <T> Test<T> create(SocketOption<T> option, T value) {
+            return new Test<T>(option, value);
         }
-        Object option;
-        Object testValue;
+
     }
 
     // The tests set the option using the new API, read back the set value
     // which could be diferent, and then use the legacy get API to check
     // these values are the same
 
-    static Test[] socketTests = new Test[] {
+    static Test<?>[] socketTests = new Test<?>[] {
         Test.create(StandardSocketOptions.SO_KEEPALIVE, Boolean.TRUE),
         Test.create(StandardSocketOptions.SO_SNDBUF, Integer.valueOf(10 * 100)),
         Test.create(StandardSocketOptions.SO_RCVBUF, Integer.valueOf(8 * 100)),
@@ -62,14 +63,14 @@
         Test.create(StandardSocketOptions.IP_TOS, Integer.valueOf(100))
     };
 
-    static Test[] serverSocketTests = new Test[] {
+    static Test<?>[] serverSocketTests = new Test<?>[] {
         Test.create(StandardSocketOptions.SO_RCVBUF, Integer.valueOf(8 * 100)),
         Test.create(StandardSocketOptions.SO_REUSEADDR, Boolean.FALSE),
         Test.create(StandardSocketOptions.SO_REUSEPORT, Boolean.FALSE),
         Test.create(StandardSocketOptions.IP_TOS, Integer.valueOf(100))
     };
 
-    static Test[] dgSocketTests = new Test[] {
+    static Test<?>[] datagramSocketTests = new Test<?>[] {
         Test.create(StandardSocketOptions.SO_SNDBUF, Integer.valueOf(10 * 100)),
         Test.create(StandardSocketOptions.SO_RCVBUF, Integer.valueOf(8 * 100)),
         Test.create(StandardSocketOptions.SO_REUSEADDR, Boolean.FALSE),
@@ -77,7 +78,7 @@
         Test.create(StandardSocketOptions.IP_TOS, Integer.valueOf(100))
     };
 
-    static Test[] mcSocketTests = new Test[] {
+    static Test<?>[] multicastSocketTests = new Test<?>[] {
         Test.create(StandardSocketOptions.IP_MULTICAST_IF, getNetworkInterface()),
         Test.create(StandardSocketOptions.IP_MULTICAST_TTL, Integer.valueOf(10)),
         Test.create(StandardSocketOptions.IP_MULTICAST_LOOP, Boolean.TRUE)
@@ -87,7 +88,7 @@
         try {
             Enumeration<NetworkInterface> nifs = NetworkInterface.getNetworkInterfaces();
             while (nifs.hasMoreElements()) {
-                NetworkInterface ni = (NetworkInterface)nifs.nextElement();
+                NetworkInterface ni = nifs.nextElement();
                 if (ni.supportsMulticast()) {
                     return ni;
                 }
@@ -97,99 +98,108 @@
         return null;
     }
 
+    static boolean okayToTest(Socket s, SocketOption<?> option) {
+        if (option == StandardSocketOptions.SO_REUSEPORT) {
+            // skip SO_REUSEPORT if option is not supported
+            return s.supportedOptions().contains(StandardSocketOptions.SO_REUSEPORT);
+        }
+        if (option == StandardSocketOptions.IP_TOS && s.isConnected()) {
+            // skip IP_TOS if connected
+            return false;
+        }
+        return true;
+    }
+
+    static <T> void testEqual(SocketOption<T> option, T value1, T value2) {
+        if (!value1.equals(value2)) {
+            throw new RuntimeException("Test of " + option.name() + " failed: "
+                    + value1 + " != " + value2);
+        }
+    }
+
+    static <T> void test(Socket s, Test<T> test) throws Exception {
+        SocketOption<T> option = test.option;
+        s.setOption(option, test.value);
+        T value1 = s.getOption(test.option);
+        T value2 = (T) legacyGetOption(Socket.class, s, test.option);
+        testEqual(option, value1, value2);
+    }
+
+    static <T> void test(ServerSocket ss, Test<T> test) throws Exception {
+        SocketOption<T> option = test.option;
+        ss.setOption(option, test.value);
+        T value1 = ss.getOption(test.option);
+        T value2 = (T) legacyGetOption(ServerSocket.class, ss, test.option);
+        testEqual(option, value1, value2);
+    }
+
+    static <T> void test(DatagramSocket ds, Test<T> test) throws Exception {
+        SocketOption<T> option = test.option;
+        ds.setOption(option, test.value);
+        T value1 = ds.getOption(test.option);
+        T value2 = (T) legacyGetOption(ds.getClass(), ds, test.option);
+        testEqual(option, value1, value2);
+    }
+
+    @SuppressWarnings("try")
     static void doSocketTests() throws Exception {
-        try (
-            ServerSocket srv = new ServerSocket(0);
-            Socket c = new Socket(InetAddress.getLoopbackAddress(), srv.getLocalPort());
-            Socket s = srv.accept();
-        ) {
-            Set<SocketOption<?>> options = c.supportedOptions();
-            boolean reuseport = options.contains(StandardSocketOptions.SO_REUSEPORT);
-            for (int i=0; i<socketTests.length; i++) {
-                Test test = socketTests[i];
-                if (!(test.option == StandardSocketOptions.SO_REUSEPORT && !reuseport)) {
-                    c.setOption((SocketOption)test.option, test.testValue);
-                    Object getval = c.getOption((SocketOption)test.option);
-                    Object legacyget = legacyGetOption(Socket.class, c,test.option);
-                    if (!getval.equals(legacyget)) {
-                        Formatter f = new Formatter();
-                        f.format("S Err %d: %s/%s", i, getval, legacyget);
-                        throw new RuntimeException(f.toString());
-                    }
+        // unconnected socket
+        try (Socket s = new Socket()) {
+            for (Test<?> test : socketTests) {
+                if (okayToTest(s, test.option)) {
+                    test(s, test);
                 }
             }
         }
-    }
 
-    static void doDgSocketTests() throws Exception {
-        try (
-            DatagramSocket c = new DatagramSocket(0);
-        ) {
-            Set<SocketOption<?>> options = c.supportedOptions();
-            boolean reuseport = options.contains(StandardSocketOptions.SO_REUSEPORT);
-            for (int i=0; i<dgSocketTests.length; i++) {
-                Test test = dgSocketTests[i];
-                if (!(test.option == StandardSocketOptions.SO_REUSEPORT && !reuseport)) {
-                    c.setOption((SocketOption)test.option, test.testValue);
-                    Object getval = c.getOption((SocketOption)test.option);
-                    Object legacyget = legacyGetOption(DatagramSocket.class, c,test.option);
-                    if (!getval.equals(legacyget)) {
-                        Formatter f = new Formatter();
-                        f.format("DG Err %d: %s/%s", i, getval, legacyget);
-                        throw new RuntimeException(f.toString());
+        // connected socket
+        try (ServerSocket ss = new ServerSocket(0)) {
+            try (Socket s1 = new Socket()) {
+                s1.connect(ss.getLocalSocketAddress());
+                try (Socket s2 = ss.accept()) {
+                    for (Test<?> test : socketTests) {
+                        if (okayToTest(s1, test.option)) {
+                            test(s1, test);
+                        }
                     }
                 }
             }
         }
     }
 
-    static void doMcSocketTests() throws Exception {
-        try (
-            MulticastSocket c = new MulticastSocket(0);
-        ) {
-            for (int i=0; i<mcSocketTests.length; i++) {
-                Test test = mcSocketTests[i];
-                c.setOption((SocketOption)test.option, test.testValue);
-                Object getval = c.getOption((SocketOption)test.option);
-                Object legacyget = legacyGetOption(MulticastSocket.class, c,test.option);
-                if (!getval.equals(legacyget)) {
-                    Formatter f = new Formatter();
-                    f.format("MC Err %d: %s/%s", i, getval, legacyget);
-                    throw new RuntimeException(f.toString());
+    static void doServerSocketTests() throws Exception {
+        try (ServerSocket ss = new ServerSocket(0)) {
+            Set<SocketOption<?>> options = ss.supportedOptions();
+            boolean reuseport = options.contains(StandardSocketOptions.SO_REUSEPORT);
+            for (Test<?> test : serverSocketTests) {
+                if (!(test.option == StandardSocketOptions.SO_REUSEPORT && !reuseport)) {
+                    test(ss, test);
                 }
             }
         }
     }
 
-    static void doServerSocketTests() throws Exception {
-        try (
-            ServerSocket c = new ServerSocket(0);
-        ) {
-            Set<SocketOption<?>> options = c.supportedOptions();
+    static void doDatagramSocketTests() throws Exception {
+        try (DatagramSocket ds = new DatagramSocket(0)) {
+            Set<SocketOption<?>> options = ds.supportedOptions();
             boolean reuseport = options.contains(StandardSocketOptions.SO_REUSEPORT);
-            for (int i=0; i<serverSocketTests.length; i++) {
-                Test test = serverSocketTests[i];
+            for (Test<?> test : datagramSocketTests) {
                 if (!(test.option == StandardSocketOptions.SO_REUSEPORT && !reuseport)) {
-                    c.setOption((SocketOption)test.option, test.testValue);
-                    Object getval = c.getOption((SocketOption)test.option);
-                    Object legacyget = legacyGetOption(
-                        ServerSocket.class, c, test.option
-                    );
-                    if (!getval.equals(legacyget)) {
-                        Formatter f = new Formatter();
-                        f.format("SS Err %d: %s/%s", i, getval, legacyget);
-                        throw new RuntimeException(f.toString());
-                    }
+                    test(ds, test);
                 }
             }
         }
     }
 
-    static Object legacyGetOption(
-        Class<?> type, Object s, Object option)
+    static void doMulticastSocketTests() throws Exception {
+        try (MulticastSocket ms = new MulticastSocket(0)) {
+            for (Test<?> test : multicastSocketTests) {
+                test(ms, test);
+            }
+        }
+    }
 
-        throws Exception
-    {
+    static Object legacyGetOption(Class<?> type, Object s, Object option) throws Exception {
         if (type.equals(Socket.class)) {
             Socket socket = (Socket)s;
             Set<SocketOption<?>> options = socket.supportedOptions();
@@ -280,8 +290,8 @@
     public static void main(String args[]) throws Exception {
         doSocketTests();
         doServerSocketTests();
-        doDgSocketTests();
-        doMcSocketTests();
+        doDatagramSocketTests();
+        doMulticastSocketTests();
     }
 
     // Reflectively access jdk.net.Sockets.getOption so that the test can run