src/java.base/share/classes/sun/security/ssl/PreSharedKeyExtension.java
author apetcher
Fri, 25 May 2018 13:20:01 -0400
branch JDK-8145252-TLS13-branch
changeset 56608 34f33526b9a5
parent 56558 4a3deb6759b1
child 56661 2a820e434f17
permissions -rw-r--r--
A couple of minor session resumption fixes

/*
 * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package sun.security.ssl;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import javax.crypto.Mac;
import javax.crypto.SecretKey;

import sun.security.ssl.SSLExtension.ExtensionConsumer;
import sun.security.ssl.SSLExtension.SSLExtensionSpec;
import sun.security.ssl.SSLHandshake.HandshakeMessage;

import static sun.security.ssl.SSLExtension.*;

/**
 * Pack of the "pre_shared_key" extension.
 */
final class PreSharedKeyExtension {
    // Handlers for the ClientHello ("CH") flavor of the extension:
    // produced by the client, consumed/absence-checked by the server, plus
    // a later "on trade" step that runs CHPreSharedKeyUpdate.
    static final HandshakeProducer chNetworkProducer =
            new CHPreSharedKeyProducer();
    static final ExtensionConsumer chOnLoadConsumer =
            new CHPreSharedKeyConsumer();
    static final HandshakeAbsence chOnLoadAbsence =
            new CHPreSharedKeyAbsence();
    static final HandshakeConsumer chOnTradeConsumer =
            new CHPreSharedKeyUpdate();

    // Handlers for the ServerHello ("SH") flavor of the extension:
    // produced by the server, consumed/absence-checked by the client.
    static final HandshakeProducer shNetworkProducer =
            new SHPreSharedKeyProducer();
    static final ExtensionConsumer shOnLoadConsumer =
            new SHPreSharedKeyConsumer();
    static final HandshakeAbsence shOnLoadAbsence =
            new SHPreSharedKeyAbsence();

    /**
     * One PskIdentity entry of the ClientHello "pre_shared_key" extension:
     * an opaque identity (2-byte length prefix) followed by a 4-byte
     * obfuscated ticket age.
     */
    static final class PskIdentity {
        final byte[] identity;
        final int obfuscatedAge;

        public PskIdentity(byte[] identity, int obfuscatedAge) {
            this.identity = identity;
            this.obfuscatedAge = obfuscatedAge;
        }

        /**
         * Decode one entry from {@code m}.
         *
         * @throws IllegalParameterException if the identity is empty
         * @throws IOException if the record helpers fail to decode
         */
        public PskIdentity(ByteBuffer m)
            throws IllegalParameterException, IOException {

            byte[] decodedId = Record.getBytes16(m);
            if (decodedId.length == 0) {
                throw new IllegalParameterException("identity has length 0");
            }
            identity = decodedId;
            obfuscatedAge = Record.getInt32(m);
        }

        /** Encoded size: 2-byte length, identity bytes, 4-byte age. */
        int getEncodedLength() {
            return Short.BYTES + identity.length + Integer.BYTES;
        }

        /** Write this entry to {@code m} in wire format. */
        public void writeEncoded(ByteBuffer m) throws IOException {
            Record.putBytes16(m, identity);
            Record.putInt32(m, obfuscatedAge);
        }

        @Override
        public String toString() {
            return new StringBuilder("{")
                .append(Utilities.toHexString(identity))
                .append(',')
                .append(obfuscatedAge)
                .append('}')
                .toString();
        }
    }

    /**
     * Body of the ClientHello "pre_shared_key" extension: the offered PSK
     * identities followed by their binders (RFC 8446, Section 4.2.11).
     */
    static final class CHPreSharedKeySpec implements SSLExtensionSpec {
        // Offered identities, in wire order.
        final List<PskIdentity> identities;
        // Binder values; the server-side consumer requires one per identity.
        final List<byte[]> binders;

        CHPreSharedKeySpec(List<PskIdentity> identities, List<byte[]> binders) {
            this.identities = identities;
            this.binders = binders;
        }

