8212885: TLS 1.3 resumed session does not retain peer certificate chain
authorjnimeh
Tue, 13 Nov 2018 18:22:52 -0800
changeset 52512 1838347a803b
parent 52511 ddcbc20e8c6a
child 52513 d4f3e37d1fda
8212885: TLS 1.3 resumed session does not retain peer certificate chain Reviewed-by: xuelei, wetmore
src/java.base/share/classes/sun/security/ssl/NewSessionTicket.java
src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java
src/java.base/share/classes/sun/security/ssl/PreSharedKeyExtension.java
src/java.base/share/classes/sun/security/ssl/SSLSessionImpl.java
test/jdk/sun/security/ssl/SSLSessionImpl/ResumeChecksClient.java
--- a/src/java.base/share/classes/sun/security/ssl/NewSessionTicket.java	Wed Nov 14 08:46:25 2018 +0800
+++ b/src/java.base/share/classes/sun/security/ssl/NewSessionTicket.java	Tue Nov 13 18:22:52 2018 -0800
@@ -260,9 +260,8 @@
             // create and cache the new session
             // The new session must be a child of the existing session so
             // they will be invalidated together, etc.
-            SSLSessionImpl sessionCopy = new SSLSessionImpl(shc,
-                    shc.handshakeSession.getSuite(), newId,
-                    shc.handshakeSession.getCreationTime());
+            SSLSessionImpl sessionCopy =
+                    new SSLSessionImpl(shc.handshakeSession, newId);
             shc.handshakeSession.addChild(sessionCopy);
             sessionCopy.setPreSharedKey(psk);
             sessionCopy.setPskIdentity(newId.getId());
@@ -375,9 +374,8 @@
             // they will be invalidated together, etc.
             SessionId newId =
                 new SessionId(true, hc.sslContext.getSecureRandom());
-            SSLSessionImpl sessionCopy = new SSLSessionImpl(
-                    hc, sessionToSave.getSuite(), newId,
-                    sessionToSave.getCreationTime());
+            SSLSessionImpl sessionCopy = new SSLSessionImpl(sessionToSave,
+                    newId);
             sessionToSave.addChild(sessionCopy);
             sessionCopy.setPreSharedKey(psk);
             sessionCopy.setTicketAgeAdd(nstm.ticketAgeAdd);
--- a/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java	Wed Nov 14 08:46:25 2018 +0800
+++ b/src/java.base/share/classes/sun/security/ssl/PostHandshakeContext.java	Tue Nov 13 18:22:52 2018 -0800
@@ -50,9 +50,6 @@
         this.localSupportedSignAlgs = new ArrayList<SignatureScheme>(
             context.conSession.getLocalSupportedSignatureSchemes());
 
-        this.requestedServerNames =
-                context.conSession.getRequestedServerNames();
-
         handshakeConsumers = new LinkedHashMap<>(consumers);
         handshakeFinished = true;
     }
--- a/src/java.base/share/classes/sun/security/ssl/PreSharedKeyExtension.java	Wed Nov 14 08:46:25 2018 +0800
+++ b/src/java.base/share/classes/sun/security/ssl/PreSharedKeyExtension.java	Tue Nov 13 18:22:52 2018 -0800
@@ -415,6 +415,16 @@
             result = false;
         }
 
