6370908: Add support for HTTP_CONNECT proxy in Socket class
authorchegar
Thu, 07 Mar 2013 10:07:13 +0000
changeset 16061 c80133bafef0
parent 16060 08b9a416a770
child 16062 c64ef2b01401
6370908: Add support for HTTP_CONNECT proxy in Socket class Reviewed-by: chegar Contributed-by: Damjan Jovanovic <damjan.jov@gmail.com>, Chris Hegarty <chris.hegarty@oracle.com>
jdk/src/share/classes/java/net/HttpConnectSocketImpl.java
jdk/src/share/classes/java/net/Socket.java
jdk/test/java/net/Socket/HttpProxy.java
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/jdk/src/share/classes/java/net/HttpConnectSocketImpl.java	Thu Mar 07 10:07:13 2013 +0000
@@ -0,0 +1,210 @@
+/*
+ * Copyright (c) 2010, 2013, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package java.net;
+
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Basic SocketImpl that relies on the internal HTTP protocol handler
+ * implementation to perform the HTTP tunneling and authentication. The
+ * sockets impl is swapped out and replaced with the socket from the HTTP
+ * handler after the tunnel is successfully setup.
+ *
+ * @since 1.8
+ */
+
+/*package*/ class HttpConnectSocketImpl extends PlainSocketImpl {
+
+    private static final String httpURLClazzStr =
+                                  "sun.net.www.protocol.http.HttpURLConnection";
+    private static final String netClientClazzStr = "sun.net.NetworkClient";
+    private static final String doTunnelingStr = "doTunneling";
+    private static final Field httpField;
+    private static final Field serverSocketField;
+    private static final Method doTunneling;
+
+    private final String server;
+    private InetSocketAddress external_address;
+    private HashMap<Integer, Object> optionsMap = new HashMap<>();
+
+    static  {
+        try {
+            Class<?> httpClazz = Class.forName(httpURLClazzStr, true, null);
+            httpField = httpClazz.getDeclaredField("http");
+            doTunneling = httpClazz.getDeclaredMethod(doTunnelingStr);
+            Class<?> netClientClazz = Class.forName(netClientClazzStr, true, null);
+            serverSocketField = netClientClazz.getDeclaredField("serverSocket");
+
+            java.security.AccessController.doPrivileged(
+                new java.security.PrivilegedAction<Void>() {
+                    public Void run() {
+                        httpField.setAccessible(true);
+                        serverSocketField.setAccessible(true);
+                        return null;
+                }
+            });
+        } catch (ReflectiveOperationException x) {
+            throw new InternalError("Should not reach here", x);
+        }
+    }
+
+    HttpConnectSocketImpl(String server, int port) {
+        this.server = server;
+        this.port = port;
+    }
+
+    HttpConnectSocketImpl(Proxy proxy) {
+        SocketAddress a = proxy.address();
+        if ( !(a instanceof InetSocketAddress) )
+            throw new IllegalArgumentException("Unsupported address type");
+
+        InetSocketAddress ad = (InetSocketAddress) a;
+        server = ad.getHostString();
+        port = ad.getPort();
+    }
+
+    @Override
+    protected void connect(SocketAddress endpoint, int timeout)
+        throws IOException
+    {
+        if (endpoint == null || !(endpoint instanceof InetSocketAddress))
+            throw new IllegalArgumentException("Unsupported address type");
+        final InetSocketAddress epoint = (InetSocketAddress)endpoint;
+        final String destHost = epoint.isUnresolved() ? epoint.getHostName()
+                                                      : epoint.getAddress().getHostAddress();
+        final int destPort = epoint.getPort();
+
+        SecurityManager security = System.getSecurityManager();
+        if (security != null)
+            security.checkConnect(destHost, destPort);
+
+        // Connect to the HTTP proxy server
+        String urlString = "http://" + destHost + ":" + destPort;
+        Socket httpSocket = privilegedDoTunnel(urlString, timeout);
+
+        // Success!
+        external_address = epoint;
+
+        // close the original socket impl and release its descriptor
+        close();
+
+        // update the Sockets impl to the impl from the http Socket
+        AbstractPlainSocketImpl psi = (AbstractPlainSocketImpl) httpSocket.impl;
+        this.getSocket().impl = psi;
+
+        // best effort is made to try and reset options previously set
+        Set<Map.Entry<Integer,Object>> options = optionsMap.entrySet();
+        try {
+            for(Map.Entry<Integer,Object> entry : options) {
+                psi.setOption(entry.getKey(), entry.getValue());
+            }
+        } catch (IOException x) {  /* gulp! */  }
+    }
+
+    @Override
+    public void setOption(int opt, Object val) throws SocketException {
+        super.setOption(opt, val);
+
+        if (external_address != null)
+            return;  // we're connected, just return
+
+        // store options so that they can be re-applied to the impl after connect
+        optionsMap.put(opt, val);
+    }
+
+    private Socket privilegedDoTunnel(final String urlString,
+                                      final int timeout)
+        throws IOException
+    {
+        try {
+            return java.security.AccessController.doPrivileged(
+                new java.security.PrivilegedExceptionAction<Socket>() {
+                    public Socket run() throws IOException {
+                        return doTunnel(urlString, timeout);
+                }
+            });
+        } catch (java.security.PrivilegedActionException pae) {
+            throw (IOException) pae.getException();
+        }
+    }
+
+    private Socket doTunnel(String urlString, int connectTimeout)
+        throws IOException
+    {
+        Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(server, port));
+        URL destURL = new URL(urlString);
+        HttpURLConnection conn = (HttpURLConnection) destURL.openConnection(proxy);
+        conn.setConnectTimeout(connectTimeout);
+        conn.setReadTimeout(this.timeout);
+        conn.connect();
+        doTunneling(conn);
+        try {
+            Object httpClient = httpField.get(conn);
+            return (Socket) serverSocketField.get(httpClient);
+        } catch (IllegalAccessException x) {
+            throw new InternalError("Should not reach here", x);
+        }
+    }
+
+    private void doTunneling(HttpURLConnection conn) {
+        try {
+            doTunneling.invoke(conn);
+        } catch (ReflectiveOperationException x) {
+            throw new InternalError("Should not reach here", x);
+        }
+    }
+
+    @Override
+    protected InetAddress getInetAddress() {
+        if (external_address != null)
+            return external_address.getAddress();
+        else
+            return super.getInetAddress();
+    }
+
+    @Override
+    protected int getPort() {
+        if (external_address != null)
+            return external_address.getPort();
+        else
+            return super.getPort();
+    }
+
+    @Override
+    protected int getLocalPort() {
+        if (socket != null)
+            return super.getLocalPort();
+        if (external_address != null)
+            return external_address.getPort();
+        else
+            return super.getLocalPort();
+    }
+}
--- a/jdk/src/share/classes/java/net/Socket.java	Thu Mar 07 11:32:14 2013 +0800
+++ b/jdk/src/share/classes/java/net/Socket.java	Thu Mar 07 10:07:13 2013 +0000
@@ -117,8 +117,10 @@
         if (proxy == null) {
             throw new IllegalArgumentException("Invalid Proxy");
         }