        /**
         * Decode the extension body from {@code m}.
         *
         * @throws IllegalParameterException if an identity is empty, no
         *         identity is present, a binder is shorter than 32 bytes, or
         *         a vector's declared length does not match its entries
         * @throws IOException if the record helpers fail to decode
         */
        CHPreSharedKeySpec(ByteBuffer m)
            throws IllegalParameterException, IOException {

            identities = new ArrayList<>();
            int idEncodedLength = Record.getInt16(m);
            int idReadLength = 0;
            while (idReadLength < idEncodedLength) {
                PskIdentity id = new PskIdentity(m);
                identities.add(id);
                idReadLength += id.getEncodedLength();
            }
            // The last entry must end exactly at the declared length;
            // overshooting means it consumed bytes of the binders vector.
            if (idReadLength != idEncodedLength) {
                throw new IllegalParameterException(
                    "identities length does not match its entries");
            }
            // identities<7..2^16-1>: at least one entry is required.
            if (identities.isEmpty()) {
                throw new IllegalParameterException("no PSK identities");
            }

            binders = new ArrayList<>();
            int bindersEncodedLength = Record.getInt16(m);
            int bindersReadLength = 0;
            while (bindersReadLength < bindersEncodedLength) {
                byte[] binder = Record.getBytes8(m);
                if (binder.length < 32) {
                    throw new IllegalParameterException(
                        "binder has length < 32");
                }
                binders.add(binder);
                bindersReadLength += 1 + binder.length;
            }
            // Same exact-length requirement for the binders vector.
            if (bindersReadLength != bindersEncodedLength) {
                throw new IllegalParameterException(
                    "binders length does not match its entries");
            }
        }

        /** Total encoded size of all identity entries (no vector prefix). */
        int getIdsEncodedLength() {
            int idEncodedLength = 0;
            for (PskIdentity curId : identities) {
                idEncodedLength += curId.getEncodedLength();
            }
            return idEncodedLength;
        }

        /** Total encoded size of this spec's binders (no vector prefix). */
        int getBindersEncodedLength() {
            return getBindersEncodedLength(binders);
        }

        /** Each binder is encoded as a 1-byte length plus its bytes. */
        static int getBindersEncodedLength(Iterable<byte[]> binders) {
            int binderEncodedLength = 0;
            for (byte[] curBinder : binders) {
                binderEncodedLength += 1 + curBinder.length;
            }
            return binderEncodedLength;
        }

        /**
         * Encode as: 2-byte identities length, identities, 2-byte binders
         * length, binders.
         */
        byte[] getEncoded() throws IOException {

            int idsEncodedLength = getIdsEncodedLength();
            int bindersEncodedLength = getBindersEncodedLength();
            int encodedLength = 4 + idsEncodedLength + bindersEncodedLength;
            byte[] buffer = new byte[encodedLength];
            ByteBuffer m = ByteBuffer.wrap(buffer);
            Record.putInt16(m, idsEncodedLength);
            for (PskIdentity curId : identities) {
                curId.writeEncoded(m);
            }
            Record.putInt16(m, bindersEncodedLength);
            for (byte[] curBinder : binders) {
                Record.putBytes8(m, curBinder);
            }

            return buffer;
        }

        @Override
        public String toString() {
            MessageFormat messageFormat = new MessageFormat(
                "\"PreSharedKey\": '{'\n" +
                "  \"identities\"      : \"{0}\",\n" +
                "  \"binders\"       : \"{1}\",\n" +
                "'}'",
                Locale.ENGLISH);

            Object[] messageFields = {
                Utilities.indent(identitiesString()),
                Utilities.indent(bindersString())
            };

            return messageFormat.format(messageFields);
        }

        // One identity per line.
        String identitiesString() {
            StringBuilder result = new StringBuilder();
            for (PskIdentity curId : identities) {
                result.append(curId).append('\n');
            }

            return result.toString();
        }

        // One hex-encoded binder per line, wrapped in braces.
        String bindersString() {
            StringBuilder result = new StringBuilder();
            for (byte[] curBinder : binders) {
                result.append('{')
                      .append(Utilities.toHexString(curBinder))
                      .append("}\n");
            }

            return result.toString();
        }
    }

    /**
     * Body of the ServerHello "pre_shared_key" extension: the 2-byte index
     * of the identity the server selected from the client's list.
     */
    static final class SHPreSharedKeySpec implements SSLExtensionSpec {
        final int selectedIdentity;

