test/jdk/java/net/httpclient/websocket/WebSocketTest.java
changeset 53521 41fa3e6f2785
parent 50681 4254bed3c09d
child 58289 3a79d4cccbcb
--- a/test/jdk/java/net/httpclient/websocket/WebSocketTest.java	Mon Jan 28 09:56:00 2019 +0100
+++ b/test/jdk/java/net/httpclient/websocket/WebSocketTest.java	Mon Jan 28 13:51:16 2019 +0000
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -23,6 +23,7 @@
 
 /*
  * @test
+ * @bug 8217429
  * @build DummyWebSocketServer
  * @run testng/othervm
  *       WebSocketTest
@@ -33,23 +34,32 @@
 import org.testng.annotations.Test;
 
 import java.io.IOException;
+import java.net.Authenticator;
+import java.net.PasswordAuthentication;
+import java.net.http.HttpResponse;
 import java.net.http.WebSocket;
+import java.net.http.WebSocketHandshakeException;
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.Base64;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import static java.net.http.HttpClient.Builder.NO_PROXY;
 import static java.net.http.HttpClient.newBuilder;
 import static java.net.http.WebSocket.NORMAL_CLOSURE;
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.fail;
 
 public class WebSocketTest {
 
@@ -68,8 +78,11 @@
 
     @AfterTest
     public void cleanup() {
-        server.close();
-        webSocket.abort();
+        System.out.println("AFTER TEST");
+        if (server != null)
+            server.close();
+        if (webSocket != null)
+            webSocket.abort();
     }
 
     @Test
@@ -134,6 +147,8 @@
         assertThrows(IAE, () -> webSocket.request(Long.MIN_VALUE));
         assertThrows(IAE, () -> webSocket.request(-1));
         assertThrows(IAE, () -> webSocket.request(0));
+
+        server.close();
     }
 
     @Test
@@ -149,6 +164,7 @@
         // Pings & Pongs are fine
         webSocket.sendPing(ByteBuffer.allocate(125)).join();
         webSocket.sendPong(ByteBuffer.allocate(125)).join();
+        server.close();
     }
 
     @Test
@@ -165,6 +181,7 @@
         // Pings & Pongs are fine
         webSocket.sendPing(ByteBuffer.allocate(125)).join();
         webSocket.sendPong(ByteBuffer.allocate(125)).join();
+        server.close();
     }
 
     @Test
@@ -198,6 +215,8 @@
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(124)));
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(1)));
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(0)));
+
+        server.close();
     }
 
     @DataProvider(name = "sequence")
@@ -318,6 +337,8 @@
         listener.invocations();
         violation.complete(null); // won't affect if completed exceptionally
         violation.join();
+
+        server.close();
     }
 
     @Test
@@ -372,10 +393,48 @@
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(124)));
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(1)));
         assertFails(IOE, webSocket.sendPong(ByteBuffer.allocate(0)));
+
+        server.close();
+    }
+
+    // Used to verify a server requiring Authentication
+    private static final String USERNAME = "chegar";
+    private static final String PASSWORD = "a1b2c3";
+
+    static class WSAuthenticator extends Authenticator {
+        @Override
+        protected PasswordAuthentication getPasswordAuthentication() {
+            return new PasswordAuthentication(USERNAME, PASSWORD.toCharArray());
+        }
     }
 
-    @Test
-    public void simpleAggregatingBinaryMessages() throws IOException {
+    static final Function<int[],DummyWebSocketServer> SERVER_WITH_CANNED_DATA =
+        new Function<>() {
+            @Override public DummyWebSocketServer apply(int[] data) {
+                return Support.serverWithCannedData(data); }
+            @Override public String toString() { return "SERVER_WITH_CANNED_DATA"; }
+        };
+
+    static final Function<int[],DummyWebSocketServer> AUTH_SERVER_WITH_CANNED_DATA =
+        new Function<>() {
+            @Override public DummyWebSocketServer apply(int[] data) {
+                return Support.serverWithCannedDataAndAuthentication(USERNAME, PASSWORD, data); }
+            @Override public String toString() { return "AUTH_SERVER_WITH_CANNED_DATA"; }
+        };
+
+    @DataProvider(name = "servers")
+    public Object[][] servers() {
+        return new Object[][] {
+            { SERVER_WITH_CANNED_DATA },
+            { AUTH_SERVER_WITH_CANNED_DATA },
+        };
+    }
+
+    @Test(dataProvider = "servers")
+    public void simpleAggregatingBinaryMessages
+            (Function<int[],DummyWebSocketServer> serverSupplier)
+        throws IOException
+    {
         List<byte[]> expected = List.of("alpha", "beta", "gamma", "delta")
                 .stream()
                 .map(s -> s.getBytes(StandardCharsets.US_ASCII))
@@ -399,7 +458,7 @@
         };
         CompletableFuture<List<byte[]>> actual = new CompletableFuture<>();
 
-        server = Support.serverWithCannedData(binary);
+        server = serverSupplier.apply(binary);
         server.open();
 
         WebSocket.Listener listener = new WebSocket.Listener() {
@@ -437,7 +496,7 @@
             }
 
             private void processWholeBinary(byte[] bytes) {
-                String stringBytes = new String(bytes, StandardCharsets.UTF_8);
+                String stringBytes = new String(bytes, UTF_8);
                 System.out.println("processWholeBinary: " + stringBytes);
                 collectedBytes.add(bytes);
             }
@@ -456,17 +515,24 @@
             }
         };
 
-        webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
+        webSocket = newBuilder()
+                .proxy(NO_PROXY)
+                .authenticator(new WSAuthenticator())
+                .build().newWebSocketBuilder()
                 .buildAsync(server.getURI(), listener)
                 .join();
 
         List<byte[]> a = actual.join();
         assertEquals(a, expected);
+
+        server.close();
     }
 
-    @Test
-    public void simpleAggregatingTextMessages() throws IOException {
-
+    @Test(dataProvider = "servers")
+    public void simpleAggregatingTextMessages
+            (Function<int[],DummyWebSocketServer> serverSupplier)
+        throws IOException
+    {
         List<String> expected = List.of("alpha", "beta", "gamma", "delta");
 
         int[] binary = new int[]{
@@ -488,7 +554,7 @@
         };
         CompletableFuture<List<String>> actual = new CompletableFuture<>();
 
-        server = Support.serverWithCannedData(binary);
+        server = serverSupplier.apply(binary);
         server.open();
 
         WebSocket.Listener listener = new WebSocket.Listener() {
@@ -530,21 +596,28 @@
             }
         };
 
-        webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
+        webSocket = newBuilder()
+                .proxy(NO_PROXY)
+                .authenticator(new WSAuthenticator())
+                .build().newWebSocketBuilder()
                 .buildAsync(server.getURI(), listener)
                 .join();
 
         List<String> a = actual.join();
         assertEquals(a, expected);
+
+        server.close();
     }
 
     /*
      * Exercises the scenario where requests for more messages are made prior to
      * completing the returned CompletionStage instances.
      */
