src/java.base/share/classes/sun/security/ssl/SSLExtensions.java
author wetmore
Fri, 11 May 2018 15:53:12 -0700
branch JDK-8145252-TLS13-branch
changeset 56542 56aaa6cb3693
parent 48225 src/java.base/share/classes/sun/security/ssl/HelloExtensions.java@718669e6b375
child 56584 a0f3377c58c7
permissions -rw-r--r--
Initial TLSv1.3 Implementation

/*
 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package sun.security.ssl;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.text.MessageFormat;
import java.util.*;

import sun.security.ssl.SSLHandshake.HandshakeMessage;
import sun.security.util.HexDumpEncoder;

/**
 * SSL/(D)TLS extensions in a handshake message.
 */
final class SSLExtensions {
    private final HandshakeMessage handshakeMessage;

    // Extensions known to this implementation, keyed by SSLExtension and
    // holding the raw extension_data bytes.  Never reassigned.
    private final Map<SSLExtension, byte[]> extMap = new LinkedHashMap<>();

    // Encoded size of the whole extensions block, including the 2-byte
    // length field of the block itself.
    private int encodedLength;

    // Extension map for debug logging, keyed by the raw extension id.  It
    // also caches extensions this implementation does not support, so they
    // can still be dumped.  Only allocated when debug logging is enabled;
    // null otherwise.
    private final Map<Integer, byte[]> logMap =
            SSLLogger.isOn ? new LinkedHashMap<>() : null;

    /**
     * Create an empty extensions block, to be filled via
     * {@link #produce} or {@link #reproduce}.
     *
     * @param handshakeMessage the handshake message owning the extensions
     */
    SSLExtensions(HandshakeMessage handshakeMessage) {
        this.handshakeMessage = handshakeMessage;
        this.encodedLength = 2;         // 2: the length of the extensions.
    }

