8217429: WebSocket over authenticating proxy fails to send Upgrade headers
authorchegar
Mon, 28 Jan 2019 13:51:16 +0000
changeset 53521 41fa3e6f2785
parent 53520 5178e4b58b17
child 53522 40eb23e0a8c5
8217429: WebSocket over authenticating proxy fails to send Upgrade headers Reviewed-by: dfuchs, prappo
src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestImpl.java
src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java
src/java.net.http/share/classes/jdk/internal/net/http/websocket/OpeningHandshake.java
test/jdk/java/net/httpclient/ProxyServer.java
test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java
test/jdk/java/net/httpclient/websocket/Support.java
test/jdk/java/net/httpclient/websocket/WebSocketProxyTest.java
test/jdk/java/net/httpclient/websocket/WebSocketTest.java
--- a/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestImpl.java	Mon Jan 28 09:56:00 2019 +0100
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/HttpRequestImpl.java	Mon Jan 28 13:51:16 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -43,6 +43,7 @@
 import java.net.http.HttpRequest;
 import jdk.internal.net.http.common.HttpHeadersBuilder;
 import jdk.internal.net.http.common.Utils;
+import jdk.internal.net.http.websocket.OpeningHandshake;
 import jdk.internal.net.http.websocket.WebSocketRequest;
 
 import static jdk.internal.net.http.common.Utils.ALLOWED_HEADERS;
