Throw SocketException consistently after connection reset detected niosocketimpl-branch
authoralanb
Fri, 15 Feb 2019 18:07:17 +0000
branchniosocketimpl-branch
changeset 57186 997178749c87
parent 57185 e0e1493fa166
child 57187 056911ad3ee7
Throw SocketException consistently after connection reset detected
src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java
src/java.base/unix/classes/sun/nio/ch/SocketDispatcher.java
src/java.base/unix/native/libnio/ch/SocketDispatcher.c
src/java.base/windows/classes/sun/nio/ch/SocketDispatcher.java
--- a/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java	Fri Feb 15 10:39:45 2019 +0000
+++ b/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java	Fri Feb 15 18:07:17 2019 +0000
@@ -56,6 +56,7 @@
 
 import jdk.internal.access.SharedSecrets;
 import jdk.internal.ref.CleanerFactory;
+import sun.net.ConnectionResetException;
 import sun.net.NetHooks;
 import sun.net.PlatformSocketImpl;
 import sun.net.ResourceManager;
@@ -75,14 +76,10 @@
  * blocking. If a connect, accept or read is attempted with a timeout then the
  * socket is changed to non-blocking mode. When in non-blocking mode, operations
  * that don't complete immediately will poll the socket.
- *
- * Behavior differences to examine:
- * "Connection reset" handling differs to PlainSocketImpl for cases where
- * an application continues to call read or available after a reset.
  */
 
 public final class NioSocketImpl extends SocketImpl implements PlatformSocketImpl {
-    private static final NativeDispatcher nd = new SocketDispatcher();
+    private static final NativeDispatcher nd = new SocketDispatcher(true);
 
     // The maximum number of bytes to read/write per syscall to avoid needing
     // a huge buffer from the temporary buffer cache
@@ -128,6 +125,20 @@
     private volatile boolean isInputClosed;
     private volatile boolean isOutputClosed;
 
+    // socket input/output streams
+    private volatile InputStream in;
+    private volatile OutputStream out;
+    private static final VarHandle IN, OUT;
+    static {
+        try {
+            MethodHandles.Lookup l = MethodHandles.lookup();
+            IN = l.findVarHandle(NioSocketImpl.class, "in", InputStream.class);
+            OUT = l.findVarHandle(NioSocketImpl.class, "out", OutputStream.class);
+        } catch (Exception e) {
+            throw new InternalError(e);
+        }
+    }
+
     /**
      * Creates a instance of this SocketImpl.
      * @param server true if this is a SocketImpl for a ServerSocket
@@ -741,78 +752,110 @@
 
     @Override
     protected InputStream getInputStream() {
-        return new InputStream() {
-            private volatile boolean eof;  // to emulate legacy SocketInputStream
-            @Override
-            public int read() throws IOException {
-                byte[] a = new byte[1];
-                int n = read(a, 0, 1);
-                return (n > 0) ? (a[0] & 0xff) : -1;
+        InputStream in = this.in;
+        if (in == null) {
+            in = new SocketInputStream(this);
+            if (!IN.compareAndSet(this, null, in)) {
+                in = this.in;
             }
-            @Override
-            public int read(byte[] b, int off, int len) throws IOException {
-                Objects.checkFromIndexSize(off, len, b.length);
-                if (eof) {
-                    return -1; // return -1, even if socket is closed
-                } else if (len == 0) {
-                    return 0;  // return 0, even if socket is closed
-                } else {
-                    try {
-                        // read up to MAX_BUFFER_SIZE bytes
-                        int size = Math.min(len, MAX_BUFFER_SIZE);
-                        int n = NioSocketImpl.this.read(b, off, size);
-                        if (n == -1)
-                            eof = true;
-                        return n;
-                    } catch (SocketTimeoutException e) {
-                        throw e;
-                    } catch (IOException ioe) {
-                        throw new SocketException(ioe.getMessage());
-                    }
+        }
+        return in;
+    }
+
+    private static class SocketInputStream extends InputStream {
+        private final NioSocketImpl impl;
+        // EOF or connection reset detected, not thread safe
+        private boolean eof, reset;
+        SocketInputStream(NioSocketImpl impl) {
+            this.impl = impl;
+        }
+        @Override
+        public int read() throws IOException {
+            byte[] a = new byte[1];
+            int n = read(a, 0, 1);
+            return (n > 0) ? (a[0] & 0xff) : -1;
+        }
+        @Override
+        public int read(byte[] b, int off, int len) throws IOException {
+            Objects.checkFromIndexSize(off, len, b.length);
+            if (eof) {
+                return -1;
+            } else if (reset) {
+                throw new SocketException("Connection reset");
+            } else if (len == 0) {
+                return 0;
+            } else {
+                try {
+                    // read up to MAX_BUFFER_SIZE bytes
+                    int size = Math.min(len, MAX_BUFFER_SIZE);
+                    int n = impl.read(b, off, size);
+                    if (n == -1)
+                        eof = true;
+                    return n;
+                } catch (ConnectionResetException e) {
+                    reset = true;
+                    throw new SocketException("Connection reset");
+                } catch (SocketTimeoutException e) {
+                    throw e;
+                } catch (IOException ioe) {
+                    throw new SocketException(ioe.getMessage());
                 }
             }
-            @Override
-            public int available() throws IOException {
-                return NioSocketImpl.this.available();
-            }
-            @Override
-            public void close() throws IOException {
-                NioSocketImpl.this.close();
-            }
-        };
+        }
+        @Override
+        public int available() throws IOException {
+            return impl.available();
+        }
+        @Override
+        public void close() throws IOException {
+            impl.close();
+        }
     }
 
     @Override
     protected OutputStream getOutputStream() {
-        return new OutputStream() {
-            @Override
-            public void write(int b) throws IOException {
-                byte[] a = new byte[] { (byte) b };
-                write(a, 0, 1);
+        OutputStream out = this.out;
+        if (out == null) {
+            out = new SocketOutputStream(this);
+            if (!OUT.compareAndSet(this, null, out)) {
+                out = this.out;
             }
-            @Override
-            public void write(byte[] b, int off, int len) throws IOException {
-                Objects.checkFromIndexSize(off, len, b.length);
-                if (len > 0) {
-                    try {
-                        int pos = off;
-                        int end = off + len;
-                        while (pos < end) {
-                            // write up to MAX_BUFFER_SIZE bytes
-                            int size = Math.min((end - pos), MAX_BUFFER_SIZE);
-                            int n = NioSocketImpl.this.write(b, pos, size);
-                            pos += n;
-                        }
-                    } catch (IOException ioe) {
-                        throw new SocketException(ioe.getMessage());
+        }
+        return out;
+    }
+
+    private static class SocketOutputStream extends OutputStream {
+        private final NioSocketImpl impl;
+        SocketOutputStream(NioSocketImpl impl) {
+            this.impl = impl;
+        }
+        @Override
+        public void write(int b) throws IOException {
+            byte[] a = new byte[]{(byte) b};
+            write(a, 0, 1);
+        }
+        @Override
+        public void write(byte[] b, int off, int len) throws IOException {
+            Objects.checkFromIndexSize(off, len, b.length);
+            if (len > 0) {
+                try {
+                    int pos = off;
+                    int end = off + len;
+                    while (pos < end) {
+                        // write up to MAX_BUFFER_SIZE bytes
+                        int size = Math.min((end - pos), MAX_BUFFER_SIZE);
+                        int n = impl.write(b, pos, size);
+                        pos += n;
                     }
+                } catch (IOException ioe) {
+                    throw new SocketException(ioe.getMessage());
                 }
             }
-            @Override
-            public void close() throws IOException {
-                NioSocketImpl.this.close();
-            }
-        };
+        }
+        @Override
+        public void close() throws IOException {
+            impl.close();
+        }
     }
 
     @Override
--- a/src/java.base/unix/classes/sun/nio/ch/SocketDispatcher.java	Fri Feb 15 10:39:45 2019 +0000
+++ b/src/java.base/unix/classes/sun/nio/ch/SocketDispatcher.java	Fri Feb 15 18:07:17 2019 +0000
@@ -34,9 +34,22 @@
  */
 
 class SocketDispatcher extends NativeDispatcher {
+    private final boolean detectConnectionReset;
+
+    SocketDispatcher(boolean detectConnectionReset) {
+        this.detectConnectionReset = detectConnectionReset;
+    }
+
+    SocketDispatcher() {
+        this(false);
+    }
 
     int read(FileDescriptor fd, long address, int len) throws IOException {
-        return FileDispatcherImpl.read0(fd, address, len);
+        if (detectConnectionReset) {
+            return read0(fd, address, len);
+        } else {
+            return FileDispatcherImpl.read0(fd, address, len);
+        }
     }
 
     long readv(FileDescriptor fd, long address, int len) throws IOException {
@@ -58,4 +71,20 @@
     void preClose(FileDescriptor fd) throws IOException {
         FileDispatcherImpl.preClose0(fd);
     }
+
+    // -- Native methods --
+
+    /**
+     * Reads up to len bytes from a socket with special handling for "connection
+     * reset".
+     *
+     * @throws sun.net.ConnectionResetException if connection reset is detected
+     * @throws IOException if another I/O error occurs
+     */
+    private static native int read0(FileDescriptor fd, long address, int len)
+        throws IOException;
+
+    static {
+        IOUtil.load();
+    }
 }
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/java.base/unix/native/libnio/ch/SocketDispatcher.c	Fri Feb 15 18:07:17 2019 +0000
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+ #include <sys/types.h>
+ #include <unistd.h>
+
+ #include "jni.h"
+ #include "jni_util.h"
+ #include "jlong.h"
+ #include "nio.h"
+ #include "nio_util.h"
+ #include "sun_nio_ch_SocketDispatcher.h"
+
+ JNIEXPORT jint JNICALL
+ Java_sun_nio_ch_SocketDispatcher_read0(JNIEnv *env, jclass clazz,
+                                        jobject fdo, jlong address, jint len)
+ {
+     jint fd = fdval(env, fdo);
+     void *buf = (void *)jlong_to_ptr(address);
+     jint n = read(fd, buf, len);
+     if ((n == -1) && (errno == ECONNRESET || errno == EPIPE)) {
+         JNU_ThrowByName(env, "sun/net/ConnectionResetException", "Connection reset");
+         return IOS_THROWN;
+     } else {
+         return convertReturnVal(env, n, JNI_TRUE);
+     }
+ }
--- a/src/java.base/windows/classes/sun/nio/ch/SocketDispatcher.java	Fri Feb 15 10:39:45 2019 +0000
+++ b/src/java.base/windows/classes/sun/nio/ch/SocketDispatcher.java	Fri Feb 15 18:07:17 2019 +0000
@@ -32,12 +32,11 @@
  * for read and write operations.
  */
 
-class SocketDispatcher extends NativeDispatcher
-{
+class SocketDispatcher extends NativeDispatcher {
 
-    static {
-        IOUtil.load();
-    }
+    SocketDispatcher() { }
+
+    SocketDispatcher(boolean ignore) { }
 
     int read(FileDescriptor fd, long address, int len) throws IOException {
         return read0(fd, address, len);
@@ -63,7 +62,8 @@
         close0(fd);
     }
 
-    //-- Native methods
+    // -- Native methods --
+
     static native int read0(FileDescriptor fd, long address, int len)
         throws IOException;
 
@@ -79,4 +79,8 @@
     static native void preClose0(FileDescriptor fd) throws IOException;
 
     static native void close0(FileDescriptor fd) throws IOException;
+
+    static {
+        IOUtil.load();
+    }
 }