jdk/test/javax/security/sasl/Sasl/ClientServerTest.java
author asmotrak
Thu, 16 Jul 2015 09:20:39 +0800
changeset 31724 f6f1365b416f
permissions -rw-r--r--
8049814: Additional SASL client-server tests Reviewed-by: weijun

/*
 * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.StringJoiner;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

/*
 * @test
 * @bug 8049814
 * @summary Java SASL server and client tests with CRAM-MD5 and
 *          DIGEST-MD5 mechanisms. The tests try different QOP values on
 *          the client and server sides.
 * @modules java.security.sasl/javax.security.sasl
 */
public class ClientServerTest {

    private static final String LOCALHOST = "localhost";
    private static final String DIGEST_MD5 = "DIGEST-MD5";
    private static final String CRAM_MD5 = "CRAM-MD5";
    private static final String PROTOCOL = "saslservice";
    private static final String USER_ID = "sasltester";
    private static final String PASSWD = "password";
    private static final String QOP_AUTH = "auth";
    private static final String QOP_AUTH_CONF = "auth-conf";
    private static final String QOP_AUTH_INT = "auth-int";
    private static final String AUTHID_SASL_TESTER = "sasl_tester";

    /**
     * Mechanisms the server advertises. Kept as an {@code ArrayList} because
     * the client checks for that type after deserialization.
     */
    private static final ArrayList<String> SUPPORT_MECHS =
            new ArrayList<>(Arrays.asList(DIGEST_MD5, CRAM_MD5));

    /**
     * Runs each client/server QOP combination and fails if any case does not
     * behave as expected.
     *
     * @param args ignored
     * @throws Exception if a case fails in a way that {@link #runTest} does
     *         not report through its return value
     */
    public static void main(String[] args) throws Exception {
        String[] allQops = { QOP_AUTH_CONF, QOP_AUTH_INT, QOP_AUTH };
        String[] twoQops = { QOP_AUTH_INT, QOP_AUTH };
        String[] authQop = { QOP_AUTH };
        String[] authIntQop = { QOP_AUTH_INT };
        String[] authConfQop = { QOP_AUTH_CONF };
        String[] emptyQop = {};

        boolean success = true;

        // Compatible QOP sets: authentication is expected to succeed.
        success &= runTest("", CRAM_MD5, authQop, authQop, false);
        success &= runTest("", DIGEST_MD5, authQop, authQop, false);
        success &= runTest(AUTHID_SASL_TESTER, DIGEST_MD5,
                authQop, authQop, false);
        success &= runTest("", DIGEST_MD5, allQops, authQop, false);
        success &= runTest("", DIGEST_MD5, allQops, authIntQop, false);
        success &= runTest("", DIGEST_MD5, allQops, authConfQop, false);
        success &= runTest("", DIGEST_MD5, twoQops, authQop, false);
        success &= runTest("", DIGEST_MD5, twoQops, authIntQop, false);

        // Incompatible QOP sets: a SaslException is expected.
        success &= runTest("", DIGEST_MD5, twoQops, authConfQop, true);
        success &= runTest("", DIGEST_MD5, authIntQop, authQop, true);
        success &= runTest("", DIGEST_MD5, authConfQop, authQop, true);
        success &= runTest("", DIGEST_MD5, authConfQop, emptyQop, true);
        success &= runTest("", DIGEST_MD5, authIntQop, emptyQop, true);
        success &= runTest("", DIGEST_MD5, authQop, emptyQop, true);

        if (!success) {
            throw new RuntimeException("At least one test case failed");
        }

        System.out.println("Test passed");
    }

    /**
     * Runs one client/server authentication over a socket to localhost.
     *
     * @param authId authorization id configured on both peers
     * @param mech SASL mechanism the client requests
     * @param clientQops QOP values offered by the client
     * @param serverQops QOP values offered by the server
     * @param expectException whether a {@link SaslException} is expected
     * @return {@code true} if the case behaved as expected
     * @throws Exception on failures other than {@link SaslException}
     */
    private static boolean runTest(String authId, String mech,
            String[] clientQops, String[] serverQops, boolean expectException)
            throws Exception {

        System.out.println("AuthId:" + authId
                + " mechanism:" + mech
                + " clientQops: " + Arrays.toString(clientQops)
                + " serverQops: " + Arrays.toString(serverQops)
                + " expect exception:" + expectException);

        try (Server server = Server.start(LOCALHOST, authId, serverQops)) {
            new Client(LOCALHOST, server.getPort(), mech, authId, clientQops)
                    .run();
            if (expectException) {
                System.out.println("Expected exception not thrown");
                return false;
            }
        } catch (SaslException e) {
            if (!expectException) {
                System.out.println("Unexpected exception: " + e);
                return false;
            }
            System.out.println("Expected exception: " + e);
        }

        return true;
    }