        SHPreSharedKeySpec(int selectedIdentity) {
            this.selectedIdentity = selectedIdentity;
        }

        SHPreSharedKeySpec(ByteBuffer m) throws IOException {
            this.selectedIdentity = Record.getInt16(m);
        }

        /** Encode the selected index as a 2-byte value. */
        byte[] getEncoded() throws IOException {
            ByteBuffer out = ByteBuffer.allocate(2);
            Record.putInt16(out, selectedIdentity);
            return out.array();
        }

        @Override
        public String toString() {
            String pattern =
                "\"PreSharedKey\": '{'\n" +
                "  \"selected_identity\"      : \"{0}\",\n" +
                "'}'";
            return new MessageFormat(pattern, Locale.ENGLISH)
                .format(new Object[] { selectedIdentity });
        }

    }


    /**
     * Thrown while decoding a malformed "pre_shared_key" extension. The
     * text is kept in {@code message} (read by the extension consumer) and
     * is also passed to {@link Exception} so that {@code getMessage()} and
     * stack traces carry it.
     */
    private static class IllegalParameterException extends Exception {

        private static final long serialVersionUID = 0;

        private final String message;

        private IllegalParameterException(String message) {
            super(message);
            this.message = message;
        }
    }

    /**
     * Server-side consumer of the ClientHello "pre_shared_key" extension.
     *
     * Decodes and sanity-checks the offered identities and binders, and
     * records the spec in the handshake context. It then selects the first
     * identity that maps to a rejoinable cached session holding a PSK. The
     * binder itself is checked later, by CHPreSharedKeyUpdate.
     */
    private static final class CHPreSharedKeyConsumer implements ExtensionConsumer {
        // Prevent instantiation of this class.
        private CHPreSharedKeyConsumer() {
            // blank
        }

        @Override
        public void consume(ConnectionContext context,
                            HandshakeMessage message,
                            ByteBuffer buffer) throws IOException {

            ServerHandshakeContext shc = (ServerHandshakeContext) message.handshakeContext;
            // Is it a supported and enabled extension?
            if (!shc.sslConfig.isAvailable(SSLExtension.CH_PRE_SHARED_KEY)) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine(
                    "Ignore unavailable pre_shared_key extension");
                }
                return;     // ignore the extension
            }

