start/beginHandshake and more post-handshake changes JDK-8145252-TLS13-branch
authorascarpino
Sun, 13 May 2018 08:52:25 -0700
branchJDK-8145252-TLS13-branch
changeset 56544 ad120e0dfcfb
parent 56543 2352538d2f6e
child 56546 f8a11b589cc5
start/beginHandshake and more post-handshake changes
src/java.base/share/classes/sun/security/ssl/HandshakeContext.java
src/java.base/share/classes/sun/security/ssl/KeyUpdate.java
src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java
src/java.base/share/classes/sun/security/ssl/SSLEngineImpl.java
src/java.base/share/classes/sun/security/ssl/TransportContext.java
test/jdk/javax/net/ssl/SSLSession/RenegotiateTLS13.java
test/jdk/sun/security/ssl/SSLEngineImpl/TLS13BeginHandshake.java
--- a/src/java.base/share/classes/sun/security/ssl/HandshakeContext.java	Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/HandshakeContext.java	Sun May 13 08:52:25 2018 -0700
@@ -70,7 +70,7 @@
                     "sun.security.ssl.allowLegacyHelloMessages", true);
 
     // registered handshake message actors
-    final LinkedHashMap<Byte, SSLConsumer>  handshakeConsumers;
+    LinkedHashMap<Byte, SSLConsumer>  handshakeConsumers;
     final HashMap<Byte, HandshakeProducer>  handshakeProducers;
 
     // context
@@ -115,7 +115,6 @@
     SecretKey                               baseWriteSecret;
 
     // protocol version being established
-    ProtocolVersion                         protocolVersion;
     int                                     clientHelloVersion;
     String                                  applicationProtocol;
 
@@ -207,6 +206,30 @@
         initialize();
     }
 
+    /**
+     * Constructor for PostHandshakeContext
+     */
+    HandshakeContext(TransportContext conContext) {
+        this.sslContext = conContext.sslContext;
+        this.conContext = conContext;
+        this.sslConfig = conContext.sslConfig;
+
+        this.negotiatedProtocol = conContext.protocolVersion;
+        this.negotiatedCipherSuite = conContext.cipherSuite;
+        this.handshakeOutput = new HandshakeOutStream(conContext.outputRecord);
+        this.delegatedActions = new LinkedList<>();
+
+        this.handshakeProducers = null;
+        this.handshakeHash = null;
+        this.activeProtocols = null;
+        this.activeCipherSuites = null;
+        this.algorithmConstraints = null;
+        this.maximumActiveProtocol = null;
+        this.handshakeExtensions = null;
+        this.handshakePossessions = null;
+        this.handshakeCredentials = null;
+    }
+
     // Initialize the non-final class variables.
     private void initialize() throws IOException {
         ProtocolVersion inputHelloVersion;
@@ -331,9 +354,10 @@
         return Collections.unmodifiableList(suites);
     }
 