    /** Status attached to every message exchanged between the peers. */
    static enum SaslStatus {
        SUCCESS, FAILURE, CONTINUE
    }

    /** A status plus the SASL challenge/response bytes (may be null). */
    static class Message implements Serializable {

        private static final long serialVersionUID = 1L;

        private final SaslStatus status;
        private final byte[] data;

        public Message(SaslStatus status, byte[] data) {
            this.status = status;
            this.data = data;
        }

        public SaslStatus getStatus() {
            return status;
        }

        public byte[] getData() {
            return data;
        }
    }

    /** State shared by the client and server: host, mechanism, QOP, callback. */
    static class SaslPeer {

        final String host;
        final String mechanism;
        /** Comma-separated QOP list passed as the {@link Sasl#QOP} property. */
        final String qop;
        final CallbackHandler callback;

        SaslPeer(String host, String authId, String... qops) {
            this(host, null, authId, qops);
        }

        SaslPeer(String host, String mechanism, String authId, String... qops) {
            this.host = host;
            this.mechanism = mechanism;
            this.qop = String.join(",", qops);
            this.callback = new TestCallbackHandler(USER_ID, PASSWD, host, authId);
        }

        /**
         * Casts a received object to {@link Message}.
         *
         * @throws RuntimeException if {@code ob} is not a {@code Message}
         */
        Message getMessage(Object ob) {
            if (!(ob instanceof Message)) {
                throw new RuntimeException("Expected an instance of Message");
            }
            return (Message) ob;
        }
    }

    /** A one-shot SASL server that accepts a single connection. */
    static class Server extends SaslPeer implements Runnable, Closeable {

        // Bound in the constructor so the port is valid before the accepting
        // thread starts; no readiness polling is needed.
        private final ServerSocket ssocket;

        /**
         * Binds a server on an ephemeral port and starts a daemon thread
         * that serves one connection.
         *
         * @throws IOException if the server socket cannot be created
         */
        static Server start(String host, String authId, String[] serverQops)
                throws IOException {
            Server server = new Server(host, authId, serverQops);
            Thread thread = new Thread(server);
            thread.setDaemon(true);
            thread.start();
            return server;
        }

        Server(String host, String authId, String... qops) throws IOException {
            super(host, authId, qops);
            ssocket = new ServerSocket(0);
        }

        int getPort() {
            return ssocket.getLocalPort();
        }

        /**
         * Advertises mechanisms, reads the client's choice, then exchanges
         * challenges until the SASL server completes or an error occurs.
         */
        private void processConnection(SaslEndpoint endpoint)
                throws SaslException, IOException, ClassNotFoundException {
            System.out.println("process connection");
            endpoint.send(SUPPORT_MECHS);
            Object o = endpoint.receive();
            if (!(o instanceof String)) {
                throw new RuntimeException("Received unexpected object: " + o);
            }
            SaslServer saslServer = createSaslServer((String) o);
            try {
                Message msg = getMessage(endpoint.receive());
                while (!saslServer.isComplete()) {
                    byte[] data = processData(msg.getData(), endpoint,
                            saslServer);
                    if (saslServer.isComplete()) {
                        System.out.println("server is complete");
                        endpoint.send(new Message(SaslStatus.SUCCESS, data));
                    } else {
                        System.out.println("server continues");
                        endpoint.send(new Message(SaslStatus.CONTINUE, data));
                        msg = getMessage(endpoint.receive());
                    }
                }
            } finally {
                saslServer.dispose();
            }
        }

        /**
         * Evaluates a client response. On failure it first notifies the
         * client with a FAILURE message, then rethrows.
         */
        private byte[] processData(byte[] data, SaslEndpoint endpoint,
                SaslServer server) throws SaslException, IOException {
            try {
                return server.evaluateResponse(data);
            } catch (SaslException e) {
                endpoint.send(new Message(SaslStatus.FAILURE, null));
                System.out.println("Error while processing data");
                throw e;
            }
        }