            // Both failure paths return explicitly, so the compiler proves
            // pskSpec is assigned below instead of relying on fatal()
            // throwing and a null placeholder.
            final CHPreSharedKeySpec pskSpec;
            try {
                pskSpec = new CHPreSharedKeySpec(buffer);
            } catch (IOException ioe) {
                shc.conContext.fatal(Alert.UNEXPECTED_MESSAGE, ioe);
                return;     // fatal() always throws, make the compiler happy.
            } catch (IllegalParameterException ex) {
                shc.conContext.fatal(Alert.ILLEGAL_PARAMETER, ex.message);
                return;     // fatal() always throws, make the compiler happy.
            }

            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                SSLLogger.fine(
                "Received PSK extension: ", pskSpec);
            }

            if (shc.pskKeyExchangeModes.isEmpty()) {
                shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
                "Client sent PSK but does not support PSK modes");
            }

            // error if id and binder lists are not the same length
            if (pskSpec.identities.size() != pskSpec.binders.size()) {
                shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
                "PSK extension has incorrect number of binders");
            }

            shc.handshakeExtensions.put(SSLExtension.CH_PRE_SHARED_KEY, pskSpec);

            SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
            message.handshakeContext.sslContext.engineGetServerSessionContext();

            // The session to resume will be decided below.
            // It could have been set by previous actions (e.g. PSK received
            // earlier), and it must be recalculated.
            shc.isResumption = false;
            shc.resumingSession = null;

            // Pick the first offered identity that resolves to a usable
            // cached session; its index becomes the selected identity.
            int idIndex = 0;
            for (PskIdentity requestedId : pskSpec.identities) {
                SSLSessionImpl s = sessionCache.get(requestedId.identity);
                if (s != null && s.isRejoinable() &&
                    s.getPreSharedKey().isPresent()) {

                    resumeSession(shc, s, idIndex);
                    break;
                }

                ++idIndex;
            }
        }
    }

    /**
     * Server-side step that verifies the PSK binder of the identity chosen
     * for resumption.
     *
     * Registered as {@code chOnTradeConsumer}; presumably it runs after the
     * ClientHello extensions have been consumed (verify against the handshake
     * wiring). If no session was selected it does nothing. Otherwise it:
     * rebuilds a transcript hash over the ClientHello up to the start of its
     * binders, then calls checkBinder for the binder at the selected index.
     */
    private static final class CHPreSharedKeyUpdate implements HandshakeConsumer {
        // Prevent instantiation of this class.
        private CHPreSharedKeyUpdate() {
            // blank
        }

        @Override
        public void consume(ConnectionContext context,
                            HandshakeMessage message) throws IOException {

            ServerHandshakeContext shc = (ServerHandshakeContext) message.handshakeContext;

            if (!shc.isResumption || shc.resumingSession == null) {
                // not resuming---nothing to do
                return;
            }

            CHPreSharedKeySpec chPsk = (CHPreSharedKeySpec)shc.handshakeExtensions.get(SSLExtension.CH_PRE_SHARED_KEY);
            SHPreSharedKeySpec shPsk = (SHPreSharedKeySpec)shc.handshakeExtensions.get(SSLExtension.SH_PRE_SHARED_KEY);

            // Both specs are expected to have been stored when the session
            // was selected; their absence is an internal inconsistency.
            if (chPsk == null || shPsk == null) {
                shc.conContext.fatal(Alert.INTERNAL_ERROR,
                "Required extensions are unavailable");
            }

            // Only the binder belonging to the selected identity is checked.
            byte[] binder = chPsk.binders.get(shPsk.selectedIdentity);

            // set up PSK binder hash
            // Work on a copy of the transcript hash, from which the last
            // received message (presumably the raw ClientHello -- verify) is
            // taken back out so a truncated version can be fed instead.
            HandshakeHash pskBinderHash = shc.handshakeHash.copy();
            byte[] lastMessage = pskBinderHash.removeLastReceived();
            ByteBuffer messageBuf = ByteBuffer.wrap(lastMessage);
            // skip the type and length
            // (4 bytes: handshake type plus 3-byte length)
            messageBuf.position(4);
            // read to find the beginning of the binders
            // NOTE(review): assumes readPartial leaves the position exactly
            // at the start of the binders list.
            ClientHello.ClientHelloMessage.readPartial(shc.conContext, messageBuf);
            int length = messageBuf.position();
            // Re-feed bytes [0, length), i.e. including the 4-byte header,
            // into the binder hash.
            messageBuf.position(0);
            pskBinderHash.receive(messageBuf, length);

            checkBinder(shc, shc.resumingSession, pskBinderHash, binder);
        }
    }

    /**
     * Mark {@code session} as the one being resumed and record
     * {@code index} as the identity to report in the ServerHello
     * "pre_shared_key" extension. The binder is not verified here; that
     * happens later.
     */
    private static void resumeSession(ServerHandshakeContext shc,
                                      SSLSessionImpl session,
                                      int index)
        throws IOException {

        if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
            SSLLogger.fine(
            "Resuming session: ", session);
        }

        shc.resumingSession = session;
        shc.isResumption = true;

        shc.handshakeExtensions.put(SH_PRE_SHARED_KEY,
            new SHPreSharedKeySpec(index));
    }

    /**
     * Verify the client-supplied PSK binder against one recomputed from the
     * session's PSK and the given (truncated ClientHello) transcript hash.
     *
     * @param shc the server handshake context, used to raise fatal alerts
     * @param session the session being resumed; must hold a PSK
     * @param pskBinderHash transcript hash over the truncated ClientHello
     * @param binder the binder value received from the client
     * @throws IOException via {@code fatal()} with INTERNAL_ERROR if the
     *         session has no PSK, or ILLEGAL_PARAMETER if the binder differs
     */
    private static void checkBinder(ServerHandshakeContext shc, SSLSessionImpl session,
                                    HandshakeHash pskBinderHash, byte[] binder) throws IOException {

        Optional<SecretKey> pskOpt = session.getPreSharedKey();
        if (!pskOpt.isPresent()) {
            shc.conContext.fatal(Alert.INTERNAL_ERROR,
            "Session has no PSK");
            return;     // fatal() always throws, make the compiler happy.
        }
        SecretKey psk = pskOpt.get();

        SecretKey binderKey = deriveBinderKey(psk, session);
        byte[] computedBinder = computeBinder(binderKey, session, pskBinderHash);
        // Compare in constant time so that response timing does not reveal
        // how many leading bytes of a forged binder were correct.
        if (!MessageDigest.isEqual(binder, computedBinder)) {
            shc.conContext.fatal(Alert.ILLEGAL_PARAMETER,
            "Incorrect PSK binder value");
        }
    }

    /**
     * Handshake message that writes a truncated ClientHello, used to compute
     * the PSK binder hash.
     *
     * It wraps an existing ClientHelloMessage together with a prototype PSK
     * spec. When written, it emits the ClientHello content, then every
     * extension except any pre_shared_key already present. Last comes a
     * pre_shared_key extension whose header carries the full encoded length
     * but whose body stops after the identities vector: the binders vector,
     * including its length field, is omitted.
     */
    static final class PartialClientHelloMessage extends HandshakeMessage {

        // The ClientHello whose non-PSK content is reproduced.
        private final ClientHello.ClientHelloMessage msg;
        // Prototype PSK spec: its identities are written, while its binders
        // contribute only to the length fields.
        private final CHPreSharedKeySpec psk;

        PartialClientHelloMessage(HandshakeContext ctx,
                                  ClientHello.ClientHelloMessage msg,
                                  CHPreSharedKeySpec psk) {
            super(ctx);

            this.msg = msg;
            this.psk = psk;
        }

        @Override
        SSLHandshake handshakeType() {
            return msg.handshakeType();
        }

        // Wire size of the complete pre_shared_key extension:
        // 2 (extension type) + 2 (extension length) + 2 (identities vector
        // length) + 2 (binders vector length) + identities + binders.
        private int pskTotalLength() {
            return psk.getIdsEncodedLength() +
                psk.getBindersEncodedLength() + 8;
        }

        // Length of the full, untruncated ClientHello body.
        // NOTE(review): if msg already holds a pre_shared_key extension (e.g.
        // from an earlier round), its existing size is used as-is. This
        // assumes it equals pskTotalLength() for the prototype -- confirm.
        @Override
        int messageLength() {

            if (msg.extensions.get(SSLExtension.CH_PRE_SHARED_KEY) != null) {
                return msg.messageLength();
            } else {
                return msg.messageLength() + pskTotalLength();
            }
        }

        @Override
        void send(HandshakeOutStream hos) throws IOException {
            // Presumably writes the ClientHello fields that precede the
            // extensions block -- verify against ClientHelloMessage.sendCore.
            msg.sendCore(hos);

            // complete extensions
            // NOTE(review): assumes msg.extensions.length() includes the
            // 2-byte extensions-length prefix, hence the "- 2" below.
            int extsLen = msg.extensions.length();
            if (msg.extensions.get(SSLExtension.CH_PRE_SHARED_KEY) == null) {
                extsLen += pskTotalLength();
            }
            hos.putInt16(extsLen - 2);
            // write the complete extensions
            // NOTE(review): emitted in SSLExtension.values() order. The binder
            // only verifies if this matches the order in which the real
            // ClientHello is sent -- confirm.
            for (SSLExtension ext : SSLExtension.values()) {
                byte[] extData = msg.extensions.get(ext);
                if (extData == null) {
                    continue;
                }
                // the PSK could be there from an earlier round
                if (ext == SSLExtension.CH_PRE_SHARED_KEY) {
                    continue;
                }
                int extID = ext.id;
                hos.putInt16(extID);
                hos.putBytes16(extData);
            }

            // partial PSK extension
            // The header advertises the full encoded length. Only the first
            // (identities length + 2) bytes of the encoding follow, i.e. the
            // 2-byte identities vector length and the identities themselves.
            int extID = SSLExtension.CH_PRE_SHARED_KEY.id;
            hos.putInt16(extID);
            byte[] encodedPsk = psk.getEncoded();
            hos.putInt16(encodedPsk.length);
            hos.write(encodedPsk, 0, psk.getIdsEncodedLength() + 2);
        }
    }

    /**
     * Client-side producer of the ClientHello "pre_shared_key" extension.
     *
     * When resuming a session that still holds a PSK and an unused PSK
     * identity, it offers exactly one identity together with its binder.
     * The binder is computed over the ClientHello truncated before the
     * binders. Otherwise it produces nothing (returns null).
     */
    private static final class CHPreSharedKeyProducer implements HandshakeProducer {

        // Prevent instantiation of this class.
        private CHPreSharedKeyProducer() {
            // blank
        }

        @Override
        public byte[] produce(ConnectionContext context,
                HandshakeMessage message) throws IOException {

            // The producing happens in client side only.
            ClientHandshakeContext chc = (ClientHandshakeContext)context;
            if (!chc.isResumption || chc.resumingSession == null) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine(
                    "No session to resume.");
                }
                return null;
            }

            Optional<SecretKey> pskOpt = chc.resumingSession.getPreSharedKey();
            if (!pskOpt.isPresent()) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine(
                    "Existing session has no PSK.");
                }
                return null;
            }
            SecretKey psk = pskOpt.get();
            Optional<byte[]> pskIdOpt = chc.resumingSession.getPskIdentity();
            if (!pskIdOpt.isPresent()) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine(
                    "PSK has no identity, or identity was already used");
                }
                return null;
            }
            byte[] pskId = pskIdOpt.get();

            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                SSLLogger.fine(
                "Found resumable session. Preparing PSK message.");
            }

            // Ticket age (presumably milliseconds -- getTicketCreationTime()
            // is assumed to be epoch millis) plus ticket_age_add. int
            // arithmetic wraps, matching the modulo 2^32 rule of RFC 8446.
            List<PskIdentity> identities = new ArrayList<>();
            int ageMillis = (int)(System.currentTimeMillis() -
                    chc.resumingSession.getTicketCreationTime());
            int obfuscatedAge = ageMillis + chc.resumingSession.getTicketAgeAdd();
            identities.add(new PskIdentity(pskId, obfuscatedAge));

            SecretKey binderKey = deriveBinderKey(psk, chc.resumingSession);
            ClientHello.ClientHelloMessage clientHello =
                    (ClientHello.ClientHelloMessage) message;
            CHPreSharedKeySpec pskPrototype = createPskPrototype(
                    chc.resumingSession.getSuite().hashAlg.hashLength,
                    identities);
            HandshakeHash pskBinderHash = chc.handshakeHash.copy();

            byte[] binder = computeBinder(binderKey, pskBinderHash,
                    chc.resumingSession, chc, clientHello, pskPrototype);

            List<byte[]> binders = new ArrayList<>();
            binders.add(binder);

            CHPreSharedKeySpec pskMessage = new CHPreSharedKeySpec(identities, binders);
            chc.handshakeExtensions.put(CH_PRE_SHARED_KEY, pskMessage);
            return pskMessage.getEncoded();
        }

        /**
         * Build a spec with one zero-filled placeholder binder of
         * {@code hashLength} bytes per identity. Placeholders give the
         * truncated ClientHello its final length fields before the real
         * binders are known.
         */
        private CHPreSharedKeySpec createPskPrototype(int hashLength, List<PskIdentity> identities) {
            List<byte[]> binders =
                Collections.nCopies(identities.size(), new byte[hashLength]);
            return new CHPreSharedKeySpec(identities, binders);
        }
    }

    /**
     * Finalize {@code pskBinderHash} using the session's protocol version and
     * cipher suite, then compute the binder over the resulting digest.
     */
    private static byte[] computeBinder(SecretKey binderKey, SSLSessionImpl session, HandshakeHash pskBinderHash) throws IOException {
        pskBinderHash.determine(session.getProtocolVersion(), session.getSuite());
        pskBinderHash.update();
        return computeBinder(binderKey, session, pskBinderHash.digest());
    }

    /**
     * Write the truncated form of {@code hello} (with {@code pskPrototype}
     * as the partial PSK extension) through an output record constructed
     * over {@code hash}. Then finalize {@code hash} for the session's
     * version/suite and compute the binder over its digest.
     *
     * NOTE(review): relies on SSLEngineOutputRecord(hash) feeding written
     * handshake bytes into {@code hash} -- verify.
     */
    private static byte[] computeBinder(SecretKey binderKey, HandshakeHash hash, SSLSessionImpl session,
                                        HandshakeContext ctx, ClientHello.ClientHelloMessage hello,
                                        CHPreSharedKeySpec pskPrototype) throws IOException {

        PartialClientHelloMessage truncatedHello =
            new PartialClientHelloMessage(ctx, hello, pskPrototype);
        HandshakeOutStream out =
            new HandshakeOutStream(new SSLEngineOutputRecord(hash));
        truncatedHello.write(out);

        hash.determine(session.getProtocolVersion(), session.getSuite());
        hash.update();
        return computeBinder(binderKey, session, hash.digest());
    }

    /**
     * Compute a PSK binder from a transcript digest. This follows the TLS 1.3
     * Finished computation, with the binder key as the base key:
     * finished_key = HKDF-Expand-Label(binderKey, "finished", "", Hash.length)
     * binder = HMAC(finished_key, digest).
     *
     * @param binderKey the key from deriveBinderKey
     * @param session supplies the cipher suite, and hence the hash algorithm
     * @param digest the transcript hash of the truncated ClientHello
     * @return the binder value
     * @throws IOException wrapping any GeneralSecurityException
     */
    private static byte[] computeBinder(SecretKey binderKey, SSLSessionImpl session,
                                        byte[] digest) throws IOException {

        try {
            CipherSuite.HashAlg hashAlg = session.getSuite().hashAlg;
            HKDF hkdf = new HKDF(hashAlg.name);
            // Protocol labels are ASCII; do not depend on the platform charset.
            byte[] label = "tls13 finished".getBytes(StandardCharsets.US_ASCII);
            byte[] hkdfInfo = SSLSecretDerivation.createHkdfInfo(
                label, new byte[0], hashAlg.hashLength);
            SecretKey finishedKey = hkdf.expand(
                binderKey, hkdfInfo, hashAlg.hashLength, "TlsBinderKey");

            // e.g. "SHA-256" -> "HmacSHA256"
            String hmacAlg = "Hmac" + hashAlg.name.replace("-", "");
            Mac hmac = JsseJce.getMac(hmacAlg);
            hmac.init(finishedKey);
            return hmac.doFinal(digest);
        } catch (GeneralSecurityException ex) {
            // Covers NoSuchAlgorithmException and InvalidKeyException too.
            throw new IOException(ex);
        }
    }

    /**
     * Derive the resumption binder key from a PSK (RFC 8446, Section 7.1):
     * early_secret = HKDF-Extract(zeros, psk), then
     * binder_key = Derive-Secret(early_secret, "res binder", "").
     *
     * @param psk the pre-shared key of the resumed session
     * @param session supplies the cipher suite, and hence the hash algorithm
     * @return the binder key
     * @throws IOException wrapping any GeneralSecurityException
     */
    private static SecretKey deriveBinderKey(SecretKey psk,
                                             SSLSessionImpl session)
        throws IOException {

        try {
            CipherSuite.HashAlg hashAlg = session.getSuite().hashAlg;
            HKDF hkdf = new HKDF(hashAlg.name);
            byte[] zeros = new byte[hashAlg.hashLength];
            SecretKey earlySecret = hkdf.extract(zeros, psk, "TlsEarlySecret");

            // Protocol labels are ASCII; do not depend on the platform charset.
            byte[] label = "tls13 res binder".getBytes(StandardCharsets.US_ASCII);
            // Use the same algorithm name as the HKDF above (the enum's
            // name field) rather than relying on the enum's toString().
            MessageDigest md = MessageDigest.getInstance(hashAlg.name);
            // Derive-Secret over an empty message list: hash of no bytes.
            byte[] hkdfInfo = SSLSecretDerivation.createHkdfInfo(
                label, md.digest(new byte[0]), hashAlg.hashLength);
            return hkdf.expand(earlySecret, hkdfInfo, hashAlg.hashLength,
                "TlsBinderKey");

        } catch (GeneralSecurityException ex) {
            throw new IOException(ex);
        }
    }

    /**
     * Server-side handler for a ClientHello without a "pre_shared_key"
     * extension: clears any resumption state.
     */
    private static final class CHPreSharedKeyAbsence implements HandshakeAbsence {
        @Override
        public void absent(ConnectionContext context,
                           HandshakeMessage message) throws IOException {

            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                SSLLogger.fine(
                "Handling pre_shared_key absence.");
            }

            ServerHandshakeContext serverContext =
                (ServerHandshakeContext) context;

            // When enabled, resumption is decided solely by the PSK, so
            // without one nothing is being resumed.
            serverContext.isResumption = false;
            serverContext.resumingSession = null;
        }
    }

    /**
     * Client-side consumer of the ServerHello "pre_shared_key" extension.
     *
     * Rejects the extension if the client never offered a PSK. Retires the
     * offered PSK identity, requires the selected index to be 0 (the only
     * identity this client offers), and evicts the resumed session from the
     * client session cache.
     */
    private static final class SHPreSharedKeyConsumer implements ExtensionConsumer {
        // Prevent instantiation of this class.
        private SHPreSharedKeyConsumer() {
            // blank
        }

        @Override
        public void consume(ConnectionContext context,
                HandshakeMessage message, ByteBuffer buffer) throws IOException {

            ClientHandshakeContext clientContext =
                (ClientHandshakeContext) message.handshakeContext;

            SHPreSharedKeySpec serverPsk = new SHPreSharedKeySpec(buffer);
            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                SSLLogger.fine(
                "Received pre_shared_key extension: ", serverPsk);
            }

            boolean pskOffered = clientContext.handshakeExtensions
                .containsKey(SSLExtension.CH_PRE_SHARED_KEY);
            if (!pskOffered) {
                clientContext.conContext.fatal(Alert.UNEXPECTED_MESSAGE,
                "Server sent unexpected pre_shared_key extension");
            }

            // A PSK identity is single-use: retire it whether or not the
            // server's selection turns out to be valid.
            clientContext.resumingSession.consumePskIdentity();

            if (serverPsk.selectedIdentity != 0) {
                clientContext.conContext.fatal(Alert.ILLEGAL_PARAMETER,
                "Selected identity index is not in correct range.");
            }

            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                SSLLogger.fine(
                "Resuming session: ", clientContext.resumingSession);
            }

            // The resumed session must not be offered again.
            ((SSLSessionContextImpl) clientContext.sslContext
                    .engineGetClientSessionContext())
                .remove(clientContext.resumingSession.getSessionId());
        }
    }

    /**
     * Client-side handler for a ServerHello without a "pre_shared_key"
     * extension. Any offered PSK identity is retired, and resumption state
     * is cleared.
     */
    private static final class SHPreSharedKeyAbsence implements HandshakeAbsence {
        @Override
        public void absent(ConnectionContext context,
                           HandshakeMessage message) throws IOException {

            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                SSLLogger.fine(
                "Handling pre_shared_key absence.");
            }

            ClientHandshakeContext clientContext =
                (ClientHandshakeContext) context;

            boolean pskOffered = clientContext.handshakeExtensions
                .containsKey(SSLExtension.CH_PRE_SHARED_KEY);
            if (pskOffered) {
                // A PSK identity is single-use, even if the server ignored it.
                clientContext.resumingSession.consumePskIdentity();
            }

            // Either the server declined to resume, or no TLS 1.3
            // resumption was requested.
            clientContext.isResumption = false;
            clientContext.resumingSession = null;
        }
    }

    /**
     * Server-side producer of the ServerHello "pre_shared_key" extension.
     * It emits the selected identity index if one was recorded, and
     * produces nothing otherwise.
     */
    private static final class SHPreSharedKeyProducer implements HandshakeProducer {

        // Prevent instantiation of this class.
        private SHPreSharedKeyProducer() {
            // blank
        }

        @Override
        public byte[] produce(ConnectionContext context,
                HandshakeMessage message) throws IOException {

            ServerHandshakeContext serverContext =
                (ServerHandshakeContext) message.handshakeContext;
            SHPreSharedKeySpec selected = (SHPreSharedKeySpec)
                serverContext.handshakeExtensions.get(SH_PRE_SHARED_KEY);

            return (selected == null) ? null : selected.getEncoded();
        }
    }
}