Change Socket.getInputStream/getOutputStream to return the same stream each time niosocketimpl-branch
authoralanb
Sun, 17 Feb 2019 10:00:26 +0000
branchniosocketimpl-branch
changeset 57188 1f2101ee432d
parent 57187 056911ad3ee7
child 57189 c56554b46dec
Change Socket.getInputStream/getOutputStream to return the same stream each time
src/java.base/share/classes/java/net/Socket.java
src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java
--- a/src/java.base/share/classes/java/net/Socket.java	Sat Feb 16 19:53:43 2019 +0000
+++ b/src/java.base/share/classes/java/net/Socket.java	Sun Feb 17 10:00:26 2019 +0000
@@ -28,6 +28,8 @@
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.io.IOException;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.VarHandle;
 import java.nio.channels.SocketChannel;
 import java.security.AccessController;
 import java.security.PrivilegedAction;
@@ -75,6 +77,22 @@
     private boolean oldImpl = false;
 
     /**
+     * Socket input/output streams
+     */
+    private volatile InputStream in;
+    private volatile OutputStream out;
+    private static final VarHandle IN, OUT;
+    static {
+        try {
+            MethodHandles.Lookup l = MethodHandles.lookup();
+            IN = l.findVarHandle(Socket.class, "in", InputStream.class);
+            OUT = l.findVarHandle(Socket.class, "out", OutputStream.class);
+        } catch (Exception e) {
+            throw new InternalError(e);
+        }
+    }
+
+    /**
      * Creates an unconnected socket, with the
      * system-default type of SocketImpl.
      *
@@ -920,8 +938,15 @@
             throw new SocketException("Socket is not connected");
         if (isInputShutdown())
             throw new SocketException("Socket input is shutdown");
-        // wrap the input stream so that the close method closes this socket
-        return new SocketInputStream(this, impl.getInputStream());
+        InputStream in = this.in;
+        if (in == null) {
+            // wrap the input stream so that the close method closes this socket
+            in = new SocketInputStream(this, impl.getInputStream());
+            if (!IN.compareAndSet(this, null, in)) {
+                in = this.in;
+            }
+        }
+        return in;
     }
 
     private static class SocketInputStream extends InputStream {
@@ -976,8 +1001,15 @@
             throw new SocketException("Socket is not connected");
         if (isOutputShutdown())
             throw new SocketException("Socket output is shutdown");
-        // wrap the output stream so that the close method closes this socket
-        return new SocketOutputStream(this, impl.getOutputStream());
+        OutputStream out = this.out;
+        if (out == null) {
+            // wrap the output stream so that the close method closes this socket
+            out = new SocketOutputStream(this, impl.getOutputStream());
+            if (!OUT.compareAndSet(this, null, out)) {
+                out = this.out;
+            }
+        }
+        return out;
     }
 
     private static class SocketOutputStream extends OutputStream {
--- a/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java	Sat Feb 16 19:53:43 2019 +0000
+++ b/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java	Sun Feb 17 10:00:26 2019 +0000
@@ -128,20 +128,6 @@
     private volatile boolean isInputClosed;
     private volatile boolean isOutputClosed;
 
-    // socket input/output streams
-    private volatile InputStream in;
-    private volatile OutputStream out;
-    private static final VarHandle IN, OUT;
-    static {
-        try {
-            MethodHandles.Lookup l = MethodHandles.lookup();
-            IN = l.findVarHandle(NioSocketImpl.class, "in", InputStream.class);
-            OUT = l.findVarHandle(NioSocketImpl.class, "out", OutputStream.class);
-        } catch (Exception e) {
-            throw new InternalError(e);
-        }
-    }
-
     /**
      * Creates a instance of this SocketImpl.
      * @param server true if this is a SocketImpl for a ServerSocket
@@ -755,110 +741,85 @@
 
     @Override
     protected InputStream getInputStream() {
-        InputStream in = this.in;
-        if (in == null) {
-            in = new SocketInputStream(this);
-            if (!IN.compareAndSet(this, null, in)) {
-                in = this.in;
+        return new InputStream() {
+            // EOF or connection reset detected, not thread safe
+            private boolean eof, reset;
+            @Override
+            public int read() throws IOException {
+                byte[] a = new byte[1];
+                int n = read(a, 0, 1);
+                return (n > 0) ? (a[0] & 0xff) : -1;
             }
-        }
-        return in;
-    }
-
-    private static class SocketInputStream extends InputStream {
-        private final NioSocketImpl impl;
-        // EOF or connection reset detected, not thread safe
-        private boolean eof, reset;
-        SocketInputStream(NioSocketImpl impl) {
-            this.impl = impl;
-        }
-        @Override
-        public int read() throws IOException {
-            byte[] a = new byte[1];
-            int n = read(a, 0, 1);
-            return (n > 0) ? (a[0] & 0xff) : -1;
-        }
-        @Override
-        public int read(byte[] b, int off, int len) throws IOException {
-            Objects.checkFromIndexSize(off, len, b.length);
-            if (eof) {
-                return -1;
-            } else if (reset) {
-                throw new SocketException("Connection reset");
-            } else if (len == 0) {
-                return 0;
-            } else {
-                try {
-                    // read up to MAX_BUFFER_SIZE bytes
-                    int size = Math.min(len, MAX_BUFFER_SIZE);
-                    int n = impl.read(b, off, size);
-                    if (n == -1)
-                        eof = true;
-                    return n;
-                } catch (ConnectionResetException e) {
-                    reset = true;
+            @Override
+            public int read(byte[] b, int off, int len) throws IOException {
+                Objects.checkFromIndexSize(off, len, b.length);
+                if (eof) {
+                    return -1;
+                } else if (reset) {
                     throw new SocketException("Connection reset");
-                } catch (SocketTimeoutException e) {
-                    throw e;
-                } catch (IOException ioe) {
-                    throw new SocketException(ioe.getMessage());
+                } else if (len == 0) {
+                    return 0;
+                } else {
+                    try {
+                        // read up to MAX_BUFFER_SIZE bytes
+                        int size = Math.min(len, MAX_BUFFER_SIZE);
+                        int n = NioSocketImpl.this.read(b, off, size);
+                        if (n == -1)
+                            eof = true;
+                        return n;
+                    } catch (ConnectionResetException e) {
+                        reset = true;
+                        throw new SocketException("Connection reset");
+                    } catch (SocketTimeoutException e) {
+                        throw e;
+                    } catch (IOException ioe) {
+                        throw new SocketException(ioe.getMessage());
+                    }
                 }
             }
-        }
-        @Override
-        public int available() throws IOException {
-            return impl.available();
-        }
-        @Override
-        public void close() throws IOException {
-            impl.close();
-        }
+            @Override
+            public int available() throws IOException {
+                return NioSocketImpl.this.available();
+            }
+            @Override
+            public void close() throws IOException {
+                NioSocketImpl.this.close();
+            }
+        };
     }
 
     @Override
     protected OutputStream getOutputStream() {
-        OutputStream out = this.out;
-        if (out == null) {
-            out = new SocketOutputStream(this);
-            if (!OUT.compareAndSet(this, null, out)) {
-                out = this.out;
+        return new OutputStream() {
+            @Override
+            public void write(int b) throws IOException {
+                byte[] a = new byte[]{(byte) b};
+                write(a, 0, 1);
             }
-        }
-        return out;
-    }
-
-    private static class SocketOutputStream extends OutputStream {
-        private final NioSocketImpl impl;
-        SocketOutputStream(NioSocketImpl impl) {
-            this.impl = impl;
-        }
-        @Override
-        public void write(int b) throws IOException {
-            byte[] a = new byte[]{(byte) b};
-            write(a, 0, 1);
-        }
-        @Override
-        public void write(byte[] b, int off, int len) throws IOException {
-            Objects.checkFromIndexSize(off, len, b.length);
-            if (len > 0) {
-                try {
-                    int pos = off;
-                    int end = off + len;
-                    while (pos < end) {
-                        // write up to MAX_BUFFER_SIZE bytes
-                        int size = Math.min((end - pos), MAX_BUFFER_SIZE);
-                        int n = impl.write(b, pos, size);
-                        pos += n;
+            @Override
+            public void write(byte[] b, int off, int len) throws IOException {
+                Objects.checkFromIndexSize(off, len, b.length);
+                if (len > 0) {
+                    try {
+                        int pos = off;
+                        int end = off + len;
+                        while (pos < end) {
+                            // write up to MAX_BUFFER_SIZE bytes
+                            int size = Math.min((end - pos), MAX_BUFFER_SIZE);
+                            int n = NioSocketImpl.this.write(b, pos, size);
+                            pos += n;
+                        }
+                    } catch (IOException ioe) {
+                        throw new SocketException(ioe.getMessage());
                     }
-                } catch (IOException ioe) {
-                    throw new SocketException(ioe.getMessage());
                 }
             }
-        }
-        @Override
-        public void close() throws IOException {
-            impl.close();
-        }
+
+            @Override
+            public void close() throws IOException {
+                NioSocketImpl.this.close();
+            }
+        };
     }
 
     @Override
@@ -990,8 +951,7 @@
                 }
                 case IP_TOS: {
                     int i = intValue(value, "IP_TOS");
-                    var IP_TOS = StandardSocketOptions.IP_TOS;
-                    Net.setSocketOption(fd, family(), IP_TOS, i);
+                    Net.setSocketOption(fd, family(), StandardSocketOptions.IP_TOS, i);
                     trafficClass = i;
                     break;
                 }