test/jdk/java/net/Socks/SocksServer.java
author aeubanks
Tue, 16 Apr 2019 13:06:23 -0700
changeset 54770 62b6e7587b1f
parent 54634 59c01214e478
permissions -rw-r--r--
8220673: Add test library support for determining platform IP support Reviewed-by: dfuchs, chegar Contributed-by: aeubanks@google.com

/*
 * Copyright (c) 2002, 2019, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
import java.net.*;
import java.io.*;
import java.util.HashMap;

/**
 * A small SOCKS proxy server used by networking tests.
 * <p>
 * By default it speaks SOCKS v5 (RFC 1928). It selects username/password
 * authentication (RFC 1929) when the client proposes it, checking the
 * credentials against those registered with {@link #addUser}. When created
 * with {@code useV4 == true} it speaks SOCKS v4 instead and handles only the
 * CONNECT command. Call {@link #start()} to begin accepting clients; each
 * accepted client is served on its own {@link ClientHandler} thread.
 */
public class SocksServer extends Thread implements Closeable {
    // Some useful SOCKS constants

    static final int PROTO_VERS4        = 4;
    static final int PROTO_VERS         = 5;
    static final int DEFAULT_PORT       = 1080;

    static final int NO_AUTH            = 0;
    static final int GSSAPI             = 1;
    static final int USER_PASSW         = 2;
    static final int NO_METHODS         = -1;

    static final int CONNECT            = 1;
    static final int BIND               = 2;
    static final int UDP_ASSOC          = 3;

    static final int IPV4               = 1;
    static final int DOMAIN_NAME        = 3;
    static final int IPV6               = 4;

    static final int REQUEST_OK         = 0;
    static final int GENERAL_FAILURE    = 1;
    static final int NOT_ALLOWED        = 2;
    static final int NET_UNREACHABLE    = 3;
    static final int HOST_UNREACHABLE   = 4;
    static final int CONN_REFUSED       = 5;
    static final int TTL_EXPIRED        = 6;
    static final int CMD_NOT_SUPPORTED  = 7;
    static final int ADDR_TYPE_NOT_SUP  = 8;

    private int port;
    private ServerSocket server;
    private boolean useV4 = false;
    private HashMap<String,String> users = new HashMap<>();
    private volatile boolean done = false;

    // Inner class to handle protocol with client
    // This is the bulk of the work (protocol handler)
    class ClientHandler extends Thread {
        private InputStream in;
        private OutputStream out;
        private Socket client;
        private Socket dest;

        // Simple tunneling class: copies bytes from one stream to another
        // until EOF, then closes both streams.
        class Tunnel extends Thread {
            private InputStream tin;
            private OutputStream tout;

            Tunnel(InputStream in, OutputStream out) {
                tin = in;
                tout = out;
            }

            @Override
            public void run() {
                byte[] buf = new byte[8192];
                try {
                    int n;
                    while ((n = tin.read(buf)) != -1) {
                        tout.write(buf, 0, n);
                        tout.flush();
                    }
                    // Propagate EOF. For socket streams, closing the stream
                    // also closes the underlying socket.
                    tin.close();
                    tout.close();
                } catch (IOException e) {
                    // One side was closed or broken: just let the thread exit.
                }
            }
        }

        ClientHandler(Socket s) throws IOException {
            client = s;
            in = new BufferedInputStream(client.getInputStream());
            out = new BufferedOutputStream(client.getOutputStream());
        }

        // Fills buf completely from is; throws IOException on premature EOF.
        private void readBuf(InputStream is, byte[] buf) throws IOException {
            int l = buf.length;
            int count = 0;
            int i;
            do {
                i = is.read(buf, count, l - count);
                if (i == -1)
                    throw new IOException("unexpected EOF");
                count += i;
            } while (count < l);
        }

        // Reads a big-endian 16-bit port number from the client.
        private int readPort() throws IOException {
            byte[] b = new byte[2];
            readBuf(in, b);
            return ((b[0] & 0xff) << 8) | (b[1] & 0xff);
        }