        /**
         * Creates the SASL server for {@code mechanism} using this peer's QOP.
         *
         * @throws SaslException if no server is available for the mechanism
         */
        private SaslServer createSaslServer(String mechanism)
                throws SaslException {
            Map<String, String> props = new HashMap<>();
            props.put(Sasl.QOP, qop);
            SaslServer server = Sasl.createSaslServer(mechanism, PROTOCOL,
                    host, props, callback);
            if (server == null) {
                // Sasl.createSaslServer returns null when no factory matches.
                throw new SaslException(
                        "No SaslServer available for mechanism " + mechanism);
            }
            return server;
        }

        @Override
        public void run() {
            System.out.println("server started on port " + getPort());
            try (ServerSocket ss = ssocket;
                    SaslEndpoint endpoint = new SaslEndpoint(ss.accept())) {
                System.out.println("server accepted connection");
                processConnection(endpoint);
            } catch (Exception e) {
                // Not rethrown: the client observes the failure and reports
                // it. Logged so server-side causes remain visible.
                System.out.println("Server error: " + e);
            }
        }

        @Override
        public void close() throws IOException {
            if (!ssocket.isClosed()) {
                ssocket.close();
            }
        }
    }

    /** A SASL client that authenticates against a {@link Server}. */
    static class Client extends SaslPeer {

        private final int port;

        Client(String host, int port, String mech, String authId,
                String... qops) {
            super(host, mech, authId, qops);
            this.port = port;
        }

        /**
         * Connects, negotiates the mechanism and runs the SASL exchange.
         *
         * @throws SaslException if authentication fails on either side,
         *         including a FAILURE status reported by the server
         */
        public void run() throws Exception {
            System.out.println("Host:" + host + " port: "
                    + port);
            try (SaslEndpoint endpoint = SaslEndpoint.create(host, port)) {
                negotiateMechanism(endpoint);
                SaslClient client = createSaslClient();
                try {
                    authenticate(client, endpoint);
                } finally {
                    client.dispose();
                }
            }

            System.out.println("Done");
        }

        /** Exchanges challenges/responses until the client completes. */
        private void authenticate(SaslClient client, SaslEndpoint endpoint)
                throws SaslException, IOException, ClassNotFoundException {
            byte[] data = new byte[0];
            if (client.hasInitialResponse()) {
                data = client.evaluateChallenge(data);
            }
            endpoint.send(new Message(SaslStatus.CONTINUE, data));
            Message msg = getMessage(endpoint.receive());
            while (!client.isComplete()) {
                switch (msg.getStatus()) {
                    case CONTINUE:
                        System.out.println("client continues");
                        data = client.evaluateChallenge(msg.getData());
                        endpoint.send(new Message(SaslStatus.CONTINUE,
                                data));
                        msg = getMessage(endpoint.receive());
                        break;
                    case SUCCESS:
                        System.out.println("client succeeded");
                        data = client.evaluateChallenge(msg.getData());
                        if (data != null) {
                            throw new SaslException("data should be null");
                        }
                        // No further message follows SUCCESS; guard
                        // against looping forever on the same message.
                        if (!client.isComplete()) {
                            throw new SaslException("Server reported SUCCESS"
                                    + " but the client is not complete");
                        }
                        break;
                    case FAILURE:
                        // The server rejected the exchange: a SASL failure.
                        throw new SaslException("Status is FAILURE");
                    default:
                        throw new RuntimeException("Wrong status:"
                                + msg.getStatus());
                }
            }

            // The client may complete before hearing the server's verdict
            // (e.g. CRAM-MD5), so the last status must still be checked.
            if (msg.getStatus() != SaslStatus.SUCCESS) {
                throw new SaslException("Status is " + msg.getStatus());
            }
        }

        /**
         * Creates the SASL client for this peer's mechanism and QOP.
         *
         * @throws RuntimeException if no client is available for the
         *         mechanism (a test-setup error, not an auth failure)
         */
        private SaslClient createSaslClient() throws SaslException {
            Map<String, String> props = new HashMap<>();
            props.put(Sasl.QOP, qop);
            SaslClient client = Sasl.createSaslClient(new String[] {mechanism},
                    USER_ID, PROTOCOL, host, props, callback);
            if (client == null) {
                throw new RuntimeException(
                        "No SaslClient available for mechanism " + mechanism);
            }
            return client;
        }