-    @Test
-    public void aggregatingTextMessages() throws IOException {
-
+    @Test(dataProvider = "servers")
+    public void aggregatingTextMessages
+        (Function<int[],DummyWebSocketServer> serverSupplier)
+        throws IOException
+    {
         List<String> expected = List.of("alpha", "beta", "gamma", "delta");
 
         int[] binary = new int[]{
@@ -566,8 +639,7 @@
         };
         CompletableFuture<List<String>> actual = new CompletableFuture<>();
 
-
-        server = Support.serverWithCannedData(binary);
+        server = serverSupplier.apply(binary);
         server.open();
 
         WebSocket.Listener listener = new WebSocket.Listener() {
@@ -623,11 +695,111 @@
             }
         };
 
-        webSocket = newBuilder().proxy(NO_PROXY).build().newWebSocketBuilder()
+        webSocket = newBuilder()
+                .proxy(NO_PROXY)
+                .authenticator(new WSAuthenticator())
+                .build().newWebSocketBuilder()
                 .buildAsync(server.getURI(), listener)
                 .join();
 
         List<String> a = actual.join();
         assertEquals(a, expected);
+
+        server.close();
+    }
+
+    // -- authentication specific tests
+
+    /*
+     * Ensures authentication succeeds when an Authenticator set on client builder.
+     */
+    @Test
+    public void clientAuthenticate() throws IOException  {
+        try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)){
+            server.open();
+
+            var webSocket = newBuilder()
+                    .proxy(NO_PROXY)
+                    .authenticator(new WSAuthenticator())
+                    .build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { })
+                    .join();
+        }
+    }
+
+    /*
+     * Ensures authentication succeeds when an `Authorization` header is explicitly set.
+     */
+    @Test
+    public void explicitAuthenticate() throws IOException  {
+        try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
+            server.open();
+
+            String hv = "Basic " + Base64.getEncoder().encodeToString(
+                    (USERNAME + ":" + PASSWORD).getBytes(UTF_8));
+
+            var webSocket = newBuilder()
+                    .proxy(NO_PROXY).build()
+                    .newWebSocketBuilder()
+                    .header("Authorization", hv)
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { })
+                    .join();
+        }
+    }
+
+    /*
+     * Ensures authentication does not succeed when no authenticator is present.
+     */
+    @Test
+    public void failNoAuthenticator() throws IOException  {
+        try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
+            server.open();
+
+            CompletableFuture<WebSocket> cf = newBuilder()
+                    .proxy(NO_PROXY).build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { });
+
+            try {
+                var webSocket = cf.join();
+                fail("Expected exception not thrown");
+            } catch (CompletionException expected) {
+                WebSocketHandshakeException e = (WebSocketHandshakeException)expected.getCause();
+                HttpResponse<?> response = e.getResponse();
+                assertEquals(response.statusCode(), 401);
+            }
+        }
+    }
+
+    /*
+     * Ensures authentication does not succeed when the authenticator presents
+     * unauthorized credentials.
+     */
+    @Test
+    public void failBadCredentials() throws IOException  {
+        try (var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
+            server.open();
+
+            Authenticator authenticator = new Authenticator() {
+                @Override protected PasswordAuthentication getPasswordAuthentication() {
+                    return new PasswordAuthentication("BAD"+USERNAME, "".toCharArray());
+                }
+            };
+
+            CompletableFuture<WebSocket> cf = newBuilder()
+                    .proxy(NO_PROXY)
+                    .authenticator(authenticator)
+                    .build()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), new WebSocket.Listener() { });
+
+            try {
+                var webSocket = cf.join();
+                fail("Expected exception not thrown");
+            } catch (CompletionException expected) {
+                System.out.println("caught expected exception:" + expected);
+            }
+        }
     }
 }