    /**
     * Parse an extensions block from {@code m}.
     *
     * <p>The block is read as a 2-byte total length followed by entries of
     * a 2-byte extension id, a 2-byte data length and the data itself.
     * Entries that match one of {@code extensions} and have an on-load
     * consumer are kept for later consumption; all other entries are
     * skipped (or, when debug logging is enabled, cached for logging).
     *
     * @param hm the handshake message owning the extensions
     * @param m the buffer, positioned at the extensions block
     * @param extensions the extensions acceptable for this message
     * @throws IOException if reading fails or a fatal alert is raised for
     *         a malformed or misplaced extension
     */
    SSLExtensions(HandshakeMessage hm,
            ByteBuffer m, SSLExtension[] extensions) throws IOException {
        this.handshakeMessage = hm;

        int len = Record.getInt16(m);
        encodedLength = len + 2;        // 2: the length of the extensions.
        while (len > 0) {
            int extId = Record.getInt16(m);
            int extLen = Record.getInt16(m);
            // NOTE(review): assumes fatal() always throws; otherwise the
            // reads below would run past the available data.
            if (extLen > m.remaining()) {
                hm.handshakeContext.conContext.fatal(Alert.ILLEGAL_PARAMETER,
                        "Error parsing extension (" + extId +
                        "): no sufficient data");
            }

            // Reject extensions that are recognized but not allowed in this
            // kind of handshake message.
            SSLHandshake handshakeType = hm.handshakeType();
            if (SSLExtension.isConsumable(extId) &&
                    SSLExtension.valueOf(handshakeType, extId) == null) {
                hm.handshakeContext.conContext.fatal(
                        Alert.UNSUPPORTED_EXTENSION,
                        "extension (" + extId +
                        ") should not be presented in " + handshakeType.name);
            }

            boolean isSupported = false;
            for (SSLExtension extension : extensions) {
                if ((extension.id != extId) ||
                        (extension.onLoadConcumer == null)) {
                    continue;
                }

                if (extension.handshakeType != handshakeType) {
                    hm.handshakeContext.conContext.fatal(
                            Alert.UNSUPPORTED_EXTENSION,
                            "extension (" + extId + ") should not be " +
                            "presented in " + handshakeType.name);
                }

                byte[] extData = new byte[extLen];
                m.get(extData);
                extMap.put(extension, extData);
                if (logMap != null) {
                    logMap.put(extId, extData);
                }

                isSupported = true;
                break;
            }

            if (!isSupported) {
                if (logMap != null) {
                    // cache the extension for debug logging
                    byte[] extData = new byte[extLen];
                    m.get(extData);
                    logMap.put(extId, extData);

                    if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                        SSLLogger.fine(
                                "Ignore unknown or unsupported extension",
                                toString(extId, extData));
                    }
                } else {
                    // ignore the extension
                    int pos = m.position() + extLen;
                    m.position(pos);
                }
            }

            // 4: extension_type (2) + extension_data length (2)
            len -= extLen + 4;
        }
    }

    /**
     * Return the raw extension_data bytes of {@code ext}, or null if the
     * extension is not present.
     */
    byte[] get(SSLExtension ext) {
        return extMap.get(ext);
    }

    /**
     * Consume the specified extensions.
     *
     * <p>Extensions not available for the negotiated protocol are ignored.
     * For an absent extension the extension's absence handler is invoked
     * when one is defined.  Present extensions are passed to their on-load
     * consumer.
     *
     * @param context the current handshake context
     * @param extensions the extensions to consume, in order
     * @throws IOException if an extension consumer or absence handler fails
     */
    void consumeOnLoad(HandshakeContext context,
            SSLExtension[] extensions) throws IOException {
        for (SSLExtension extension : extensions) {
            if (context.negotiatedProtocol != null &&
                    !extension.isAvailable(context.negotiatedProtocol)) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine(
                        "Ignore unsupported extension: " + extension.name);
                }
                // Deliberately ignored rather than raising a fatal
                // UNSUPPORTED_EXTENSION alert.
                continue;
            }

            if (!extMap.containsKey(extension)) {
                if (extension.onLoadAbsence != null) {
                    extension.absent(context, handshakeMessage);
                } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine(
                        "Ignore unavailable extension: " + extension.name);
                }
                continue;
            }

            if (extension.onLoadConcumer == null) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.warning(
                        "Ignore unsupported extension: " + extension.name);
                }
                continue;
            }

            ByteBuffer m = ByteBuffer.wrap(extMap.get(extension));
            extension.consumeOnLoad(context, handshakeMessage, m);

            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                SSLLogger.fine("Consumed extension: " + extension.name);
            }
        }
    }

    /**
     * Consider impact of the specified extensions.
     *
     * <p>Only extensions present in this block and having an on-trade
     * consumer are applied; absent extensions are silently skipped.
     *
     * @param context the current handshake context
     * @param extensions the extensions to apply, in order
     * @throws IOException if an extension consumer fails
     */
    void consumeOnTrade(HandshakeContext context,
            SSLExtension[] extensions) throws IOException {
        for (SSLExtension extension : extensions) {
            if (!extMap.containsKey(extension)) {
                // No impact could be expected, so just ignore the absence.
                continue;
            }

            if (extension.onTradeConsumer == null) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.warning(
                            "Ignore impact of unsupported extension: " +
                            extension.name);
                }
                continue;
            }

            extension.consumeOnTrade(context, handshakeMessage);
            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                SSLLogger.fine("Populated with extension: " + extension.name);
            }
        }
    }

    /**
     * Produce extension values for the specified extensions.
     *
     * <p>Extensions already present are left untouched.  An extension whose
     * producer returns null is not added.
     *
     * @param context the current handshake context
     * @param extensions the extensions to produce, in order
     * @throws IOException if an extension producer fails
     */
    void produce(HandshakeContext context,
            SSLExtension[] extensions) throws IOException {
        for (SSLExtension extension : extensions) {
            if (extMap.containsKey(extension)) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine(
                            "Ignore, duplicated extension: " +
                            extension.name);
                }
                continue;
            }

            if (extension.networkProducer == null) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.warning(
                            "Ignore, no extension producer defined: " +
                            extension.name);
                }
                continue;
            }

            byte[] encoded = extension.produce(context, handshakeMessage);
            if (encoded != null) {
                extMap.put(extension, encoded);
                encodedLength += encoded.length + 4; // extension_type (2)
                                                     // extension_data length(2)
            } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                // The extension is not available in the context.
                SSLLogger.fine(
                        "Ignore, context unavailable extension: " +
                        extension.name);
            }
        }
    }

    /**
     * Produce extension values for the specified extensions, replacing if
     * there is an existing extension value for a specified extension.
     *
     * <p>If the producer returns null, any existing value is kept as is.
     *
     * @param context the current handshake context
     * @param extensions the extensions to (re)produce, in order
     * @throws IOException if an extension producer fails
     */
    void reproduce(HandshakeContext context,
            SSLExtension[] extensions) throws IOException {
        for (SSLExtension extension : extensions) {
            if (extension.networkProducer == null) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.warning(
                            "Ignore, no extension producer defined: " +
                            extension.name);
                }
                continue;
            }

            byte[] encoded = extension.produce(context, handshakeMessage);
            if (encoded != null) {
                if (extMap.containsKey(extension)) {
                    // Swap the old value's size out of the running total.
                    byte[] old = extMap.replace(extension, encoded);
                    if (old != null) {
                        encodedLength -= old.length + 4;
                    }
                    encodedLength += encoded.length + 4;
                } else {
                    extMap.put(extension, encoded);
                    encodedLength += encoded.length + 4;
                                                    // extension_type (2)
                                                    // extension_data length(2)
                }
            } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                // The extension is not available in the context.
                SSLLogger.fine(
                        "Ignore, context unavailable extension: " +
                        extension.name);
            }
        }
    }

    /**
     * Return the encoded size of the extensions block, including its
     * 2-byte length field, or 0 if no known extension is present.
     *
     * <p>Note that TLS 1.3 may use empty extensions.  Please consider it
     * while using this method.
     */
    int length() {
        if (extMap.isEmpty()) {
            return 0;
        } else {
            return encodedLength;
        }
    }

    /**
     * Write the extensions block to {@code hos}.  Nothing is written when
     * {@link #length()} is 0.
     *
     * <p>Note that TLS 1.3 may use empty extensions.  Please consider it
     * while using this method.
     *
     * @param hos the output stream of the handshake message
     * @throws IOException if writing fails
     */
    void send(HandshakeOutStream hos) throws IOException {
        int extsLen = length();
        if (extsLen == 0) {
            return;
        }
        // The length field excludes its own 2 bytes.
        hos.putInt16(extsLen - 2);
        // extensions must be sent in the order they appear in the enum
        for (SSLExtension ext : SSLExtension.values()) {
            byte[] extData = extMap.get(ext);
            if (extData != null) {
                hos.putInt16(ext.id);
                hos.putBytes16(extData);
            }
        }
    }

    /**
     * Describe the extensions for debug logging.
     *
     * <p>When the debug-logging map holds entries (it is only populated
     * while parsing), it is used so that unknown extensions are included.
     * Otherwise, e.g. for locally produced extensions, the known extension
     * map is used.
     */
    @Override
    public String toString() {
        if (extMap.isEmpty() && (logMap == null || logMap.isEmpty())) {
            return "<no extension>";
        } else {
            StringBuilder builder = new StringBuilder(512);
            // An allocated but empty logMap must not hide extMap entries
            // added by produce()/reproduce().
            if (logMap != null && !logMap.isEmpty()) {
                for (Map.Entry<Integer, byte[]> en : logMap.entrySet()) {
                    SSLExtension ext = SSLExtension.valueOf(
                            handshakeMessage.handshakeType(), en.getKey());
                    if (builder.length() != 0) {
                        builder.append(",\n");
                    }
                    if (ext != null) {
                        builder.append(
                                ext.toString(ByteBuffer.wrap(en.getValue())));
                    } else {
                        builder.append(toString(en.getKey(), en.getValue()));
                    }
                }

                return builder.toString();
            } else {
                for (Map.Entry<SSLExtension, byte[]> en : extMap.entrySet()) {
                    if (builder.length() != 0) {
                        builder.append(",\n");
                    }
                    builder.append(
                        en.getKey().toString(ByteBuffer.wrap(en.getValue())));
                }

                return builder.toString();
            }
        }
    }

    /**
     * Format an unknown extension as its id plus an indented hex dump of
     * its data.
     */
    private static String toString(int extId, byte[] extData) {
        MessageFormat messageFormat = new MessageFormat(
            "\"unknown extension ({0})\": '{'\n" +
            "{1}\n" +
            "'}'",
            Locale.ENGLISH);

        HexDumpEncoder hexEncoder = new HexDumpEncoder();
        String encoded = hexEncoder.encodeBuffer(extData);

        Object[] messageFields = {
            extId,
            Utilities.indent(encoded)
        };

        return messageFormat.format(messageFields);
    }
}