-        Proxy p = proxy == Proxy.NO_PROXY ? Proxy.NO_PROXY : sun.net.ApplicationProxy.create(proxy);
-        if (p.type() == Proxy.Type.SOCKS) {
+        Proxy p = proxy == Proxy.NO_PROXY ? Proxy.NO_PROXY
+                                          : sun.net.ApplicationProxy.create(proxy);
+        Proxy.Type type = p.type();
+        if (type == Proxy.Type.SOCKS || type == Proxy.Type.HTTP) {
             SecurityManager security = System.getSecurityManager();
             InetSocketAddress epoint = (InetSocketAddress) p.address();
             if (epoint.getAddress() != null) {
@@ -133,7 +135,8 @@
                     security.checkConnect(epoint.getAddress().getHostAddress(),
                                   epoint.getPort());
             }
-            impl = new SocksSocketImpl(p);
+            impl = type == Proxy.Type.SOCKS ? new SocksSocketImpl(p)
+                                            : new HttpConnectSocketImpl(p);
             impl.setSocket(this);
         } else {
             if (p == Proxy.NO_PROXY) {
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/jdk/test/java/net/Socket/HttpProxy.java	Thu Mar 07 10:07:13 2013 +0000
@@ -0,0 +1,281 @@
+/*
+ * Copyright (c) 2010, 2013, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @bug 6370908
+ * @summary Add support for HTTP_CONNECT proxy in Socket class
+ */
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.PrintWriter;
+import static java.lang.System.out;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.Proxy;
+import java.net.ServerSocket;
+import java.net.Socket;
+import sun.net.www.MessageHeader;
+
+public class HttpProxy {
+    final String proxyHost;
+    final int proxyPort;
+    static final int SO_TIMEOUT = 15000;
+
+    public static void main(String[] args) throws Exception {
+        String host;
+        int port;
+        if (args.length == 0) {
+            // Start internal proxy
+            ConnectProxyTunnelServer proxy = new ConnectProxyTunnelServer();
+            proxy.start();
+            host = "localhost";
+            port = proxy.getLocalPort();
+            out.println("Running with internal proxy: " + host + ":" + port);
+        } else if (args.length == 2) {
+            host = args[0];
+            port = Integer.valueOf(args[1]);
+            out.println("Running against specified proxy server: " + host + ":" + port);
+        } else {
+            System.err.println("Usage: java HttpProxy [<proxy host> <proxy port>]");
+            return;
+        }
+
+        HttpProxy p = new HttpProxy(host, port);
+        p.test();
+    }
+
+    public HttpProxy(String proxyHost, int proxyPort) {
+        this.proxyHost = proxyHost;
+        this.proxyPort = proxyPort;
+    }
+
+    void test() throws Exception {
+        InetSocketAddress proxyAddress = new InetSocketAddress(proxyHost, proxyPort);
+        Proxy httpProxy = new Proxy(Proxy.Type.HTTP, proxyAddress);
+
+        try (ServerSocket ss = new ServerSocket(0);
+             Socket sock = new Socket(httpProxy)) {
+            sock.setSoTimeout(SO_TIMEOUT);
+            sock.setTcpNoDelay(false);
+
+            InetSocketAddress externalAddress =
+                new InetSocketAddress(InetAddress.getLocalHost(), ss.getLocalPort());
+
+            out.println("Trying to connect to server socket on " + externalAddress);
+            sock.connect(externalAddress);
+            try (Socket externalSock = ss.accept()) {
+                // perform some simple checks
+                check(sock.isBound(), "Socket is not bound");
+                check(sock.isConnected(), "Socket is not connected");
+                check(!sock.isClosed(), "Socket should not be closed");
+                check(sock.getSoTimeout() == SO_TIMEOUT,
+                        "Socket should have a previously set timeout");
+                check(sock.getTcpNoDelay() ==  false, "NODELAY should be false");
+
+                simpleDataExchange(sock, externalSock);
+            }
+        }
+    }
+
+    static void check(boolean condition, String message) {
+        if (!condition) out.println(message);
+    }
+
+    static Exception unexpected(Exception e) {
+        out.println("Unexcepted Exception: " + e);
+        e.printStackTrace();
+        return e;
+    }
+
+    // performs a simple exchange of data between the two sockets
+    // and throws an exception if there is any problem.
+    void simpleDataExchange(Socket s1, Socket s2) throws Exception {
+        try (final InputStream i1 = s1.getInputStream();
+             final InputStream i2 = s2.getInputStream();
+             final OutputStream o1 = s1.getOutputStream();
+             final OutputStream o2 = s2.getOutputStream()) {
+            startSimpleWriter("simpleWriter1", o1, 100);
+            startSimpleWriter("simpleWriter2", o2, 200);
+            simpleRead(i2, 100);
+            simpleRead(i1, 200);
+        }
+    }
+
+    void startSimpleWriter(String threadName, final OutputStream os, final int start) {
+        (new Thread(new Runnable() {
+            public void run() {
+                try { simpleWrite(os, start); }
+                catch (Exception e) {unexpected(e); }
+            }}, threadName)).start();
+    }
+
+    void simpleWrite(OutputStream os, int start) throws Exception {
+        byte b[] = new byte[2];
+        for (int i=start; i<start+100; i++) {
+            b[0] = (byte) (i / 256);
+            b[1] = (byte) (i % 256);
+            os.write(b);
+        }
+    }
+
+    void simpleRead(InputStream is, int start) throws Exception {
+        byte b[] = new byte [2];
+        for (int i=start; i<start+100; i++) {
+            int x = is.read(b);
+            if (x == 1)
+                x += is.read(b,1,1);
+            if (x!=2)
+                throw new Exception("read error");
+            int r = bytes(b[0], b[1]);
+            if (r != i)
+                throw new Exception("read " + r + " expected " +i);
+        }
+    }
+
+    int bytes(byte b1, byte b2) {
+        int i1 = (int)b1 & 0xFF;
+        int i2 = (int)b2 & 0xFF;
+        return i1 * 256 + i2;
+    }
+
+    static class ConnectProxyTunnelServer extends Thread {
+
+        private final ServerSocket ss;
+
+        public ConnectProxyTunnelServer() throws IOException {
+            ss = new ServerSocket(0);
+        }
+
+        @Override
+        public void run() {
+            try (Socket clientSocket = ss.accept()) {
+                processRequest(clientSocket);
+            } catch (Exception e) {
+                out.println("Proxy Failed: " + e);
+                e.printStackTrace();
+            } finally {
+                try { ss.close(); } catch (IOException x) { unexpected(x); }
+            }
+        }
+
+        /**
+         * Returns the port on which the proxy is accepting connections.
+         */
+        public int getLocalPort() {
+            return ss.getLocalPort();
+        }
+
+        /*
+         * Processes the CONNECT request
+         */
+        private void processRequest(Socket clientSocket) throws Exception {
+            MessageHeader mheader = new MessageHeader(clientSocket.getInputStream());
+            String statusLine = mheader.getValue(0);
+
+            if (!statusLine.startsWith("CONNECT")) {
+                out.println("proxy server: processes only "
+                                  + "CONNECT method requests, recieved: "
+                                  + statusLine);
+                return;
+            }
+
+            // retrieve the host and port info from the status-line
+            InetSocketAddress serverAddr = getConnectInfo(statusLine);
+
+            //open socket to the server
+            try (Socket serverSocket = new Socket(serverAddr.getAddress(),
+                                                  serverAddr.getPort())) {
+                Forwarder clientFW = new Forwarder(clientSocket.getInputStream(),
+                                                   serverSocket.getOutputStream());
+                Thread clientForwarderThread = new Thread(clientFW, "ClientForwarder");
+                clientForwarderThread.start();
+                send200(clientSocket);
+                Forwarder serverFW = new Forwarder(serverSocket.getInputStream(),
+                                                   clientSocket.getOutputStream());
+                serverFW.run();
+                clientForwarderThread.join();
+            }
+        }
+
+        private void send200(Socket clientSocket) throws IOException {
+            OutputStream out = clientSocket.getOutputStream();
+            PrintWriter pout = new PrintWriter(out);
+
+            pout.println("HTTP/1.1 200 OK");
+            pout.println();
+            pout.flush();
+        }
+
+        /*
+         * This method retrieves the hostname and port of the tunnel destination
+         * from the request line.
+         * @param connectStr
+         *        of the form: <i>CONNECT server-name:server-port HTTP/1.x</i>
+         */
+        static InetSocketAddress getConnectInfo(String connectStr)
+            throws Exception
+        {
+            try {
+                int starti = connectStr.indexOf(' ');
+                int endi = connectStr.lastIndexOf(' ');
+                String connectInfo = connectStr.substring(starti+1, endi).trim();
+                // retrieve server name and port
+                endi = connectInfo.indexOf(':');
+                String name = connectInfo.substring(0, endi);
+                int port = Integer.parseInt(connectInfo.substring(endi+1));
+                return new InetSocketAddress(name, port);
+            } catch (Exception e) {
+                out.println("Proxy recieved a request: " + connectStr);
+                throw unexpected(e);
+            }
+        }
+    }
+
+    /* Reads from the given InputStream and writes to the given OutputStream */
+    static class Forwarder implements Runnable
+    {
+        private final InputStream in;
+        private final OutputStream os;
+
+        Forwarder(InputStream in, OutputStream os) {
+            this.in = in;
+            this.os = os;
+        }
+
+        @Override
+        public void run() {
+            try {
+                byte[] ba = new byte[1024];
+                int count;
+                while ((count = in.read(ba)) != -1) {
+                    os.write(ba, 0, count);
+                }
+            } catch (IOException e) {
+                unexpected(e);
+            }
+        }
+    }
+}