pre_shared_key extensio code cleanup, and a test update JDK-8145252-TLS13-branch
authorxuelei
Sun, 03 Jun 2018 19:43:10 -0700
branchJDK-8145252-TLS13-branch
changeset 56661 2a820e434f17
parent 56660 66c803c3ce32
child 56662 126e167bb2cb
pre_shared_key extensio code cleanup, and a test update
src/java.base/share/classes/sun/security/ssl/PreSharedKeyExtension.java
test/jdk/sun/security/ssl/SSLSocketImpl/AsyncSSLSocketClose.java
--- a/src/java.base/share/classes/sun/security/ssl/PreSharedKeyExtension.java	Sat Jun 02 21:23:02 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/PreSharedKeyExtension.java	Sun Jun 03 19:43:10 2018 -0700
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -28,21 +28,17 @@
 import java.nio.ByteBuffer;
 import java.security.*;
 import java.text.MessageFormat;
-import java.util.Map;
 import java.util.List;
 import java.util.ArrayList;
 import java.util.Locale;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.Optional;
+import javax.crypto.Mac;
+import javax.crypto.SecretKey;
+import sun.security.ssl.ClientHello.ClientHelloMessage;
 import sun.security.ssl.SSLExtension.ExtensionConsumer;
-
 import sun.security.ssl.SSLExtension.SSLExtensionSpec;
 import sun.security.ssl.SSLHandshake.HandshakeMessage;
