http-client-branch: (WebSocket) immutable builder copy; illegal headers; http-client-branch
authorprappo
Fri, 17 Nov 2017 13:55:41 +0300
branchhttp-client-branch
changeset 55824 b922df193260
parent 55823 bf59e29eff0c
child 55825 5928d92183d2
child 55826 66b5d6013d85
http-client-branch: (WebSocket) immutable builder copy; illegal headers;
src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/BuilderImpl.java
src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/OpeningHandshake.java
src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java
test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/BuildingWebSocketTest.java
test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/TestSupport.java
--- a/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/BuilderImpl.java	Fri Nov 17 10:45:26 2017 +0000
+++ b/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/BuilderImpl.java	Fri Nov 17 13:55:41 2017 +0300
@@ -49,19 +49,38 @@
     private final HttpClient client;
     private final URI uri;
     private final Listener listener;
-    private final Collection<Pair<String, String>> headers = new LinkedList<>();
-    private final Collection<String> subprotocols = new LinkedList<>();
     private final Optional<ProxySelector> proxySelector;
+    private final Collection<Pair<String, String>> headers;
+    private final Collection<String> subprotocols;
     private Duration timeout;
 
-    public BuilderImpl(HttpClient client, URI uri, Listener listener, ProxySelector proxySelector) {
+    public BuilderImpl(HttpClient client,
+                       URI uri,
+                       Listener listener,
+                       ProxySelector proxySelector)
+    {
+        this(client, uri, listener, Optional.ofNullable(proxySelector),
+             new LinkedList<>(), new LinkedList<>(), null);
+    }
+
+    private BuilderImpl(HttpClient client,
+                        URI uri,
+                        Listener listener,
+                        Optional<ProxySelector> proxySelector,
+                        Collection<Pair<String, String>> headers,
+                        Collection<String> subprotocols,
+                        Duration timeout) {
         this.client = requireNonNull(client, "client");
         this.uri = checkURI(requireNonNull(uri, "uri"));
         this.listener = requireNonNull(listener, "listener");
-        this.proxySelector = Optional.ofNullable(proxySelector);
-        // if the proxy selector was supplied by the user, it should be present
-        // on the client and should be the same than what we get as argument.
-        assert !client.proxy().isPresent() || client.proxy().get() == proxySelector;
+        this.proxySelector = proxySelector;
+        // If a proxy selector was supplied by the user, it should be present
+        // on the client and should be the same that what we got as an argument
+        assert !client.proxy().isPresent()
+                || client.proxy().equals(proxySelector);
+        this.headers = requireNonNull(headers);
+        this.subprotocols = requireNonNull(subprotocols);
+        this.timeout = timeout;
     }
 
     private static IllegalArgumentException newIAE(String message, Object... args) {
@@ -116,7 +135,8 @@
 
     @Override
     public CompletableFuture<WebSocket> buildAsync() {
-        return WebSocketImpl.newInstanceAsync(this);
+        BuilderImpl copy = immutableCopy();
+        return WebSocketImpl.newInstanceAsync(copy);
     }
 
     HttpClient getClient() { return client; }
@@ -131,5 +151,18 @@
 
     Duration getConnectTimeout() { return timeout; }
 
-    Optional<ProxySelector> proxySelector() { return proxySelector; }
+    Optional<ProxySelector> getProxySelector() { return proxySelector; }
+
+    BuilderImpl immutableCopy() {
+        @SuppressWarnings({"unchecked", "rawtypes"})
+        BuilderImpl copy = new BuilderImpl(
+                client,
+                uri,
+                listener,
+                proxySelector,
+                List.of(this.headers.toArray(new Pair[0])),
+                List.of(this.subprotocols.toArray(new String[0])),
+                timeout);
+        return copy;
+    }
 }
--- a/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/OpeningHandshake.java	Fri Nov 17 10:45:26 2017 +0000
+++ b/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/OpeningHandshake.java	Fri Nov 17 13:55:41 2017 +0300
@@ -72,15 +72,15 @@
     private static final String HEADER_PROTOCOL   = "Sec-WebSocket-Protocol";
     private static final String HEADER_VERSION    = "Sec-WebSocket-Version";
 
-    private static final Set<String> FORBIDDEN_HEADERS;
+    private static final Set<String> ILLEGAL_HEADERS;
 
     static {
-        FORBIDDEN_HEADERS = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
-        FORBIDDEN_HEADERS.addAll(List.of(HEADER_ACCEPT,
-                                         HEADER_EXTENSIONS,
-                                         HEADER_KEY,
-                                         HEADER_PROTOCOL,
-                                         HEADER_VERSION));
+        ILLEGAL_HEADERS = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
+        ILLEGAL_HEADERS.addAll(List.of(HEADER_ACCEPT,
+                                       HEADER_EXTENSIONS,
+                                       HEADER_KEY,
+                                       HEADER_PROTOCOL,
+                                       HEADER_VERSION));
     }
 
     private static final SecureRandom srandom = new SecureRandom();
@@ -111,7 +111,7 @@
             requestBuilder.timeout(connectTimeout);
         }
         for (Pair<String, String> p : b.getHeaders()) {
-            if (FORBIDDEN_HEADERS.contains(p.first)) {
+            if (ILLEGAL_HEADERS.contains(p.first)) {
                 throw illegal("Illegal header: " + p.first);
             }
             requestBuilder.header(p.first, p.second);
--- a/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java	Fri Nov 17 10:45:26 2017 +0000
+++ b/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java	Fri Nov 17 13:55:41 2017 +0300
@@ -34,7 +34,6 @@
 import java.net.URI;
 import java.net.URLPermission;
 import java.nio.ByteBuffer;
-import java.util.Collection;
 import java.util.List;
 import java.util.Optional;
 import java.util.Queue;
@@ -45,7 +44,6 @@
 import java.util.function.Consumer;
 import java.util.function.Function;
 
-import jdk.incubator.http.HttpClient;
 import jdk.incubator.http.WebSocket;
 import jdk.incubator.http.internal.common.Log;
 import jdk.incubator.http.internal.common.MinimalFuture;
@@ -162,8 +160,8 @@
     }
 
     static CompletableFuture<WebSocket> newInstanceAsync(BuilderImpl b) {
-        // TODO: a security issue? TOCTOU: two accesses to b.getURI
-        Proxy proxy = proxyFor(b.proxySelector(), b.getUri());
+        URI uri = b.getUri();
+        Proxy proxy = proxyFor(b.getProxySelector(), uri);
         try {
             checkPermissions(b, proxy);
         } catch (Throwable throwable) {
@@ -171,7 +169,7 @@
         }
 
         Function<Result, WebSocket> newWebSocket = r -> {
-            WebSocketImpl ws = new WebSocketImpl(b.getUri(),
+            WebSocketImpl ws = new WebSocketImpl(uri,
                                                  r.subprotocol,
                                                  r.channel,
                                                  b.getListener());
--- a/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/BuildingWebSocketTest.java	Fri Nov 17 10:45:26 2017 +0000
+++ b/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/BuildingWebSocketTest.java	Fri Nov 17 13:55:41 2017 +0300
@@ -96,22 +96,7 @@
 
     @Test
     public void illegalHeaders() {
-        List<String> headers = List.of("Authorization",
-                                       "Connection",
-                                       "Cookie",
-                                       "Content-Length",
-                                       "Date",
-                                       "Expect",
-                                       "From",
-                                       "Host",
-                                       "Origin",
-                                       "Proxy-Authorization",
-                                       "Referer",
-                                       "User-agent",
-                                       "Upgrade",
-                                       "Via",
-                                       "Warning",
-                                       "Sec-WebSocket-Accept",
+        List<String> headers = List.of("Sec-WebSocket-Accept",
                                        "Sec-WebSocket-Extensions",
                                        "Sec-WebSocket-Key",
                                        "Sec-WebSocket-Protocol",
@@ -123,6 +108,7 @@
                         .newHttpClient()
                         .newWebSocketBuilder(URI.create("ws://websocket.example.com"),
                                              listener())
+                        .header(header, "value")
                         .buildAsync();
 
         headers.forEach(h -> assertCompletesExceptionally(IllegalArgumentException.class, f.apply(h)));
--- a/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/TestSupport.java	Fri Nov 17 10:45:26 2017 +0000
+++ b/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/TestSupport.java	Fri Nov 17 13:55:41 2017 +0300
@@ -480,7 +480,7 @@
                                                   CompletionStage<?> stage) {
         CompletableFuture<?> cf =
                 CompletableFuture.completedFuture(null).thenCompose(x -> stage);
-        return assertThrows(t -> clazz.isInstance(t.getCause()), cf::get);
+        return assertThrows(t -> clazz == t.getCause().getClass(), cf::get);
     }
 
     interface ThrowingProcedure {