test/jdk/java/net/httpclient/http2/BadHeadersTest.java
branchhttp-client-branch
changeset 56619 57f17e890a40
parent 56598 4c502e3991bf
child 56771 73a6534bce94
--- a/test/jdk/java/net/httpclient/http2/BadHeadersTest.java	Mon May 28 17:22:37 2018 +0100
+++ b/test/jdk/java/net/httpclient/http2/BadHeadersTest.java	Fri May 25 16:13:11 2018 +0100
@@ -33,8 +33,7 @@
  * @run testng/othervm -Djdk.internal.httpclient.debug=true BadHeadersTest
  */
 
-import jdk.internal.net.http.common.HttpHeadersImpl;
-import jdk.internal.net.http.common.Pair;
+import jdk.internal.net.http.common.HttpHeadersBuilder;
 import jdk.internal.net.http.frame.ContinuationFrame;
 import jdk.internal.net.http.frame.HeaderFrame;
 import jdk.internal.net.http.frame.HeadersFrame;
@@ -44,41 +43,38 @@
 import org.testng.annotations.BeforeTest;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
-
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLSession;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
-import java.net.Socket;
 import java.net.URI;
 import java.net.http.HttpClient;
+import java.net.http.HttpHeaders;
 import java.net.http.HttpRequest;
 import java.net.http.HttpRequest.BodyPublishers;
+import java.net.http.HttpResponse;
 import java.net.http.HttpResponse.BodyHandlers;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
-import java.util.LinkedHashMap;
 import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.concurrent.CompletionException;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.Map.Entry;
+import java.util.concurrent.ExecutionException;
 import java.util.function.BiFunction;
-
 import static java.util.List.of;