        /** Checks the server advertises our mechanism and selects it. */
        private void negotiateMechanism(SaslEndpoint endpoint)
                throws ClassNotFoundException, IOException {
            Object o = endpoint.receive();
            if (!(o instanceof ArrayList)) {
                throw new RuntimeException(
                        "Expected an instance of ArrayList, but received " + o);
            }
            if (!((ArrayList<?>) o).contains(mechanism)) {
                throw new RuntimeException(
                        "Server does not support specified mechanism:"
                                + mechanism);
            }

            endpoint.send(mechanism);
        }

    }

    /**
     * One end of the object-stream connection used to exchange messages.
     * Streams are created lazily: constructing an {@code ObjectInputStream}
     * blocks until the peer has written its stream header, so creating the
     * input stream eagerly on both ends would block both peers.
     */
    static class SaslEndpoint implements AutoCloseable {

        private final Socket socket;
        private ObjectInputStream input;
        private ObjectOutputStream output;

        static SaslEndpoint create(String host, int port) throws IOException {
            return new SaslEndpoint(new Socket(host, port));
        }

        SaslEndpoint(Socket socket) {
            this.socket = socket;
        }

        private ObjectInputStream getInput() throws IOException {
            if (input == null && socket != null) {
                input = new ObjectInputStream(socket.getInputStream());
            }
            return input;
        }

        private ObjectOutputStream getOutput() throws IOException {
            if (output == null && socket != null) {
                output = new ObjectOutputStream(socket.getOutputStream());
            }
            return output;
        }

        public Object receive() throws IOException, ClassNotFoundException {
            return getInput().readObject();
        }

        public void send(Object obj) throws IOException {
            ObjectOutputStream out = getOutput();
            out.writeObject(obj);
            out.flush();
        }

        @Override
        public void close() throws IOException {
            if (socket != null && !socket.isClosed()) {
                socket.close();
            }
        }

    }

    /**
     * Callback handler serving fixed credentials. It approves every
     * authorization request.
     */
    static class TestCallbackHandler implements CallbackHandler {

        private final String userId;
        private final char[] passwd;
        private final String realm;
        private final String authId;

        TestCallbackHandler(String userId, String passwd, String realm,
                String authId) {
            this.userId = userId;
            this.passwd = passwd.toCharArray();
            this.realm = realm;
            this.authId = authId;
        }

        @Override
        public void handle(Callback[] callbacks) throws IOException,
                UnsupportedCallbackException {
            for (Callback callback : callbacks) {
                if (callback instanceof NameCallback) {
                    System.out.println("NameCallback");
                    ((NameCallback) callback).setName(userId);
                } else if (callback instanceof PasswordCallback) {
                    System.out.println("PasswordCallback");
                    ((PasswordCallback) callback).setPassword(passwd);
                } else if (callback instanceof RealmCallback) {
                    System.out.println("RealmCallback");
                    ((RealmCallback) callback).setText(realm);
                } else if (callback instanceof RealmChoiceCallback) {
                    System.out.println("RealmChoiceCallback");
                    RealmChoiceCallback choice = (RealmChoiceCallback) callback;
                    choice.setSelectedIndex(selectRealm(choice));
                } else if (callback instanceof AuthorizeCallback) {
                    System.out.println("AuthorizeCallback");
                    AuthorizeCallback authorize = (AuthorizeCallback) callback;
                    authorize.setAuthorized(true);
                    // A blank authId means "authorize as the user itself".
                    boolean blank = authId == null || authId.trim().isEmpty();
                    authorize.setAuthorizedID(blank ? userId : authId);
                } else {
                    throw new UnsupportedCallbackException(callback);
                }
            }
        }

        /**
         * Returns the index of {@code realm} among the choices. Falls back
         * to the default choice when the realm is null or not offered.
         */
        private int selectRealm(RealmChoiceCallback choice) {
            if (realm != null) {
                String[] choices = choice.getChoices();
                for (int j = 0; j < choices.length; j++) {
                    if (realm.equals(choices[j])) {
                        return j;
                    }
                }
            }
            return choice.getDefaultChoice();
        }
    }

}