http-client-branch: (HttpClient) deep(er) validation of pseudo-headers http-client-branch
authorprappo
Tue, 27 Feb 2018 16:08:08 +0000
branchhttp-client-branch
changeset 56205 f4c9c5920141
parent 56204 e5d0c20217a3
child 56206 a0cf7477d139
http-client-branch: (HttpClient) deep(er) validation of pseudo-headers
src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java
src/java.net.http/share/classes/jdk/internal/net/http/Stream.java
src/java.net.http/share/classes/jdk/internal/net/http/common/HttpHeadersImpl.java
test/jdk/java/net/httpclient/http2/BadHeadersTest.java
test/jdk/java/net/httpclient/http2/server/Http2TestServer.java
test/jdk/java/net/httpclient/http2/server/Http2TestServerConnection.java
--- a/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java	Tue Feb 27 15:55:24 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/Http2Connection.java	Tue Feb 27 16:08:08 2018 +0000
@@ -1191,11 +1191,16 @@
         private static final Set<String> PSEUDO_HEADERS =
                 Set.of(":authority", ":method", ":path", ":scheme", ":status");
 
+        /** Used to check that if there are pseudo-headers, they go first */
+        private boolean pseudoHeadersEnded;
+
         /**
          * Called when END_HEADERS was received. This consumer may be invoked
          * again after reset() is called, but for a whole new set of headers.
          */
-        void reset() { }
+        void reset() {
+            pseudoHeadersEnded = false;
+        }
 
         @Override
         public void onDecoded(CharSequence name, CharSequence value)