        // Closes c, ignoring a null argument and any IOException.
        private void closeQuietly(Closeable c) {
            if (c == null)
                return;
            try {
                c.close();
            } catch (IOException ignored) {
            }
        }

        // RFC 1929 username/password sub-negotiation. Replies to the client
        // and returns true only if the credentials match a user registered
        // with addUser().
        private boolean userPassAuth() throws IOException {
            in.read(); // VER of the sub-negotiation (not checked)
            int ulen = in.read();
            if (ulen <= 0)
                throw new SocketException("SOCKS protocol error");
            byte[] buf = new byte[ulen];
            readBuf(in, buf);
            String uname = new String(buf);
            String password = null;
            int plen = in.read();
            if (plen < 0)
                throw new SocketException("SOCKS protocol error");
            if (plen > 0) {
                buf = new byte[plen];
                readBuf(in, buf);
                password = new String(buf);
            }
            // Check username/password validity here
            System.err.println("User: '" + uname + "'");
            System.err.println("PSWD: '" + password + "'");
            boolean ok = false;
            if (users.containsKey(uname)) {
                String p1 = users.get(uname);
                System.err.println("p1 = " + p1);
                // A user registered with a null password can never log in
                // (previously this threw a NullPointerException).
                ok = p1 != null && p1.equals(password);
            }
            // NOTE(review): RFC 1929 specifies VER 0x01 for this reply;
            // PROTO_VERS (5) is sent as before -- confirm client expectations
            // before changing it.
            out.write(PROTO_VERS);
            out.write(ok ? REQUEST_OK : NOT_ALLOWED);
            out.flush();
            return ok;
        }

        // Discards whatever the client still sends, until EOF or until no
        // data arrives within one second.
        private void purge() throws IOException {
            client.setSoTimeout(1000);
            try {
                while (in.read() != -1) {
                    // discard
                }
            } catch (IOException e) {
                // Read timed out or connection broke: stop draining.
            }
        }

        // Writes an 8-byte SOCKS v4 reply: VN, CD, DSTPORT (2), DSTIP (4).
        private void replyV4(int code, int dstPort, byte[] ip) throws IOException {
            out.write(PROTO_VERS4);
            out.write(code);
            out.write((dstPort >> 8) & 0xff);
            out.write(dstPort & 0xff);
            out.write(ip);
            out.flush();
        }

        // Sends a SOCKS v4 "request rejected" (91) reply, drains the rest of
        // the request and closes the client streams.
        private void rejectV4(byte[] ip) throws IOException {
            replyV4(91, 0, ip);
            purge();
            out.close();
            in.close();
        }

        // Consumes the NUL-terminated SOCKS v4 USERID field (its value is not
        // used). Throws IOException if the client hits EOF first instead of
        // looping forever.
        private void skipUserIdV4() throws IOException {
            int c;
            while ((c = in.read()) != 0) {
                if (c == -1)
                    throw new IOException("unexpected EOF");
            }
        }

        // Handle the SOCKS version 4 protocol (only CONNECT is supported;
        // other commands get no reply and the connection is closed).
        private void getRequestV4() throws IOException {
            int ver = in.read();
            int cmd = in.read();
            if (ver == -1 || cmd == -1) {
                // EOF
                in.close();
                out.close();
                return;
            }

            if (ver != 0 && ver != 4) {
                rejectV4(new byte[4]); // Bad Request
                return;
            }

            if (cmd == CONNECT) {
                int dstPort = readPort();
                byte[] ip = new byte[4];
                readBuf(in, ip);
                InetAddress addr = InetAddress.getByAddress(ip);
                skipUserIdV4();
                try {
                    dest = new Socket(addr, dstPort);
                } catch (IOException e) {
                    rejectV4(ip);
                    return;
                }
                replyV4(90, dstPort, ip); // Success
                relay(new BufferedInputStream(dest.getInputStream()),
                      new BufferedOutputStream(dest.getOutputStream()));
            }
        }