-import static jdk.internal.net.http.common.Pair.pair;
-import static org.testng.Assert.assertThrows;
+import static java.util.Map.entry;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
 
 // Code copied from ContinuationFrameTest
 public class BadHeadersTest {
 
-    private static final List<List<Pair<String, String>>> BAD_HEADERS = of(
-            of(pair(":status", "200"),  pair(":hello", "GET")),                      // Unknown pseudo-header
-            of(pair(":status", "200"),  pair("hell o", "value")),                    // Space in the name
-            of(pair(":status", "200"),  pair("hello", "line1\r\n  line2\r\n")),      // Multiline value
-            of(pair(":status", "200"),  pair("hello", "DE" + ((char) 0x7F) + "L")),  // Bad byte in value
-            of(pair("hello", "world!"), pair(":status", "200"))                      // Pseudo header is not the first one
+    private static final List<List<Entry<String, String>>> BAD_HEADERS = of(
+        of(entry(":status", "200"),  entry(":hello", "GET")),                      // Unknown pseudo-header
+        of(entry(":status", "200"),  entry("hell o", "value")),                    // Space in the name
+        of(entry(":status", "200"),  entry("hello", "line1\r\n  line2\r\n")),      // Multiline value
+        of(entry(":status", "200"),  entry("hello", "DE" + ((char) 0x7F) + "L")),  // Bad byte in value
+        of(entry("hello", "world!"), entry(":status", "200"))                      // Pseudo header is not the first one
     );
 
     SSLContext sslContext;
@@ -143,7 +139,35 @@
     void test(String uri,
               boolean sameClient,
               BiFunction<Integer,List<ByteBuffer>,List<Http2Frame>> headerFramesSupplier)
-            throws Exception
+        throws Exception
+    {
+        CFTHttp2TestExchange.setHeaderFrameSupplier(headerFramesSupplier);
+
+        HttpClient client = null;
+        for (int i=0; i< BAD_HEADERS.size(); i++) {
+            if (!sameClient || client == null)
+                client = HttpClient.newBuilder().sslContext(sslContext).build();
+
+            URI uriWithQuery = URI.create(uri +  "?BAD_HEADERS=" + i);
+            HttpRequest request = HttpRequest.newBuilder(uriWithQuery)
+                    .POST(BodyPublishers.ofString("Hello there!"))
+                    .build();
+            System.out.println("\nSending request:" + uriWithQuery);
+            final HttpClient cc = client;
+            try {
+                HttpResponse<String> response = cc.send(request, BodyHandlers.ofString());
+                fail("Expected exception, got :" + response + ", " + response.body());
+            } catch (IOException ioe) {
+                System.out.println("Got EXPECTED: " + ioe);
+                assertDetailMessage(ioe, i);
+            }
+        }
+    }
+
+    @Test(dataProvider = "variants")
+    void testAsync(String uri,
+                   boolean sameClient,
+                   BiFunction<Integer,List<ByteBuffer>,List<Http2Frame>> headerFramesSupplier)
     {
         CFTHttp2TestExchange.setHeaderFrameSupplier(headerFramesSupplier);
 
@@ -152,31 +176,45 @@
             if (!sameClient || client == null)
                 client = HttpClient.newBuilder().sslContext(sslContext).build();
 
-            HttpRequest request = HttpRequest.newBuilder(URI.create(uri))
+            URI uriWithQuery = URI.create(uri +  "?BAD_HEADERS=" + i);
+            HttpRequest request = HttpRequest.newBuilder(uriWithQuery)
                     .POST(BodyPublishers.ofString("Hello there!"))
                     .build();
+            System.out.println("\nSending request:" + uriWithQuery);
             final HttpClient cc = client;
-            if (i % 2 == 0) {
-                assertThrows(IOException.class, () -> cc.send(request, BodyHandlers.ofString()));
-            } else {
-                Throwable t = null;
-                try {
-                    cc.sendAsync(request, BodyHandlers.ofString()).join();
-                } catch (Throwable t0) {
-                    t = t0;
+
+            Throwable t = null;
+            try {
+                HttpResponse<String> response = cc.sendAsync(request, BodyHandlers.ofString()).get();
+                fail("Expected exception, got :" + response + ", " + response.body());
+            } catch (Throwable t0) {
+                System.out.println("Got EXPECTED: " + t0);
+                if (t0 instanceof ExecutionException) {
+                    t0 = t0.getCause();
                 }
-                if (t == null) {
-                    throw new AssertionError("An exception was expected");
-                }
-                if (t instanceof CompletionException) {
-                    Throwable c = t.getCause();
-                    if (!(c instanceof IOException)) {
-                        throw new AssertionError("Unexpected exception", c);
-                    }
-                } else if (!(t instanceof IOException)) {
-                    throw new AssertionError("Unexpected exception", t);
-                }
+                t = t0;
             }
+            assertDetailMessage(t, i);
+        }
+    }
+
+    // Assertions based on implementation specific detail messages. Keep in
+    // sync with implementation.
+    static void assertDetailMessage(Throwable throwable, int iterationIndex) {
+        assertTrue(throwable instanceof IOException,
+                   "Expected IOException, got, " + throwable);
+        assertTrue(throwable.getMessage().contains("protocol error"),
+                "Expected \"protocol error\" in: " + throwable.getMessage());
+
+        if (iterationIndex == 0) { // unknown
+            assertTrue(throwable.getMessage().contains("Unknown pseudo-header"),
+                    "Expected \"Unknown pseudo-header\" in: " + throwable.getMessage());
+        } else if (iterationIndex == 4) { // unexpected
+            assertTrue(throwable.getMessage().contains(" Unexpected pseudo-header"),
+                    "Expected \" Unexpected pseudo-header\" in: " + throwable.getMessage());
+        } else {
+            assertTrue(throwable.getMessage().contains("Bad header"),
+                    "Expected \"Bad header\" in: " + throwable.getMessage());
         }
     }
 
@@ -186,38 +224,12 @@
         if (sslContext == null)
             throw new AssertionError("Unexpected null sslContext");
 
-        http2TestServer = new Http2TestServer("localhost", false, 0) {
-            @Override
-            protected Http2TestServerConnection createConnection(Http2TestServer http2TestServer,
-                                                                 Socket socket,
-                                                                 Http2TestExchangeSupplier exchangeSupplier)
-                    throws IOException {
-                return new Http2TestServerConnection(http2TestServer, socket, exchangeSupplier, null) {
-                    @Override
-                    protected HttpHeadersImpl createNewResponseHeaders() {
-                        return new OrderedHttpHeaders();
-                    }
-                };
-            }
-        };
+        http2TestServer = new Http2TestServer("localhost", false, 0);
         http2TestServer.addHandler(new Http2EchoHandler(), "/http2/echo");
         int port = http2TestServer.getAddress().getPort();
         http2URI = "http://localhost:" + port + "/http2/echo";
 
-        https2TestServer = new Http2TestServer("localhost", true, 0){
-            @Override
-            protected Http2TestServerConnection createConnection(Http2TestServer http2TestServer,
-                                                                 Socket socket,
-                                                                 Http2TestExchangeSupplier exchangeSupplier)
-                    throws IOException {
-                return new Http2TestServerConnection(http2TestServer, socket, exchangeSupplier, null) {
-                    @Override
-                    protected HttpHeadersImpl createNewResponseHeaders() {
-                        return new OrderedHttpHeaders();
-                    }
-                };
-            }
-        };
+        https2TestServer = new Http2TestServer("localhost", true, 0);
         https2TestServer.addHandler(new Http2EchoHandler(), "/https2/echo");
         port = https2TestServer.getAddress().getPort();
         https2URI = "https://localhost:" + port + "/https2/echo";
@@ -239,16 +251,14 @@
 
     static class Http2EchoHandler implements Http2Handler {
 
-        private final AtomicInteger requestNo = new AtomicInteger();
-
         @Override
         public void handle(Http2TestExchange t) throws IOException {
             try (InputStream is = t.getRequestBody();
                  OutputStream os = t.getResponseBody()) {
                 byte[] bytes = is.readAllBytes();
-                int i = requestNo.incrementAndGet();
-                List<Pair<String, String>> p = BAD_HEADERS.get(i % BAD_HEADERS.size());
-                p.forEach(h -> t.getResponseHeaders().addHeader(h.first, h.second));
+                // Note: strictly ordered response headers will be added within
+                // the custom sendResponseHeaders implementation, based upon the
+                // query parameter
                 t.sendResponseHeaders(200, bytes.length);
                 os.write(bytes);
             }
@@ -259,23 +269,34 @@
     // allow headers to be sent with a number of CONTINUATION frames.
     static class CFTHttp2TestExchange extends Http2TestExchangeImpl {
         static volatile BiFunction<Integer,List<ByteBuffer>,List<Http2Frame>> headerFrameSupplier;
+        volatile int badHeadersIndex = -1;
 
         static void setHeaderFrameSupplier(BiFunction<Integer,List<ByteBuffer>,List<Http2Frame>> hfs) {
             headerFrameSupplier = hfs;
         }
 
-        CFTHttp2TestExchange(int streamid, String method, HttpHeadersImpl reqheaders,
-                             HttpHeadersImpl rspheaders, URI uri, InputStream is,
+        CFTHttp2TestExchange(int streamid, String method, HttpHeaders reqheaders,
+                             HttpHeadersBuilder rspheadersBuilder, URI uri, InputStream is,
                              SSLSession sslSession, BodyOutputStream os,
                              Http2TestServerConnection conn, boolean pushAllowed) {
-            super(streamid, method, reqheaders, rspheaders, uri, is, sslSession,
+            super(streamid, method, reqheaders, rspheadersBuilder, uri, is, sslSession,
                   os, conn, pushAllowed);
-
+            String query = uri.getQuery();
+            badHeadersIndex = Integer.parseInt(query.substring(query.indexOf("=") + 1));
+            assert badHeadersIndex >= 0 && badHeadersIndex < BAD_HEADERS.size() :
+                    "Unexpected badHeadersIndex value: " + badHeadersIndex;
         }
 
         @Override
         public void sendResponseHeaders(int rCode, long responseLength) throws IOException {
-            List<ByteBuffer> encodeHeaders = conn.encodeHeaders(rspheaders);
+            assert rspheadersBuilder.build().map().size() == 0;
+            assert badHeadersIndex >= 0 && badHeadersIndex < BAD_HEADERS.size() :
+                    "Unexpected badHeadersIndex value: " + badHeadersIndex;
+
+            List<Entry<String,String>> headers = BAD_HEADERS.get(badHeadersIndex);
+            System.out.println("Server replying with bad headers: " + headers);
+            List<ByteBuffer> encodeHeaders = conn.encodeHeadersOrdered(headers);
+
             List<Http2Frame> headerFrames = headerFrameSupplier.apply(streamid, encodeHeaders);
             assert headerFrames.size() > 0;  // there must always be at least 1
 
@@ -291,36 +312,4 @@
             System.err.println("Sent response headers " + rCode);
         }
     }
-
-    /*
-     * Use carefully. This class might not be suitable outside this test's
-     * context. Pay attention working with multi Map view returned from map().
-     * The reason is that header names must be lower-cased prior to any
-     * operation that depends on whether or not the map contains a specific
-     * element.
-     */
-    private static class OrderedHttpHeaders extends HttpHeadersImpl {
-
-        private final Map<String, List<String>> map = new LinkedHashMap<>();
-
-        @Override
-        public void addHeader(String name, String value) {
-            super.addHeader(name.toLowerCase(Locale.ROOT), value);
-        }
-
-        @Override
-        public void setHeader(String name, String value) {
-            super.setHeader(name.toLowerCase(Locale.ROOT), value);
-        }
-
-        @Override
-        protected Map<String, List<String>> headersMap() {
-            return map;
-        }
-
-        @Override
-        protected HttpHeadersImpl newDeepCopy() {
-            return new OrderedHttpHeaders();
-        }
-    }
 }