@@ -1203,11 +1208,16 @@
         {
             String n = name.toString();
             if (n.startsWith(":")) {
-                if (!PSEUDO_HEADERS.contains(n)) {
+                if (pseudoHeadersEnded) {
                     throw newException("Unexpected pseudo-header '%s'", n);
+                } else if (!PSEUDO_HEADERS.contains(n)) {
+                    throw newException("Unknown pseudo-header '%s'", n);
                 }
-            } else if (!Utils.isValidName(n)) {
-                throw newException("Bad header name '%s'", n);
+            } else {
+                pseudoHeadersEnded = true;
+                if (!Utils.isValidName(n)) {
+                    throw newException("Bad header name '%s'", n);
+                }
             }
             String v = value.toString();
             if (!Utils.isValidValue(v)) {
--- a/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java	Tue Feb 27 15:55:24 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java	Tue Feb 27 16:08:08 2018 +0000
@@ -1191,6 +1191,7 @@
     private class HeadersConsumer extends Http2Connection.ValidatingHeadersConsumer {
 
         void reset() {
+            super.reset();
             responseHeaders.clear();
         }
 
--- a/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpHeadersImpl.java	Tue Feb 27 15:55:24 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/HttpHeadersImpl.java	Tue Feb 27 16:08:08 2018 +0000
@@ -41,7 +41,7 @@
  */
 public class HttpHeadersImpl extends HttpHeaders {
 
-    private final TreeMap<String,List<String>> headers;
+    private final TreeMap<String, List<String>> headers;
 
     public HttpHeadersImpl() {
         headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
@@ -49,33 +49,41 @@
 
     @Override
     public Map<String, List<String>> map() {
-        return Collections.unmodifiableMap(headers);
+        return Collections.unmodifiableMap(headersMap());
     }
 
     // non-HttpHeaders private mutators
 
     public HttpHeadersImpl deepCopy() {
-        HttpHeadersImpl h1 = new HttpHeadersImpl();
-        for (Map.Entry<String,List<String>> entry : headers.entrySet()) {
+        HttpHeadersImpl h1 = newDeepCopy();
+        for (Map.Entry<String, List<String>> entry : headersMap().entrySet()) {
             List<String> valuesCopy = new ArrayList<>(entry.getValue());
-            h1.headers.put(entry.getKey(), valuesCopy);
+            h1.headersMap().put(entry.getKey(), valuesCopy);
         }
         return h1;
     }
 
     public void addHeader(String name, String value) {
-        headers.computeIfAbsent(name, k -> new ArrayList<>(1))
-               .add(value);
+        headersMap().computeIfAbsent(name, k -> new ArrayList<>(1))
+                    .add(value);
     }
 
     public void setHeader(String name, String value) {
         // headers typically have one value
         List<String> values = new ArrayList<>(1);
         values.add(value);
-        headers.put(name, values);
+        headersMap().put(name, values);
     }
 
     public void clear() {
-        headers.clear();
+        headersMap().clear();
+    }
+
+    protected HttpHeadersImpl newDeepCopy() {
+        return new HttpHeadersImpl();
+    }
+
+    protected Map<String, List<String>> headersMap() {
+        return headers;
     }
 }
--- a/test/jdk/java/net/httpclient/http2/BadHeadersTest.java	Tue Feb 27 15:55:24 2018 +0000
+++ b/test/jdk/java/net/httpclient/http2/BadHeadersTest.java	Tue Feb 27 16:08:08 2018 +0000
@@ -50,6 +50,7 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.net.Socket;
 import java.net.URI;
 import java.net.http.HttpClient;
 import java.net.http.HttpRequest;
@@ -57,22 +58,27 @@
 import java.net.http.HttpResponse.BodyHandlers;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Locale;
+import java.util.Map;
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
 
+import static java.util.List.of;
 import static jdk.internal.net.http.common.Pair.pair;
 import static org.testng.Assert.assertThrows;
 
 // Code copied from ContinuationFrameTest
 public class BadHeadersTest {
 
-    private static final List<Pair<String, String>> BAD_HEADERS = List.of(
-            pair(":hello", "GET"),                    // Unknown pseudo-header
-            pair("hell o", "value"),                  // Space in the name
-            pair("hello", "line1\r\n  line2\r\n"),    // Multiline value
-            pair("hello", "DE" + ((char) 0x7F) + "L") // Bad byte in value
+    private static final List<List<Pair<String, String>>> BAD_HEADERS = of(
+            of(pair(":status", "200"),  pair(":hello", "GET")),                      // Unknown pseudo-header
+            of(pair(":status", "200"),  pair("hell o", "value")),                    // Space in the name
+            of(pair(":status", "200"),  pair("hello", "line1\r\n  line2\r\n")),      // Multiline value
+            of(pair(":status", "200"),  pair("hello", "DE" + ((char) 0x7F) + "L")),  // Bad byte in value
+            of(pair("hello", "world!"), pair(":status", "200"))                      // Pseudo header is not the first one
     );
 
     SSLContext sslContext;
@@ -87,12 +93,12 @@
      */
     static BiFunction<Integer,List<ByteBuffer>,List<Http2Frame>> oneContinuation =
             (Integer streamid, List<ByteBuffer> encodedHeaders) -> {
-                List<ByteBuffer> empty =  List.of(ByteBuffer.wrap(new byte[0]));
+                List<ByteBuffer> empty =  of(ByteBuffer.wrap(new byte[0]));
                 HeadersFrame hf = new HeadersFrame(streamid, 0, empty);
                 ContinuationFrame cf = new ContinuationFrame(streamid,
                                                              HeaderFrame.END_HEADERS,
                                                              encodedHeaders);
-                return List.of(hf, cf);
+                return of(hf, cf);
             };
 
     /**
@@ -108,7 +114,7 @@
                 frames.add(hf);
                 for (ByteBuffer bb : encodedHeaders) {
                     while (bb.hasRemaining()) {
-                        List<ByteBuffer> data = List.of(ByteBuffer.wrap(new byte[] {bb.get()}));
+                        List<ByteBuffer> data = of(ByteBuffer.wrap(new byte[] {bb.get()}));
                         ContinuationFrame cf = new ContinuationFrame(streamid, 0, data);
                         frames.add(cf);
                     }
@@ -180,12 +186,38 @@
         if (sslContext == null)
             throw new AssertionError("Unexpected null sslContext");
 
-        http2TestServer = new Http2TestServer("127.0.0.1", false, 0);
+        http2TestServer = new Http2TestServer("127.0.0.1", false, 0) {
+            @Override
+            protected Http2TestServerConnection createConnection(Http2TestServer http2TestServer,
+                                                                 Socket socket,
+                                                                 Http2TestExchangeSupplier exchangeSupplier)
+                    throws IOException {
+                return new Http2TestServerConnection(http2TestServer, socket, exchangeSupplier) {
+                    @Override
+                    protected HttpHeadersImpl createNewResponseHeaders() {
+                        return new OrderedHttpHeaders();
+                    }
+                };
+            }
+        };
         http2TestServer.addHandler(new Http2EchoHandler(), "/http2/echo");
         int port = http2TestServer.getAddress().getPort();
         http2URI = "http://127.0.0.1:" + port + "/http2/echo";
 
-        https2TestServer = new Http2TestServer("127.0.0.1", true, 0);
+        https2TestServer = new Http2TestServer("127.0.0.1", true, 0){
+            @Override
+            protected Http2TestServerConnection createConnection(Http2TestServer http2TestServer,
+                                                                 Socket socket,
+                                                                 Http2TestExchangeSupplier exchangeSupplier)
+                    throws IOException {
+                return new Http2TestServerConnection(http2TestServer, socket, exchangeSupplier) {
+                    @Override
+                    protected HttpHeadersImpl createNewResponseHeaders() {
+                        return new OrderedHttpHeaders();
+                    }
+                };
+            }
+        };
         https2TestServer.addHandler(new Http2EchoHandler(), "/https2/echo");
         port = https2TestServer.getAddress().getPort();
         https2URI = "https://127.0.0.1:" + port + "/https2/echo";
@@ -215,8 +247,8 @@
                  OutputStream os = t.getResponseBody()) {
                 byte[] bytes = is.readAllBytes();
                 int i = requestNo.incrementAndGet();
-                Pair<String, String> p = BAD_HEADERS.get(i % BAD_HEADERS.size());
-                t.getResponseHeaders().addHeader(p.first, p.second);
+                List<Pair<String, String>> p = BAD_HEADERS.get(i % BAD_HEADERS.size());
+                p.forEach(h -> t.getResponseHeaders().addHeader(h.first, h.second));
                 t.sendResponseHeaders(200, bytes.length);
                 os.write(bytes);
             }
@@ -243,13 +275,6 @@
 
         @Override
         public void sendResponseHeaders(int rCode, long responseLength) throws IOException {
-            this.responseLength = responseLength;
-            if (responseLength > 0 || responseLength < 0) {
-                long clen = responseLength > 0 ? responseLength : 0;
-                rspheaders.setHeader("Content-length", Long.toString(clen));
-            }
-            rspheaders.setHeader(":status", Integer.toString(rCode));
-
             List<ByteBuffer> encodeHeaders = conn.encodeHeaders(rspheaders);
             List<Http2Frame> headerFrames = headerFrameSupplier.apply(streamid, encodeHeaders);
             assert headerFrames.size() > 0;  // there must always be at least 1
@@ -266,4 +291,29 @@
             System.err.println("Sent response headers " + rCode);
         }
     }
+
+    private static class OrderedHttpHeaders extends HttpHeadersImpl {
+
+        private final Map<String, List<String>> map = new LinkedHashMap<>();
+
+        @Override
+        public void addHeader(String name, String value) {
+            super.addHeader(name.toLowerCase(Locale.ROOT), value);
+        }
+
+        @Override
+        public void setHeader(String name, String value) {
+            super.setHeader(name.toLowerCase(Locale.ROOT), value);
+        }
+
+        @Override
+        protected Map<String, List<String>> headersMap() {
+            return map;
+        }
+
+        @Override
+        protected HttpHeadersImpl newDeepCopy() {
+            return new OrderedHttpHeaders();
+        }
+    }
 }
--- a/test/jdk/java/net/httpclient/http2/server/Http2TestServer.java	Tue Feb 27 15:55:24 2018 +0000
+++ b/test/jdk/java/net/httpclient/http2/server/Http2TestServer.java	Tue Feb 27 16:08:08 2018 +0000
@@ -227,7 +227,7 @@
                     InetSocketAddress addr = null;
                     try {
                         addr = (InetSocketAddress) socket.getRemoteSocketAddress();
-                        c = new Http2TestServerConnection(this, socket, exchangeSupplier);
+                        c = createConnection(this, socket, exchangeSupplier);
                         putConnection(addr, c);
                         c.run();
                     } catch (Throwable e) {
@@ -254,6 +254,13 @@
         });
     }
 
+    protected Http2TestServerConnection createConnection(Http2TestServer http2TestServer,
+                                                         Socket socket,
+                                                         Http2TestExchangeSupplier exchangeSupplier)
+            throws IOException {
+        return new Http2TestServerConnection(http2TestServer, socket, exchangeSupplier);
+    }
+
     @Override
     public void close() throws Exception {
         stop();
--- a/test/jdk/java/net/httpclient/http2/server/Http2TestServerConnection.java	Tue Feb 27 15:55:24 2018 +0000
+++ b/test/jdk/java/net/httpclient/http2/server/Http2TestServerConnection.java	Tue Feb 27 16:08:08 2018 +0000
@@ -421,7 +421,7 @@
     }
 
     HttpHeadersImpl decodeHeaders(List<HeaderFrame> frames) throws IOException {
-        HttpHeadersImpl headers = new HttpHeadersImpl();
+        HttpHeadersImpl headers = createNewResponseHeaders();
 
         DecodingCallback cb = (name, value) -> {
             headers.addHeader(name.toString(), value.toString());
@@ -468,7 +468,7 @@
     // First stream (1) comes from a plaintext HTTP/1.1 request
     @SuppressWarnings({"rawtypes","unchecked"})
     void createPrimordialStream(Http1InitialRequest request) throws IOException {
-        HttpHeadersImpl headers = new HttpHeadersImpl();
+        HttpHeadersImpl headers = createNewResponseHeaders();
         String requestLine = getRequestLine(request.headers);
         String[] tokens = requestLine.split(" ");
         if (!tokens[2].equals("HTTP/1.1")) {
@@ -572,7 +572,7 @@
         String authority = headers.firstValue(":authority").orElse("");
         //System.out.println("authority = " + authority);
         System.err.printf("TestServer: %s %s\n", method, path);
-        HttpHeadersImpl rspheaders = new HttpHeadersImpl();
+        HttpHeadersImpl rspheaders = createNewResponseHeaders();
         int winsize = clientSettings.getParameter(
                 SettingsFrame.INITIAL_WINDOW_SIZE);
         //System.err.println ("Stream window size = " + winsize);
@@ -609,6 +609,10 @@
         }
     }
 
+    protected HttpHeadersImpl createNewResponseHeaders() {
+        return new HttpHeadersImpl();
+    }
+
     private SSLSession getSSLSession() {
         if (! (socket instanceof SSLSocket))
             return null;
@@ -797,7 +801,7 @@
     // returns a minimal response with status 200
     // that is the response to the push promise just sent
     private ResponseHeaders getPushResponse(int streamid) {
-        HttpHeadersImpl h = new HttpHeadersImpl();
+        HttpHeadersImpl h = createNewResponseHeaders();
         h.addHeader(":status", "200");
         ResponseHeaders oh = new ResponseHeaders(h);
         oh.streamid(streamid);