        // Negotiate the authentication scheme with the client. Picks
        // username/password if the client offers it, otherwise no
        // authentication. Returns false if authentication failed, in which
        // case the connection must be closed (RFC 1929).
        private boolean negociate() throws IOException {
            in.read(); // VER (not checked)
            int n = in.read();
            int scheme = NO_AUTH;
            if (n > 0) {
                byte[] methods = new byte[n];
                readBuf(in, methods);
                for (byte m : methods) {
                    if (m == USER_PASSW)
                        scheme = USER_PASSW;
                }
            }
            out.write(PROTO_VERS);
            out.write(scheme);
            out.flush();
            return scheme != USER_PASSW || userPassAuth();
        }

        // Send error message then close the streams
        private void sendError(int code) {
            try {
                out.write(PROTO_VERS);
                out.write(code);
                out.write(0);
                out.write(IPV4);
                for (int i=0; i<6; i++)
                    out.write(0);
                out.flush();
                out.close();
            } catch (IOException ex) {
            }
        }

        // Writes a SOCKS v5 reply carrying addr/port, choosing ATYP from the
        // address family. Returns false, without writing anything, if addr is
        // neither an Inet4Address nor an Inet6Address.
        private boolean replyV5(int code, InetAddress addr, int replyPort)
                throws IOException {
            int atyp;
            if (addr instanceof Inet4Address) {
                atyp = IPV4;
            } else if (addr instanceof Inet6Address) {
                atyp = IPV6;
            } else {
                return false;
            }
            out.write(PROTO_VERS);
            out.write(code);
            out.write(0);
            out.write(atyp);
            out.write(addr.getAddress());
            out.write((replyPort >> 8) & 0xff);
            out.write(replyPort & 0xff);
            out.flush();
            return true;
        }

        // Tunnels data between the client and dest: a Tunnel thread copies
        // dest -> client, while this thread copies client -> dest until the
        // client reaches EOF or its socket is closed/broken.
        private void relay(InputStream fromDest, OutputStream toDest) {
            new Tunnel(fromDest, out).start();
            byte[] buf = new byte[8192];
            while (true) {
                int n;
                try {
                    n = in.read(buf);
                } catch (IOException e) {
                    // The client socket was closed (possibly by the tunnel once
                    // dest hit EOF) or is broken: the session is over.
                    return;
                }
                if (n == -1) {
                    // Client EOF: closing the streams closes both sockets.
                    closeQuietly(in);
                    closeQuietly(toDest);
                    return;
                }
                try {
                    toDest.write(buf, 0, n);
                    toDest.flush();
                } catch (IOException e) {
                    // dest is gone; keep draining the client until the tunnel,
                    // which sees dest's EOF, closes the client socket.
                }
            }
        }

        // Actually connect the proxy to the destination then initiate tunneling
        private void doConnect(InetSocketAddress addr) throws IOException {
            dest = new Socket();
            try {
                dest.connect(addr, 10000);
            } catch (SocketTimeoutException ex) {
                sendError(HOST_UNREACHABLE);
                return;
            } catch (ConnectException cex) {
                sendError(CONN_REFUSED);
                return;
            }
            // Success: report the destination address back to the client.
            if (!replyV5(REQUEST_OK, addr.getAddress(), addr.getPort())) {
                sendError(GENERAL_FAILURE);
                return;
            }
            relay(new BufferedInputStream(dest.getInputStream()),
                  new BufferedOutputStream(dest.getOutputStream()));
        }