@@ -157,7 +158,11 @@
 
     /** Returns a new instance suitable for authentication. */
     public static HttpRequestImpl newInstanceForAuthentication(HttpRequestImpl other) {
-        return new HttpRequestImpl(other.uri(), other.method(), other);
+        HttpRequestImpl request = new HttpRequestImpl(other.uri(), other.method(), other);
+        if (request.isWebSocket()) {
+            Utils.setWebSocketUpgradeHeaders(request);
+        }
+        return request;
     }
 
     /**
--- a/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java	Mon Jan 28 09:56:00 2019 +0100
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java	Mon Jan 28 13:51:16 2019 +0000
@@ -263,6 +263,15 @@
                       : ! PROXY_AUTH_DISABLED_SCHEMES.isEmpty();
     }
 
+    // WebSocket connection Upgrade headers
+    private static final String HEADER_CONNECTION = "Connection";
+    private static final String HEADER_UPGRADE    = "Upgrade";
+
+    public static final void setWebSocketUpgradeHeaders(HttpRequestImpl request) {
+        request.setSystemHeader(HEADER_UPGRADE, "websocket");
+        request.setSystemHeader(HEADER_CONNECTION, "Upgrade");
+    }
+
     public static IllegalArgumentException newIAE(String message, Object... args) {
         return new IllegalArgumentException(format(message, args));
     }
--- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/OpeningHandshake.java	Mon Jan 28 09:56:00 2019 +0100
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/OpeningHandshake.java	Mon Jan 28 13:51:16 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -143,8 +143,7 @@
         requestBuilder.version(Version.HTTP_1_1).GET();
         request = requestBuilder.buildForWebSocket();
         request.isWebSocket(true);
-        request.setSystemHeader(HEADER_UPGRADE, "websocket");
-        request.setSystemHeader(HEADER_CONNECTION, "Upgrade");
+        Utils.setWebSocketUpgradeHeaders(request);
         request.setProxy(proxy);
     }
 
--- a/test/jdk/java/net/httpclient/ProxyServer.java	Mon Jan 28 09:56:00 2019 +0100
+++ b/test/jdk/java/net/httpclient/ProxyServer.java	Mon Jan 28 13:51:16 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2015, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -25,6 +25,9 @@
 import java.io.*;
 import java.util.*;
 import java.security.*;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.Arrays.asList;
+import static java.util.stream.Collectors.toList;
 
 /**
  * A minimal proxy server that supports CONNECT tunneling. It does not do
@@ -37,6 +40,18 @@
     ServerSocket listener;
     int port;
     volatile boolean debug;
+    private final Credentials credentials;  // may be null
+
+    private static class Credentials {
+        private final String name;
+        private final String password;
+        private Credentials(String name, String password) {
+            this.name = name;
+            this.password = password;
+        }
+        public String name() { return name; }
+        public String password() { return password; }
+    }
 
     /**
      * Create proxy on port (zero means don't care). Call getPort()
@@ -46,19 +61,42 @@
         this(port, false);
     }
 
-    public ProxyServer(Integer port, Boolean debug) throws IOException {
+    public ProxyServer(Integer port,
+                       Boolean debug,
+                       String username,
+                       String password)
+        throws IOException
+    {
+        this(port, debug, new Credentials(username, password));
+    }
+
+    public ProxyServer(Integer port,
+                       Boolean debug)
+        throws IOException
+    {
+        this(port, debug, null);
+    }
+
+    public ProxyServer(Integer port,
+                       Boolean debug,
+                       Credentials credentials)
+        throws IOException
+    {
         this.debug = debug;
         listener = new ServerSocket();
         listener.setReuseAddress(false);
         listener.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), port));
         this.port = listener.getLocalPort();
+        this.credentials = credentials;
         setName("ProxyListener");
         setDaemon(true);
         connections = new LinkedList<>();
         start();
     }
 
-    public ProxyServer(String s) {  }
+    public ProxyServer(String s) {
+        credentials = null;
+    }
 
     /**
      * Returns the port number this proxy is listening on
@@ -194,16 +232,69 @@
             return -1;
         }
 
+        // Checks credentials in the request against those allowable by the proxy.
+        private boolean authorized(Credentials credentials,
+                                   List<String> requestHeaders) {
+            List<String> authorization = requestHeaders.stream()
+                    .filter(n -> n.toLowerCase(Locale.US).startsWith("proxy-authorization"))
+                    .collect(toList());
+
+            if (authorization.isEmpty())
+                return false;
+
+            if (authorization.size() != 1) {
+                throw new IllegalStateException("Authorization unexpected count:" + authorization);
+            }
+            String value = authorization.get(0).substring("proxy-authorization".length()).trim();
+            if (!value.startsWith(":"))
+                throw new IllegalStateException("Authorization malformed: " + value);
+            value = value.substring(1).trim();
+
+            if (!value.startsWith("Basic "))
+                throw new IllegalStateException("Authorization not Basic: " + value);
+
+            value = value.substring("Basic ".length());
+            String values = new String(Base64.getDecoder().decode(value), UTF_8);
+            int sep = values.indexOf(':');
+            if (sep < 1) {
+                throw new IllegalStateException("Authorization no colon: " +  values);
+            }
+            String name = values.substring(0, sep);
+            String password = values.substring(sep + 1);
+
+            if (name.equals(credentials.name()) && password.equals(credentials.password()))
+                return true;
+
+            return false;
+        }
+
         public void init() {
             try {
-                byte[] buf = readHeaders(clientIn);
+                byte[] buf;
+                while (true) {
+                    buf = readHeaders(clientIn);
+                    if (findCRLF(buf) == -1) {
+                        close();
+                        return;
+                    }
+
+                    List<String> headers = asList(new String(buf, UTF_8).split("\r\n"));
+                    // check authorization credentials, if required by the server
+                    if (credentials != null && !authorized(credentials, headers)) {
+                        String resp = "HTTP/1.1 407 Proxy Authentication Required\r\n" +
+                                      "Content-Length: 0\r\n" +
+                                      "Proxy-Authenticate: Basic realm=\"proxy realm\"\r\n\r\n";
+
+                        clientOut.write(resp.getBytes(UTF_8));
+                    } else {
+                        break;
+                    }
+                }
+
                 int p = findCRLF(buf);
-                if (p == -1) {
-                    close();
-                    return;
-                }
                 String cmd = new String(buf, 0, p, "US-ASCII");
                 String[] params = cmd.split(" ");
+
                 if (params[0].equals("CONNECT")) {
                     doTunnel(params[1]);
                 } else {
--- a/test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java	Mon Jan 28 09:56:00 2019 +0100
+++ b/test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java	Mon Jan 28 13:51:16 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -46,13 +46,14 @@
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Function;
+import java.util.function.BiFunction;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 import static java.lang.String.format;
 import static java.lang.System.err;
 import static java.nio.charset.StandardCharsets.ISO_8859_1;
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.Arrays.asList;
 import static java.util.Objects.requireNonNull;
 
@@ -92,12 +93,32 @@
     private ByteBuffer read = ByteBuffer.allocate(16384);
     private final CountDownLatch readReady = new CountDownLatch(1);
 
-    public DummyWebSocketServer() {
-        this(defaultMapping());
+    private static class Credentials {
+        private final String name;
+        private final String password;
+        private Credentials(String name, String password) {
+            this.name = name;
+            this.password = password;
+        }
+        public String name() { return name; }
+        public String password() { return password; }
     }
 
-    public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
+    public DummyWebSocketServer() {
+        this(defaultMapping(), null, null);
+    }
+
+    public DummyWebSocketServer(String username, String password) {
+        this(defaultMapping(), username, password);
+    }
+
+    public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
+                                String username,
+                                String password) {
         requireNonNull(mapping);
+        Credentials credentials = username != null ?
+                new Credentials(username, password) : null;
+
         thread = new Thread(() -> {
             try {
                 while (!Thread.currentThread().isInterrupted()) {
@@ -107,14 +128,23 @@
                     try {
                         channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
                         channel.configureBlocking(true);
-                        StringBuilder request = new StringBuilder();
-                        if (!readRequest(channel, request)) {
-                            throw new IOException("Bad request:" + request);
+                        while (true) {
+                            StringBuilder request = new StringBuilder();
+                            if (!readRequest(channel, request)) {
+                                throw new IOException("Bad request:[" + request + "]");
+                            }
+                            List<String> strings = asList(request.toString().split("\r\n"));
+                            List<String> response = mapping.apply(strings, credentials);
+                            writeResponse(channel, response);
+
+                            if (response.get(0).startsWith("HTTP/1.1 401")) {
+                                err.println("Sent 401 Authentication response " + channel);
+                                continue;
+                            } else {
+                                serve(channel);
+                                break;
+                            }
                         }
-                        List<String> strings = asList(request.toString().split("\r\n"));
-                        List<String> response = mapping.apply(strings);
-                        writeResponse(channel, response);
-                        serve(channel);
                     } catch (IOException e) {
                         err.println("Error in connection: " + channel + ", " + e);
                     } finally {
@@ -125,7 +155,7 @@
                 }
             } catch (ClosedByInterruptException ignored) {
             } catch (Exception e) {
-                err.println(e);
+                e.printStackTrace(err);
             } finally {
                 close(ssc);
                 err.println("Stopped at: " + getURI());
@@ -256,8 +286,8 @@
         }
     }
 
-    private static Function<List<String>, List<String>> defaultMapping() {
-        return request -> {
+    private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
+        return (request, credentials) -> {
             List<String> response = new LinkedList<>();
             Iterator<String> iterator = request.iterator();
             if (!iterator.hasNext()) {
@@ -309,14 +339,57 @@
             sha1.update(x.getBytes(ISO_8859_1));
             String v = Base64.getEncoder().encodeToString(sha1.digest());
             response.add("Sec-WebSocket-Accept: " + v);
+
+            // check authorization credentials, if required by the server
+            if (credentials != null && !authorized(credentials, requestHeaders)) {
+                response.clear();
+                response.add("HTTP/1.1 401 Unauthorized");
+                response.add("Content-Length: 0");
+                response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
+            }
+
             return response;
         };
     }
 
+    // Checks credentials in the request against those allowable by the server.
+    private static boolean authorized(Credentials credentials,
+                                      Map<String,List<String>> requestHeaders) {
+        List<String> authorization = requestHeaders.get("Authorization");
+        if (authorization == null)
+            return false;
+
+        if (authorization.size() != 1) {
+            throw new IllegalStateException("Authorization unexpected count:" + authorization);
+        }
+        String header = authorization.get(0);
+        if (!header.startsWith("Basic "))
+            throw new IllegalStateException("Authorization not Basic: " + header);
+
+        header = header.substring("Basic ".length());
+        String values = new String(Base64.getDecoder().decode(header), UTF_8);
+        int sep = values.indexOf(':');
+        if (sep < 1) {
+            throw new IllegalStateException("Authorization not colon: " +  values);
+        }
+        String name = values.substring(0, sep);
+        String password = values.substring(sep + 1);
+
+        if (name.equals(credentials.name()) && password.equals(credentials.password()))
+            return true;
+
+        return false;
+    }
+
     protected static String expectHeader(Map<String, List<String>> headers,
                                          String name,
                                          String value) {
         List<String> v = headers.get(name);
+        if (v == null) {
+            throw new IllegalStateException(
+                    format("Expected '%s' header, not present in %s",
+                           name, headers));
+        }
         if (!v.contains(value)) {
             throw new IllegalStateException(
                     format("Expected '%s: %s', actual: '%s: %s'",
--- a/test/jdk/java/net/httpclient/websocket/Support.java	Mon Jan 28 09:56:00 2019 +0100
+++ b/test/jdk/java/net/httpclient/websocket/Support.java	Mon Jan 28 13:51:16 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -79,16 +79,32 @@
     }
 
     public static DummyWebSocketServer serverWithCannedData(int... data) {
+        return serverWithCannedDataAndAuthentication(null, null, data);
+    }
+
+    public static DummyWebSocketServer serverWithCannedDataAndAuthentication(
+            String username,
+            String password,
+            int... data)
+    {
         byte[] copy = new byte[data.length];
         for (int i = 0; i < data.length; i++) {
             copy[i] = (byte) data[i];
         }
-        return serverWithCannedData(copy);
+        return serverWithCannedDataAndAuthentication(username, password, copy);
     }
 
     public static DummyWebSocketServer serverWithCannedData(byte... data) {
+       return serverWithCannedDataAndAuthentication(null, null, data);
+    }
+
+    public static DummyWebSocketServer serverWithCannedDataAndAuthentication(
+            String username,
+            String password,
+            byte... data)
+    {
         byte[] copy = Arrays.copyOf(data, data.length);
-        return new DummyWebSocketServer() {
+        return new DummyWebSocketServer(username, password) {
             @Override
             protected void write(SocketChannel ch) throws IOException {
                 int off = 0; int n = 1; // 1 byte at a time
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/java/net/httpclient/websocket/WebSocketProxyTest.java	Mon Jan 28 13:51:16 2019 +0000
@@ -0,0 +1,309 @@
+/*
+ * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @bug 8217429
+ * @summary WebSocket proxy tunneling tests
+ * @compile DummyWebSocketServer.java ../ProxyServer.java
+ * @run testng/othervm
+ *         -Djdk.http.auth.tunneling.disabledSchemes=
+ *         WebSocketProxyTest
+ */
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.net.Authenticator;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.PasswordAuthentication;
+import java.net.ProxySelector;
+import java.net.http.HttpResponse;
+import java.net.http.WebSocket;
+import java.net.http.WebSocketHandshakeException;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.CompletionStage;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+import static java.net.http.HttpClient.newBuilder;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.testng.Assert.assertEquals;
+import static org.testng.FileAssert.fail;
+
+public class WebSocketProxyTest {
+
+    // Used to verify a proxy/websocket server requiring Authentication
+    private static final String USERNAME = "wally";
+    private static final String PASSWORD = "xyz987";
+
+    static class WSAuthenticator extends Authenticator {
+        @Override
+        protected PasswordAuthentication getPasswordAuthentication() {
+            return new PasswordAuthentication(USERNAME, PASSWORD.toCharArray());
+        }
+    }
+
+    static final Function<int[],DummyWebSocketServer> SERVER_WITH_CANNED_DATA =
+        new Function<>() {
+            @Override public DummyWebSocketServer apply(int[] data) {
+                return Support.serverWithCannedData(data); }
+            @Override public String toString() { return "SERVER_WITH_CANNED_DATA"; }
+        };
+
+    static final Function<int[],DummyWebSocketServer> AUTH_SERVER_WITH_CANNED_DATA =
+        new Function<>() {
+            @Override public DummyWebSocketServer apply(int[] data) {
+                return Support.serverWithCannedDataAndAuthentication(USERNAME, PASSWORD, data); }
+            @Override public String toString() { return "AUTH_SERVER_WITH_CANNED_DATA"; }
+        };
+
+    static final Supplier<ProxyServer> TUNNELING_PROXY_SERVER =
+        new Supplier<>() {
+            @Override public ProxyServer get() {
+                try { return new ProxyServer(0, true);}
+                catch(IOException e) { throw new UncheckedIOException(e); } }
+            @Override public String toString() { return "TUNNELING_PROXY_SERVER"; }
+        };
+    static final Supplier<ProxyServer> AUTH_TUNNELING_PROXY_SERVER =
+        new Supplier<>() {
+            @Override public ProxyServer get() {
+                try { return new ProxyServer(0, true, USERNAME, PASSWORD);}
+                catch(IOException e) { throw new UncheckedIOException(e); } }
+            @Override public String toString() { return "AUTH_TUNNELING_PROXY_SERVER"; }
+        };
+
+    @DataProvider(name = "servers")
+    public Object[][] servers() {
+        return new Object[][] {
+            { SERVER_WITH_CANNED_DATA,      TUNNELING_PROXY_SERVER      },
+            { SERVER_WITH_CANNED_DATA,      AUTH_TUNNELING_PROXY_SERVER },
+            { AUTH_SERVER_WITH_CANNED_DATA, TUNNELING_PROXY_SERVER      },
+        };
+    }
+
+    @Test(dataProvider = "servers")
+    public void simpleAggregatingBinaryMessages
+            (Function<int[],DummyWebSocketServer> serverSupplier,
+             Supplier<ProxyServer> proxyServerSupplier)
+        throws IOException
+    {
+        List<byte[]> expected = List.of("hello", "chegar")
+                .stream()
+                .map(s -> s.getBytes(StandardCharsets.US_ASCII))
+                .collect(Collectors.toList());
+        int[] binary = new int[]{
+                0x82, 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F,       // hello
+                0x82, 0x06, 0x63, 0x68, 0x65, 0x67, 0x61, 0x72, // chegar
+                0x88, 0x00                                      // <CLOSE>
+        };
+        CompletableFuture<List<byte[]>> actual = new CompletableFuture<>();
+
+        try (var proxyServer = proxyServerSupplier.get();
+             var server = serverSupplier.apply(binary)) {
+
+            InetSocketAddress proxyAddress = new InetSocketAddress(
+                    InetAddress.getLoopbackAddress(), proxyServer.getPort());
+            server.open();
+
+            WebSocket.Listener listener = new WebSocket.Listener() {
+
+                List<byte[]> collectedBytes = new ArrayList<>();
+                ByteBuffer buffer = ByteBuffer.allocate(1024);
+
+                @Override
+                public CompletionStage<?> onBinary(WebSocket webSocket,
+                                                   ByteBuffer message,
+                                                   boolean last) {
+                    System.out.printf("onBinary(%s, %s)%n", message, last);
+                    webSocket.request(1);
+
+                    append(message);
+                    if (last) {
+                        buffer.flip();
+                        byte[] bytes = new byte[buffer.remaining()];
+                        buffer.get(bytes);
+                        buffer.clear();
+                        processWholeBinary(bytes);
+                    }
+                    return null;
+                }
+
+                private void append(ByteBuffer message) {
+                    if (buffer.remaining() < message.remaining()) {
+                        assert message.remaining() > 0;
+                        int cap = (buffer.capacity() + message.remaining()) * 2;
+                        ByteBuffer b = ByteBuffer.allocate(cap);
+                        b.put(buffer.flip());
+                        buffer = b;
+                    }
+                    buffer.put(message);
+                }
+
+                private void processWholeBinary(byte[] bytes) {
+                    String stringBytes = new String(bytes, UTF_8);
+                    System.out.println("processWholeBinary: " + stringBytes);
+                    collectedBytes.add(bytes);
+                }
+
+                @Override
+                public CompletionStage<?> onClose(WebSocket webSocket,
+                                                  int statusCode,
+                                                  String reason) {
+                    actual.complete(collectedBytes);
+                    return null;
+                }
+
+                @Override
+                public void onError(WebSocket webSocket, Throwable error) {
+                    actual.completeExceptionally(error);
+                }
+            };
+
+            var webSocket = newBuilder()
+                    .proxy(ProxySelector.of(proxyAddress))
+                    .authenticator(new WSAuthenticator())
+                    .build().newWebSocketBuilder()
+                    .buildAsync(server.getURI(), listener)
+                    .join();
+
+            List<byte[]> a = actual.join();
+            assertEquals(a, expected);
+        }
+    }
+
+    // -- authentication specific tests
+
+    /*
+     * Ensures authentication succeeds when an Authenticator set on client builder.
+     */
+    @Test
+    public void clientAuthenticate() throws IOException  {
+        try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
+             var server = new DummyWebSocketServer()){
+            server.open();
+            InetSocketAddress proxyAddress = new InetSocketAddress(
+                    InetAddress.getLoopbackAddress(), proxyServer.getPort());
+
+            var webSocket = newBuilder()
+                    .proxy(ProxySelector.of(proxyAddress))
+                    .authenticator(new WSAuthenticator())
+                    .build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { })
+                    .join();
+        }
+    }
+
+    /*
+     * Ensures authentication succeeds when an `Authorization` header is explicitly set.
+     */
+    @Test
+    public void explicitAuthenticate() throws IOException  {
+        try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
+             var server = new DummyWebSocketServer()) {
+            server.open();
+            InetSocketAddress proxyAddress = new InetSocketAddress(
+                    InetAddress.getLoopbackAddress(), proxyServer.getPort());
+
+            String hv = "Basic " + Base64.getEncoder().encodeToString(
+                    (USERNAME + ":" + PASSWORD).getBytes(UTF_8));
+
+            var webSocket = newBuilder()
+                    .proxy(ProxySelector.of(proxyAddress)).build()
+                    .newWebSocketBuilder()
+                    .header("Proxy-Authorization", hv)
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { })
+                    .join();
+        }
+    }
+
+    /*
+     * Ensures authentication does not succeed when no authenticator is present.
+     */
+    @Test
+    public void failNoAuthenticator() throws IOException  {
+        try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
+             var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
+            server.open();
+            InetSocketAddress proxyAddress = new InetSocketAddress(
+                    InetAddress.getLoopbackAddress(), proxyServer.getPort());
+
+            CompletableFuture<WebSocket> cf = newBuilder()
+                    .proxy(ProxySelector.of(proxyAddress)).build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { });
+
+            try {
+                var webSocket = cf.join();
+                fail("Expected exception not thrown");
+            } catch (CompletionException expected) {
+                WebSocketHandshakeException e = (WebSocketHandshakeException)expected.getCause();
+                HttpResponse<?> response = e.getResponse();
+                assertEquals(response.statusCode(), 407);
+            }
+        }
+    }
+
+    /*
+     * Ensures authentication does not succeed when the authenticator presents
+     * unauthorized credentials.
+     */
+    @Test
+    public void failBadCredentials() throws IOException  {
+        try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
+             var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
+            server.open();
+            InetSocketAddress proxyAddress = new InetSocketAddress(
+                    InetAddress.getLoopbackAddress(), proxyServer.getPort());
+
+            Authenticator authenticator = new Authenticator() {
+                @Override protected PasswordAuthentication getPasswordAuthentication() {
+                    return new PasswordAuthentication("BAD"+USERNAME, "".toCharArray());
+                }
+            };
+
+            CompletableFuture<WebSocket> cf = newBuilder()
+                    .proxy(ProxySelector.of(proxyAddress))
+                    .authenticator(authenticator)
+                    .build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { });
+
+            try {
+                var webSocket = cf.join();
+                fail("Expected exception not thrown");
+            } catch (CompletionException expected) {
+                System.out.println("caught expected exception:" + expected);
+            }
+        }
+    }
+}
--- a/test/jdk/java/net/httpclient/websocket/WebSocketTest.java	Mon Jan 28 09:56:00 2019 +0100
+++ b/test/jdk/java/net/httpclient/websocket/WebSocketTest.java	Mon Jan 28 13:51:16 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -23,6 +23,7 @@
 
 /*
  * @test
+ * @bug 8217429
  * @build DummyWebSocketServer
  * @run testng/othervm
  *       WebSocketTest
@@ -33,23 +34,32 @@
 import org.testng.annotations.Test;
 
 import java.io.IOException;
+import java.net.Authenticator;
+import java.net.PasswordAuthentication;
+import java.net.http.HttpResponse;
 import java.net.http.WebSocket;
+import java.net.http.WebSocketHandshakeException;
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.Base64;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import static java.net.http.HttpClient.Builder.NO_PROXY;
 import static java.net.http.HttpClient.newBuilder;
 import static java.net.http.WebSocket.NORMAL_CLOSURE;
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.fail;
 
 public class WebSocketTest {
 
@@ -68,8 +78,11 @@
 
     @AfterTest
     public void cleanup() {
-        server.close();
-        webSocket.abort();
+        System.out.println("AFTER TEST");
+        if (server != null)
+            server.close();
+        if (webSocket != null)
+            webSocket.abort();
     }
 
     @Test
@@ -134,6 +147,8 @@
         assertThrows(IAE, () -> webSocket.request(Long.MIN_VALUE));
         assertThrows(IAE, () -> webSocket.request(-1));
         assertThrows(IAE, () -> webSocket.request(0));
+
+        server.close();
     }
 
     @Test
@@ -149,6 +164,7 @@
         // Pings & Pongs are fine
         webSocket.sendPing(ByteBuffer.allocate(125)).join();
         webSocket.sendPong(ByteBuffer.allocate(125)).join();
+        server.close();
     }
 
     @Test
@@ -165,6 +181,7 @@
         // Pings & Pongs are fine
         webSocket.sendPing(ByteBuffer.allocate(125)).join();
         webSocket.sendPong(ByteBuffer.allocate(125)).join();
+        server.close();
     }
 
     @Test
@@ -198,6 +215,8 @@
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(124)));
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(1)));
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(0)));
+
+        server.close();
     }
 
     @DataProvider(name = "sequence")
@@ -318,6 +337,8 @@
         listener.invocations();
         violation.complete(null); // won't affect if completed exceptionally
         violation.join();
+
+        server.close();
     }
 
     @Test
@@ -372,10 +393,48 @@
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(124)));
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(1)));
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(0)));
+
+        server.close();
+    }
+
+    // Used to verify a server requiring Authentication
+    private static final String USERNAME = "chegar";
+    private static final String PASSWORD = "a1b2c3";
+
+    static class WSAuthenticator extends Authenticator {
+        @Override
+        protected PasswordAuthentication getPasswordAuthentication() {
+            return new PasswordAuthentication(USERNAME, PASSWORD.toCharArray());
+        }
     }
 
-    @Test
-    public void simpleAggregatingBinaryMessages() throws IOException {
+    static final Function<int[],DummyWebSocketServer> SERVER_WITH_CANNED_DATA =
+        new Function<>() {
+            @Override public DummyWebSocketServer apply(int[] data) {
+                return Support.serverWithCannedData(data); }
+            @Override public String toString() { return "SERVER_WITH_CANNED_DATA"; }
+        };
+
+    static final Function<int[],DummyWebSocketServer> AUTH_SERVER_WITH_CANNED_DATA =
+        new Function<>() {
+            @Override public DummyWebSocketServer apply(int[] data) {
+                return Support.serverWithCannedDataAndAuthentication(USERNAME, PASSWORD, data); }
+            @Override public String toString() { return "AUTH_SERVER_WITH_CANNED_DATA"; }
+        };
+
+    @DataProvider(name = "servers")
+    public Object[][] servers() {
+        return new Object[][] {
+            { SERVER_WITH_CANNED_DATA },
+            { AUTH_SERVER_WITH_CANNED_DATA },
+        };
+    }
+
+    @Test(dataProvider = "servers")
+    public void simpleAggregatingBinaryMessages
+            (Function<int[],DummyWebSocketServer> serverSupplier)
+        throws IOException
+    {
         List<byte[]> expected = List.of("alpha", "beta", "gamma", "delta")
                 .stream()
                 .map(s -> s.getBytes(StandardCharsets.US_ASCII))
@@ -399,7 +458,7 @@
         };
         CompletableFuture<List<byte[]>> actual = new CompletableFuture<>();
 
-        server = Support.serverWithCannedData(binary);
+        server = serverSupplier.apply(binary);
         server.open();
 
         WebSocket.Listener listener = new WebSocket.Listener() {
@@ -437,7 +496,7 @@
             }
 
             private void processWholeBinary(byte[] bytes) {
-                String stringBytes = new String(bytes, StandardCharsets.UTF_8);
+                String stringBytes = new String(bytes, UTF_8);
                 System.out.println("processWholeBinary: " + stringBytes);
                 collectedBytes.add(bytes);
             }
@@ -456,17 +515,24 @@
             }
         };
 
-        webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
+        webSocket = newBuilder()
+                .proxy(NO_PROXY)
+                .authenticator(new WSAuthenticator())
+                .build().newWebSocketBuilder()
                 .buildAsync(server.getURI(), listener)
                 .join();
 
         List<byte[]> a = actual.join();
         assertEquals(a, expected);
+
+        server.close();
     }
 
-    @Test
-    public void simpleAggregatingTextMessages() throws IOException {
-
+    @Test(dataProvider = "servers")
+    public void simpleAggregatingTextMessages
+            (Function<int[],DummyWebSocketServer> serverSupplier)
+        throws IOException
+    {
         List<String> expected = List.of("alpha", "beta", "gamma", "delta");
 
         int[] binary = new int[]{
@@ -488,7 +554,7 @@
         };
         CompletableFuture<List<String>> actual = new CompletableFuture<>();
 
-        server = Support.serverWithCannedData(binary);
+        server = serverSupplier.apply(binary);
         server.open();
 
         WebSocket.Listener listener = new WebSocket.Listener() {
@@ -530,21 +596,28 @@
             }
         };
 
-        webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
+        webSocket = newBuilder()
+                .proxy(NO_PROXY)
+                .authenticator(new WSAuthenticator())
+                .build().newWebSocketBuilder()
                 .buildAsync(server.getURI(), listener)
                 .join();
 
         List<String> a = actual.join();
         assertEquals(a, expected);
+
+        server.close();
     }
 
     /*
      * Exercises the scenario where requests for more messages are made prior to
      * completing the returned CompletionStage instances.
      */
