--- a/src/java.base/share/classes/sun/security/ssl/HandshakeContext.java Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/HandshakeContext.java Sun May 13 08:52:25 2018 -0700
@@ -70,7 +70,7 @@
"sun.security.ssl.allowLegacyHelloMessages", true);
// registered handshake message actors
- final LinkedHashMap<Byte, SSLConsumer> handshakeConsumers;
+ LinkedHashMap<Byte, SSLConsumer> handshakeConsumers;
final HashMap<Byte, HandshakeProducer> handshakeProducers;
// context
@@ -115,7 +115,6 @@
SecretKey baseWriteSecret;
// protocol version being established
- ProtocolVersion protocolVersion;
int clientHelloVersion;
String applicationProtocol;
@@ -207,6 +206,30 @@
initialize();
}
+ /**
+ * Constructor for PostHandshakeContext
+ */
+ HandshakeContext(TransportContext conContext) {
+ this.sslContext = conContext.sslContext;
+ this.conContext = conContext;
+ this.sslConfig = conContext.sslConfig;
+
+ this.negotiatedProtocol = conContext.protocolVersion;
+ this.negotiatedCipherSuite = conContext.cipherSuite;
+ this.handshakeOutput = new HandshakeOutStream(conContext.outputRecord);
+ this.delegatedActions = new LinkedList<>();
+
+ this.handshakeProducers = null;
+ this.handshakeHash = null;
+ this.activeProtocols = null;
+ this.activeCipherSuites = null;
+ this.algorithmConstraints = null;
+ this.maximumActiveProtocol = null;
+ this.handshakeExtensions = null;
+ this.handshakePossessions = null;
+ this.handshakeCredentials = null;
+ }
+
// Initialize the non-final class variables.
private void initialize() throws IOException {
ProtocolVersion inputHelloVersion;
@@ -331,9 +354,10 @@
return Collections.unmodifiableList(suites);
}
- void dispatch(Plaintext plaintext) throws IOException {
- // parse the handshake record.
- //
+ /**
+ * Parse the handshake record and return the contentType
+ */
+ static byte getHandshakeType(TransportContext conContext, Plaintext plaintext) throws IOException {
// struct {
// HandshakeType msg_type; /* handshake type */
// uint24 length; /* bytes in message */
@@ -341,18 +365,19 @@
// ...
// } body;
// } Handshake;
+
if (plaintext.contentType != ContentType.HANDSHAKE.id) {
conContext.fatal(Alert.INTERNAL_ERROR,
"Unexpected operation for record: " + plaintext.contentType);
- return; // make the compiler happy
+ return 0;
}
if (plaintext.fragment == null || plaintext.fragment.remaining() < 4) {
conContext.fatal(Alert.UNEXPECTED_MESSAGE,
"Invalid handshake message: insufficient data");
- return; // make the compiler happy
+ return 0;
}
byte handshakeType = (byte)Record.getInt8(plaintext.fragment);
@@ -361,9 +386,14 @@
conContext.fatal(Alert.UNEXPECTED_MESSAGE,
"Invalid handshake message: insufficient handshake body");
- return; // make the compiler happy
+ return 0;
}
+ return handshakeType;
+ }
+
+ void dispatch(byte handshakeType, Plaintext plaintext) throws IOException {
+
if (conContext.transport.useDelegatedTask()) {
boolean hasDelegated = !delegatedActions.isEmpty();
if (hasDelegated || handshakeType != SSLHandshake.FINISHED.id) {
@@ -406,8 +436,6 @@
} else if (handshakeType == SSLHandshake.NEW_SESSION_TICKET.id) {
// new session ticket may be sent any time after server finished
consumer = SSLHandshake.NEW_SESSION_TICKET;
- } else if (handshakeType == SSLHandshake.KEY_UPDATE.id) {
- consumer = SSLHandshake.KEY_UPDATE;
} else {
consumer = handshakeConsumers.get(handshakeType);
}
@@ -416,7 +444,7 @@
conContext.fatal(Alert.UNEXPECTED_MESSAGE,
"Unexpected handshake message: " +
SSLHandshake.nameOf(handshakeType));
- return; // make the compiler happy
+ return;
}
try {
@@ -490,7 +518,6 @@
* and ServerHandshaker with the negotiated protocol version.
*/
void setVersion(ProtocolVersion protocolVersion) {
- this.protocolVersion = protocolVersion;
this.conContext.protocolVersion = protocolVersion;
}
--- a/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java Sun May 13 08:52:25 2018 -0700
@@ -66,16 +66,15 @@
static final byte NOTREQUSTED = 0;
static final byte REQUSTED = 1;
private byte status;
- HandshakeOutStream hos;
- KeyUpdateMessage(HandshakeContext context, byte status) {
+ KeyUpdateMessage(PostHandshakeContext context, byte status) {
super(context);
this.status = status;
if (status > 1) {
new IOException("KeyUpdate message value invalid: " + status);
}
}
- KeyUpdateMessage(HandshakeContext context, ByteBuffer m)
+ KeyUpdateMessage(PostHandshakeContext context, ByteBuffer m)
throws IOException{
super(context);
@@ -90,14 +89,6 @@
}
}
- KeyUpdateMessage(PostHandshakeContext context, byte status) {
- super(context);
- this.status = status;
- if (status > 1) {
- new IOException("KeyUpdate message value invalid: " + status);
- }
- }
-
@Override
public SSLHandshake handshakeType() {
return SSLHandshake.KEY_UPDATE;
@@ -134,7 +125,7 @@
// Produce kickstart handshake message.
@Override
public byte[] produce(ConnectionContext context) throws IOException {
- HandshakeContext hc = (HandshakeContext)context;
+ PostHandshakeContext hc = (PostHandshakeContext)context;
handshakeProducer.produce(hc,
new KeyUpdateMessage(hc, KeyUpdateMessage.REQUSTED));
return null;
@@ -154,7 +145,7 @@
public void consume(ConnectionContext context,
ByteBuffer message) throws IOException {
// The consuming happens in client side only.
- HandshakeContext hc = (HandshakeContext)context;
+ PostHandshakeContext hc = (PostHandshakeContext)context;
KeyUpdateMessage km = new KeyUpdateMessage(hc, message);
if (km.getStatus() == KeyUpdateMessage.NOTREQUSTED) {
@@ -221,9 +212,8 @@
public byte[] produce(ConnectionContext context,
HandshakeMessage message) throws IOException {
// The producing happens in server side only.
- HandshakeContext hc = (HandshakeContext)context;
+ PostHandshakeContext hc = (PostHandshakeContext)context;
KeyUpdateMessage km = (KeyUpdateMessage)message;
-
SecretKey secret;
if (km.getStatus() == KeyUpdateMessage.REQUSTED) {
@@ -278,6 +268,7 @@
if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
SSLLogger.fine("KeyUpdate: write key updated");
}
+ hc.free();
return null;
}
}
--- a/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java Sun May 13 08:52:25 2018 -0700
@@ -26,18 +26,64 @@
package sun.security.ssl;
import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.LinkedHashMap;
/**
- * A clean implementation of HandshakeContext for post-handshake messages
+ * A compact implementation of HandshakeContext for post-handshake messages
*/
public class PostHandshakeContext extends HandshakeContext {
+ final static LinkedHashMap<Byte, SSLConsumer> consumers;
+ static {
+ consumers = new LinkedHashMap<>() {{
+ put(SSLHandshake.KEY_UPDATE.id,
+ SSLHandshake.KEY_UPDATE);
+ }};
+ }
+
+
PostHandshakeContext(TransportContext context) throws IOException {
- super(context.sslContext, context);
+ super(context);
+
+ if (!negotiatedProtocol.useTLS13PlusSpec()) {
+ conContext.fatal(Alert.UNEXPECTED_MESSAGE, "Post-handshake not " +
+ "supported in " + negotiatedProtocol.name);
+ }
+
+ handshakeConsumers = consumers;
+ handshakeFinished = true;
}
@Override
void kickstart() throws IOException {
+ SSLHandshake.kickstart(this);
+ }
+
+ @Override
+ void dispatch(byte handshakeType, ByteBuffer fragment) throws IOException {
+
+ SSLConsumer consumer = handshakeConsumers.get(handshakeType);
+ if (consumer == null) {
+ conContext.fatal(Alert.UNEXPECTED_MESSAGE,
+ "Unexpected post-handshake message: " +
+ SSLHandshake.nameOf(handshakeType));
+ return;
+ }
+
+ try {
+ consumer.consume(this, fragment);
+ } catch (UnsupportedOperationException unsoe) {
+ conContext.fatal(Alert.UNEXPECTED_MESSAGE,
+ "Unsupported post-handshake message: " +
+ SSLHandshake.nameOf(handshakeType), unsoe);
+ }
+ }
+
+ void free() {
+ if (delegatedActions.isEmpty()) {
+ conContext.handshakeContext = null;
+ }
}
}
--- a/src/java.base/share/classes/sun/security/ssl/SSLEngineImpl.java Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/SSLEngineImpl.java Sun May 13 08:52:25 2018 -0700
@@ -287,7 +287,7 @@
// Is the handshake completed?
boolean needRetransmission =
conContext.sslContext.isDTLS() &&
- conContext.handshakeContext != null &&
+ conContext.getHandshakeContext(TransportContext.PRE) != null &&
conContext.handshakeContext.sslConfig.enableRetransmissions;
HandshakeStatus hsStatus =
tryToFinishHandshake(ciphertext.contentType);
@@ -331,6 +331,8 @@
conContext.outputRecord.isEmpty()) {
if (conContext.handshakeContext == null) {
hsStatus = HandshakeStatus.FINISHED;
+ } else if (conContext.getHandshakeContext(TransportContext.POST) != null) {
+ return null;
} else if (conContext.handshakeContext.handshakeFinished) {
hsStatus = conContext.finishHandshake();
}
@@ -684,7 +686,7 @@
@Override
public synchronized Runnable getDelegatedTask() {
- if (conContext.handshakeContext != null &&
+ if (conContext.handshakeContext != null && // PRE or POST handshake
!conContext.handshakeContext.taskDelegated &&
!conContext.handshakeContext.delegatedActions.isEmpty()) {
conContext.handshakeContext.taskDelegated = true;
--- a/src/java.base/share/classes/sun/security/ssl/TransportContext.java Fri May 11 16:07:27 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/TransportContext.java Sun May 13 08:52:25 2018 -0700
@@ -175,18 +175,24 @@
if (ct == null) {
fatal(Alert.UNEXPECTED_MESSAGE,
"Unknown content type: " + plaintext.contentType);
- return; // make compiler happy
+ return;
}
switch (ct) {
case HANDSHAKE:
+ byte type = HandshakeContext.getHandshakeType(this,
+ plaintext);
if (handshakeContext == null) {
- handshakeContext = sslConfig.isClientMode ?
- new ClientHandshakeContext(sslContext, this) :
- new ServerHandshakeContext(sslContext, this);
- outputRecord.initHandshaker();
+ if (type == SSLHandshake.KEY_UPDATE.id) {
+ handshakeContext = new PostHandshakeContext(this);
+ } else {
+ handshakeContext = sslConfig.isClientMode ?
+ new ClientHandshakeContext(sslContext, this) :
+ new ServerHandshakeContext(sslContext, this);
+ outputRecord.initHandshaker();
+ }
}
- handshakeContext.dispatch(plaintext);
+ handshakeContext.dispatch(type, plaintext);
break;
case ALERT:
Alert.alertConsumer.consume(this, plaintext.fragment);
@@ -209,29 +215,54 @@
// initialize the handshaker if necessary
if (handshakeContext == null) {
- handshakeContext = sslConfig.isClientMode ?
- new ClientHandshakeContext(sslContext, this) :
- new ServerHandshakeContext(sslContext, this);
- outputRecord.initHandshaker();
+ // TLS1.3 post-handshake
+ if (isNegotiated && protocolVersion.useTLS13PlusSpec()) {
+ handshakeContext = new PostHandshakeContext(this);
+ } else {
+ handshakeContext = sslConfig.isClientMode ?
+ new ClientHandshakeContext(sslContext, this) :
+ new ServerHandshakeContext(sslContext, this);
+ outputRecord.initHandshaker();
+ }
}
// kickstart the handshake if needed
//
// Need no kickstart message on server side unless the connection
- // has been estabilished.
+ // has been established.
if(isNegotiated || sslConfig.isClientMode) {
handshakeContext.kickstart();
}
}
void keyUpdate() throws IOException {
- // TODO: TLS 1.3
kickstart();
}
- // Note: close_notify is delivered as awarning alert.
+ final static byte PRE = 1;
+ final static byte POST = 2;
+
+ HandshakeContext getHandshakeContext(byte type) {
+ if (handshakeContext == null) {
+ return null;
+ }
+
+ if (type == PRE &&
+ (handshakeContext instanceof ClientHandshakeContext ||
+ handshakeContext instanceof ServerHandshakeContext)) {
+ return handshakeContext;
+ }
+
+ if (type == POST && handshakeContext instanceof PostHandshakeContext) {
+ return handshakeContext;
+ }
+
+ return null;
+ }
+
+ // Note: close_notify is delivered as a warning alert.
void warning(Alert alert) {
- // For initial handshaking, don't send awarning alert message to peer
+ // For initial handshaking, don't send a warning alert message to peer
// if handshaker has not started.
if (isNegotiated || handshakeContext != null) {
try {
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/javax/net/ssl/SSLSession/RenegotiateTLS13.java Sun May 13 08:52:25 2018 -0700
@@ -0,0 +1,292 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @run main/othervm -Djavax.net.debug=ssl RenegotiateTLS13
+ */
+
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLServerSocket;
+import javax.net.ssl.SSLServerSocketFactory;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManagerFactory;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.security.KeyStore;
+import java.security.SecureRandom;
+
+public class RenegotiateTLS13 {
+
+ static final String dataString = "This is a test";
+
+ // Run the server as a thread instead of the client
+ static boolean separateServerThread = false;
+
+ static String pathToStores = "../etc";
+ static String keyStoreFile = "keystore";
+ static String trustStoreFile = "truststore";
+ static String passwd = "passphrase";
+
+ // Server ready flag
+ volatile static boolean serverReady = false;
+ // Turn on SSL debugging
+ static boolean debug = false;
+ // Server done flag
+ static boolean done = false;
+
+ // Main server code
+
+ void doServerSide() throws Exception {
+ SSLServerSocketFactory sslssf;
+ sslssf = initContext().getServerSocketFactory();
+ SSLServerSocket sslServerSocket =
+ (SSLServerSocket) sslssf.createServerSocket(serverPort);
+ serverPort = sslServerSocket.getLocalPort();
+
+ serverReady = true;
+
+ SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
+
+ DataInputStream sslIS =
+ new DataInputStream(sslSocket.getInputStream());
+ String s = "";
+ while (s.compareTo("done") != 0) {
+ try {
+ s = sslIS.readUTF();
+ System.out.println("Received: " + s);
+ } catch (IOException e) {
+ throw e;
+ }
+ }
+ done = true;
+ sslSocket.close();
+ }
+
+ // Main client code
+ void doClientSide() throws Exception {
+
+ while (!serverReady) {
+ Thread.sleep(5);
+ }
+
+ SSLSocketFactory sslsf;
+ sslsf = initContext().getSocketFactory();
+
+ SSLSocket sslSocket = (SSLSocket)
+ sslsf.createSocket("localhost", serverPort);
+
+ DataOutputStream sslOS =
+ new DataOutputStream(sslSocket.getOutputStream());
+
+ sslOS.writeUTF("With " + dataString);
+ sslOS.writeUTF("With " + dataString);
+ sslOS.writeUTF("With " + dataString);
+
+ sslSocket.startHandshake();
+
+ sslOS.writeUTF("With " + dataString);
+ sslOS.writeUTF("With " + dataString);
+ sslOS.writeUTF("With " + dataString);
+
+ sslSocket.startHandshake();
+
+ sslOS.writeUTF("With " + dataString);
+ sslOS.writeUTF("With " + dataString);
+ sslOS.writeUTF("With " + dataString);
+ sslOS.writeUTF("done");
+
+ while (!done) {
+ Thread.sleep(5);
+ }
+ sslSocket.close();
+ }
+
+ volatile int serverPort = 0;
+
+ volatile Exception serverException = null;
+ volatile Exception clientException = null;
+
+ public static void main(String[] args) throws Exception {
+ String keyFilename =
+ System.getProperty("test.src", "./") + "/" + pathToStores +
+ "/" + keyStoreFile;
+ String trustFilename =
+ System.getProperty("test.src", "./") + "/" + pathToStores +
+ "/" + trustStoreFile;
+
+ System.setProperty("javax.net.ssl.keyStore", keyFilename);
+ System.setProperty("javax.net.ssl.keyStorePassword", passwd);
+ System.setProperty("javax.net.ssl.trustStore", trustFilename);
+ System.setProperty("javax.net.ssl.trustStorePassword", passwd);
+
+ if (debug)
+ System.setProperty("javax.net.debug", "ssl");
+
+ new RenegotiateTLS13();
+ }
+
+ Thread clientThread = null;
+ Thread serverThread = null;
+
+ /*
+ * Primary constructor, used to drive remainder of the test.
+ *
+ * Fork off the other side, then do your work.
+ */
+ RenegotiateTLS13() throws Exception {
+ try {
+ if (separateServerThread) {
+ startServer(true);
+ startClient(false);
+ } else {
+ startClient(true);
+ startServer(false);
+ }
+ } catch (Exception e) {
+ // swallow for now. Show later
+ }
+
+ /*
+ * Wait for other side to close down.
+ */
+ if (separateServerThread) {
+ serverThread.join();
+ } else {
+ clientThread.join();
+ }
+
+ /*
+ * When we get here, the test is pretty much over.
+ * Which side threw the error?
+ */
+ Exception local;
+ Exception remote;
+ String whichRemote;
+
+ if (separateServerThread) {
+ remote = serverException;
+ local = clientException;
+ whichRemote = "server";
+ } else {
+ remote = clientException;
+ local = serverException;
+ whichRemote = "client";
+ }
+
+ /*
+ * If both failed, return the curthread's exception, but also
+ * print the remote side Exception
+ */
+ if ((local != null) && (remote != null)) {
+ System.out.println(whichRemote + " also threw:");
+ remote.printStackTrace();
+ System.out.println();
+ throw local;
+ }
+
+ if (remote != null) {
+ throw remote;
+ }
+
+ if (local != null) {
+ throw local;
+ }
+ }
+
+ void startServer(boolean newThread) throws Exception {
+ if (newThread) {
+ serverThread = new Thread() {
+ public void run() {
+ try {
+ doServerSide();
+ } catch (Exception e) {
+ /*
+ * Our server thread just died.
+ *
+ * Release the client, if not active already...
+ */
+ System.err.println("Server died...");
+ serverReady = true;
+ serverException = e;
+ }
+ }
+ };
+ serverThread.start();
+ } else {
+ try {
+ doServerSide();
+ } catch (Exception e) {
+ serverException = e;
+ } finally {
+ serverReady = true;
+ }
+ }
+ }
+
+ void startClient(boolean newThread) throws Exception {
+ if (newThread) {
+ clientThread = new Thread() {
+ public void run() {
+ try {
+ doClientSide();
+ } catch (Exception e) {
+ /*
+ * Our client thread just died.
+ */
+ System.err.println("Client died...");
+ clientException = e;
+ }
+ }
+ };
+ clientThread.start();
+ } else {
+ try {
+ doClientSide();
+ } catch (Exception e) {
+ clientException = e;
+ }
+ }
+ }
+
+ // Initialize context for TLS 1.3
+ SSLContext initContext() throws Exception {
+ System.out.println("Using TLS13");
+ SSLContext sc = SSLContext.getInstance("TLSv1.3");
+ KeyStore ks = KeyStore.getInstance(
+ new File(System.getProperty("javax.net.ssl.keyStore")),
+ passwd.toCharArray());
+ KeyManagerFactory kmf = KeyManagerFactory.getInstance(
+ KeyManagerFactory.getDefaultAlgorithm());
+ kmf.init(ks, passwd.toCharArray());
+ TrustManagerFactory tmf = TrustManagerFactory.getInstance(
+ TrustManagerFactory.getDefaultAlgorithm());
+ tmf.init(ks);
+ sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
+ return sc;
+ }
+}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/sun/security/ssl/SSLEngineImpl/TLS13BeginHandshake.java Sun May 13 08:52:25 2018 -0700
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @summary Test SSLEngine.begineHandshake() triggers a KeyUpdate handshake
+ * in TLSv1.3
+ * @run main/othervm TLS13BeginHandshake
+ */
+
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLEngineResult;
+import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLSession;
+import javax.net.ssl.TrustManagerFactory;
+import java.io.File;
+import java.nio.ByteBuffer;
+import java.security.KeyStore;
+import java.security.SecureRandom;
+
+public class TLS13BeginHandshake {
+ static String pathToStores =
+ System.getProperty("test.src") + "/../../../../javax/net/ssl/etc/";
+ static String keyStoreFile = "keystore";
+ static String passwd = "passphrase";
+
+ private SSLEngine serverEngine, clientEngine;
+ SSLEngineResult clientResult, serverResult;
+ private ByteBuffer clientOut, clientIn;
+ private ByteBuffer serverOut, serverIn;
+ private ByteBuffer cTOs,sTOc;
+
+ public static void main(String args[]) throws Exception{
+ new TLS13BeginHandshake().runDemo();
+ }
+
+ private void runDemo() throws Exception {
+ int done = 0;
+
+ createSSLEngines();
+ createBuffers();
+
+ while (!isEngineClosed(clientEngine) || !isEngineClosed(serverEngine)) {
+
+ System.out.println("================");
+ clientResult = clientEngine.wrap(clientOut, cTOs);
+ System.out.println("client wrap: " + clientResult);
+ runDelegatedTasks(clientResult, clientEngine);
+ serverResult = serverEngine.wrap(serverOut, sTOc);
+ System.out.println("server wrap: " + serverResult);
+ runDelegatedTasks(serverResult, serverEngine);
+
+ cTOs.flip();
+ sTOc.flip();
+
+ System.out.println("----");
+ clientResult = clientEngine.unwrap(sTOc, clientIn);
+ System.out.println("client unwrap: " + clientResult);
+ if (clientResult.getStatus() == SSLEngineResult.Status.CLOSED) {
+ break;
+ } runDelegatedTasks(clientResult, clientEngine);
+ serverResult = serverEngine.unwrap(cTOs, serverIn);
+ System.out.println("server unwrap: " + serverResult);
+ runDelegatedTasks(serverResult, serverEngine);
+
+ cTOs.compact();
+ sTOc.compact();
+
+ //System.err.println("so limit="+serverOut.limit()+" so pos="+serverOut.position());
+ //System.out.println("bf ctos limit="+cTOs.limit()+" pos="+cTOs.position()+" cap="+cTOs.capacity());
+ //System.out.println("bf stoc limit="+sTOc.limit()+" pos="+sTOc.position()+" cap="+sTOc.capacity());
+ if (done < 2 && (clientOut.limit() == serverIn.position()) &&
+ (serverOut.limit() == clientIn.position())) {
+
+ if (done == 0) {
+ checkTransfer(serverOut, clientIn);
+ checkTransfer(clientOut, serverIn);
+ clientEngine.beginHandshake();
+ done++;
+ continue;
+ }
+
+ checkTransfer(serverOut, clientIn);
+ checkTransfer(clientOut, serverIn);
+ System.out.println("\tClosing...");
+ clientEngine.closeOutbound();
+ done++;
+ continue;
+ }
+ }
+
+ }
+
+ private static boolean isEngineClosed(SSLEngine engine) {
+ if (engine.isInboundDone())
+ System.out.println("inbound closed");
+ if (engine.isOutboundDone())
+ System.out.println("outbound closed");
+ return (engine.isOutboundDone() && engine.isInboundDone());
+ }
+
+ private static void checkTransfer(ByteBuffer a, ByteBuffer b)
+ throws Exception {
+ a.flip();
+ b.flip();
+
+ if (!a.equals(b)) {
+ throw new Exception("Data didn't transfer cleanly");
+ } else {
+ System.out.println("\tData transferred cleanly");
+ }
+
+ a.compact();
+ b.compact();
+
+ }
+ private void createBuffers() {
+ SSLSession session = clientEngine.getSession();
+ int appBufferMax = session.getApplicationBufferSize();
+ int netBufferMax = session.getPacketBufferSize();
+
+ clientIn = ByteBuffer.allocate(appBufferMax + 50);
+ serverIn = ByteBuffer.allocate(appBufferMax + 50);
+
+ cTOs = ByteBuffer.allocateDirect(netBufferMax);
+ sTOc = ByteBuffer.allocateDirect(netBufferMax);
+
+ clientOut = ByteBuffer.wrap("client".getBytes());
+ serverOut = ByteBuffer.wrap("server".getBytes());
+ }
+
+ private void createSSLEngines() throws Exception {
+ serverEngine = initContext().createSSLEngine();
+ serverEngine.setUseClientMode(false);
+ serverEngine.setNeedClientAuth(true);
+
+ clientEngine = initContext().createSSLEngine("client", 80);
+ clientEngine.setUseClientMode(true);
+ }
+
+ private SSLContext initContext() throws Exception {
+ SSLContext sc = SSLContext.getInstance("TLSv1.3");
+ KeyStore ks = KeyStore.getInstance(new File(pathToStores + keyStoreFile),
+ passwd.toCharArray());
+ KeyManagerFactory kmf =
+ KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+ kmf.init(ks, passwd.toCharArray());
+ TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+ tmf.init(ks);
+ sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
+ return sc;
+ }
+
+ private static void runDelegatedTasks(SSLEngineResult result,
+ SSLEngine engine) throws Exception {
+
+ if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
+ Runnable runnable;
+ while ((runnable = engine.getDelegatedTask()) != null) {
+ runnable.run();
+ }
+ HandshakeStatus hsStatus = engine.getHandshakeStatus();
+ if (hsStatus == HandshakeStatus.NEED_TASK) {
+ throw new Exception(
+ "handshake shouldn't need additional tasks");
+ }
+ }
+ }
+}