src/java.base/share/classes/sun/nio/ch/SocketAdaptor.java
changeset 54620 13b67c1420b8
parent 54246 f04e3492fd88
child 58614 29624901d8bc
--- a/src/java.base/share/classes/sun/nio/ch/SocketAdaptor.java	Thu Apr 25 09:12:40 2019 +0200
+++ b/src/java.base/share/classes/sun/nio/ch/SocketAdaptor.java	Thu Apr 25 10:41:49 2019 +0100
@@ -33,18 +33,12 @@
 import java.net.Socket;
 import java.net.SocketAddress;
 import java.net.SocketException;
-import java.net.SocketImpl;
 import java.net.SocketOption;
-import java.net.SocketTimeoutException;
 import java.net.StandardSocketOptions;
-import java.nio.ByteBuffer;
-import java.nio.channels.Channels;
-import java.nio.channels.ClosedChannelException;
-import java.nio.channels.IllegalBlockingModeException;
 import java.nio.channels.SocketChannel;
-import java.security.AccessController;
-import java.security.PrivilegedExceptionAction;
-import static java.util.concurrent.TimeUnit.*;
+import java.util.Set;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
 
 // Make a socket channel look like a socket.
 //
@@ -62,11 +56,11 @@
     private volatile int timeout;
 
     private SocketAdaptor(SocketChannelImpl sc) throws SocketException {
-        super((SocketImpl) null);
+        super(DummySocketImpl.create());
         this.sc = sc;
     }
 
