src/java.net.http/share/classes/jdk/internal/net/http/websocket/OpeningHandshake.java
branchhttp-client-branch
changeset 56092 fd85b2bf2b0d
parent 56089 42208b2f224e
child 56167 96fa4f49a9ff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/OpeningHandshake.java	Wed Feb 07 21:45:37 2018 +0000
@@ -0,0 +1,392 @@
+/*
+ * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package jdk.internal.net.http.websocket;
+
+import java.net.http.HttpClient;
+import java.net.http.HttpClient.Version;
+import java.net.http.HttpHeaders;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.net.http.HttpResponse.BodyHandler;
+import java.net.http.WebSocketHandshakeException;
+import jdk.internal.net.http.common.MinimalFuture;
+import jdk.internal.net.http.common.Pair;
+import jdk.internal.net.http.common.Utils;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.Proxy;
+import java.net.ProxySelector;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.URLPermission;
+import java.nio.charset.StandardCharsets;
+import java.security.AccessController;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import java.security.PrivilegedAction;
+import java.security.SecureRandom;
+import java.time.Duration;
+import java.util.Base64;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static java.lang.String.format;
+import static jdk.internal.net.http.common.Utils.isValidName;
+import static jdk.internal.net.http.common.Utils.permissionForProxy;
+import static jdk.internal.net.http.common.Utils.stringOf;
+
+public class OpeningHandshake {
+
+    private static final String HEADER_CONNECTION = "Connection";
+    private static final String HEADER_UPGRADE    = "Upgrade";
+    private static final String HEADER_ACCEPT     = "Sec-WebSocket-Accept";
+    private static final String HEADER_EXTENSIONS = "Sec-WebSocket-Extensions";
+    private static final String HEADER_KEY        = "Sec-WebSocket-Key";
+    private static final String HEADER_PROTOCOL   = "Sec-WebSocket-Protocol";
+    private static final String HEADER_VERSION    = "Sec-WebSocket-Version";
+    private static final String VERSION           = "13";  // WebSocket's lucky number
+
+    private static final Set<String> ILLEGAL_HEADERS;
+
+    static {
+        ILLEGAL_HEADERS = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
+        ILLEGAL_HEADERS.addAll(List.of(HEADER_ACCEPT,
+                                       HEADER_EXTENSIONS,
+                                       HEADER_KEY,
+                                       HEADER_PROTOCOL,
+                                       HEADER_VERSION));
+    }
+
+    private static final SecureRandom random = new SecureRandom();
+
+    private final MessageDigest sha1;
+    private final HttpClient client;
+
+    {
+        try {
+            sha1 = MessageDigest.getInstance("SHA-1");
+        } catch (NoSuchAlgorithmException e) {
+            // Shouldn't happen: SHA-1 must be available in every Java platform
+            // implementation
+            throw new InternalError("Minimum requirements", e);
+        }
+    }
+
+    private final HttpRequest request;
+    private final Collection<String> subprotocols;
+    private final String nonce;
+
+    public OpeningHandshake(BuilderImpl b) {
+        checkURI(b.getUri());
+        Proxy proxy = proxyFor(b.getProxySelector(), b.getUri());
+        checkPermissions(b, proxy);
+        this.client = b.getClient();
+        URI httpURI = createRequestURI(b.getUri());
+        HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(httpURI);
+        Duration connectTimeout = b.getConnectTimeout();
+        if (connectTimeout != null) {
+            requestBuilder.timeout(connectTimeout);
+        }
+        for (Pair<String, String> p : b.getHeaders()) {
+            if (ILLEGAL_HEADERS.contains(p.first)) {
+                throw illegal("Illegal header: " + p.first);
+            }
+            requestBuilder.header(p.first, p.second);
+        }
+        this.subprotocols = createRequestSubprotocols(b.getSubprotocols());
+        if (!this.subprotocols.isEmpty()) {
+            String p = this.subprotocols.stream().collect(Collectors.joining(", "));
+            requestBuilder.header(HEADER_PROTOCOL, p);
+        }
+        requestBuilder.header(HEADER_VERSION, VERSION);
+        this.nonce = createNonce();
+        requestBuilder.header(HEADER_KEY, this.nonce);
+        // Setting request version to HTTP/1.1 forcibly, since it's not possible
+        // to upgrade from HTTP/2 to WebSocket (as of August 2016):
+        //
+        //     https://tools.ietf.org/html/draft-hirano-httpbis-websocket-over-http2-00
+        this.request = requestBuilder.version(Version.HTTP_1_1).GET().build();
+        WebSocketRequest r = (WebSocketRequest) this.request;
+        r.isWebSocket(true);
+        r.setSystemHeader(HEADER_UPGRADE, "websocket");
+        r.setSystemHeader(HEADER_CONNECTION, "Upgrade");
+        r.setProxy(proxy);
+    }
+
+    private static Collection<String> createRequestSubprotocols(
+            Collection<String> subprotocols)
+    {
+        LinkedHashSet<String> sp = new LinkedHashSet<>(subprotocols.size(), 1);
+        for (String s : subprotocols) {
+            if (s.trim().isEmpty() || !isValidName(s)) {
+                throw illegal("Bad subprotocol syntax: " + s);
+            }
+            if (!sp.add(s)) {
+                throw illegal("Duplicating subprotocol: " + s);
+            }
+        }
+        return Collections.unmodifiableCollection(sp);
+    }
+
+    /*
+     * Checks the given URI for being a WebSocket URI and translates it into a
+     * target HTTP URI for the Opening Handshake.
+     *
+     * https://tools.ietf.org/html/rfc6455#section-3
+     */
+    static URI createRequestURI(URI uri) {
+        String s = uri.getScheme();
+        assert "ws".equalsIgnoreCase(s) || "wss".equalsIgnoreCase(s);
+        String scheme = "ws".equalsIgnoreCase(s) ? "http" : "https";
+        try {
+            return new URI(scheme,
+                           uri.getUserInfo(),
+                           uri.getHost(),
+                           uri.getPort(),
+                           uri.getPath(),
+                           uri.getQuery(),
+                           null); // No fragment
+        } catch (URISyntaxException e) {
+            // Shouldn't happen: URI invariant
+            throw new InternalError(e);
+        }
+    }
+
+    public CompletableFuture<Result> send() {
+        PrivilegedAction<CompletableFuture<Result>> pa = () ->
+                client.sendAsync(this.request, BodyHandler.discard())
+                      .thenCompose(this::resultFrom);
+        return AccessController.doPrivileged(pa);
+    }
+
+    /*
+     * The result of the opening handshake.
+     */
+    static final class Result {
+
+        final String subprotocol;
+        final TransportFactory transport;
+
+        private Result(String subprotocol, TransportFactory transport) {
+            this.subprotocol = subprotocol;
+            this.transport = transport;
+        }
+    }
+
+    private CompletableFuture<Result> resultFrom(HttpResponse<?> response) {
+        // Do we need a special treatment for SSLHandshakeException?
+        // Namely, invoking
+        //
+        //     Listener.onClose(StatusCodes.TLS_HANDSHAKE_FAILURE, "")
+        //
+        // See https://tools.ietf.org/html/rfc6455#section-7.4.1
+        Result result = null;
+        Exception exception = null;
+        try {
+            result = handleResponse(response);
+        } catch (IOException e) {
+            exception = e;
+        } catch (Exception e) {
+            exception = new WebSocketHandshakeException(response).initCause(e);
+        }
+        if (exception == null) {
+            return MinimalFuture.completedFuture(result);
+        }
+        try {
+            ((RawChannel.Provider) response).rawChannel().close();
+        } catch (IOException e) {
+            exception.addSuppressed(e);
+        }
+        return MinimalFuture.failedFuture(exception);
+    }
+
+    private Result handleResponse(HttpResponse<?> response) throws IOException {
+        // By this point all redirects, authentications, etc. (if any) MUST have
+        // been done by the HttpClient used by the WebSocket; so only 101 is
+        // expected
+        int c = response.statusCode();
+        if (c != 101) {
+            throw checkFailed("Unexpected HTTP response status code " + c);
+        }
+        HttpHeaders headers = response.headers();
+        String upgrade = requireSingle(headers, HEADER_UPGRADE);
+        if (!upgrade.equalsIgnoreCase("websocket")) {
+            throw checkFailed("Bad response field: " + HEADER_UPGRADE);
+        }
+        String connection = requireSingle(headers, HEADER_CONNECTION);
+        if (!connection.equalsIgnoreCase("Upgrade")) {
+            throw checkFailed("Bad response field: " + HEADER_CONNECTION);
+        }
+        Optional<String> version = requireAtMostOne(headers, HEADER_VERSION);
+        if (version.isPresent() && !version.get().equals(VERSION)) {
+            throw checkFailed("Bad response field: " + HEADER_VERSION);
+        }
+        requireAbsent(headers, HEADER_EXTENSIONS);
+        String x = this.nonce + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+        this.sha1.update(x.getBytes(StandardCharsets.ISO_8859_1));
+        String expected = Base64.getEncoder().encodeToString(this.sha1.digest());
+        String actual = requireSingle(headers, HEADER_ACCEPT);
+        if (!actual.trim().equals(expected)) {
+            throw checkFailed("Bad " + HEADER_ACCEPT);
+        }
+        String subprotocol = checkAndReturnSubprotocol(headers);
+        RawChannel channel = ((RawChannel.Provider) response).rawChannel();
+        return new Result(subprotocol, new TransportFactoryImpl(channel));
+    }
+
+    private String checkAndReturnSubprotocol(HttpHeaders responseHeaders)
+            throws CheckFailedException
+    {
+        Optional<String> opt = responseHeaders.firstValue(HEADER_PROTOCOL);
+        if (!opt.isPresent()) {
+            // If there is no such header in the response, then the server
+            // doesn't want to use any subprotocol
+            return "";
+        }
+        String s = requireSingle(responseHeaders, HEADER_PROTOCOL);
+        // An empty string as a subprotocol's name is not allowed by the spec
+        // and the check below will detect such responses too
+        if (this.subprotocols.contains(s)) {
+            return s;
+        } else {
+            throw checkFailed("Unexpected subprotocol: " + s);
+        }
+    }
+
+    private static void requireAbsent(HttpHeaders responseHeaders,
+                                      String headerName)
+    {
+        List<String> values = responseHeaders.allValues(headerName);
+        if (!values.isEmpty()) {
+            throw checkFailed(format("Response field '%s' present: %s",
+                                     headerName,
+                                     stringOf(values)));
+        }
+    }
+
+    private static Optional<String> requireAtMostOne(HttpHeaders responseHeaders,
+                                                     String headerName)
+    {
+        List<String> values = responseHeaders.allValues(headerName);
+        if (values.size() > 1) {
+            throw checkFailed(format("Response field '%s' multivalued: %s",
+                                     headerName,
+                                     stringOf(values)));
+        }
+        return values.stream().findFirst();
+    }
+
+    private static String requireSingle(HttpHeaders responseHeaders,
+                                        String headerName)
+    {
+        List<String> values = responseHeaders.allValues(headerName);
+        if (values.isEmpty()) {
+            throw checkFailed("Response field missing: " + headerName);
+        } else if (values.size() > 1) {
+            throw checkFailed(format("Response field '%s' multivalued: %s",
+                                     headerName,
+                                     stringOf(values)));
+        }
+        return values.get(0);
+    }
+
+    private static String createNonce() {
+        byte[] bytes = new byte[16];
+        OpeningHandshake.random.nextBytes(bytes);
+        return Base64.getEncoder().encodeToString(bytes);
+    }
+
+    private static CheckFailedException checkFailed(String message) {
+        throw new CheckFailedException(message);
+    }
+
+    private static URI checkURI(URI uri) {
+        String scheme = uri.getScheme();
+        if (!("ws".equalsIgnoreCase(scheme) || "wss".equalsIgnoreCase(scheme)))
+            throw illegal("invalid URI scheme: " + scheme);
+        if (uri.getHost() == null)
+            throw illegal("URI must contain a host: " + uri);
+        if (uri.getFragment() != null)
+            throw illegal("URI must not contain a fragment: " + uri);
+        return uri;
+    }
+
+    private static IllegalArgumentException illegal(String message) {
+        return new IllegalArgumentException(message);
+    }
+
+    /**
+     * Returns the proxy for the given URI when sent through the given client,
+     * or {@code null} if none is required or applicable.
+     */
+    private static Proxy proxyFor(Optional<ProxySelector> selector, URI uri) {
+        if (!selector.isPresent()) {
+            return null;
+        }
+        URI requestURI = createRequestURI(uri); // Based on the HTTP scheme
+        List<Proxy> pl = selector.get().select(requestURI);
+        if (pl.isEmpty()) {
+            return null;
+        }
+        Proxy proxy = pl.get(0);
+        if (proxy.type() != Proxy.Type.HTTP) {
+            return null;
+        }
+        return proxy;
+    }
+
+    /**
+     * Performs the necessary security permissions checks to connect ( possibly
+     * through a proxy ) to the builders WebSocket URI.
+     *
+     * @throws SecurityException if the security manager denies access
+     */
+    static void checkPermissions(BuilderImpl b, Proxy proxy) {
+        SecurityManager sm = System.getSecurityManager();
+        if (sm == null) {
+            return;
+        }
+        Stream<String> headers = b.getHeaders().stream().map(p -> p.first).distinct();
+        URLPermission perm1 = Utils.permissionForServer(b.getUri(), "", headers);
+        sm.checkPermission(perm1);
+        if (proxy == null) {
+            return;
+        }
+        URLPermission perm2 = permissionForProxy((InetSocketAddress) proxy.address());
+        if (perm2 != null) {
+            sm.checkPermission(perm2);
+        }
+    }
+}