-    @Test
-    public void aggregatingTextMessages() throws IOException {
-
+    @Test(dataProvider = "servers")
+    public void aggregatingTextMessages
+        (Function<int[],DummyWebSocketServer> serverSupplier)
+        throws IOException
+    {
         List<String> expected = List.of("alpha", "beta", "gamma", "delta");
 
         int[] binary = new int[]{
@@ -566,8 +639,7 @@
         };
         CompletableFuture<List<String>> actual = new CompletableFuture<>();
 
-
-        server = Support.serverWithCannedData(binary);
+        server = serverSupplier.apply(binary);
         server.open();
 
         WebSocket.Listener listener = new WebSocket.Listener() {
@@ -623,11 +695,111 @@
             }
         };
 
-        webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
+        webSocket = newBuilder()
+                .proxy(NO_PROXY)
+                .authenticator(new WSAuthenticator())
+                .build().newWebSocketBuilder()
                 .buildAsync(server.getURI(), listener)
                 .join();
 
         List<String> a = actual.join();
         assertEquals(a, expected);
+
+        server.close();
+    }
+
+    // -- authentication specific tests
+
+    /*
+     * Ensures authentication succeeds when an Authenticator set on client builder.
+     */
+    @Test
+    public void clientAuthenticate() throws IOException  {
+        try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)){
+            server.open();
+
+            var webSocket = newBuilder()
+                    .proxy(NO_PROXY)
+                    .authenticator(new WSAuthenticator())
+                    .build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { })
+                    .join();
+        }
+    }
+
+    /*
+     * Ensures authentication succeeds when an `Authorization` header is explicitly set.
+     */
+    @Test
+    public void explicitAuthenticate() throws IOException  {
+        try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
+            server.open();
+
+            String hv = "Basic " + Base64.getEncoder().encodeToString(
+                    (USERNAME + ":" + PASSWORD).getBytes(UTF_8));
+
+            var webSocket = newBuilder()
+                    .proxy(NO_PROXY).build()
+                    .newWebSocketBuilder()
+                    .header("Authorization", hv)
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { })
+                    .join();
+        }
+    }
+
+    /*
+     * Ensures authentication does not succeed when no authenticator is present.
+     */
+    @Test
+    public void failNoAuthenticator() throws IOException  {
+        try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
+            server.open();
+
+            CompletableFuture<WebSocket> cf = newBuilder()
+                    .proxy(NO_PROXY).build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { });
+
+            try {
+                var webSocket = cf.join();
+                fail("Expected exception not thrown");
+            } catch (CompletionException expected) {
+                WebSocketHandshakeException e = (WebSocketHandshakeException)expected.getCause();
+                HttpResponse<?> response = e.getResponse();
+                assertEquals(response.statusCode(), 401);
+            }
+        }
+    }
+
+    /*
+     * Ensures authentication does not succeed when the authenticator presents
+     * unauthorized credentials.
+     */
+    @Test
+    public void failBadCredentials() throws IOException  {
+        try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
+            server.open();
+
+            Authenticator authenticator = new Authenticator() {
+                @Override protected PasswordAuthentication getPasswordAuthentication() {
+                    return new PasswordAuthentication("BAD"+USERNAME, "".toCharArray());
+                }
+            };
+
+            CompletableFuture<WebSocket> cf = newBuilder()
+                    .proxy(NO_PROXY)
+                    .authenticator(authenticator)
+                    .build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { });
+
+            try {
+                var webSocket = cf.join();
+                fail("Expected exception not thrown");
+            } catch (CompletionException expected) {
+                System.out.println("caught expected exception:" + expected);
+            }
+        }
     }
 }