-    public static Socket create(SocketChannelImpl sc) {
+    static Socket create(SocketChannelImpl sc) {
         try {
             return new SocketAdaptor(sc);
         } catch (SocketException e) {
@@ -74,70 +68,30 @@
         }
     }
 
-    public SocketChannel getChannel() {
-        return sc;
-    }
-
-    // Override this method just to protect against changes in the superclass
-    //
+    @Override
     public void connect(SocketAddress remote) throws IOException {
         connect(remote, 0);
     }
 
+    @Override
     public void connect(SocketAddress remote, int timeout) throws IOException {
         if (remote == null)
             throw new IllegalArgumentException("connect: The address can't be null");
         if (timeout < 0)
             throw new IllegalArgumentException("connect: timeout can't be negative");
-
-        synchronized (sc.blockingLock()) {
-            if (!sc.isBlocking())
-                throw new IllegalBlockingModeException();
-
-            try {
-                // no timeout
-                if (timeout == 0) {
-                    sc.connect(remote);
-                    return;
-                }
-
-                // timed connect
-                sc.configureBlocking(false);
-                try {
-                    if (sc.connect(remote))
-                        return;
-                } finally {
-                    try {
-                        sc.configureBlocking(true);
-                    } catch (ClosedChannelException e) { }
-                }
-
-                long timeoutNanos = NANOSECONDS.convert(timeout, MILLISECONDS);
-                long to = timeout;
-                for (;;) {
-                    long startTime = System.nanoTime();
-                    if (sc.pollConnected(to)) {
-                        boolean connected = sc.finishConnect();
-                        assert connected;
-                        break;
-                    }
-                    timeoutNanos -= System.nanoTime() - startTime;
-                    if (timeoutNanos <= 0) {
-                        try {
-                            sc.close();
-                        } catch (IOException x) { }
-                        throw new SocketTimeoutException();
-                    }
-                    to = MILLISECONDS.convert(timeoutNanos, NANOSECONDS);
-                }
-
-            } catch (Exception x) {
-                Net.translateException(x, true);
+        try {
+            if (timeout > 0) {
+                long nanos = MILLISECONDS.toNanos(timeout);
+                sc.blockingConnect(remote, nanos);
+            } else {
+                sc.blockingConnect(remote, Long.MAX_VALUE);
             }
+        } catch (Exception e) {
+            Net.translateException(e, true);
         }
-
     }
 
+    @Override
     public void bind(SocketAddress local) throws IOException {
         try {
             sc.bind(local);
@@ -146,6 +100,7 @@
         }
     }
 
+    @Override
     public InetAddress getInetAddress() {
         InetSocketAddress remote = sc.remoteAddress();
         if (remote == null) {
@@ -155,6 +110,7 @@
         }
     }
 
+    @Override
     public InetAddress getLocalAddress() {
         if (sc.isOpen()) {
             InetSocketAddress local = sc.localAddress();
@@ -165,6 +121,7 @@
         return new InetSocketAddress(0).getAddress();
     }
 
+    @Override
     public int getPort() {
         InetSocketAddress remote = sc.remoteAddress();
         if (remote == null) {
@@ -174,6 +131,7 @@
         }
     }
 
+    @Override
     public int getLocalPort() {
         InetSocketAddress local = sc.localAddress();
         if (local == null) {
@@ -183,48 +141,27 @@
         }
     }
 
-    private class SocketInputStream
-        extends ChannelInputStream
-    {
-        private SocketInputStream() {
-            super(sc);
-        }
-
-        protected int read(ByteBuffer bb)
-            throws IOException
-        {
-            synchronized (sc.blockingLock()) {
-                if (!sc.isBlocking())
-                    throw new IllegalBlockingModeException();
-
-                // no timeout
-                long to = SocketAdaptor.this.timeout;
-                if (to == 0)
-                    return sc.read(bb);
+    @Override
+    public SocketAddress getRemoteSocketAddress() {
+        return sc.remoteAddress();
+    }
 
-                // timed read
-                long timeoutNanos = NANOSECONDS.convert(to, MILLISECONDS);
-                for (;;) {
-                    long startTime = System.nanoTime();
-                    if (sc.pollRead(to)) {
-                        return sc.read(bb);
-                    }
-                    timeoutNanos -= System.nanoTime() - startTime;
-                    if (timeoutNanos <= 0)
-                        throw new SocketTimeoutException();
-                    to = MILLISECONDS.convert(timeoutNanos, NANOSECONDS);
-                }
-            }
-        }
-
-        @Override
-        public int available() throws IOException {
-            return sc.available();
+    @Override
+    public SocketAddress getLocalSocketAddress() {
+        InetSocketAddress local = sc.localAddress();
+        if (local != null) {
+            return Net.getRevealedLocalAddress(local);
+        } else {
+            return null;
         }
     }
 
-    private InputStream socketInputStream = null;
+    @Override
+    public SocketChannel getChannel() {
+        return sc;
+    }
 
+    @Override
     public InputStream getInputStream() throws IOException {
         if (!sc.isOpen())
             throw new SocketException("Socket is closed");
@@ -232,21 +169,35 @@
             throw new SocketException("Socket is not connected");
         if (!sc.isInputOpen())
             throw new SocketException("Socket input is shutdown");
-        if (socketInputStream == null) {
-            try {
-                socketInputStream = AccessController.doPrivileged(
-                    new PrivilegedExceptionAction<InputStream>() {
-                        public InputStream run() throws IOException {
-                            return new SocketInputStream();
-                        }
-                    });
-            } catch (java.security.PrivilegedActionException e) {
-                throw (IOException)e.getException();
+        return new InputStream() {
+            @Override
+            public int read() throws IOException {
+                byte[] a = new byte[1];
+                int n = read(a, 0, 1);
+                return (n > 0) ? (a[0] & 0xff) : -1;
             }
-        }
-        return socketInputStream;
+            @Override
+            public int read(byte[] b, int off, int len) throws IOException {
+                int timeout = SocketAdaptor.this.timeout;
+                if (timeout > 0) {
+                    long nanos = MILLISECONDS.toNanos(timeout);
+                    return sc.blockingRead(b, off, len, nanos);
+                } else {
+                    return sc.blockingRead(b, off, len, 0);
+                }
+            }
+            @Override
+            public int available() throws IOException {
+                return sc.available();
+            }
+            @Override
+            public void close() throws IOException {
+                sc.close();
+            }
+        };
     }
 
+    @Override
     public OutputStream getOutputStream() throws IOException {
         if (!sc.isOpen())
             throw new SocketException("Socket is closed");
@@ -254,18 +205,21 @@
             throw new SocketException("Socket is not connected");
         if (!sc.isOutputOpen())
             throw new SocketException("Socket output is shutdown");
-        OutputStream os = null;
-        try {
-            os = AccessController.doPrivileged(
-                new PrivilegedExceptionAction<OutputStream>() {
-                    public OutputStream run() throws IOException {
-                        return Channels.newOutputStream(sc);
-                    }
-                });
-        } catch (java.security.PrivilegedActionException e) {
-            throw (IOException)e.getException();
-        }
-        return os;
+        return new OutputStream() {
+            @Override
+            public void write(int b) throws IOException {
+                byte[] a = new byte[]{(byte) b};
+                write(a, 0, 1);
+            }
+            @Override
+            public void write(byte[] b, int off, int len) throws IOException {
+                sc.blockingWriteFully(b, off, len);
+            }
+            @Override
+            public void close() throws IOException {
+                sc.close();
+            }
+        };
     }
 
     private void setBooleanOption(SocketOption<Boolean> name, boolean value)
@@ -306,48 +260,62 @@
         }
     }
 
+    @Override
     public void setTcpNoDelay(boolean on) throws SocketException {
         setBooleanOption(StandardSocketOptions.TCP_NODELAY, on);
     }
 
+    @Override
     public boolean getTcpNoDelay() throws SocketException {
         return getBooleanOption(StandardSocketOptions.TCP_NODELAY);
     }
 
+    @Override
     public void setSoLinger(boolean on, int linger) throws SocketException {
         if (!on)
             linger = -1;
         setIntOption(StandardSocketOptions.SO_LINGER, linger);
     }
 
+    @Override
     public int getSoLinger() throws SocketException {
         return getIntOption(StandardSocketOptions.SO_LINGER);
     }
 
+    @Override
     public void sendUrgentData(int data) throws IOException {
         int n = sc.sendOutOfBandData((byte) data);
         if (n == 0)
             throw new IOException("Socket buffer full");
     }
 
+    @Override
     public void setOOBInline(boolean on) throws SocketException {
         setBooleanOption(ExtendedSocketOption.SO_OOBINLINE, on);
     }
 
+    @Override
     public boolean getOOBInline() throws SocketException {
         return getBooleanOption(ExtendedSocketOption.SO_OOBINLINE);
     }
 
+    @Override
     public void setSoTimeout(int timeout) throws SocketException {
+        if (!sc.isOpen())
+            throw new SocketException("Socket is closed");
         if (timeout < 0)
-            throw new IllegalArgumentException("timeout can't be negative");
+            throw new IllegalArgumentException("timeout < 0");
         this.timeout = timeout;
     }
 
+    @Override
     public int getSoTimeout() throws SocketException {
+        if (!sc.isOpen())
+            throw new SocketException("Socket is closed");
         return timeout;
     }
 
+    @Override
     public void setSendBufferSize(int size) throws SocketException {
         // size 0 valid for SocketChannel, invalid for Socket
         if (size <= 0)
@@ -355,10 +323,12 @@
         setIntOption(StandardSocketOptions.SO_SNDBUF, size);
     }
 
+    @Override
     public int getSendBufferSize() throws SocketException {
         return getIntOption(StandardSocketOptions.SO_SNDBUF);
     }
 
+    @Override
     public void setReceiveBufferSize(int size) throws SocketException {
         // size 0 valid for SocketChannel, invalid for Socket
         if (size <= 0)
@@ -366,38 +336,47 @@
         setIntOption(StandardSocketOptions.SO_RCVBUF, size);
     }
 
+    @Override
     public int getReceiveBufferSize() throws SocketException {
         return getIntOption(StandardSocketOptions.SO_RCVBUF);
     }
 
+    @Override
     public void setKeepAlive(boolean on) throws SocketException {
         setBooleanOption(StandardSocketOptions.SO_KEEPALIVE, on);
     }
 
+    @Override
     public boolean getKeepAlive() throws SocketException {
         return getBooleanOption(StandardSocketOptions.SO_KEEPALIVE);
     }
 
+    @Override
     public void setTrafficClass(int tc) throws SocketException {
         setIntOption(StandardSocketOptions.IP_TOS, tc);
     }
 
+    @Override
     public int getTrafficClass() throws SocketException {
         return getIntOption(StandardSocketOptions.IP_TOS);
     }
 
+    @Override
     public void setReuseAddress(boolean on) throws SocketException {
         setBooleanOption(StandardSocketOptions.SO_REUSEADDR, on);
     }
 
+    @Override
     public boolean getReuseAddress() throws SocketException {
         return getBooleanOption(StandardSocketOptions.SO_REUSEADDR);
     }
 
+    @Override
     public void close() throws IOException {
         sc.close();
     }
 
+    @Override
     public void shutdownInput() throws IOException {
         try {
             sc.shutdownInput();
@@ -406,6 +385,7 @@
         }
     }
 
+    @Override
     public void shutdownOutput() throws IOException {
         try {
             sc.shutdownOutput();
@@ -414,6 +394,7 @@
         }
     }
 
+    @Override
     public String toString() {
         if (sc.isConnected())
             return "Socket[addr=" + getInetAddress() +
@@ -422,23 +403,44 @@
         return "Socket[unconnected]";
     }
 
+    @Override
     public boolean isConnected() {
         return sc.isConnected();
     }
 
+    @Override
     public boolean isBound() {
         return sc.localAddress() != null;
     }
 
+    @Override
     public boolean isClosed() {
         return !sc.isOpen();
     }
 
+    @Override
     public boolean isInputShutdown() {
         return !sc.isInputOpen();
     }
 
+    @Override
     public boolean isOutputShutdown() {
         return !sc.isOutputOpen();
     }
+
+    @Override
+    public <T> Socket setOption(SocketOption<T> name, T value) throws IOException {
+        sc.setOption(name, value);
+        return this;
+    }
+
+    @Override
+    public <T> T getOption(SocketOption<T> name) throws IOException {
+        return sc.getOption(name);
+    }
+
+    @Override
+    public Set<SocketOption<?>> supportedOptions() {
+        return sc.supportedOptions();
+    }
 }