-    void dispatch(Plaintext plaintext) throws IOException {
-        // parse the handshake record.
-        //
+    /**
+     * Parse the handshake record and return the contentType
+     */
+    static byte getHandshakeType(TransportContext conContext, Plaintext plaintext) throws IOException {
         //     struct {
         //         HandshakeType msg_type;    /* handshake type */
         //         uint24 length;             /* bytes in message */
@@ -341,18 +365,19 @@
         //             ...
         //         } body;
         //     } Handshake;
+
         if (plaintext.contentType != ContentType.HANDSHAKE.id) {
             conContext.fatal(Alert.INTERNAL_ERROR,
                 "Unexpected operation for record: " + plaintext.contentType);
 
-            return;     // make the compiler happy
+            return 0;
         }
 
         if (plaintext.fragment == null || plaintext.fragment.remaining() < 4) {
             conContext.fatal(Alert.UNEXPECTED_MESSAGE,
                     "Invalid handshake message: insufficient data");
 
-            return;     // make the compiler happy
+            return 0;
         }
 
         byte handshakeType = (byte)Record.getInt8(plaintext.fragment);
@@ -361,9 +386,14 @@
             conContext.fatal(Alert.UNEXPECTED_MESSAGE,
                     "Invalid handshake message: insufficient handshake body");
 
-            return;     // make the compiler happy
+            return 0;
         }
 
+        return handshakeType;
+    }
+
+    void dispatch(byte handshakeType, Plaintext plaintext) throws IOException {
+
         if (conContext.transport.useDelegatedTask()) {
             boolean hasDelegated = !delegatedActions.isEmpty();
             if (hasDelegated || handshakeType != SSLHandshake.FINISHED.id) {
@@ -406,8 +436,6 @@
         } else if (handshakeType == SSLHandshake.NEW_SESSION_TICKET.id) {
             // new session ticket may be sent any time after server finished
             consumer = SSLHandshake.NEW_SESSION_TICKET;
-        } else if (handshakeType == SSLHandshake.KEY_UPDATE.id) {
-            consumer = SSLHandshake.KEY_UPDATE;
         } else {
             consumer = handshakeConsumers.get(handshakeType);
         }
@@ -416,7 +444,7 @@
             conContext.fatal(Alert.UNEXPECTED_MESSAGE,
                     "Unexpected handshake message: " +
                     SSLHandshake.nameOf(handshakeType));
-            return;     // make the compiler happy
+            return;
         }
 
         try {
@@ -490,7 +518,6 @@
      * and ServerHandshaker with the negotiated protocol version.
      */
     void setVersion(ProtocolVersion protocolVersion) {
-        this.protocolVersion = protocolVersion;
         this.conContext.protocolVersion = protocolVersion;
     }
 
--- a/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java	Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java	Sun May 13 08:52:25 2018 -0700
@@ -66,16 +66,15 @@
         static final byte NOTREQUSTED = 0;
         static final byte REQUSTED = 1;
         private byte status;
-        HandshakeOutStream hos;
 
-        KeyUpdateMessage(HandshakeContext context, byte status) {
+        KeyUpdateMessage(PostHandshakeContext context, byte status) {
             super(context);
             this.status = status;
             if (status > 1) {
                 new IOException("KeyUpdate message value invalid: " + status);
             }
         }
-        KeyUpdateMessage(HandshakeContext context, ByteBuffer m)
+        KeyUpdateMessage(PostHandshakeContext context, ByteBuffer m)
                 throws IOException{
             super(context);
 
@@ -90,14 +89,6 @@
             }
         }
 
-        KeyUpdateMessage(PostHandshakeContext context, byte status) {
-            super(context);
-            this.status = status;
-            if (status > 1) {
-                new IOException("KeyUpdate message value invalid: " + status);
-            }
-        }
-
         @Override
         public SSLHandshake handshakeType() {
             return SSLHandshake.KEY_UPDATE;
@@ -134,7 +125,7 @@
         // Produce kickstart handshake message.
         @Override
         public byte[] produce(ConnectionContext context) throws IOException {
-            HandshakeContext hc = (HandshakeContext)context;
+            PostHandshakeContext hc = (PostHandshakeContext)context;
             handshakeProducer.produce(hc,
                     new KeyUpdateMessage(hc, KeyUpdateMessage.REQUSTED));
             return null;
@@ -154,7 +145,7 @@
         public void consume(ConnectionContext context,
                 ByteBuffer message) throws IOException {
             // The consuming happens in client side only.
-            HandshakeContext hc = (HandshakeContext)context;
+            PostHandshakeContext hc = (PostHandshakeContext)context;
             KeyUpdateMessage km = new KeyUpdateMessage(hc, message);
 
             if (km.getStatus() == KeyUpdateMessage.NOTREQUSTED) {
@@ -221,9 +212,8 @@
         public byte[] produce(ConnectionContext context,
                 HandshakeMessage message) throws IOException {
             // The producing happens in server side only.
-            HandshakeContext hc = (HandshakeContext)context;
+            PostHandshakeContext hc = (PostHandshakeContext)context;
             KeyUpdateMessage km = (KeyUpdateMessage)message;
-
             SecretKey secret;
 
             if (km.getStatus() == KeyUpdateMessage.REQUSTED) {
@@ -278,6 +268,7 @@
             if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
                 SSLLogger.fine("KeyUpdate: write key updated");
             }
+            hc.free();
             return null;
         }
     }
--- a/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java	Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java	Sun May 13 08:52:25 2018 -0700
@@ -26,18 +26,64 @@
 package sun.security.ssl;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.LinkedHashMap;
 
 /**
- * A clean implementation of HandshakeContext for post-handshake messages
+ * A compact implementation of HandshakeContext for post-handshake messages
  */
 
 public class PostHandshakeContext extends HandshakeContext {
 
+    final static LinkedHashMap<Byte, SSLConsumer> consumers;
+    static {
+        consumers = new LinkedHashMap<>() {{
+            put(SSLHandshake.KEY_UPDATE.id,
+                    SSLHandshake.KEY_UPDATE);
+        }};
+    }
+
+
     PostHandshakeContext(TransportContext context) throws IOException {
-        super(context.sslContext, context);
+        super(context);
+
+        if (!negotiatedProtocol.useTLS13PlusSpec()) {
+            conContext.fatal(Alert.UNEXPECTED_MESSAGE, "Post-handshake not " +
+                    "supported in " + negotiatedProtocol.name);
+        }
+
+        handshakeConsumers = consumers;
+        handshakeFinished = true;
     }
 
     @Override
     void kickstart() throws IOException {
+        SSLHandshake.kickstart(this);
+    }
+
+    @Override
+    void dispatch(byte handshakeType, ByteBuffer fragment) throws IOException {
+
+        SSLConsumer consumer = handshakeConsumers.get(handshakeType);
+        if (consumer == null) {
+            conContext.fatal(Alert.UNEXPECTED_MESSAGE,
+                    "Unexpected post-handshake message: " +
+                            SSLHandshake.nameOf(handshakeType));
+            return;
+        }
+
+        try {
+            consumer.consume(this, fragment);
+        } catch (UnsupportedOperationException unsoe) {
+            conContext.fatal(Alert.UNEXPECTED_MESSAGE,
+                    "Unsupported post-handshake message: " +
+                            SSLHandshake.nameOf(handshakeType), unsoe);
+        }
+    }
+
+    void free() {
+        if (delegatedActions.isEmpty()) {
+            conContext.handshakeContext = null;
+        }
     }
 }
--- a/src/java.base/share/classes/sun/security/ssl/SSLEngineImpl.java	Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/SSLEngineImpl.java	Sun May 13 08:52:25 2018 -0700
@@ -287,7 +287,7 @@
         // Is the handshake completed?
         boolean needRetransmission =
                 conContext.sslContext.isDTLS() &&
-                conContext.handshakeContext != null &&
+                conContext.getHandshakeContext(TransportContext.PRE) != null &&
                 conContext.handshakeContext.sslConfig.enableRetransmissions;
         HandshakeStatus hsStatus =
                 tryToFinishHandshake(ciphertext.contentType);
@@ -331,6 +331,8 @@
                 conContext.outputRecord.isEmpty()) {
             if (conContext.handshakeContext == null) {
                 hsStatus = HandshakeStatus.FINISHED;
+            } else if (conContext.getHandshakeContext(TransportContext.POST) != null) {
+                return null;
             } else if (conContext.handshakeContext.handshakeFinished) {
                 hsStatus = conContext.finishHandshake();
             }
@@ -684,7 +686,7 @@
 
     @Override
     public synchronized Runnable getDelegatedTask() {
-        if (conContext.handshakeContext != null &&
+        if (conContext.handshakeContext != null && // PRE or POST handshake
                 !conContext.handshakeContext.taskDelegated &&
                 !conContext.handshakeContext.delegatedActions.isEmpty()) {
             conContext.handshakeContext.taskDelegated = true;
--- a/src/java.base/share/classes/sun/security/ssl/TransportContext.java	Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/TransportContext.java	Sun May 13 08:52:25 2018 -0700
@@ -175,18 +175,24 @@
         if (ct == null) {
             fatal(Alert.UNEXPECTED_MESSAGE,
                 "Unknown content type: " + plaintext.contentType);
-            return;     // make compiler happy
+            return;
         }
 
         switch (ct) {
             case HANDSHAKE:
+                byte type = HandshakeContext.getHandshakeType(this,
+                        plaintext);
                 if (handshakeContext == null) {
-                    handshakeContext = sslConfig.isClientMode ?
-                            new ClientHandshakeContext(sslContext, this) :
-                            new ServerHandshakeContext(sslContext, this);
-                    outputRecord.initHandshaker();
+                    if (type == SSLHandshake.KEY_UPDATE.id) {
+                        handshakeContext = new PostHandshakeContext(this);
+                    } else {
+                        handshakeContext = sslConfig.isClientMode ?
+                                new ClientHandshakeContext(sslContext, this) :
+                                new ServerHandshakeContext(sslContext, this);
+                        outputRecord.initHandshaker();
+                    }
                 }
-                handshakeContext.dispatch(plaintext);
+                handshakeContext.dispatch(type, plaintext);
                 break;
             case ALERT:
                 Alert.alertConsumer.consume(this, plaintext.fragment);
@@ -209,29 +215,54 @@
 
         // initialize the handshaker if necessary
         if (handshakeContext == null) {
-            handshakeContext = sslConfig.isClientMode ?
-                    new ClientHandshakeContext(sslContext, this) :
-                    new ServerHandshakeContext(sslContext, this);
-            outputRecord.initHandshaker();
+            //  TLS1.3 post-handshake
+            if (isNegotiated && protocolVersion.useTLS13PlusSpec()) {
+                handshakeContext = new PostHandshakeContext(this);
+            } else {
+                handshakeContext = sslConfig.isClientMode ?
+                        new ClientHandshakeContext(sslContext, this) :
+                        new ServerHandshakeContext(sslContext, this);
+                outputRecord.initHandshaker();
+            }
         }
 
         // kickstart the handshake if needed
         //
         // Need no kickstart message on server side unless the connection
-        // has been estabilished.
+        // has been established.
         if(isNegotiated || sslConfig.isClientMode) {
            handshakeContext.kickstart();
         }
     }
 
     void keyUpdate() throws IOException {
-        // TODO: TLS 1.3
         kickstart();
     }
 
-    // Note: close_notify is delivered as awarning alert.
+    final static byte PRE = 1;
+    final static byte POST = 2;
+
+    HandshakeContext getHandshakeContext(byte type) {
+        if (handshakeContext == null) {
+            return null;
+        }
+
+        if (type == PRE &&
+                (handshakeContext instanceof ClientHandshakeContext ||
+                        handshakeContext instanceof ServerHandshakeContext)) {
+            return handshakeContext;
+        }
+
+        if (type == POST && handshakeContext instanceof PostHandshakeContext) {
+            return handshakeContext;
+        }
+
+        return null;
+    }
+
+    // Note: close_notify is delivered as a warning alert.
     void warning(Alert alert) {
-        // For initial handshaking, don't send awarning alert message to peer
+        // For initial handshaking, don't send a warning alert message to peer
         // if handshaker has not started.
         if (isNegotiated || handshakeContext != null) {
             try {
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/javax/net/ssl/SSLSession/RenegotiateTLS13.java	Sun May 13 08:52:25 2018 -0700
@@ -0,0 +1,292 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @run main/othervm -Djavax.net.debug=ssl RenegotiateTLS13
+ */
+
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLServerSocket;
+import javax.net.ssl.SSLServerSocketFactory;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManagerFactory;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.security.KeyStore;
+import java.security.SecureRandom;
+
+public class RenegotiateTLS13 {
+
+    static final String dataString = "This is a test";
+
+    // Run the server as a thread instead of the client
+    static boolean separateServerThread = false;
+
+    static String pathToStores = "../etc";
+    static String keyStoreFile = "keystore";
+    static String trustStoreFile = "truststore";
+    static String passwd = "passphrase";
+
+    // Server ready flag
+    volatile static boolean serverReady = false;
+    // Turn on SSL debugging
+    static boolean debug = false;
+    // Server done flag
+    static boolean done = false;
+
+    // Main server code
+
+    void doServerSide() throws Exception {
+        SSLServerSocketFactory sslssf;
+            sslssf = initContext().getServerSocketFactory();
+        SSLServerSocket sslServerSocket =
+            (SSLServerSocket) sslssf.createServerSocket(serverPort);
+        serverPort = sslServerSocket.getLocalPort();
+
+        serverReady = true;
+
+        SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
+
+        DataInputStream sslIS =
+            new DataInputStream(sslSocket.getInputStream());
+        String s = "";
+        while (s.compareTo("done") != 0) {
+            try {
+                s = sslIS.readUTF();
+                System.out.println("Received: " + s);
+            } catch (IOException e) {
+                throw e;
+            }
+        }
+        done = true;
+        sslSocket.close();
+    }
+
+    // Main client code
+    void doClientSide() throws Exception {
+
+        while (!serverReady) {
+            Thread.sleep(5);
+        }
+
+        SSLSocketFactory sslsf;
+        sslsf = initContext().getSocketFactory();
+
+        SSLSocket sslSocket = (SSLSocket)
+            sslsf.createSocket("localhost", serverPort);
+
+        DataOutputStream sslOS =
+            new DataOutputStream(sslSocket.getOutputStream());
+
+        sslOS.writeUTF("With " + dataString);
+        sslOS.writeUTF("With " + dataString);
+        sslOS.writeUTF("With " + dataString);
+
+        sslSocket.startHandshake();
+
+        sslOS.writeUTF("With " + dataString);
+        sslOS.writeUTF("With " + dataString);
+        sslOS.writeUTF("With " + dataString);
+
+        sslSocket.startHandshake();
+
+        sslOS.writeUTF("With " + dataString);
+        sslOS.writeUTF("With " + dataString);
+        sslOS.writeUTF("With " + dataString);
+        sslOS.writeUTF("done");
+
+        while (!done) {
+            Thread.sleep(5);
+        }
+        sslSocket.close();
+    }
+
+    volatile int serverPort = 0;
+
+    volatile Exception serverException = null;
+    volatile Exception clientException = null;
+
+    public static void main(String[] args) throws Exception {
+        String keyFilename =
+            System.getProperty("test.src", "./") + "/" + pathToStores +
+                "/" + keyStoreFile;
+        String trustFilename =
+            System.getProperty("test.src", "./") + "/" + pathToStores +
+                "/" + trustStoreFile;
+
+        System.setProperty("javax.net.ssl.keyStore", keyFilename);
+        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
+        System.setProperty("javax.net.ssl.trustStore", trustFilename);
+        System.setProperty("javax.net.ssl.trustStorePassword", passwd);
+
+        if (debug)
+            System.setProperty("javax.net.debug", "ssl");
+
+        new RenegotiateTLS13();
+    }
+
+    Thread clientThread = null;
+    Thread serverThread = null;
+
+    /*
+     * Primary constructor, used to drive remainder of the test.
+     *
+     * Fork off the other side, then do your work.
+     */
+    RenegotiateTLS13() throws Exception {
+        try {
+            if (separateServerThread) {
+                startServer(true);
+                startClient(false);
+            } else {
+                startClient(true);
+                startServer(false);
+            }
+        } catch (Exception e) {
+            // swallow for now.  Show later
+        }
+
+        /*
+         * Wait for other side to close down.
+         */
+        if (separateServerThread) {
+            serverThread.join();
+        } else {
+            clientThread.join();
+        }
+
+        /*
+         * When we get here, the test is pretty much over.
+         * Which side threw the error?
+         */
+        Exception local;
+        Exception remote;
+        String whichRemote;
+
+        if (separateServerThread) {
+            remote = serverException;
+            local = clientException;
+            whichRemote = "server";
+        } else {
+            remote = clientException;
+            local = serverException;
+            whichRemote = "client";
+        }
+
+        /*
+         * If both failed, return the curthread's exception, but also
+         * print the remote side Exception
+         */
+        if ((local != null) && (remote != null)) {
+            System.out.println(whichRemote + " also threw:");
+            remote.printStackTrace();
+            System.out.println();
+            throw local;
+        }
+
+        if (remote != null) {
+            throw remote;
+        }
+
+        if (local != null) {
+            throw local;
+        }
+    }
+
+    void startServer(boolean newThread) throws Exception {
+        if (newThread) {
+            serverThread = new Thread() {
+                public void run() {
+                    try {
+                        doServerSide();
+                    } catch (Exception e) {
+                        /*
+                         * Our server thread just died.
+                         *
+                         * Release the client, if not active already...
+                         */
+                        System.err.println("Server died...");
+                        serverReady = true;
+                        serverException = e;
+                    }
+                }
+            };
+            serverThread.start();
+        } else {
+            try {
+                doServerSide();
+            } catch (Exception e) {
+                serverException = e;
+            } finally {
+                serverReady = true;
+            }
+        }
+    }
+
+    void startClient(boolean newThread) throws Exception {
+        if (newThread) {
+            clientThread = new Thread() {
+                public void run() {
+                    try {
+                        doClientSide();
+                    } catch (Exception e) {
+                        /*
+                         * Our client thread just died.
+                         */
+                        System.err.println("Client died...");
+                        clientException = e;
+                    }
+                }
+            };
+            clientThread.start();
+        } else {
+            try {
+                doClientSide();
+            } catch (Exception e) {
+                clientException = e;
+            }
+        }
+    }
+
+    // Initialize context for TLS 1.3
+    SSLContext initContext() throws Exception {
+        System.out.println("Using TLS13");
+        SSLContext sc = SSLContext.getInstance("TLSv1.3");
+        KeyStore ks = KeyStore.getInstance(
+                new File(System.getProperty("javax.net.ssl.keyStore")),
+                passwd.toCharArray());
+        KeyManagerFactory kmf = KeyManagerFactory.getInstance(
+                KeyManagerFactory.getDefaultAlgorithm());
+        kmf.init(ks, passwd.toCharArray());
+        TrustManagerFactory tmf = TrustManagerFactory.getInstance(
+                TrustManagerFactory.getDefaultAlgorithm());
+        tmf.init(ks);
+        sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
+        return sc;
+    }
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/sun/security/ssl/SSLEngineImpl/TLS13BeginHandshake.java	Sun May 13 08:52:25 2018 -0700
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @summary Test SSLEngine.begineHandshake() triggers a KeyUpdate handshake
+ * in TLSv1.3
+ * @run main/othervm TLS13BeginHandshake
+ */
+
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLEngineResult;
+import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLSession;
+import javax.net.ssl.TrustManagerFactory;
+import java.io.File;
+import java.nio.ByteBuffer;
+import java.security.KeyStore;
+import java.security.SecureRandom;
+
+public class TLS13BeginHandshake {
+    static String pathToStores =
+            System.getProperty("test.src") + "/../../../../javax/net/ssl/etc/";
+    static String keyStoreFile = "keystore";
+    static String passwd = "passphrase";
+
+    private SSLEngine serverEngine, clientEngine;
+    SSLEngineResult clientResult, serverResult;
+    private ByteBuffer clientOut, clientIn;
+    private ByteBuffer serverOut, serverIn;
+    private ByteBuffer cTOs,sTOc;
+
+    public static void main(String args[]) throws Exception{
+        new TLS13BeginHandshake().runDemo();
+    }
+
+    private void runDemo() throws Exception {
+        int done = 0;
+
+        createSSLEngines();
+        createBuffers();
+
+        while (!isEngineClosed(clientEngine) || !isEngineClosed(serverEngine)) {
+
+            System.out.println("================");
+            clientResult = clientEngine.wrap(clientOut, cTOs);
+            System.out.println("client wrap: " + clientResult);
+            runDelegatedTasks(clientResult, clientEngine);
+            serverResult = serverEngine.wrap(serverOut, sTOc);
+            System.out.println("server wrap: " + serverResult);
+            runDelegatedTasks(serverResult, serverEngine);
+
+            cTOs.flip();
+            sTOc.flip();
+
+            System.out.println("----");
+            clientResult = clientEngine.unwrap(sTOc, clientIn);
+            System.out.println("client unwrap: " + clientResult);
+            if (clientResult.getStatus() == SSLEngineResult.Status.CLOSED) {
+                break;
+            }            runDelegatedTasks(clientResult, clientEngine);
+            serverResult = serverEngine.unwrap(cTOs, serverIn);
+            System.out.println("server unwrap: " + serverResult);
+            runDelegatedTasks(serverResult, serverEngine);
+
+            cTOs.compact();
+            sTOc.compact();
+
+            //System.err.println("so limit="+serverOut.limit()+" so pos="+serverOut.position());
+            //System.out.println("bf ctos limit="+cTOs.limit()+" pos="+cTOs.position()+" cap="+cTOs.capacity());
+            //System.out.println("bf stoc limit="+sTOc.limit()+" pos="+sTOc.position()+" cap="+sTOc.capacity());
+            if (done < 2  && (clientOut.limit() == serverIn.position()) &&
+                    (serverOut.limit() == clientIn.position())) {
+
+                if (done == 0) {
+                    checkTransfer(serverOut, clientIn);
+                    checkTransfer(clientOut, serverIn);
+                    clientEngine.beginHandshake();
+                    done++;
+                    continue;
+                }
+
+                checkTransfer(serverOut, clientIn);
+                checkTransfer(clientOut, serverIn);
+                System.out.println("\tClosing...");
+                clientEngine.closeOutbound();
+                done++;
+                continue;
+            }
+        }
+
+    }
+
+    private static boolean isEngineClosed(SSLEngine engine) {
+        if (engine.isInboundDone())
+            System.out.println("inbound closed");
+        if (engine.isOutboundDone())
+            System.out.println("outbound closed");
+        return (engine.isOutboundDone() && engine.isInboundDone());
+    }
+
+    private static void checkTransfer(ByteBuffer a, ByteBuffer b)
+            throws Exception {
+        a.flip();
+        b.flip();
+
+        if (!a.equals(b)) {
+            throw new Exception("Data didn't transfer cleanly");
+        } else {
+            System.out.println("\tData transferred cleanly");
+        }
+
+        a.compact();
+        b.compact();
+
+    }
+    private void createBuffers() {
+        SSLSession session = clientEngine.getSession();
+        int appBufferMax = session.getApplicationBufferSize();
+        int netBufferMax = session.getPacketBufferSize();
+
+        clientIn = ByteBuffer.allocate(appBufferMax + 50);
+        serverIn = ByteBuffer.allocate(appBufferMax + 50);
+
+        cTOs = ByteBuffer.allocateDirect(netBufferMax);
+        sTOc = ByteBuffer.allocateDirect(netBufferMax);
+
+        clientOut = ByteBuffer.wrap("client".getBytes());
+        serverOut = ByteBuffer.wrap("server".getBytes());
+    }
+
+    private void createSSLEngines() throws Exception {
+        serverEngine = initContext().createSSLEngine();
+        serverEngine.setUseClientMode(false);
+        serverEngine.setNeedClientAuth(true);
+
+        clientEngine = initContext().createSSLEngine("client", 80);
+        clientEngine.setUseClientMode(true);
+    }
+
+    private SSLContext initContext() throws Exception {
+        SSLContext sc = SSLContext.getInstance("TLSv1.3");
+        KeyStore ks = KeyStore.getInstance(new File(pathToStores + keyStoreFile),
+                passwd.toCharArray());
+        KeyManagerFactory kmf =
+                KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+        kmf.init(ks, passwd.toCharArray());
+        TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+        tmf.init(ks);
+        sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
+        return sc;
+    }
+
+    private static void runDelegatedTasks(SSLEngineResult result,
+            SSLEngine engine) throws Exception {
+
+        if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
+            Runnable runnable;
+            while ((runnable = engine.getDelegatedTask()) != null) {
+                runnable.run();
+            }
+            HandshakeStatus hsStatus = engine.getHandshakeStatus();
+            if (hsStatus == HandshakeStatus.NEED_TASK) {
+                throw new Exception(
+                    "handshake shouldn't need additional tasks");
+            }
+        }
+    }
+}