-
-import javax.crypto.Mac;
-import javax.crypto.SecretKey;
-
 import static sun.security.ssl.SSLExtension.*;
 
 /**
@@ -55,7 +51,7 @@
             new CHPreSharedKeyConsumer();
     static final HandshakeAbsence chOnLoadAbsence =
             new CHPreSharedKeyAbsence();
-    static final HandshakeConsumer chOnTradeConsumer=
+    static final HandshakeConsumer chOnTradeConsumer =
             new CHPreSharedKeyUpdate();
 
     static final HandshakeProducer shNetworkProducer =
@@ -65,33 +61,24 @@
     static final HandshakeAbsence shOnLoadAbsence =
             new SHPreSharedKeyAbsence();
 
-    static final class PskIdentity {
+    private static final class PskIdentity {
         final byte[] identity;
         final int obfuscatedAge;
 
-        public PskIdentity(byte[] identity, int obfuscatedAge) {
+        PskIdentity(byte[] identity, int obfuscatedAge) {
             this.identity = identity;
             this.obfuscatedAge = obfuscatedAge;
         }
 
-        public PskIdentity(ByteBuffer m)
-            throws IllegalParameterException, IOException {
-
-            identity = Record.getBytes16(m);
-            if (identity.length == 0) {
-                throw new IllegalParameterException("identity has length 0");
-            }
-            obfuscatedAge = Record.getInt32(m);
-        }
-
         int getEncodedLength() {
             return 2 + identity.length + 4;
         }
 
-        public void writeEncoded(ByteBuffer m) throws IOException {
+        void writeEncoded(ByteBuffer m) throws IOException {
             Record.putBytes16(m, identity);
             Record.putInt32(m, obfuscatedAge);
         }
+
         @Override
         public String toString() {
             return "{" + Utilities.toHexString(identity) + "," +
@@ -99,7 +86,8 @@
         }
     }
 
-    static final class CHPreSharedKeySpec implements SSLExtensionSpec {
+    private static final
+            class CHPreSharedKeySpec implements SSLExtensionSpec {
         final List<PskIdentity> identities;
         final List<byte[]> binders;
 
@@ -108,26 +96,65 @@
             this.binders = binders;
         }
 
-        CHPreSharedKeySpec(ByteBuffer m)
-            throws IllegalParameterException, IOException {
+        CHPreSharedKeySpec(HandshakeContext context,
+                ByteBuffer m) throws IOException {
+            // struct {
+            //     PskIdentity identities<7..2^16-1>;
+            //     PskBinderEntry binders<33..2^16-1>;
+            // } OfferedPsks;
+            if (m.remaining() < 44) {
+                context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                    "Invalid pre_shared_key extension: " +
+                    "insufficient data (length=" + m.remaining() + ")");
+            }
+
+            int idEncodedLength = Record.getInt16(m);
+            if (idEncodedLength < 7) {
+                context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                    "Invalid pre_shared_key extension: " +
+                    "insufficient identities (length=" + idEncodedLength + ")");
+            }
 
             identities = new ArrayList<>();
-            int idEncodedLength = Record.getInt16(m);
             int idReadLength = 0;
             while (idReadLength < idEncodedLength) {
-                PskIdentity id = new PskIdentity(m);
-                identities.add(id);
-                idReadLength += id.getEncodedLength();
+                byte[] id = Record.getBytes16(m);
+                if (id.length < 1) {
+                    context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                        "Invalid pre_shared_key extension: " +
+                        "insufficient identity (length=" + id.length + ")");
+                }
+                int obfuscatedTicketAge = Record.getInt32(m);
+
+                PskIdentity pskId = new PskIdentity(id, obfuscatedTicketAge);
+                identities.add(pskId);
+                idReadLength += pskId.getEncodedLength();
+            }
+
+            if (m.remaining() < 35) {
+                context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                        "Invalid pre_shared_key extension: " +
+                        "insufficient binders data (length=" +
+                        m.remaining() + ")");
+            }
+
+            int bindersEncodedLen = Record.getInt16(m);
+            if (bindersEncodedLen < 33) {
+                context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                        "Invalid pre_shared_key extension: " +
+                        "insufficient binders (length=" +
+                        bindersEncodedLen + ")");
             }
 
             binders = new ArrayList<>();
-            int bindersEncodedLength = Record.getInt16(m);
             int bindersReadLength = 0;
-            while (bindersReadLength < bindersEncodedLength) {
+            while (bindersReadLength < bindersEncodedLen) {
                 byte[] binder = Record.getBytes8(m);
                 if (binder.length < 32) {
-                    throw new IllegalParameterException(
-                        "binder has length < 32");
+                    context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                            "Invalid pre_shared_key extension: " +
+                            "insufficient binder entry (length=" +
+                            binder.length + ")");
                 }
                 binders.add(binder);
                 bindersReadLength += 1 + binder.length;
@@ -139,22 +166,20 @@
             for(PskIdentity curId : identities) {
                 idEncodedLength += curId.getEncodedLength();
             }
+
             return idEncodedLength;
         }
 
         int getBindersEncodedLength() {
-            return getBindersEncodedLength(binders);
-        }
-        static int getBindersEncodedLength(Iterable<byte[]> binders) {
             int binderEncodedLength = 0;
             for (byte[] curBinder : binders) {
                 binderEncodedLength += 1 + curBinder.length;
             }
+
             return binderEncodedLength;
         }
 
         byte[] getEncoded() throws IOException {
-
             int idsEncodedLength = getIdsEncodedLength();
             int bindersEncodedLength = getBindersEncodedLength();
             int encodedLength = 4 + idsEncodedLength + bindersEncodedLength;
@@ -176,7 +201,7 @@
         public String toString() {
             MessageFormat messageFormat = new MessageFormat(
                 "\"PreSharedKey\": '{'\n" +
-                "  \"identities\"      : \"{0}\",\n" +
+                "  \"identities\"    : \"{0}\",\n" +
                 "  \"binders\"       : \"{1}\",\n" +
                 "'}'",
                 Locale.ENGLISH);
@@ -208,24 +233,30 @@
         }
     }
 
-    static final class SHPreSharedKeySpec implements SSLExtensionSpec {
+    private static final
+            class SHPreSharedKeySpec implements SSLExtensionSpec {
         final int selectedIdentity;
 
         SHPreSharedKeySpec(int selectedIdentity) {
             this.selectedIdentity = selectedIdentity;
         }
 
-        SHPreSharedKeySpec(ByteBuffer m) throws IOException {
+        SHPreSharedKeySpec(HandshakeContext context,
+                ByteBuffer m) throws IOException {
+            if (m.remaining() < 2) {
+                context.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                        "Invalid pre_shared_key extension: " +
+                        "insufficient selected_identity (length=" +
+                        m.remaining() + ")");
+            }
             this.selectedIdentity = Record.getInt16(m);
         }
 
         byte[] getEncoded() throws IOException {
-
-            byte[] buffer = new byte[2];
-            ByteBuffer m = ByteBuffer.wrap(buffer);
-            Record.putInt16(m, selectedIdentity);
-
-            return buffer;
+            return new byte[] {
+                (byte)((selectedIdentity >> 8) & 0xFF),
+                (byte)(selectedIdentity & 0xFF)
+            };
         }
 
         @Override
@@ -237,27 +268,15 @@
                 Locale.ENGLISH);
 
             Object[] messageFields = {
-                selectedIdentity
+                Utilities.byte16HexString(selectedIdentity)
             };
 
             return messageFormat.format(messageFields);
         }
-
     }
 
-
-    private static class IllegalParameterException extends Exception {
-
-        private static final long serialVersionUID = 0;
-
-        private final String message;
-
-        private IllegalParameterException(String message) {
-            this.message = message;
-        }
-    }
-
-    private static final class CHPreSharedKeyConsumer implements ExtensionConsumer {
+    private static final
+            class CHPreSharedKeyConsumer implements ExtensionConsumer {
         // Prevent instantiation of this class.
         private CHPreSharedKeyConsumer() {
             // blank
@@ -267,70 +286,77 @@
         public void consume(ConnectionContext context,
                             HandshakeMessage message,
                             ByteBuffer buffer) throws IOException {
-
-            ServerHandshakeContext shc = (ServerHandshakeContext) message.handshakeContext;
+            ServerHandshakeContext shc = (ServerHandshakeContext)context;
             // Is it a supported and enabled extension?
             if (!shc.sslConfig.isAvailable(SSLExtension.CH_PRE_SHARED_KEY)) {
                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                     SSLLogger.fine(
-                    "Ignore unavailable pre_shared_key extension");
+                            "Ignore unavailable pre_shared_key extension");
                 }
                 return;     // ignore the extension
             }
 
+            // Parse the extension.
             CHPreSharedKeySpec pskSpec = null;
             try {
-                pskSpec = new CHPreSharedKeySpec(buffer);
+                pskSpec = new CHPreSharedKeySpec(shc, buffer);
             } catch (IOException ioe) {
                 shc.conContext.fatal(Alert.UNEXPECTED_MESSAGE, ioe);
                 return;     // fatal() always throws, make the compiler happy.
-            } catch (IllegalParameterException ex) {
-                shc.conContext.fatal(Alert.ILLEGAL_PARAMETER, ex.message);
             }
 
             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
-                SSLLogger.fine(
-                "Received PSK extension: ", pskSpec);
+                SSLLogger.fine("Received PSK extension: ", pskSpec);
             }
 
             if (shc.pskKeyExchangeModes.isEmpty()) {
                 shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
-                "Client sent PSK but does not support PSK modes");
+                        "Client sent PSK but does not support PSK modes");
             }
 
             // error if id and binder lists are not the same length
             if (pskSpec.identities.size() != pskSpec.binders.size()) {
                 shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
-                "PSK extension has incorrect number of binders");
+                        "PSK extension has incorrect number of binders");
             }
 
-            shc.handshakeExtensions.put(SSLExtension.CH_PRE_SHARED_KEY, pskSpec);
-
-            SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
-            message.handshakeContext.sslContext.engineGetServerSessionContext();
+            if (shc.isResumption && shc.resumingSession != null) {
+                SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
+                        shc.sslContext.engineGetServerSessionContext();
+                int idIndex = 0;
+                for (PskIdentity requestedId : pskSpec.identities) {
+                    SSLSessionImpl s = sessionCache.get(requestedId.identity);
+                    if (s != null && s.isRejoinable() &&
+                            s.getPreSharedKey().isPresent()) {
+                        if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                            SSLLogger.fine("Resuming session: ", s);
+                        }
 
-            // The session to resume will be decided below.
-            // It could have been set by previous actions (e.g. PSK received
-            // earlier), and it must be recalculated.
-            shc.isResumption = false;
-            shc.resumingSession = null;
+                        // binder will be checked later
+                        shc.resumingSession = s;
+                        shc.handshakeExtensions.put(SH_PRE_SHARED_KEY,
+                            new SHPreSharedKeySpec(idIndex));   // for the index
+                        break;
+                    }
 
-            int idIndex = 0;
-            for (PskIdentity requestedId : pskSpec.identities) {
-                SSLSessionImpl s = sessionCache.get(requestedId.identity);
-                if (s != null && s.isRejoinable() &&
-                    s.getPreSharedKey().isPresent()) {
-
-                    resumeSession(shc, s, idIndex);
-                    break;
+                    ++idIndex;
                 }
 
-                ++idIndex;
+                if (idIndex == pskSpec.identities.size()) {
+                    // no resumable session
+                    shc.isResumption = false;
+                    shc.resumingSession = null;
+                }
             }
+
+            // update the context
+            shc.handshakeExtensions.put(
+                    SSLExtension.CH_PRE_SHARED_KEY, pskSpec);
         }
     }
 
-    private static final class CHPreSharedKeyUpdate implements HandshakeConsumer {
+    private static final
+            class CHPreSharedKeyUpdate implements HandshakeConsumer {
         // Prevent instantiation of this class.
         private CHPreSharedKeyUpdate() {
             // blank
@@ -338,21 +364,20 @@
 
         @Override
         public void consume(ConnectionContext context,
-                            HandshakeMessage message) throws IOException {
-
-            ServerHandshakeContext shc = (ServerHandshakeContext) message.handshakeContext;
-
+                HandshakeMessage message) throws IOException {
+            ServerHandshakeContext shc = (ServerHandshakeContext)context;
             if (!shc.isResumption || shc.resumingSession == null) {
                 // not resuming---nothing to do
                 return;
             }
 
-            CHPreSharedKeySpec chPsk = (CHPreSharedKeySpec)shc.handshakeExtensions.get(SSLExtension.CH_PRE_SHARED_KEY);
-            SHPreSharedKeySpec shPsk = (SHPreSharedKeySpec)shc.handshakeExtensions.get(SSLExtension.SH_PRE_SHARED_KEY);
-
+            CHPreSharedKeySpec chPsk = (CHPreSharedKeySpec)
+                    shc.handshakeExtensions.get(SSLExtension.CH_PRE_SHARED_KEY);
+            SHPreSharedKeySpec shPsk = (SHPreSharedKeySpec)
+                    shc.handshakeExtensions.get(SSLExtension.SH_PRE_SHARED_KEY);
             if (chPsk == null || shPsk == null) {
                 shc.conContext.fatal(Alert.INTERNAL_ERROR,
-                "Required extensions are unavailable");
+                        "Required extensions are unavailable");
             }
 
             byte[] binder = chPsk.binders.get(shPsk.selectedIdentity);
@@ -364,7 +389,7 @@
             // skip the type and length
             messageBuf.position(4);
             // read to find the beginning of the binders
-            ClientHello.ClientHelloMessage.readPartial(shc.conContext, messageBuf);
+            ClientHelloMessage.readPartial(shc.conContext, messageBuf);
             int length = messageBuf.position();
             messageBuf.position(0);
             pskBinderHash.receive(messageBuf, length);
@@ -373,37 +398,19 @@
         }
     }
 
-    private static void resumeSession(ServerHandshakeContext shc,
-                                      SSLSessionImpl session,
-                                      int index)
-        throws IOException {
-
-        if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
-            SSLLogger.fine(
-            "Resuming session: ", session);
-        }
-
-        // binder will be checked later
-
-        shc.isResumption = true;
-        shc.resumingSession = session;
-
-        SHPreSharedKeySpec pskMsg = new SHPreSharedKeySpec(index);
-        shc.handshakeExtensions.put(SH_PRE_SHARED_KEY, pskMsg);
-    }
-
-    private static void checkBinder(ServerHandshakeContext shc, SSLSessionImpl session,
-                                    HandshakeHash pskBinderHash, byte[] binder) throws IOException {
-
+    private static void checkBinder(ServerHandshakeContext shc,
+            SSLSessionImpl session,
+            HandshakeHash pskBinderHash, byte[] binder) throws IOException {
         Optional<SecretKey> pskOpt = session.getPreSharedKey();
         if (!pskOpt.isPresent()) {
             shc.conContext.fatal(Alert.INTERNAL_ERROR,
-            "Session has no PSK");
+                    "Session has no PSK");
         }
         SecretKey psk = pskOpt.get();
 
         SecretKey binderKey = deriveBinderKey(psk, session);
-        byte[] computedBinder = computeBinder(binderKey, session, pskBinderHash);
+        byte[] computedBinder =
+                computeBinder(binderKey, session, pskBinderHash);
         if (!Arrays.equals(binder, computedBinder)) {
             shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
             "Incorect PSK binder value");
@@ -479,8 +486,8 @@
         }
     }
 
-    private static final class CHPreSharedKeyProducer implements HandshakeProducer {
-
+    private static final
+            class CHPreSharedKeyProducer implements HandshakeProducer {
         // Prevent instantiation of this class.
         private CHPreSharedKeyProducer() {
             // blank
@@ -494,8 +501,7 @@
             ClientHandshakeContext chc = (ClientHandshakeContext)context;
             if (!chc.isResumption || chc.resumingSession == null) {
                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
-                    SSLLogger.fine(
-                    "No session to resume.");
+                    SSLLogger.fine("No session to resume.");
                 }
                 return null;
             }
@@ -503,8 +509,7 @@
             Optional<SecretKey> pskOpt = chc.resumingSession.getPreSharedKey();
             if (!pskOpt.isPresent()) {
                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
-                    SSLLogger.fine(
-                    "Existing session has no PSK.");
+                    SSLLogger.fine("Existing session has no PSK.");
                 }
                 return null;
             }
@@ -513,7 +518,7 @@
             if (!pskIdOpt.isPresent()) {
                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                     SSLLogger.fine(
-                    "PSK has no identity, or identity was already used");
+                        "PSK has no identity, or identity was already used");
                 }
                 return null;
             }
@@ -521,30 +526,36 @@
 
             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                 SSLLogger.fine(
-                "Found resumable session. Preparing PSK message.");
+                    "Found resumable session. Preparing PSK message.");
             }
 
             List<PskIdentity> identities = new ArrayList<>();
-            int ageMillis = (int)(System.currentTimeMillis() - chc.resumingSession.getTicketCreationTime());
-            int obfuscatedAge = ageMillis + chc.resumingSession.getTicketAgeAdd();
+            int ageMillis = (int)(System.currentTimeMillis() -
+                    chc.resumingSession.getTicketCreationTime());
+            int obfuscatedAge =
+                    ageMillis + chc.resumingSession.getTicketAgeAdd();
             identities.add(new PskIdentity(pskId, obfuscatedAge));
 
             SecretKey binderKey = deriveBinderKey(psk, chc.resumingSession);
-            ClientHello.ClientHelloMessage clientHello = (ClientHello.ClientHelloMessage) message;
-            CHPreSharedKeySpec pskPrototype = createPskPrototype(chc.resumingSession.getSuite().hashAlg.hashLength, identities);
+            ClientHelloMessage clientHello = (ClientHelloMessage)message;
+            CHPreSharedKeySpec pskPrototype = createPskPrototype(
+                chc.resumingSession.getSuite().hashAlg.hashLength, identities);
             HandshakeHash pskBinderHash = chc.handshakeHash.copy();
 
-            byte[] binder = computeBinder(binderKey, pskBinderHash, chc.resumingSession, chc, clientHello, pskPrototype);
+            byte[] binder = computeBinder(binderKey, pskBinderHash,
+                    chc.resumingSession, chc, clientHello, pskPrototype);
 
             List<byte[]> binders = new ArrayList<>();
             binders.add(binder);
 
-            CHPreSharedKeySpec pskMessage = new CHPreSharedKeySpec(identities, binders);
+            CHPreSharedKeySpec pskMessage =
+                    new CHPreSharedKeySpec(identities, binders);
             chc.handshakeExtensions.put(CH_PRE_SHARED_KEY, pskMessage);
             return pskMessage.getEncoded();
         }
 
-        private CHPreSharedKeySpec createPskPrototype(int hashLength, List<PskIdentity> identities) {
+        private CHPreSharedKeySpec createPskPrototype(
+                int hashLength, List<PskIdentity> identities) {
             List<byte[]> binders = new ArrayList<>();
             byte[] binderProto = new byte[hashLength];
             for (PskIdentity curId : identities) {
@@ -555,20 +566,25 @@
         }
     }
 
-    private static byte[] computeBinder(SecretKey binderKey, SSLSessionImpl session, HandshakeHash pskBinderHash) throws IOException {
+    private static byte[] computeBinder(SecretKey binderKey,
+            SSLSessionImpl session,
+            HandshakeHash pskBinderHash) throws IOException {
 
-        pskBinderHash.determine(session.getProtocolVersion(), session.getSuite());
+        pskBinderHash.determine(
+                session.getProtocolVersion(), session.getSuite());
         pskBinderHash.update();
         byte[] digest = pskBinderHash.digest();
 
         return computeBinder(binderKey, session, digest);
     }
 
-    private static byte[] computeBinder(SecretKey binderKey, HandshakeHash hash, SSLSessionImpl session,
-                                        HandshakeContext ctx, ClientHello.ClientHelloMessage hello,
-                                        CHPreSharedKeySpec pskPrototype) throws IOException {
+    private static byte[] computeBinder(SecretKey binderKey,
+            HandshakeHash hash, SSLSessionImpl session,
+            HandshakeContext ctx, ClientHello.ClientHelloMessage hello,
+            CHPreSharedKeySpec pskPrototype) throws IOException {
 
-        PartialClientHelloMessage partialMsg = new PartialClientHelloMessage(ctx, hello, pskPrototype);
+        PartialClientHelloMessage partialMsg =
+                new PartialClientHelloMessage(ctx, hello, pskPrototype);
 
         SSLEngineOutputRecord record = new SSLEngineOutputRecord(hash);
         HandshakeOutStream hos = new HandshakeOutStream(record);
@@ -581,15 +597,16 @@
         return computeBinder(binderKey, session, digest);
     }
 
-    private static byte[] computeBinder(SecretKey binderKey, SSLSessionImpl session,
-                                        byte[] digest) throws IOException {
-
+    private static byte[] computeBinder(SecretKey binderKey,
+            SSLSessionImpl session, byte[] digest) throws IOException {
         try {
             CipherSuite.HashAlg hashAlg = session.getSuite().hashAlg;
             HKDF hkdf = new HKDF(hashAlg.name);
             byte[] label = ("tls13 finished").getBytes();
-            byte[] hkdfInfo = SSLSecretDerivation.createHkdfInfo(label, new byte[0], hashAlg.hashLength);
-            SecretKey finishedKey = hkdf.expand(binderKey, hkdfInfo, hashAlg.hashLength, "TlsBinderKey");
+            byte[] hkdfInfo = SSLSecretDerivation.createHkdfInfo(
+                    label, new byte[0], hashAlg.hashLength);
+            SecretKey finishedKey = hkdf.expand(
+                    binderKey, hkdfInfo, hashAlg.hashLength, "TlsBinderKey");
 
             String hmacAlg =
                 "Hmac" + hashAlg.name.replace("-", "");
@@ -606,9 +623,7 @@
     }
 
     private static SecretKey deriveBinderKey(SecretKey psk,
-                                             SSLSessionImpl session)
-        throws IOException {
-
+            SSLSessionImpl session) throws IOException {
         try {
             CipherSuite.HashAlg hashAlg = session.getSuite().hashAlg;
             HKDF hkdf = new HKDF(hashAlg.name);
@@ -618,16 +633,16 @@
             byte[] label = ("tls13 res binder").getBytes();
             MessageDigest md = MessageDigest.getInstance(hashAlg.toString());;
             byte[] hkdfInfo = SSLSecretDerivation.createHkdfInfo(
-                label, md.digest(new byte[0]), hashAlg.hashLength);
-            return hkdf.expand(earlySecret, hkdfInfo, hashAlg.hashLength,
-                "TlsBinderKey");
-
+                    label, md.digest(new byte[0]), hashAlg.hashLength);
+            return hkdf.expand(earlySecret,
+                    hkdfInfo, hashAlg.hashLength, "TlsBinderKey");
         } catch (GeneralSecurityException ex) {
             throw new IOException(ex);
         }
     }
 
-    private static final class CHPreSharedKeyAbsence implements HandshakeAbsence {
+    private static final
+            class CHPreSharedKeyAbsence implements HandshakeAbsence {
         @Override
         public void absent(ConnectionContext context,
                            HandshakeMessage message) throws IOException {
@@ -645,27 +660,30 @@
         }
     }
 
-    private static final class SHPreSharedKeyConsumer implements ExtensionConsumer {
+    private static final
+            class SHPreSharedKeyConsumer implements ExtensionConsumer {
         // Prevent instantiation of this class.
         private SHPreSharedKeyConsumer() {
-
+            // blank
         }
 
         @Override
         public void consume(ConnectionContext context,
-                HandshakeMessage message, ByteBuffer buffer) throws IOException {
+            HandshakeMessage message, ByteBuffer buffer) throws IOException {
+            // The consuming happens in client side only.
+            ClientHandshakeContext chc = (ClientHandshakeContext)context;
 
-            ClientHandshakeContext chc = (ClientHandshakeContext) message.handshakeContext;
+            // Is it a response of the specific request?
+            if (!chc.handshakeExtensions.containsKey(
+                    SSLExtension.CH_PRE_SHARED_KEY)) {
+                chc.conContext.fatal(Alert.UNEXPECTED_MESSAGE,
+                    "Server sent unexpected pre_shared_key extension");
+            }
 
-            SHPreSharedKeySpec shPsk = new SHPreSharedKeySpec(buffer);
+            SHPreSharedKeySpec shPsk = new SHPreSharedKeySpec(chc, buffer);
             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                 SSLLogger.fine(
-                "Received pre_shared_key extension: ", shPsk);
-            }
-
-            if (!chc.handshakeExtensions.containsKey(SSLExtension.CH_PRE_SHARED_KEY)) {
-                chc.conContext.fatal(Alert.UNEXPECTED_MESSAGE,
-                "Server sent unexpected pre_shared_key extension");
+                    "Received pre_shared_key extension: ", shPsk);
             }
 
             // The PSK identity should not be reused, even if it is
@@ -674,7 +692,7 @@
 
             if (shPsk.selectedIdentity != 0) {
                 chc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
-                "Selected identity index is not in correct range.");
+                    "Selected identity index is not in correct range.");
             }
 
             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
@@ -684,37 +702,38 @@
 
             // remove the session from the cache
             SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
-                chc.sslContext.engineGetClientSessionContext();
+                    chc.sslContext.engineGetClientSessionContext();
             sessionCache.remove(chc.resumingSession.getSessionId());
         }
     }
 
-    private static final class SHPreSharedKeyAbsence implements HandshakeAbsence {
+    private static final
+            class SHPreSharedKeyAbsence implements HandshakeAbsence {
         @Override
         public void absent(ConnectionContext context,
-                           HandshakeMessage message) throws IOException {
+                HandshakeMessage message) throws IOException {
+            ClientHandshakeContext chc = (ClientHandshakeContext)context;
 
             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
-                SSLLogger.fine(
-                "Handling pre_shared_key absence.");
+                SSLLogger.fine("Handling pre_shared_key absence.");
             }
 
-            ClientHandshakeContext chc = (ClientHandshakeContext)context;
-
-            if (chc.handshakeExtensions.containsKey(SSLExtension.CH_PRE_SHARED_KEY)) {
+            if (chc.handshakeExtensions.containsKey(
+                    SSLExtension.CH_PRE_SHARED_KEY)) {
                 // The PSK identity should not be reused, even if it is
                 // not selected.
                 chc.resumingSession.consumePskIdentity();
             }
 
-            // the server refused to resume, or the client did not request 1.3 resumption
+            // The server refused to resume, or the client did not
+            // request 1.3 resumption.
             chc.resumingSession = null;
             chc.isResumption = false;
         }
     }
 
-    private static final class SHPreSharedKeyProducer implements HandshakeProducer {
-
+    private static final
+            class SHPreSharedKeyProducer implements HandshakeProducer {
         // Prevent instantiation of this class.
         private SHPreSharedKeyProducer() {
             // blank
@@ -723,11 +742,9 @@
         @Override
         public byte[] produce(ConnectionContext context,
                 HandshakeMessage message) throws IOException {
-
-            ServerHandshakeContext shc = (ServerHandshakeContext)
-                message.handshakeContext;
+            ServerHandshakeContext shc = (ServerHandshakeContext)context;
             SHPreSharedKeySpec psk = (SHPreSharedKeySpec)
-                shc.handshakeExtensions.get(SH_PRE_SHARED_KEY);
+                    shc.handshakeExtensions.get(SH_PRE_SHARED_KEY);
             if (psk == null) {
                 return null;
             }
--- a/test/jdk/sun/security/ssl/SSLSocketImpl/AsyncSSLSocketClose.java	Sat Jun 02 21:23:02 2018 -0700
+++ b/test/jdk/sun/security/ssl/SSLSocketImpl/AsyncSSLSocketClose.java	Sun Jun 03 19:43:10 2018 -0700
@@ -36,15 +36,15 @@
 
 import javax.net.ssl.*;
 import java.io.*;
-import java.util.concurrent.locks.*;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
 public class AsyncSSLSocketClose implements Runnable {
     SSLSocket socket;
     SSLServerSocket ss;
 
-    final Lock lock = new ReentrantLock();
-    final Condition isRunning = lock.newCondition(); 
+    // Is the socket ready to close?
+    private final CountDownLatch closeCondition = new CountDownLatch(1);
 
     // Where do we find the keystores?
     static String pathToStores = "../../../../javax/net/ssl/etc";
@@ -81,25 +81,19 @@
 
         (new Thread(this)).start();
         serverSoc.startHandshake();
-        if (lock.tryLock() || lock.tryLock(5000, TimeUnit.MILLISECONDS)) {
-            boolean started = false;
-            try {
-                started = isRunning.await(5000, TimeUnit.MILLISECONDS);
-            } finally {
-                lock.unlock();
-            }
-            if (started) {
-                socket.setSoLinger(true, 10);
-                System.out.println("Calling Socket.close");
-                socket.close();
-                System.out.println("ssl socket get closed");
-                System.out.flush();
-            } else {
-                throw new Exception("Did not get the signal in main thread");
-            }
-        } else {
-            throw new Exception("Unable get the lock in main thread");
+
+        boolean closeIsReady = closeCondition.await(90L, TimeUnit.SECONDS);
+        if (!closeIsReady) {
+            System.out.println(
+                    "Ignore, the closure is not ready yet in 90 seconds.");
+            return;
         }
+
+        socket.setSoLinger(true, 10);
+        System.out.println("Calling Socket.close");
+        socket.close();
+        System.out.println("ssl socket get closed");
+        System.out.flush();
     }
 
     // block in write
@@ -119,16 +113,8 @@
             os.write(ba);
             System.out.println(count + " bytes written");
 
-            if (lock.tryLock() || lock.tryLock(5000, TimeUnit.MILLISECONDS)) {
-                try {
-                    isRunning.signal();
-                } finally {
-                    lock.unlock();
-                }
-            } else {
-                throw new RuntimeException(
-                    "Unable get the lock in write thread");
-            }
+            // Signal, ready to close.
+            closeCondition.countDown();
 
             // write more
             while (true) {
@@ -138,7 +124,7 @@
                 System.out.println(count + " bytes written");
             }
         } catch (Exception e) {
-            if (socket.isClosed()) {
+            if (socket.isClosed() || socket.isOutputShutdown()) {
                 System.out.println("interrupted, the socket is closed");
             } else {
                 throw new RuntimeException("interrupted?", e);
@@ -147,4 +133,3 @@
     }
 }
 
-