8199645: javax/net/ssl/SSLSession/TestEnabledProtocols.java failed with Connection reset
authorjjiang
Wed, 11 Jul 2018 10:39:58 +0800
changeset 51031 a40b75d39ecd
parent 51030 33be1da67b11
child 51032 43ee4f1c333b
8199645: javax/net/ssl/SSLSession/TestEnabledProtocols.java failed with Connection reset Summary: Refactor this test with SSLSocketTemplate Reviewed-by: xuelei
test/jdk/javax/net/ssl/SSLSession/TestEnabledProtocols.java
--- a/test/jdk/javax/net/ssl/SSLSession/TestEnabledProtocols.java	Tue Jul 10 19:42:48 2018 -0700
+++ b/test/jdk/javax/net/ssl/SSLSession/TestEnabledProtocols.java	Wed Jul 11 10:39:58 2018 +0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2001, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2001, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -33,313 +33,251 @@
  *                  session
  *          4701722 protocol mismatch exceptions should be consistent between
  *                  SSLv3 and TLSv1
+ * @library /javax/net/ssl/templates
  * @run main/othervm TestEnabledProtocols
  * @author Ram Marti
  */
 
-import java.io.*;
-import java.net.*;
-import java.util.*;
-import java.security.*;
-import javax.net.ssl.*;
-import java.security.cert.*;
-
-public class TestEnabledProtocols {
-
-    /*
-     * For each of the valid protocols combinations, start a server thread
-     * that sets up an SSLServerSocket supporting that protocol. Then run
-     * a client thread that attemps to open a connection with all
-     * possible protocol combinataion.  Verify that we get handshake
-     * exceptions correctly. Whenever the connection is established
-     * successfully, verify that the negotiated protocol was correct.
-     * See results file in this directory for complete results.
-     */
+import java.io.InputStream;
+import java.io.InterruptedIOException;
+import java.io.OutputStream;
+import java.security.Security;
+import java.util.Arrays;
 
-    static final String[][] protocolStrings = {
-                                {"TLSv1"},
-                                {"TLSv1", "SSLv2Hello"},
-                                {"TLSv1", "SSLv3"},
-                                {"SSLv3", "SSLv2Hello"},
-                                {"SSLv3"},
-                                {"TLSv1", "SSLv3", "SSLv2Hello"}
-                                };
+import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLHandshakeException;
+import javax.net.ssl.SSLServerSocket;
+import javax.net.ssl.SSLSocket;
 
-    static final boolean [][] eXceptionArray = {
-        // Do we expect exception?       Protocols supported by the server
-        { false, true,  false, true,  true,  true }, // TLSv1
-        { false, false, false, true,  true,  false}, // TLSv1,SSLv2Hello
-        { false, true,  false, true,  false, true }, // TLSv1,SSLv3
-        { true,  true,  false, false, false, false}, // SSLv3, SSLv2Hello
-        { true,  true,  false, true,  false, true }, // SSLv3
-        { false, false, false, false, false, false } // TLSv1,SSLv3,SSLv2Hello
-        };
-
-    static final String[][] protocolSelected = {
-        // TLSv1
-        { "TLSv1",  null,   "TLSv1",  null,   null,     null },
-
-        // TLSv1,SSLv2Hello
-        { "TLSv1", "TLSv1", "TLSv1",  null,   null,    "TLSv1"},
-
-        // TLSv1,SSLv3
-        { "TLSv1",  null,   "TLSv1",  null,   "SSLv3",  null },
-
-        // SSLv3, SSLv2Hello
-        {  null,    null,   "SSLv3", "SSLv3", "SSLv3",  "SSLv3"},
+public class TestEnabledProtocols extends SSLSocketTemplate {
 
-        // SSLv3
-        {  null,    null,   "SSLv3",  null,   "SSLv3",  null },
-
-        // TLSv1,SSLv3,SSLv2Hello
-        { "TLSv1", "TLSv1", "TLSv1", "SSLv3", "SSLv3", "TLSv1" }
-
-    };
-
-    /*
-     * Where do we find the keystores?
-     */
-    final static String pathToStores = "../etc";
-    static String passwd = "passphrase";
-    static String keyStoreFile = "keystore";
-    static String trustStoreFile = "truststore";
-
-    /*
-     * Is the server ready to serve?
-     */
-    volatile static boolean serverReady = false;
-
-    /*
-     * Turn on SSL debugging?
-     */
-    final static boolean debug = false;
+    private final String[] serverProtocols;
+    private final String[] clientProtocols;
+    private final boolean exceptionExpected;
+    private final String selectedProtocol;
 
-    // use any free port by default
-    volatile int serverPort = 0;
-
-    volatile Exception clientException = null;
-
-    public static void main(String[] args) throws Exception {
-        // reset the security property to make sure that the algorithms
-        // and keys used in this test are not disabled.
-        Security.setProperty("jdk.tls.disabledAlgorithms", "");
+    public TestEnabledProtocols(String[] serverProtocols,
+            String[] clientProtocols, boolean exceptionExpected,
+            String selectedProtocol) {
+        this.serverProtocols = serverProtocols;
+        this.clientProtocols = clientProtocols;
+        this.exceptionExpected = exceptionExpected;
+        this.selectedProtocol = selectedProtocol;
+    }
 
-        String keyFilename =
-            System.getProperty("test.src", "./") + "/" + pathToStores +
-                "/" + keyStoreFile;
-        String trustFilename =
-            System.getProperty("test.src", "./") + "/" + pathToStores +
-                "/" + trustStoreFile;
-
-        System.setProperty("javax.net.ssl.keyStore", keyFilename);
-        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
-        System.setProperty("javax.net.ssl.trustStore", trustFilename);
-        System.setProperty("javax.net.ssl.trustStorePassword", passwd);
-
-        if (debug)
-            System.setProperty("javax.net.debug", "all");
-
-        new TestEnabledProtocols();
+    @Override
+    protected void configureServerSocket(SSLServerSocket sslServerSocket) {
+        sslServerSocket.setEnabledProtocols(serverProtocols);
     }
 
-    TestEnabledProtocols() throws Exception  {
-        /*
-         * Start the tests.
-         */
-        SSLServerSocketFactory sslssf =
-            (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
-        SSLServerSocket sslServerSocket =
-            (SSLServerSocket) sslssf.createServerSocket(serverPort);
-        serverPort = sslServerSocket.getLocalPort();
-        // sslServerSocket.setNeedClientAuth(true);
+    @Override
+    protected void runServerApplication(SSLSocket socket) throws Exception {
+        try {
+            socket.startHandshake();
+
+            InputStream in = socket.getInputStream();
+            OutputStream out = socket.getOutputStream();
+            out.write(280);
+            in.read();
+        } catch (SSLHandshakeException se) {
+            // ignore it; this is part of the testing
+            // log it for debugging
+            System.out.println("Server SSLHandshakeException:");
+            se.printStackTrace(System.out);
+        } catch (InterruptedIOException ioe) {
+            // must have been interrupted, no harm
+        } catch (SSLException ssle) {
+            // The client side may have closed the socket.
+            System.out.println("Server SSLException:");
+            ssle.printStackTrace(System.out);
+        } catch (Exception e) {
+            System.out.println("Server exception:");
+            e.printStackTrace(System.out);
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    protected void runClientApplication(SSLSocket sslSocket) throws Exception {
+        try {
+            System.out.println("=== Starting new test run ===");
+            showProtocols("server", serverProtocols);
+            showProtocols("client", clientProtocols);
 
-        for (int i = 0; i < protocolStrings.length; i++) {
-            String [] serverProtocols = protocolStrings[i];
-            startServer ss = new startServer(serverProtocols,
-                sslServerSocket, protocolStrings.length);
-            ss.setDaemon(true);
-            ss.start();
-            for (int j = 0; j < protocolStrings.length; j++) {
-                String [] clientProtocols = protocolStrings[j];
-                startClient sc = new startClient(
-                    clientProtocols, serverProtocols,
-                    eXceptionArray[i][j], protocolSelected[i][j]);
-                sc.start();
-                sc.join();
-                if (clientException != null) {
-                    ss.requestStop();
-                    throw clientException;
-                }
+            sslSocket.setEnabledProtocols(clientProtocols);
+            sslSocket.startHandshake();
+
+            String protocolName = sslSocket.getSession().getProtocol();
+            System.out.println("Protocol name after getSession is " +
+                protocolName);
+
+            if (protocolName.equals(selectedProtocol)) {
+                System.out.println("** Success **");
+            } else {
+                System.out.println("** FAILURE ** ");
+                throw new RuntimeException
+                    ("expected protocol " + selectedProtocol +
+                     " but using " + protocolName);
             }
-            ss.requestStop();
-            System.out.println("Waiting for the server to complete");
-            ss.join();
+
+            InputStream in = sslSocket.getInputStream();
+            OutputStream out = sslSocket.getOutputStream();
+            in.read();
+            out.write(280);
+        } catch (SSLHandshakeException e) {
+            if (!exceptionExpected) {
+                System.out.println(
+                        "Client got UNEXPECTED SSLHandshakeException:");
+                e.printStackTrace(System.out);
+                System.out.println("** FAILURE **");
+                throw new RuntimeException(e);
+            } else {
+                System.out.println(
+                        "Client got expected SSLHandshakeException:");
+                e.printStackTrace(System.out);
+                System.out.println("** Success **");
+            }
+        } catch (Exception e) {
+            System.out.println("Client got UNEXPECTED Exception:");
+            e.printStackTrace(System.out);
+            System.out.println("** FAILURE **");
+            throw new RuntimeException(e);
         }
     }
 
-    class startServer extends Thread  {
-        private String[] enabledP = null;
-        SSLServerSocket sslServerSocket = null;
-        int numExpConns;
-        volatile boolean stopRequested = false;
+    public static void main(String[] args) throws Exception {
+        Security.setProperty("jdk.tls.disabledAlgorithms", "");
 
-        public startServer(String[] enabledProtocols,
-                            SSLServerSocket sslServerSocket,
-                            int numExpConns) {
-            super("Server Thread");
-            serverReady = false;
-            enabledP = enabledProtocols;
-            this.sslServerSocket = sslServerSocket;
-            sslServerSocket.setEnabledProtocols(enabledP);
-            this.numExpConns = numExpConns;
-        }
+        runCase(new String[] { "TLSv1" },
+                new String[] { "TLSv1" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1" },
+                new String[] { "TLSv1", "SSLv2Hello" },
+                true, null);
+        runCase(new String[] { "TLSv1" },
+                new String[] { "TLSv1", "SSLv3" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1" },
+                new String[] { "SSLv3", "SSLv2Hello" },
+                true, null);
+        runCase(new String[] { "TLSv1" },
+                new String[] { "SSLv3" },
+                true, null);
+        runCase(new String[] { "TLSv1" },
+                new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                true, null);
 
-        public void requestStop() {
-            stopRequested = true;
-        }
+        runCase(new String[] { "TLSv1", "SSLv2Hello" },
+                new String[] { "TLSv1" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv2Hello" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv3" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1", "SSLv2Hello" },
+                new String[] { "SSLv3", "SSLv2Hello" },
+                true, null);
+        runCase(new String[] { "TLSv1", "SSLv2Hello" },
+                new String[] { "SSLv3" },
+                true, null);
+        runCase(new String[] { "TLSv1", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                false, "TLSv1");
 
-        public void run() {
-            int conns = 0;
-            while (!stopRequested) {
-                SSLSocket socket = null;
-                try {
-                    serverReady = true;
-                    socket = (SSLSocket)sslServerSocket.accept();
-                    conns++;
-
-                    // set ready to false. this is just to make the
-                    // client wait and synchronise exception messages
-                    serverReady = false;
-                    socket.startHandshake();
-                    SSLSession session = socket.getSession();
-                    session.invalidate();
+        runCase(new String[] { "TLSv1", "SSLv3" },
+                new String[] { "TLSv1" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1", "SSLv3" },
+                new String[] { "TLSv1", "SSLv2Hello" },
+                true, null);
+        runCase(new String[] { "TLSv1", "SSLv3" },
+                new String[] { "TLSv1", "SSLv3" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1", "SSLv3" },
+                new String[] { "SSLv3", "SSLv2Hello" },
+                true, null);
+        runCase(new String[] { "TLSv1", "SSLv3" },
+                new String[] { "SSLv3" },
+                false, "SSLv3");
+        runCase(new String[] { "TLSv1", "SSLv3" },
+                new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                true, null);
 
-                    InputStream in = socket.getInputStream();
-                    OutputStream out = socket.getOutputStream();
-                    out.write(280);
-                    in.read();
+        runCase(new String[] { "SSLv3", "SSLv2Hello" },
+                new String[] { "TLSv1" },
+                true, null);
+        runCase(new String[] { "SSLv3", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv2Hello" },
+                true, null);
+        runCase(new String[] { "SSLv3", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv3" },
+                false, "SSLv3");
+        runCase(new String[] { "SSLv3", "SSLv2Hello" },
+                new String[] { "SSLv3", "SSLv2Hello" },
+                false, "SSLv3");
+        runCase(new String[] { "SSLv3", "SSLv2Hello" },
+                new String[] { "SSLv3" },
+                false, "SSLv3");
+        runCase(new String[] { "SSLv3", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                false, "SSLv3");
 
-                    socket.close();
-                    // sleep for a while so that the server thread can be
-                    // stopped
-                    Thread.sleep(30);
-                } catch (SSLHandshakeException se) {
-                    // ignore it; this is part of the testing
-                    // log it for debugging
-                    System.out.println("Server SSLHandshakeException:");
-                    se.printStackTrace(System.out);
-                } catch (java.io.InterruptedIOException ioe) {
-                    // must have been interrupted, no harm
-                    break;
-                } catch (java.lang.InterruptedException ie) {
-                    // must have been interrupted, no harm
-                    break;
-                } catch (SSLException ssle) {
-                    // The client side may have closed the socket.
-                    System.out.println("Server SSLException:");
-                    ssle.printStackTrace(System.out);
-                } catch (Exception e) {
-                    System.out.println("Server exception:");
-                    e.printStackTrace(System.out);
-                    throw new RuntimeException(e);
-                } finally {
-                    try {
-                        if (socket != null) {
-                            socket.close();
-                        }
-                    } catch (IOException e) {
-                        // ignore
-                    }
-                }
-                if (conns >= numExpConns) {
-                    break;
-                }
-            }
-        }
+        runCase(new String[] { "SSLv3" },
+                new String[] { "TLSv1" },
+                true, null);
+        runCase(new String[] { "SSLv3" },
+                new String[] { "TLSv1", "SSLv2Hello" },
+                true, null);
+        runCase(new String[] { "SSLv3" },
+                new String[] { "TLSv1", "SSLv3" },
+                false, "SSLv3");
+        runCase(new String[] { "SSLv3" },
+                new String[] { "SSLv3", "SSLv2Hello" },
+                true, null);
+        runCase(new String[] { "SSLv3" },
+                new String[] { "SSLv3" },
+                false, "SSLv3");
+        runCase(new String[] { "SSLv3" },
+                new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                true, null);
+
+        runCase(new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                new String[] { "TLSv1" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv2Hello" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv3" },
+                false, "TLSv1");
+        runCase(new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                new String[] { "SSLv3", "SSLv2Hello" },
+                false, "SSLv3");
+        runCase(new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                new String[] { "SSLv3" },
+                false, "SSLv3");
+        runCase(new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                new String[] { "TLSv1", "SSLv3", "SSLv2Hello" },
+                false, "TLSv1");
+    }
+
+    private static void runCase(
+            String[] serverProtocols,
+            String[] clientProtocols,
+            boolean exceptionExpected,
+            String selectedProtocol) throws Exception {
+        new TestEnabledProtocols(
+                serverProtocols,
+                clientProtocols,
+                exceptionExpected,
+                selectedProtocol).run();
     }
 
     private static void showProtocols(String name, String[] protocols) {
-        System.out.println("Enabled protocols on the " + name + " are: " + Arrays.asList(protocols));
+        System.out.printf("Enabled protocols on the %s are: %s%n",
+                name,
+                Arrays.asList(protocols));
     }
-
-    class startClient extends Thread {
-        boolean hsCompleted = false;
-        boolean exceptionExpected = false;
-        private String[] enabledP = null;
-        private String[] serverP = null; // used to print the result
-        private String protocolToUse = null;
-
-        startClient(String[] enabledProtocol,
-                    String[] serverP,
-                    boolean eXception,
-                    String protocol) throws Exception {
-            super("Client Thread");
-            this.enabledP = enabledProtocol;
-            this.serverP = serverP;
-            this.exceptionExpected = eXception;
-            this.protocolToUse = protocol;
-        }
-
-        public void run() {
-            SSLSocket sslSocket = null;
-            try {
-                while (!serverReady) {
-                    Thread.sleep(50);
-                }
-                System.out.flush();
-                System.out.println("=== Starting new test run ===");
-                showProtocols("server", serverP);
-                showProtocols("client", enabledP);
-
-                SSLSocketFactory sslsf =
-                    (SSLSocketFactory)SSLSocketFactory.getDefault();
-                sslSocket = (SSLSocket)
-                    sslsf.createSocket("localhost", serverPort);
-                sslSocket.setEnabledProtocols(enabledP);
-                sslSocket.startHandshake();
-
-                SSLSession session = sslSocket.getSession();
-                session.invalidate();
-                String protocolName = session.getProtocol();
-                System.out.println("Protocol name after getSession is " +
-                    protocolName);
-
-                if (protocolName.equals(protocolToUse)) {
-                    System.out.println("** Success **");
-                } else {
-                    System.out.println("** FAILURE ** ");
-                    throw new RuntimeException
-                        ("expected protocol " + protocolToUse +
-                         " but using " + protocolName);
-                }
-
-                InputStream in = sslSocket.getInputStream();
-                OutputStream out = sslSocket.getOutputStream();
-                in.read();
-                out.write(280);
-
-                sslSocket.close();
-
-            } catch (SSLHandshakeException e) {
-                if (!exceptionExpected) {
-                    System.out.println("Client got UNEXPECTED SSLHandshakeException:");
-                    e.printStackTrace(System.out);
-                    System.out.println("** FAILURE **");
-                    clientException = e;
-                } else {
-                    System.out.println("Client got expected SSLHandshakeException:");
-                    e.printStackTrace(System.out);
-                    System.out.println("** Success **");
-                }
-            } catch (RuntimeException e) {
-                clientException = e;
-            } catch (Exception e) {
-                System.out.println("Client got UNEXPECTED Exception:");
-                e.printStackTrace(System.out);
-                System.out.println("** FAILURE **");
-                clientException = e;
-            }
-        }
-    }
-
 }