src/java.base/share/classes/sun/security/ssl/SSLExtensions.java
branchJDK-8145252-TLS13-branch
changeset 56542 56aaa6cb3693
parent 48225 718669e6b375
child 56584 a0f3377c58c7
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/java.base/share/classes/sun/security/ssl/SSLExtensions.java	Fri May 11 15:53:12 2018 -0700
@@ -0,0 +1,360 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package sun.security.ssl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.text.MessageFormat;
+import java.util.*;
+
+import sun.security.ssl.SSLHandshake.HandshakeMessage;
+import sun.security.util.HexDumpEncoder;
+
+/**
+ * SSL/(D)TLS extensions in a handshake message.
+ */
+final class SSLExtensions {
+    private final HandshakeMessage handshakeMessage;
+    private Map<SSLExtension, byte[]> extMap = new LinkedHashMap<>();
+    private int encodedLength;
+
+    // Extension map for debug logging
+    private final Map<Integer, byte[]> logMap =
+            SSLLogger.isOn ? null : new LinkedHashMap<>();
+
+    SSLExtensions(HandshakeMessage handshakeMessage) {
+        this.handshakeMessage = handshakeMessage;
+        this.encodedLength = 2;         // 2: the length of the extensions.
+    }
+
+    SSLExtensions(HandshakeMessage hm,
+            ByteBuffer m, SSLExtension[] extensions) throws IOException {
+        this.handshakeMessage = hm;
+
+        int len = Record.getInt16(m);
+        encodedLength = len + 2;        // 2: the length of the extensions.
+        while (len > 0) {
+            int extId = Record.getInt16(m);
+            int extLen = Record.getInt16(m);
+            if (extLen > m.remaining()) {
+                hm.handshakeContext.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                        "Error parsing extension (" + extId +
+                        "): no sufficient data");
+            }
+
+            SSLHandshake handshakeType = hm.handshakeType();
+            if (SSLExtension.isConsumable(extId) &&
+                    SSLExtension.valueOf(handshakeType, extId) == null) {
+                hm.handshakeContext.conContext.fatal(
+                        Alert.UNSUPPORTED_EXTENSION,
+                        "extension (" + extId +
+                        ") should not be presented in " + handshakeType.name);
+            }
+
+            boolean isSupported = false;
+            for (SSLExtension extension : extensions) {
+                if ((extension.id != extId) ||
+                        (extension.onLoadConcumer == null)) {
+                    continue;
+                }
+
+                if (extension.handshakeType != handshakeType) {
+                    hm.handshakeContext.conContext.fatal(
+                            Alert.UNSUPPORTED_EXTENSION,
+                            "extension (" + extId + ") should not be " +
+                            "presented in " + handshakeType.name);
+                }
+
+                byte[] extData = new byte[extLen];
+                m.get(extData);
+                extMap.put(extension, extData);
+                if (logMap != null) {
+                    logMap.put(extId, extData);
+                }
+
+                isSupported = true;
+                break;
+            }
+
+            if (!isSupported) {
+                if (logMap != null) {
+                    // cache the extension for debug logging
+                    byte[] extData = new byte[extLen];
+                    m.get(extData);
+                    logMap.put(extId, extData);
+
+                    if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                        SSLLogger.fine(
+                                "Ignore unknown or unsupported extension",
+                                toString(extId, extData));
+                    }
+                } else {
+                    // ignore the extension
+                    int pos = m.position() + extLen;
+                    m.position(pos);
+                }
+            }
+
+            len -= extLen + 4;
+        }
+    }
+
+    byte[] get(SSLExtension ext) {
+        return extMap.get(ext);
+    }
+
+    /**
+     * Consume the specified extensions.
+     */
+    void consumeOnLoad(HandshakeContext context,
+            SSLExtension[] extensions) throws IOException {
+        for (SSLExtension extension : extensions) {
+            if (context.negotiatedProtocol != null &&
+                    !extension.isAvailable(context.negotiatedProtocol)) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.fine(
+                        "Ignore unsupported extension: " + extension.name);
+                }
+                continue;
+                // context.conContext.fatal(Alert.UNSUPPORTED_EXTENSION,
+                //         context.negotiatedProtocol + " does not support " +
+                //         extension + " extension");
+            }
+
+            if (!extMap.containsKey(extension)) {
+                if (extension.onLoadAbsence != null) {
+                    extension.absent(context, handshakeMessage);
+                } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.fine(
+                        "Ignore unavailable extension: " + extension.name);
+                }
+                continue;
+            }
+
+
+            if (extension.onLoadConcumer == null) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.warning(
+                        "Ignore unsupported extension: " + extension.name);
+                }
+                continue;
+            }
+
+            ByteBuffer m = ByteBuffer.wrap(extMap.get(extension));
+            extension.consumeOnLoad(context, handshakeMessage, m);
+
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                SSLLogger.fine("Consumed extension: " + extension.name);
+            }
+        }
+    }
+
+    /**
+     * Consider impact of the specified extensions.
+     */
+    void consumeOnTrade(HandshakeContext context,
+            SSLExtension[] extensions) throws IOException {
+        for (SSLExtension extension : extensions) {
+            if (!extMap.containsKey(extension)) {
+                // No impact could be expected, so just ignore the absence.
+                continue;
+            }
+
+            if (extension.onTradeConsumer == null) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.warning(
+                            "Ignore impact of unsupported extension: " +
+                            extension.name);
+                }
+                continue;
+            }
+
+            extension.consumeOnTrade(context, handshakeMessage);
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                SSLLogger.fine("Populated with extension: " + extension.name);
+            }
+        }
+    }
+
+    /**
+     * Produce extension values for the specified extensions.
+     */
+    void produce(HandshakeContext context,
+            SSLExtension[] extensions) throws IOException {
+        for (SSLExtension extension : extensions) {
+            if (extMap.containsKey(extension)) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.fine(
+                            "Ignore, duplicated extension: " +
+                            extension.name);
+                }
+                continue;
+            }
+
+            if (extension.networkProducer == null) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.warning(
+                            "Ignore, no extension producer defined: " +
+                            extension.name);
+                }
+                continue;
+            }
+
+            byte[] encoded = extension.produce(context, handshakeMessage);
+            if (encoded != null) {
+                extMap.put(extension, encoded);
+                encodedLength += encoded.length + 4; // extension_type (2)
+                                                     // extension_data length(2)
+            } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                // The extension is not available in the context.
+                SSLLogger.fine(
+                        "Ignore, context unavailable extension: " +
+                        extension.name);
+            }
+        }
+    }
+
+    /**
+     * Produce extension values for the specified extensions, replacing if
+     * there is an existing extension value for a specified extension.
+     */
+    void reproduce(HandshakeContext context,
+            SSLExtension[] extensions) throws IOException {
+        for (SSLExtension extension : extensions) {
+            if (extension.networkProducer == null) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.warning(
+                            "Ignore, no extension producer defined: " +
+                            extension.name);
+                }
+                continue;
+            }
+
+            byte[] encoded = extension.produce(context, handshakeMessage);
+            if (encoded != null) {
+                if (extMap.containsKey(extension)) {
+                    byte[] old = extMap.replace(extension, encoded);
+                    if (old != null) {
+                        encodedLength -= old.length + 4;
+                    }
+                    encodedLength += encoded.length + 4;
+                } else {
+                    extMap.put(extension, encoded);
+                    encodedLength += encoded.length + 4;
+                                                    // extension_type (2)
+                                                    // extension_data length(2)
+                }
+            } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                // The extension is not available in the context.
+                SSLLogger.fine(
+                        "Ignore, context unavailable extension: " +
+                        extension.name);
+            }
+        }
+    }
+
+    // Note that TLS 1.3 may use empty extensions.  Please consider it while
+    // using this method.
+    int length() {
+        if (extMap.isEmpty()) {
+            return 0;
+        } else {
+            return encodedLength;
+        }
+    }
+
+    // Note that TLS 1.3 may use empty extensions.  Please consider it while
+    // using this method.
+    void send(HandshakeOutStream hos) throws IOException {
+        int extsLen = length();
+        if (extsLen == 0) {
+            return;
+        }
+        hos.putInt16(extsLen - 2);
+        // extensions must be sent in the order they appear in the enum
+        for (SSLExtension ext : SSLExtension.values()) {
+            byte[] extData = extMap.get(ext);
+            if (extData != null) {
+                hos.putInt16(ext.id);
+                hos.putBytes16(extData);
+            }
+        }
+    }
+
+    @Override
+    public String toString() {
+        if (extMap.isEmpty() && (logMap == null || logMap.isEmpty())) {
+            return "<no extension>";
+        } else {
+            StringBuilder builder = new StringBuilder(512);
+            if (logMap != null) {
+                for (Map.Entry<Integer, byte[]> en : logMap.entrySet()) {
+                    SSLExtension ext = SSLExtension.valueOf(
+                            handshakeMessage.handshakeType(), en.getKey());
+                    if (builder.length() != 0) {
+                        builder.append(",\n");
+                    }
+                    if (ext != null) {
+                        builder.append(
+                                ext.toString(ByteBuffer.wrap(en.getValue())));
+                    } else {
+                        builder.append(toString(en.getKey(), en.getValue()));
+                    }
+                }
+
+                return builder.toString();
+            } else {
+                for (Map.Entry<SSLExtension, byte[]> en : extMap.entrySet()) {
+                    if (builder.length() != 0) {
+                        builder.append(",\n");
+                    }
+                    builder.append(
+                        en.getKey().toString(ByteBuffer.wrap(en.getValue())));
+                }
+
+                return builder.toString();
+            }
+        }
+    }
+
+    private static String toString(int extId, byte[] extData) {
+        MessageFormat messageFormat = new MessageFormat(
+            "\"unknown extension ({0})\": '{'\n" +
+            "{1}\n" +
+            "'}'",
+            Locale.ENGLISH);
+
+        HexDumpEncoder hexEncoder = new HexDumpEncoder();
+        String encoded = hexEncoder.encodeBuffer(extData);
+
+        Object[] messageFields = {
+            extId,
+            Utilities.indent(encoded)
+        };
+
+        return messageFormat.format(messageFields);
+    }
+}