+        // Make sure that the server handshake context's localSupportedSignAlgs
+        // field is populated.  This is particularly important when
+        // client authentication was used in an initial session and it is
+        // now being resumed.
+        if (shc.localSupportedSignAlgs == null) {
+            shc.localSupportedSignAlgs =
+                    SignatureScheme.getSupportedAlgorithms(
+                            shc.algorithmConstraints, shc.activeProtocols);
+        }
+
         // Validate the required client authentication.
         if (result &&
             (shc.sslConfig.clientAuthType == CLIENT_AUTH_REQUIRED)) {
@@ -763,7 +773,7 @@
             SecretKey earlySecret = hkdf.extract(zeros, psk, "TlsEarlySecret");
 
             byte[] label = ("tls13 res binder").getBytes();
-            MessageDigest md = MessageDigest.getInstance(hashAlg.toString());;
+            MessageDigest md = MessageDigest.getInstance(hashAlg.name);
             byte[] hkdfInfo = SSLSecretDerivation.createHkdfInfo(
                     label, md.digest(new byte[0]), hashAlg.hashLength);
             return hkdf.expand(earlySecret,
--- a/src/java.base/share/classes/sun/security/ssl/SSLSessionImpl.java	Wed Nov 14 08:46:25 2018 +0800
+++ b/src/java.base/share/classes/sun/security/ssl/SSLSessionImpl.java	Tue Nov 13 18:22:52 2018 -0800
@@ -154,6 +154,7 @@
         this.useExtendedMasterSecret = false;
         this.creationTime = System.currentTimeMillis();
         this.identificationProtocol = null;
+        this.boundValues = new ConcurrentHashMap<>();
     }
 
     /*
@@ -204,6 +205,41 @@
         }
         this.creationTime = creationTime;
         this.identificationProtocol = hc.sslConfig.identificationProtocol;
+        this.boundValues = new ConcurrentHashMap<>();
+
+        if (SSLLogger.isOn && SSLLogger.isOn("session")) {
+             SSLLogger.finest("Session initialized:  " + this);
+        }
+    }
+
+    SSLSessionImpl(SSLSessionImpl baseSession, SessionId newId) {
+        this.protocolVersion = baseSession.getProtocolVersion();
+        this.cipherSuite = baseSession.cipherSuite;
+        this.sessionId = newId;
+        this.host = baseSession.getPeerHost();
+        this.port = baseSession.getPeerPort();
+        this.localSupportedSignAlgs =
+            baseSession.localSupportedSignAlgs == null ?
+                Collections.emptySet() :
+                Collections.unmodifiableCollection(
+                        baseSession.localSupportedSignAlgs);
+        this.peerSupportedSignAlgs =
+                baseSession.getPeerSupportedSignatureAlgorithms();
+        this.serverNameIndication = baseSession.serverNameIndication;
+        this.requestedServerNames = baseSession.getRequestedServerNames();
+        this.masterSecret = baseSession.getMasterSecret();
+        this.useExtendedMasterSecret = baseSession.useExtendedMasterSecret;
+        this.creationTime = baseSession.getCreationTime();
+        this.lastUsedTime = System.currentTimeMillis();
+        this.identificationProtocol = baseSession.getIdentificationProtocol();
+        this.localCerts = baseSession.localCerts;
+        this.peerCerts = baseSession.peerCerts;
+        this.statusResponses = baseSession.statusResponses;
+        this.resumptionMasterSecret = baseSession.resumptionMasterSecret;
+        this.context = baseSession.context;
+        this.negotiatedMaxFragLen = baseSession.negotiatedMaxFragLen;
+        this.maximumPacketSize = baseSession.maximumPacketSize;
+        this.boundValues = baseSession.boundValues;
 
         if (SSLLogger.isOn && SSLLogger.isOn("session")) {
              SSLLogger.finest("Session initialized:  " + this);
@@ -772,8 +808,7 @@
      * key and the calling security context. This is important since
      * sessions can be shared across different protection domains.
      */
-    private final ConcurrentHashMap<SecureKey, Object> boundValues =
-            new ConcurrentHashMap<>();
+    private final ConcurrentHashMap<SecureKey, Object> boundValues;
 
     /**
      * Assigns a session value.  Session change events are given if
--- a/test/jdk/sun/security/ssl/SSLSessionImpl/ResumeChecksClient.java	Wed Nov 14 08:46:25 2018 +0800
+++ b/test/jdk/sun/security/ssl/SSLSessionImpl/ResumeChecksClient.java	Tue Nov 13 18:22:52 2018 -0800
@@ -23,7 +23,7 @@
 
 /*
  * @test
- * @bug 8206929
+ * @bug 8206929 8212885
  * @summary ensure that client only resumes a session if certain properties
  *    of the session are compatible with the new connection
  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 ResumeChecksClient BASIC
@@ -80,7 +80,7 @@
         while (!server.started) {
             Thread.yield();
         }
-        connect(sslContext, server.port, mode, false);
+        SSLSession firstSession = connect(sslContext, server.port, mode, false);
 
         server.signal();
         long secondStartTime = System.currentTimeMillis();
@@ -93,9 +93,7 @@
         switch (mode) {
         case BASIC:
             // fail if session is not resumed
-            if (secondSession.getCreationTime() > secondStartTime) {
-                throw new RuntimeException("Session was not reused");
-            }
+            checkResumedSession(firstSession, secondSession);
             break;
         case VERSION_2_TO_3:
         case VERSION_3_TO_2:
@@ -124,14 +122,17 @@
             return !a.toLowerCase().contains(alg.toLowerCase());
         }
 
+        @Override
         public boolean permits(Set<CryptoPrimitive> primitives, Key key) {
             return true;
         }
+        @Override
         public boolean permits(Set<CryptoPrimitive> primitives,
             String algorithm, AlgorithmParameters parameters) {
 
             return test(algorithm);
         }
+        @Override
         public boolean permits(Set<CryptoPrimitive> primitives,
             String algorithm, Key key, AlgorithmParameters parameters) {
 
@@ -205,6 +206,81 @@
         }
     }
 
+    private static void checkResumedSession(SSLSession initSession,
+            SSLSession resSession) throws Exception {
+        StringBuilder diffLog = new StringBuilder();
+
+        // Initial and resumed SSLSessions should have the same creation
+        // times so they get invalidated together.
+        long initCt = initSession.getCreationTime();
+        long resumeCt = resSession.getCreationTime();
+        if (initCt != resumeCt) {
+            diffLog.append("Session creation time is different. Initial: ").
+                    append(initCt).append(", Resumed: ").append(resumeCt).
+                    append("\n");
+        }
+
+        // Ensure that peer and local certificate lists are preserved
+        if (!Arrays.equals(initSession.getLocalCertificates(),
+                resSession.getLocalCertificates())) {
+            diffLog.append("Local certificate mismatch between initial " +
+                    "and resumed sessions\n");
+        }
+
+        if (!Arrays.equals(initSession.getPeerCertificates(),
+                resSession.getPeerCertificates())) {
+            diffLog.append("Peer certificate mismatch between initial " +
+                    "and resumed sessions\n");
+        }
+
+        // Buffer sizes should also be the same
+        if (initSession.getApplicationBufferSize() !=
+                resSession.getApplicationBufferSize()) {
+            diffLog.append(String.format(
+                    "App Buffer sizes differ: Init: %d, Res: %d\n",
+                    initSession.getApplicationBufferSize(),
+                    resSession.getApplicationBufferSize()));
+        }
+
+        if (initSession.getPacketBufferSize() !=
+                resSession.getPacketBufferSize()) {
+            diffLog.append(String.format(
+                    "Packet Buffer sizes differ: Init: %d, Res: %d\n",
+                    initSession.getPacketBufferSize(),
+                    resSession.getPacketBufferSize()));
+        }
+
+        // Cipher suite should match
+        if (!initSession.getCipherSuite().equals(
+                resSession.getCipherSuite())) {
+            diffLog.append(String.format(
+                    "CipherSuite does not match - Init: %s, Res: %s\n",
+                    initSession.getCipherSuite(), resSession.getCipherSuite()));
+        }
+
+        // Peer host/port should match
+        if (!initSession.getPeerHost().equals(resSession.getPeerHost()) ||
+                initSession.getPeerPort() != resSession.getPeerPort()) {
+            diffLog.append(String.format(
+                    "Host/Port mismatch - Init: %s/%d, Res: %s/%d\n",
+                    initSession.getPeerHost(), initSession.getPeerPort(),
+                    resSession.getPeerHost(), resSession.getPeerPort()));
+        }
+
+        // Check protocol
+        if (!initSession.getProtocol().equals(resSession.getProtocol())) {
+            diffLog.append(String.format(
+                    "Protocol mismatch - Init: %s, Res: %s\n",
+                    initSession.getProtocol(), resSession.getProtocol()));
+        }
+
+        // If the StringBuilder has any data in it then one of the checks
+        // above failed and we should throw an exception.
+        if (diffLog.length() > 0) {
+            throw new RuntimeException(diffLog.toString());
+        }
+    }
+
     private static Server startServer() {
         Server server = new Server();
         new Thread(server).start();
@@ -233,6 +309,7 @@
             notify();
         }
 
+        @Override
         public void run() {
             try {