--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/jdk/test/javax/security/sasl/Sasl/ClientServerTest.java Thu Jul 16 09:20:39 2015 +0800
@@ -0,0 +1,477 @@
+/*
+ * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.StringJoiner;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+/*
+ * @test
+ * @bug 8049814
+ * @summary JAVA SASL server and client tests with CRAM-MD5 and
+ * DIGEST-MD5 mechanisms. The tests try different QOP values on
+ * client and server side.
+ * @modules java.security.sasl/javax.security.sasl
+ */
+public class ClientServerTest {
+
+ private static final int DELAY = 100;
+ private static final String LOCALHOST = "localhost";
+ private static final String DIGEST_MD5 = "DIGEST-MD5";
+ private static final String CRAM_MD5 = "CRAM-MD5";
+ private static final String PROTOCOL = "saslservice";
+ private static final String USER_ID = "sasltester";
+ private static final String PASSWD = "password";
+ private static final String QOP_AUTH = "auth";
+ private static final String QOP_AUTH_CONF = "auth-conf";
+ private static final String QOP_AUTH_INT = "auth-int";
+ private static final String AUTHID_SASL_TESTER = "sasl_tester";
+ private static final ArrayList<String> SUPPORT_MECHS = new ArrayList<>();
+
+ static {
+ SUPPORT_MECHS.add(DIGEST_MD5);
+ SUPPORT_MECHS.add(CRAM_MD5);
+ }
+
+ public static void main(String[] args) throws Exception {
+ String[] allQops = { QOP_AUTH_CONF, QOP_AUTH_INT, QOP_AUTH };
+ String[] twoQops = { QOP_AUTH_INT, QOP_AUTH };
+ String[] authQop = { QOP_AUTH };
+ String[] authIntQop = { QOP_AUTH_INT };
+ String[] authConfQop = { QOP_AUTH_CONF };
+ String[] emptyQop = {};
+
+ boolean success = true;
+
+ success &= runTest("", CRAM_MD5, new String[] { QOP_AUTH },
+ new String[] { QOP_AUTH }, false);
+ success &= runTest("", DIGEST_MD5, new String[] { QOP_AUTH },
+ new String[] { QOP_AUTH }, false);
+ success &= runTest(AUTHID_SASL_TESTER, DIGEST_MD5,
+ new String[] { QOP_AUTH }, new String[] { QOP_AUTH }, false);
+ success &= runTest("", DIGEST_MD5, allQops, authQop, false);
+ success &= runTest("", DIGEST_MD5, allQops, authIntQop, false);
+ success &= runTest("", DIGEST_MD5, allQops, authConfQop, false);
+ success &= runTest("", DIGEST_MD5, twoQops, authQop, false);
+ success &= runTest("", DIGEST_MD5, twoQops, authIntQop, false);
+ success &= runTest("", DIGEST_MD5, twoQops, authConfQop, true);
+ success &= runTest("", DIGEST_MD5, authIntQop, authQop, true);
+ success &= runTest("", DIGEST_MD5, authConfQop, authQop, true);
+ success &= runTest("", DIGEST_MD5, authConfQop, emptyQop, true);
+ success &= runTest("", DIGEST_MD5, authIntQop, emptyQop, true);
+ success &= runTest("", DIGEST_MD5, authQop, emptyQop, true);
+
+ if (!success) {
+ throw new RuntimeException("At least one test case failed");
+ }
+
+ System.out.println("Test passed");
+ }
+
+ private static boolean runTest(String authId, String mech,
+ String[] clientQops, String[] serverQops, boolean expectException)
+ throws Exception {
+
+ System.out.println("AuthId:" + authId
+ + " mechanism:" + mech
+ + " clientQops: " + Arrays.toString(clientQops)
+ + " serverQops: " + Arrays.toString(serverQops)
+ + " expect exception:" + expectException);
+
+ try (Server server = Server.start(LOCALHOST, authId, serverQops)) {
+ new Client(LOCALHOST, server.getPort(), mech, authId, clientQops)
+ .run();
+ if (expectException) {
+ System.out.println("Expected exception not thrown");
+ return false;
+ }
+ } catch (SaslException e) {
+ if (!expectException) {
+ System.out.println("Unexpected exception: " + e);
+ return false;
+ }
+ System.out.println("Expected exception: " + e);
+ }
+
+ return true;
+ }
+
+ static enum SaslStatus {
+ SUCCESS, FAILURE, CONTINUE
+ }
+
+ static class Message implements Serializable {
+
+ private final SaslStatus status;
+ private final byte[] data;
+
+ public Message(SaslStatus status, byte[] data) {
+ this.status = status;
+ this.data = data;
+ }
+
+ public SaslStatus getStatus() {
+ return status;
+ }
+
+ public byte[] getData() {
+ return data;
+ }
+ }
+
+ static class SaslPeer {
+
+ final String host;
+ final String mechanism;
+ final String qop;
+ final CallbackHandler callback;
+
+ SaslPeer(String host, String authId, String... qops) {
+ this(host, null, authId, qops);
+ }
+
+ SaslPeer(String host, String mechanism, String authId, String... qops) {
+ this.host = host;
+ this.mechanism = mechanism;
+
+ StringJoiner sj = new StringJoiner(",");
+ for (String q : qops) {
+ sj.add(q);
+ }
+ qop = sj.toString();
+
+ callback = new TestCallbackHandler(USER_ID, PASSWD, host, authId);
+ }
+
+ Message getMessage(Object ob) {
+ if (!(ob instanceof Message)) {
+ throw new RuntimeException("Expected an instance of Message");
+ }
+ return (Message) ob;
+ }
+ }
+
+ static class Server extends SaslPeer implements Runnable, Closeable {
+
+ private volatile boolean ready = false;
+ private volatile ServerSocket ssocket;
+
+ static Server start(String host, String authId, String[] serverQops)
+ throws UnknownHostException {
+ Server server = new Server(host, authId, serverQops);
+ Thread thread = new Thread(server);
+ thread.setDaemon(true);
+ thread.start();
+
+ while (!server.ready) {
+ try {
+ Thread.sleep(DELAY);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ return server;
+ }
+
+ Server(String host, String authId, String... qops) {
+ super(host, authId, qops);
+ }
+
+ int getPort() {
+ return ssocket.getLocalPort();
+ }
+
+ private void processConnection(SaslEndpoint endpoint)
+ throws SaslException, IOException, ClassNotFoundException {
+ System.out.println("process connection");
+ endpoint.send(SUPPORT_MECHS);
+ Object o = endpoint.receive();
+ if (!(o instanceof String)) {
+ throw new RuntimeException("Received unexpected object: " + o);
+ }
+ String mech = (String) o;
+ SaslServer saslServer = createSaslServer(mech);
+ Message msg = getMessage(endpoint.receive());
+ while (!saslServer.isComplete()) {
+ byte[] data = processData(msg.getData(), endpoint,
+ saslServer);
+ if (saslServer.isComplete()) {
+ System.out.println("server is complete");
+ endpoint.send(new Message(SaslStatus.SUCCESS, data));
+ } else {
+ System.out.println("server continues");
+ endpoint.send(new Message(SaslStatus.CONTINUE, data));
+ msg = getMessage(endpoint.receive());
+ }
+ }
+ }
+
+ private byte[] processData(byte[] data, SaslEndpoint endpoint,
+ SaslServer server) throws SaslException, IOException {
+ try {
+ return server.evaluateResponse(data);
+ } catch (SaslException e) {
+ endpoint.send(new Message(SaslStatus.FAILURE, null));
+ System.out.println("Error while processing data");
+ throw e;
+ }
+ }
+
+ private SaslServer createSaslServer(String mechanism)
+ throws SaslException {
+ Map<String, String> props = new HashMap<>();
+ props.put(Sasl.QOP, qop);
+ return Sasl.createSaslServer(mechanism, PROTOCOL, host, props,
+ callback);
+ }
+
+ @Override
+ public void run() {
+ try (ServerSocket ss = new ServerSocket(0)) {
+ ssocket = ss;
+ System.out.println("server started on port " + getPort());
+ ready = true;
+ Socket socket = ss.accept();
+ try (SaslEndpoint endpoint = new SaslEndpoint(socket)) {
+ System.out.println("server accepted connection");
+ processConnection(endpoint);
+ }
+ } catch (Exception e) {
+ // ignore it for now, client will throw an exception
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (!ssocket.isClosed()) {
+ ssocket.close();
+ }
+ }
+ }
+
+ static class Client extends SaslPeer {
+
+ private final int port;
+
+ Client(String host, int port, String mech, String authId,
+ String... qops) {
+ super(host, mech, authId, qops);
+ this.port = port;
+ }
+
+ public void run() throws Exception {
+ System.out.println("Host:" + host + " port: "
+ + port);
+ try (SaslEndpoint endpoint = SaslEndpoint.create(host, port)) {
+ negotiateMechanism(endpoint);
+ SaslClient client = createSaslClient();
+ byte[] data = new byte[0];
+ if (client.hasInitialResponse()) {
+ data = client.evaluateChallenge(data);
+ }
+ endpoint.send(new Message(SaslStatus.CONTINUE, data));
+ Message msg = getMessage(endpoint.receive());
+ while (!client.isComplete()
+ && msg.getStatus() != SaslStatus.FAILURE) {
+ switch (msg.getStatus()) {
+ case CONTINUE:
+ System.out.println("client continues");
+ data = client.evaluateChallenge(msg.getData());
+ endpoint.send(new Message(SaslStatus.CONTINUE,
+ data));
+ msg = getMessage(endpoint.receive());
+ break;
+ case SUCCESS:
+ System.out.println("client succeeded");
+ data = client.evaluateChallenge(msg.getData());
+ if (data != null) {
+ throw new SaslException("data should be null");
+ }
+ break;
+ default:
+ throw new RuntimeException("Wrong status:"
+ + msg.getStatus());
+ }
+ }
+
+ if (msg.getStatus() == SaslStatus.FAILURE) {
+ throw new RuntimeException("Status is FAILURE");
+ }
+ }
+
+ System.out.println("Done");
+ }
+
+ private SaslClient createSaslClient() throws SaslException {
+ Map<String, String> props = new HashMap<>();
+ props.put(Sasl.QOP, qop);
+ return Sasl.createSaslClient(new String[] {mechanism}, USER_ID,
+ PROTOCOL, host, props, callback);
+ }
+
+ private void negotiateMechanism(SaslEndpoint endpoint)
+ throws ClassNotFoundException, IOException {
+ Object o = endpoint.receive();
+ if (o instanceof ArrayList) {
+ ArrayList list = (ArrayList) o;
+ if (!list.contains(mechanism)) {
+ throw new RuntimeException(
+ "Server does not support specified mechanism:"
+ + mechanism);
+ }
+ } else {
+ throw new RuntimeException(
+ "Expected an instance of ArrayList, but received " + o);
+ }
+
+ endpoint.send(mechanism);
+ }
+
+ }
+
+ static class SaslEndpoint implements AutoCloseable {
+
+ private final Socket socket;
+ private ObjectInputStream input;
+ private ObjectOutputStream output;
+
+ static SaslEndpoint create(String host, int port) throws IOException {
+ return new SaslEndpoint(new Socket(host, port));
+ }
+
+ SaslEndpoint(Socket socket) throws IOException {
+ this.socket = socket;
+ }
+
+ private ObjectInputStream getInput() throws IOException {
+ if (input == null && socket != null) {
+ input = new ObjectInputStream(socket.getInputStream());
+ }
+ return input;
+ }
+
+ private ObjectOutputStream getOutput() throws IOException {
+ if (output == null && socket != null) {
+ output = new ObjectOutputStream(socket.getOutputStream());
+ }
+ return output;
+ }
+
+ public Object receive() throws IOException, ClassNotFoundException {
+ return getInput().readObject();
+ }
+
+ public void send(Object obj) throws IOException {
+ getOutput().writeObject(obj);
+ getOutput().flush();
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (socket != null && !socket.isClosed()) {
+ socket.close();
+ }
+ }
+
+ }
+
+ static class TestCallbackHandler implements CallbackHandler {
+
+ private final String userId;
+ private final char[] passwd;
+ private final String realm;
+ private String authId;
+
+ TestCallbackHandler(String userId, String passwd, String realm,
+ String authId) {
+ this.userId = userId;
+ this.passwd = passwd.toCharArray();
+ this.realm = realm;
+ this.authId = authId;
+ }
+
+ @Override
+ public void handle(Callback[] callbacks) throws IOException,
+ UnsupportedCallbackException {
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ System.out.println("NameCallback");
+ ((NameCallback) callback).setName(userId);
+ } else if (callback instanceof PasswordCallback) {
+ System.out.println("PasswordCallback");
+ ((PasswordCallback) callback).setPassword(passwd);
+ } else if (callback instanceof RealmCallback) {
+ System.out.println("RealmCallback");
+ ((RealmCallback) callback).setText(realm);
+ } else if (callback instanceof RealmChoiceCallback) {
+ System.out.println("RealmChoiceCallback");
+ RealmChoiceCallback choice = (RealmChoiceCallback) callback;
+ if (realm == null) {
+ choice.setSelectedIndex(choice.getDefaultChoice());
+ } else {
+ String[] choices = choice.getChoices();
+ for (int j = 0; j < choices.length; j++) {
+ if (realm.equals(choices[j])) {
+ choice.setSelectedIndex(j);
+ break;
+ }
+ }
+ }
+ } else if (callback instanceof AuthorizeCallback) {
+ System.out.println("AuthorizeCallback");
+ ((AuthorizeCallback) callback).setAuthorized(true);
+ if (authId == null || authId.trim().length() == 0) {
+ authId = userId;
+ }
+ ((AuthorizeCallback) callback).setAuthorizedID(authId);
+ } else {
+ throw new UnsupportedCallbackException(callback);
+ }
+ }
+ }
+ }
+
+}