# HG changeset patch # User asmotrak # Date 1437009639 -28800 # Node ID f6f1365b416fe9e93fa0db6c41309a9bdcf38c42 # Parent 2e16a59cc5cb7b45d4e0000fa2b35ce78c5e079d 8049814: Additional SASL client-server tests Reviewed-by: weijun diff -r 2e16a59cc5cb -r f6f1365b416f jdk/test/javax/security/sasl/Sasl/ClientServerTest.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/jdk/test/javax/security/sasl/Sasl/ClientServerTest.java Thu Jul 16 09:20:39 2015 +0800 @@ -0,0 +1,477 @@ +/* + * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.io.Closeable; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.StringJoiner; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +/* + * @test + * @bug 8049814 + * @summary JAVA SASL server and client tests with CRAM-MD5 and + * DIGEST-MD5 mechanisms. The tests try different QOP values on + * client and server side. + * @modules java.security.sasl/javax.security.sasl + */ +public class ClientServerTest { + + private static final int DELAY = 100; + private static final String LOCALHOST = "localhost"; + private static final String DIGEST_MD5 = "DIGEST-MD5"; + private static final String CRAM_MD5 = "CRAM-MD5"; + private static final String PROTOCOL = "saslservice"; + private static final String USER_ID = "sasltester"; + private static final String PASSWD = "password"; + private static final String QOP_AUTH = "auth"; + private static final String QOP_AUTH_CONF = "auth-conf"; + private static final String QOP_AUTH_INT = "auth-int"; + private static final String AUTHID_SASL_TESTER = "sasl_tester"; + private static final ArrayList SUPPORT_MECHS = new ArrayList<>(); + + static { + SUPPORT_MECHS.add(DIGEST_MD5); + SUPPORT_MECHS.add(CRAM_MD5); + } + + public static void main(String[] args) throws Exception { + String[] allQops = { QOP_AUTH_CONF, QOP_AUTH_INT, QOP_AUTH }; + String[] twoQops = { QOP_AUTH_INT, QOP_AUTH }; + String[] authQop = { QOP_AUTH }; + String[] authIntQop = { QOP_AUTH_INT }; + String[] authConfQop = { QOP_AUTH_CONF }; + String[] emptyQop = {}; + + boolean success = true; + + success &= runTest("", CRAM_MD5, new String[] { QOP_AUTH }, + new String[] { QOP_AUTH }, false); + success &= runTest("", DIGEST_MD5, new String[] { QOP_AUTH }, + new String[] { QOP_AUTH }, false); + success &= runTest(AUTHID_SASL_TESTER, DIGEST_MD5, + new String[] { QOP_AUTH }, new String[] { QOP_AUTH }, false); + success &= runTest("", DIGEST_MD5, allQops, authQop, false); + success &= runTest("", DIGEST_MD5, allQops, authIntQop, false); + success &= runTest("", DIGEST_MD5, allQops, authConfQop, false); + success &= runTest("", DIGEST_MD5, twoQops, authQop, false); + success &= runTest("", DIGEST_MD5, twoQops, authIntQop, false); + success &= runTest("", DIGEST_MD5, twoQops, authConfQop, true); + success &= runTest("", DIGEST_MD5, authIntQop, authQop, true); + success &= runTest("", DIGEST_MD5, authConfQop, authQop, true); + success &= runTest("", DIGEST_MD5, authConfQop, emptyQop, true); + success &= runTest("", DIGEST_MD5, authIntQop, emptyQop, true); + success &= runTest("", DIGEST_MD5, authQop, emptyQop, true); + + if (!success) { + throw new RuntimeException("At least one test case failed"); + } + + System.out.println("Test passed"); + } + + private static boolean runTest(String authId, String mech, + String[] clientQops, String[] serverQops, boolean expectException) + throws Exception { + + System.out.println("AuthId:" + authId + + " mechanism:" + mech + + " clientQops: " + Arrays.toString(clientQops) + + " serverQops: " + Arrays.toString(serverQops) + + " expect exception:" + expectException); + + try (Server server = Server.start(LOCALHOST, authId, serverQops)) { + new Client(LOCALHOST, server.getPort(), mech, authId, clientQops) + .run(); + if (expectException) { + System.out.println("Expected exception not thrown"); + return false; + } + } catch (SaslException e) { + if (!expectException) { + System.out.println("Unexpected exception: " + e); + return false; + } + System.out.println("Expected exception: " + e); + } + + return true; + } + + static enum SaslStatus { + SUCCESS, FAILURE, CONTINUE + } + + static class Message implements Serializable { + + private final SaslStatus status; + private final byte[] data; + + public Message(SaslStatus status, byte[] data) { + this.status = status; + this.data = data; + } + + public SaslStatus getStatus() { + return status; + } + + public byte[] getData() { + return data; + } + } + + static class SaslPeer { + + final String host; + final String mechanism; + final String qop; + final CallbackHandler callback; + + SaslPeer(String host, String authId, String... qops) { + this(host, null, authId, qops); + } + + SaslPeer(String host, String mechanism, String authId, String... qops) { + this.host = host; + this.mechanism = mechanism; + + StringJoiner sj = new StringJoiner(","); + for (String q : qops) { + sj.add(q); + } + qop = sj.toString(); + + callback = new TestCallbackHandler(USER_ID, PASSWD, host, authId); + } + + Message getMessage(Object ob) { + if (!(ob instanceof Message)) { + throw new RuntimeException("Expected an instance of Message"); + } + return (Message) ob; + } + } + + static class Server extends SaslPeer implements Runnable, Closeable { + + private volatile boolean ready = false; + private volatile ServerSocket ssocket; + + static Server start(String host, String authId, String[] serverQops) + throws UnknownHostException { + Server server = new Server(host, authId, serverQops); + Thread thread = new Thread(server); + thread.setDaemon(true); + thread.start(); + + while (!server.ready) { + try { + Thread.sleep(DELAY); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + return server; + } + + Server(String host, String authId, String... qops) { + super(host, authId, qops); + } + + int getPort() { + return ssocket.getLocalPort(); + } + + private void processConnection(SaslEndpoint endpoint) + throws SaslException, IOException, ClassNotFoundException { + System.out.println("process connection"); + endpoint.send(SUPPORT_MECHS); + Object o = endpoint.receive(); + if (!(o instanceof String)) { + throw new RuntimeException("Received unexpected object: " + o); + } + String mech = (String) o; + SaslServer saslServer = createSaslServer(mech); + Message msg = getMessage(endpoint.receive()); + while (!saslServer.isComplete()) { + byte[] data = processData(msg.getData(), endpoint, + saslServer); + if (saslServer.isComplete()) { + System.out.println("server is complete"); + endpoint.send(new Message(SaslStatus.SUCCESS, data)); + } else { + System.out.println("server continues"); + endpoint.send(new Message(SaslStatus.CONTINUE, data)); + msg = getMessage(endpoint.receive()); + } + } + } + + private byte[] processData(byte[] data, SaslEndpoint endpoint, + SaslServer server) throws SaslException, IOException { + try { + return server.evaluateResponse(data); + } catch (SaslException e) { + endpoint.send(new Message(SaslStatus.FAILURE, null)); + System.out.println("Error while processing data"); + throw e; + } + } + + private SaslServer createSaslServer(String mechanism) + throws SaslException { + Map props = new HashMap<>(); + props.put(Sasl.QOP, qop); + return Sasl.createSaslServer(mechanism, PROTOCOL, host, props, + callback); + } + + @Override + public void run() { + try (ServerSocket ss = new ServerSocket(0)) { + ssocket = ss; + System.out.println("server started on port " + getPort()); + ready = true; + Socket socket = ss.accept(); + try (SaslEndpoint endpoint = new SaslEndpoint(socket)) { + System.out.println("server accepted connection"); + processConnection(endpoint); + } + } catch (Exception e) { + // ignore it for now, client will throw an exception + } + } + + @Override + public void close() throws IOException { + if (!ssocket.isClosed()) { + ssocket.close(); + } + } + } + + static class Client extends SaslPeer { + + private final int port; + + Client(String host, int port, String mech, String authId, + String... qops) { + super(host, mech, authId, qops); + this.port = port; + } + + public void run() throws Exception { + System.out.println("Host:" + host + " port: " + + port); + try (SaslEndpoint endpoint = SaslEndpoint.create(host, port)) { + negotiateMechanism(endpoint); + SaslClient client = createSaslClient(); + byte[] data = new byte[0]; + if (client.hasInitialResponse()) { + data = client.evaluateChallenge(data); + } + endpoint.send(new Message(SaslStatus.CONTINUE, data)); + Message msg = getMessage(endpoint.receive()); + while (!client.isComplete() + && msg.getStatus() != SaslStatus.FAILURE) { + switch (msg.getStatus()) { + case CONTINUE: + System.out.println("client continues"); + data = client.evaluateChallenge(msg.getData()); + endpoint.send(new Message(SaslStatus.CONTINUE, + data)); + msg = getMessage(endpoint.receive()); + break; + case SUCCESS: + System.out.println("client succeeded"); + data = client.evaluateChallenge(msg.getData()); + if (data != null) { + throw new SaslException("data should be null"); + } + break; + default: + throw new RuntimeException("Wrong status:" + + msg.getStatus()); + } + } + + if (msg.getStatus() == SaslStatus.FAILURE) { + throw new RuntimeException("Status is FAILURE"); + } + } + + System.out.println("Done"); + } + + private SaslClient createSaslClient() throws SaslException { + Map props = new HashMap<>(); + props.put(Sasl.QOP, qop); + return Sasl.createSaslClient(new String[] {mechanism}, USER_ID, + PROTOCOL, host, props, callback); + } + + private void negotiateMechanism(SaslEndpoint endpoint) + throws ClassNotFoundException, IOException { + Object o = endpoint.receive(); + if (o instanceof ArrayList) { + ArrayList list = (ArrayList) o; + if (!list.contains(mechanism)) { + throw new RuntimeException( + "Server does not support specified mechanism:" + + mechanism); + } + } else { + throw new RuntimeException( + "Expected an instance of ArrayList, but received " + o); + } + + endpoint.send(mechanism); + } + + } + + static class SaslEndpoint implements AutoCloseable { + + private final Socket socket; + private ObjectInputStream input; + private ObjectOutputStream output; + + static SaslEndpoint create(String host, int port) throws IOException { + return new SaslEndpoint(new Socket(host, port)); + } + + SaslEndpoint(Socket socket) throws IOException { + this.socket = socket; + } + + private ObjectInputStream getInput() throws IOException { + if (input == null && socket != null) { + input = new ObjectInputStream(socket.getInputStream()); + } + return input; + } + + private ObjectOutputStream getOutput() throws IOException { + if (output == null && socket != null) { + output = new ObjectOutputStream(socket.getOutputStream()); + } + return output; + } + + public Object receive() throws IOException, ClassNotFoundException { + return getInput().readObject(); + } + + public void send(Object obj) throws IOException { + getOutput().writeObject(obj); + getOutput().flush(); + } + + @Override + public void close() throws IOException { + if (socket != null && !socket.isClosed()) { + socket.close(); + } + } + + } + + static class TestCallbackHandler implements CallbackHandler { + + private final String userId; + private final char[] passwd; + private final String realm; + private String authId; + + TestCallbackHandler(String userId, String passwd, String realm, + String authId) { + this.userId = userId; + this.passwd = passwd.toCharArray(); + this.realm = realm; + this.authId = authId; + } + + @Override + public void handle(Callback[] callbacks) throws IOException, + UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + System.out.println("NameCallback"); + ((NameCallback) callback).setName(userId); + } else if (callback instanceof PasswordCallback) { + System.out.println("PasswordCallback"); + ((PasswordCallback) callback).setPassword(passwd); + } else if (callback instanceof RealmCallback) { + System.out.println("RealmCallback"); + ((RealmCallback) callback).setText(realm); + } else if (callback instanceof RealmChoiceCallback) { + System.out.println("RealmChoiceCallback"); + RealmChoiceCallback choice = (RealmChoiceCallback) callback; + if (realm == null) { + choice.setSelectedIndex(choice.getDefaultChoice()); + } else { + String[] choices = choice.getChoices(); + for (int j = 0; j < choices.length; j++) { + if (realm.equals(choices[j])) { + choice.setSelectedIndex(j); + break; + } + } + } + } else if (callback instanceof AuthorizeCallback) { + System.out.println("AuthorizeCallback"); + ((AuthorizeCallback) callback).setAuthorized(true); + if (authId == null || authId.trim().length() == 0) { + authId = userId; + } + ((AuthorizeCallback) callback).setAuthorizedID(authId); + } else { + throw new UnsupportedCallbackException(callback); + } + } + } + } + +}