        // BIND: listen on an ephemeral port, report it, accept a single
        // incoming connection, report its origin, then tunnel. The requested
        // address is not used.
        private void doBind(InetSocketAddress addr) throws IOException {
            try (ServerSocket svr = new ServerSocket()) {
                svr.bind(null);
                InetSocketAddress bound =
                        (InetSocketAddress) svr.getLocalSocketAddress();
                if (!replyV5(REQUEST_OK, bound.getAddress(), bound.getPort())) {
                    sendError(GENERAL_FAILURE);
                    return;
                }
                dest = svr.accept();
            } // the listener is no longer needed once one peer is accepted
            InetSocketAddress peer = (InetSocketAddress) dest.getRemoteSocketAddress();
            if (!replyV5(REQUEST_OK, peer.getAddress(), peer.getPort())) {
                sendError(GENERAL_FAILURE);
                return;
            }
            relay(dest.getInputStream(), dest.getOutputStream());
        }

        // Handle the SOCKS v5 requests
        private void getRequest() throws IOException {
            int ver = in.read();
            int cmd = in.read();
            if (ver == -1 || cmd == -1) {
                in.close();
                out.close();
                return;
            }
            in.read(); // RSV
            int atyp = in.read();
            String addr;

            switch (atyp) {
            case IPV4:
            case IPV6: {
                byte[] raw = new byte[atyp == IPV4 ? 4 : 16];
                readBuf(in, raw);
                addr = InetAddress.getByAddress(raw).getHostAddress();
                break;
            }
            case DOMAIN_NAME: {
                int len = in.read();
                if (len == -1)
                    throw new IOException("unexpected EOF");
                byte[] raw = new byte[len];
                readBuf(in, raw);
                addr = new String(raw);
                break;
            }
            default:
                // Unknown ATYP: the remaining request layout is unknown, so
                // reject it instead of failing on a null address.
                sendError(ADDR_TYPE_NOT_SUP);
                return;
            }

            int dstPort = readPort();

            switch (cmd) {
            case CONNECT:
                doConnect(new InetSocketAddress(addr, dstPort));
                break;
            case BIND:
                doBind(new InetSocketAddress(addr, dstPort));
                break;
            default:
                // UDP ASSOCIATE (and anything else) is not implemented.
                sendError(CMD_NOT_SUPPORTED);
                break;
            }
        }

        @Override
        public void run() {
            try {
                if (useV4) {
                    getRequestV4();
                } else if (negociate()) {
                    getRequest();
                }
            } catch (IOException ex) {
                try {
                    sendError(GENERAL_FAILURE);
                } catch (Exception ignored) {
                }
            } finally {
                // Release the destination socket too, which also ends any
                // tunnel thread still reading from it.
                closeQuietly(dest);
                closeQuietly(client);
            }
        }

    }

    public SocksServer(int port, boolean v4) throws IOException {
        this(port);
        this.useV4 = v4;
    }

    public SocksServer(int port) throws IOException {
        this.port = port;
        server = new ServerSocket();
        if (port == 0) {
            server.bind(null);
            this.port = server.getLocalPort();
        } else {
            server.bind(new InetSocketAddress(port));
        }
    }

    /**
     * Creates a server bound to {@code addr} (wildcard if null) and
     * {@code port} (ephemeral if 0).
     */
    public SocksServer(InetAddress addr, int port, boolean useV4) throws IOException {
        this.useV4 = useV4;
        server = new ServerSocket();
        // A null address means the wildcard address; port 0 means ephemeral.
        server.bind(new InetSocketAddress(addr, port));
        this.port = server.getLocalPort();
    }

    public SocksServer() throws IOException {
        this (DEFAULT_PORT);
    }

    public void addUser(String user, String passwd) {
        users.put(user, passwd);
    }

    public int getPort() {
        return port;
    }

    @Override
    public void close() {
        done = true;
        try { server.close(); } catch (IOException unused) {}
    }

    @Override
    public void run() {
        ClientHandler cl = null;
        while (!done) {
            Socket s;
            try {
                s = server.accept();
            } catch (IOException ex) {
                if (cl != null)
                    cl.interrupt();
                continue;
            }
            try {
                cl = new ClientHandler(s);
                cl.start();
            } catch (IOException ex) {
                // Could not set up the handler: drop this client.
                try { s.close(); } catch (IOException ignored) {}
            }
        }
    }
}