8218554: HttpServer: allow custom handlers to request that the connection be closed after the exchange.
authordfuchs
Mon, 11 Feb 2019 18:41:24 +0100
changeset 53720 3e451bff6f7f
parent 53719 3a56e823d843
child 53721 f35a8aaabcb9
8218554: HttpServer: allow custom handlers to request that the connection be closed after the exchange. Summary: custom handler code can supply `Connection: close` to response headers in order to force connection close after the exchange terminates. Reviewed-by: chegar
src/jdk.httpserver/share/classes/com/sun/net/httpserver/HttpExchange.java
src/jdk.httpserver/share/classes/sun/net/httpserver/ExchangeImpl.java
test/jdk/com/sun/net/httpserver/bugs/HandlerConnectionClose.java
--- a/src/jdk.httpserver/share/classes/com/sun/net/httpserver/HttpExchange.java	Mon Feb 11 16:49:08 2019 +0100
+++ b/src/jdk.httpserver/share/classes/com/sun/net/httpserver/HttpExchange.java	Mon Feb 11 18:41:24 2019 +0100
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2005, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2005, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -171,6 +171,13 @@
      * this is set to the appropriate value depending on the response length parameter.
      * <p>
      * This method must be called prior to calling {@link #getResponseBody()}.
+     *
+     * @implNote This implementation allows the caller to instruct the
+     * server to force a connection close after the exchange terminates, by
+     * supplying a {@code Connection: close} header to the {@linkplain
+     * #getResponseHeaders() response headers} before {@code sendResponseHeaders}
+     * is called.
+     *
      * @param rCode the response code to send
      * @param responseLength if {@literal > 0}, specifies a fixed response
      *        body length and that exact number of bytes must be written
--- a/src/jdk.httpserver/share/classes/sun/net/httpserver/ExchangeImpl.java	Mon Feb 11 16:49:08 2019 +0100
+++ b/src/jdk.httpserver/share/classes/sun/net/httpserver/ExchangeImpl.java	Mon Feb 11 18:41:24 2019 +0100
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2005, 2010, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2005, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -32,6 +32,7 @@
 import java.lang.System.Logger;
 import java.lang.System.Logger.Level;
 import java.text.*;
+import java.util.stream.Stream;
 import com.sun.net.httpserver.*;
 
 class ExchangeImpl {
@@ -261,6 +262,22 @@
                 o.setWrappedStream (new FixedLengthOutputStream (this, ros, contentLen));
             }
         }
+
+        // A custom handler can request that the connection be
+        // closed after the exchange by supplying Connection: close
+        // to the response header. Nothing to do if the exchange is
+        // already set up to be closed.
+        if (!close) {
+            Stream<String> conheader =
+                    Optional.ofNullable(rspHdrs.get("Connection"))
+                    .map(List::stream).orElse(Stream.empty());
+            if (conheader.anyMatch("close"::equalsIgnoreCase)) {
+                Logger logger = server.getLogger();
+                logger.log (Level.DEBUG, "Connection: close requested by handler");
+                close = true;
+            }
+        }
+
         write (rspHdrs, tmpout);
         this.rspContentLen = contentLen;
         tmpout.flush() ;
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/com/sun/net/httpserver/bugs/HandlerConnectionClose.java	Mon Feb 11 18:41:24 2019 +0100
@@ -0,0 +1,437 @@
+/*
+ * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/**
+ * @test
+ * @bug 8218554
+ * @summary  test that the handler can request a connection close.
+ * @library /test/lib
+ * @build jdk.test.lib.net.SimpleSSLContext
+ * @run main/othervm HandlerConnectionClose
+ */
+
+import com.sun.net.httpserver.HttpExchange;
+import com.sun.net.httpserver.HttpHandler;
+import com.sun.net.httpserver.HttpServer;
+import com.sun.net.httpserver.HttpsConfigurator;
+import com.sun.net.httpserver.HttpsServer;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.HttpURLConnection;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.net.URI;
+import java.net.URL;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Locale;
+import java.util.logging.Handler;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.logging.SimpleFormatter;
+import java.util.logging.StreamHandler;
+import jdk.test.lib.net.SimpleSSLContext;
+import javax.net.ssl.HttpsURLConnection;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSession;
+
+public class HandlerConnectionClose
+{
+    static final int ONEK = 1024;
+    static final long POST_SIZE = ONEK * 1L;
+    SSLContext sslContext;
+    Logger logger;
+
+    void test(String[] args) throws Exception {
+
+        HttpServer httpServer = startHttpServer("http");
+        try {
+            testHttpURLConnection(httpServer, "http","/close/legacy/http/chunked");
+            testHttpURLConnection(httpServer, "http","/close/legacy/http/fixed");
+            testPlainSocket(httpServer, "http","/close/plain/http/chunked");
+            testPlainSocket(httpServer, "http","/close/plain/http/fixed");
+        } finally {
+            httpServer.stop(0);
+        }
+        sslContext = new SimpleSSLContext().get();
+        HttpServer httpsServer = startHttpServer("https");
+        try {
+            testHttpURLConnection(httpsServer, "https","/close/legacy/https/chunked");
+            testHttpURLConnection(httpsServer, "https","/close/legacy/https/fixed");
+            testPlainSocket(httpsServer, "https","/close/plain/https/chunked");
+            testPlainSocket(httpsServer, "https","/close/plain/https/fixed");
+        } finally{
+            httpsServer.stop(0);
+        }
+    }
+
+    void testHttpURLConnection(HttpServer httpServer, String protocol, String path) throws Exception {
+        int port = httpServer.getAddress().getPort();
+        String host = httpServer.getAddress().getHostString();
+        URL url = new URI(protocol, null, host, port, path, null, null).toURL();
+        HttpURLConnection uc = (HttpURLConnection) url.openConnection();
+        if ("https".equalsIgnoreCase(protocol)) {
+            ((HttpsURLConnection)uc).setSSLSocketFactory(sslContext.getSocketFactory());
+            ((HttpsURLConnection)uc).setHostnameVerifier((String hostname, SSLSession session) -> true);
+        }
+        uc.setDoOutput(true);
+        uc.setRequestMethod("POST");
+        uc.setFixedLengthStreamingMode(POST_SIZE);
+        OutputStream os = uc.getOutputStream();
+
+        /* create a 1K byte array with data to POST */
+        byte[] ba = new byte[ONEK];
+        for (int i = 0; i < ONEK; i++)
+            ba[i] = (byte) i;
+
+        System.out.println("\n" + uc.getClass().getSimpleName() +": POST " + url + " HTTP/1.1");
+        long times = POST_SIZE / ONEK;
+        for (int i = 0; i < times; i++) {
+            os.write(ba);
+        }
+
+        os.close();
+        InputStream is = uc.getInputStream();
+        int read;
+        long count = 0;
+        while ((read = is.read(ba)) != -1) {
+            for (int i = 0; i < read; i++) {
+                byte expected = (byte) count++;
+                if (ba[i] != expected) {
+                    throw new IOException("byte mismatch at "
+                            + (count - 1) + ": expected " + expected + " got " + ba[i]);
+                }
+            }
+        }
+        if (count != POST_SIZE) {
+            throw new IOException("Unexpected length: " + count + " expected " + POST_SIZE);
+        }
+        is.close();
+
+        pass();
+    }
+
+    void testPlainSocket(HttpServer httpServer, String protocol, String path) throws Exception {
+        int port = httpServer.getAddress().getPort();
+        String host = httpServer.getAddress().getHostString();
+        URL url = new URI(protocol, null, host, port, path, null, null).toURL();
+        Socket socket;
+        if ("https".equalsIgnoreCase(protocol)) {
+            socket = sslContext.getSocketFactory().createSocket(host, port);
+        } else {
+            socket = new Socket(host, port);
+        }
+        try (Socket sock = socket) {
+            OutputStream os = socket.getOutputStream();
+
+            // send request headers
+            String request = new StringBuilder()
+                    .append("POST ").append(path).append(" HTTP/1.1").append("\r\n")
+                    .append("host: ").append(host).append(':').append(port).append("\r\n")
+                    .append("Content-Length: ").append(POST_SIZE).append("\r\n")
+                    .append("\r\n")
+                    .toString();
+            os.write(request.getBytes(StandardCharsets.US_ASCII));
+
+            /* create a 1K byte array with data to POST */
+            byte[] ba = new byte[ONEK];
+            for (int i = 0; i < ONEK; i++)
+                ba[i] = (byte) i;
+
+            // send request data
+            long times = POST_SIZE / ONEK;
+            for (int i = 0; i < times; i++) {
+                os.write(ba);
+            }
+            os.flush();
+
+            InputStream is = socket.getInputStream();
+            ByteArrayOutputStream bos = new ByteArrayOutputStream();
+
+            // read all response headers
+            int c;
+            int crlf = 0;
+            while ((c = is.read()) != -1) {
+                if (c == '\r') continue;
+                if (c == '\n') crlf++;
+                else crlf = 0;
+                bos.write(c);
+                if (crlf == 2) break;
+            }
+            String responseHeadersStr = bos.toString(StandardCharsets.US_ASCII);
+            List<String> responseHeaders = List.of(responseHeadersStr.split("\n"));
+            System.out.println("\nPOST " + url + " HTTP/1.1");
+            responseHeaders.stream().forEach(s -> System.out.println("[reply]\t" + s));
+            String statusLine = responseHeaders.get(0);
+            if (!statusLine.startsWith("HTTP/1.1 200 "))
+                throw new IOException("Unexpected status: " + statusLine);
+            String cl = responseHeaders.stream()
+                    .map(s -> s.toLowerCase(Locale.ROOT))
+                    .filter(s -> s.startsWith("content-length: "))
+                    .findFirst()
+                    .orElse(null);
+            String te = responseHeaders.stream()
+                    .map(s -> s.toLowerCase(Locale.ROOT))
+                    .filter(s -> s.startsWith("transfer-encoding: "))
+                    .findFirst()
+                    .orElse(null);
+
+            // check content-length and transfer-encoding are as expected
+            int read = 0;
+            long count = 0;
+            if (path.endsWith("/fixed")) {
+                if (!("content-length: " + POST_SIZE).equalsIgnoreCase(cl)) {
+                    throw new IOException("Unexpected Content-Length: [" + cl + "]");
+                }
+                if (te != null) {
+                    throw new IOException("Unexpected Transfer-Encoding: [" + te + "]");
+                }
+                // Got expected Content-Length: 1024 - read response data
+                while ((read = is.read()) != -1) {
+                    int expected = (int) (count & 0xFF);
+                    if ((read & 0xFF) != expected) {
+                        throw new IOException("byte mismatch at "
+                                + (count - 1) + ": expected " + expected + " got " + read);
+                    }
+                    if (++count == POST_SIZE) break;
+                }
+            } else if (cl != null) {
+                throw new IOException("Unexpected Content-Length: [" + cl + "]");
+            } else {
+                if (!("transfer-encoding: chunked").equalsIgnoreCase(te)) {
+                    throw new IOException("Unexpected Transfer-Encoding: [" + te + "]");
+                }
+                // This is a quick & dirty implementation of
+                // chunk decoding - no trailers - no extensions
+                StringBuilder chunks = new StringBuilder();
+                int cs = -1;
+                while (cs != 0) {
+                    cs = 0;
+                    chunks.setLength(0);
+
+                    // read next chunk length
+                    while ((read = is.read()) != -1) {
+                        if (read == '\r') continue;
+                        if (read == '\n') break;
+                        chunks.append((char) read);
+                    }
+                    cs = Integer.parseInt(chunks.toString().trim(), 16);
+                    System.out.println("Got chunk length: " + cs);
+
+                    // If chunk size is 0, then we have read the last chunk.
+                    if (cs == 0) break;
+
+                    // Read the chunk data
+                    while (--cs >= 0) {
+                        read = is.read();
+                        if (read == -1) break; // EOF
+                        int expected = (int) (count & 0xFF);
+                        if ((read & 0xFF) != expected) {
+                            throw new IOException("byte mismatch at "
+                                    + (count - 1) + ": expected " + expected + " got " + read);
+                        }
+                        // This is cheating: we know the size :-)
+                        if (++count == POST_SIZE) break;
+                    }
+
+                    if (read == -1) {
+                        throw new IOException("Unexpected EOF after " + count + " data bytes");
+                    }
+
+                    // read CRLF
+                    if ((read = is.read()) != '\r') {
+                        throw new IOException("Expected CR at " + count + "after chunk data - got " + read);
+                    }
+                    if ((read = is.read()) != '\n') {
+                        throw new IOException("Expected LF at " + count + "after chunk data - got " + read);
+                    }
+
+                    if (cs == 0 && count == POST_SIZE) {
+                        cs = -1;
+                    }
+
+                    if (cs != -1) {
+                        // count == POST_SIZE, but some chunk data still to be read?
+                        throw new IOException("Unexpected chunk size, "
+                                + cs + " bytes still to read after " + count +
+                                " data bytes received.");
+                    }
+                }
+                // Last CRLF?
+                for (int i = 0; i < 2; i++) {
+                    if ((read = is.read()) == -1) break;
+                }
+            }
+
+            if (count != POST_SIZE) {
+                throw new IOException("Unexpected length: " + count + " expected " + POST_SIZE);
+            }
+
+            if (!sock.isClosed()) {
+                try {
+                    // We send an end request to the server to verify that the
+                    // connection is closed. If the server has not closed the
+                    // connection, it will reply. If we receive a response,
+                    // we should fail...
+                    String endrequest = new StringBuilder()
+                            .append("GET ").append("/close/end").append(" HTTP/1.1").append("\r\n")
+                            .append("host: ").append(host).append(':').append(port).append("\r\n")
+                            .append("Content-Length: ").append(0).append("\r\n")
+                            .append("\r\n")
+                            .toString();
+                    os.write(endrequest.getBytes(StandardCharsets.US_ASCII));
+                    os.flush();
+                    StringBuilder resp = new StringBuilder();
+                    crlf = 0;
+
+                    // read all headers.
+                    // If the server closed the connection as expected
+                    // we should immediately read EOF
+                    while ((read = is.read()) != -1) {
+                        if (read == '\r') continue;
+                        if (read == '\n') crlf++;
+                        else crlf = 0;
+                        if (crlf == 2) break;
+                        resp.append((char) read);
+                    }
+
+                    List<String> lines = List.of(resp.toString().split("\n"));
+                    if (read != -1 || resp.length() != 0) {
+                        System.err.println("Connection not closed!");
+                        System.err.println("Got: ");
+                        lines.stream().forEach(s -> System.err.println("[end]\t" + s));
+                        throw new AssertionError("EOF not received after " + count + " data bytes");
+                    }
+                    if (read != -1) {
+                        throw new AssertionError("EOF was expected after " + count + " bytes, but got: " + read);
+                    } else {
+                        System.out.println("Got expected EOF (" + read + ")");
+                    }
+                } catch (IOException x) {
+                    // expected! all is well
+                    System.out.println("Socket closed as expected, got exception writing to it.");
+                }
+            } else {
+                System.out.println("Socket closed as expected");
+            }
+            pass();
+        }
+    }
+
+    /**
+     * Http Server
+     */
+    HttpServer startHttpServer(String protocol) throws IOException {
+        if (debug) {
+            logger = Logger.getLogger("com.sun.net.httpserver");
+            Handler outHandler = new StreamHandler(System.out,
+                                     new SimpleFormatter());
+            outHandler.setLevel(Level.FINEST);
+            logger.setLevel(Level.FINEST);
+            logger.addHandler(outHandler);
+        }
+
+        InetSocketAddress serverAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0);
+        HttpServer httpServer = null;
+        if ("http".equalsIgnoreCase(protocol)) {
+            httpServer = HttpServer.create(serverAddress, 0);
+        }
+        if ("https".equalsIgnoreCase(protocol)) {
+            HttpsServer httpsServer = HttpsServer.create(serverAddress, 0);
+            httpsServer.setHttpsConfigurator(new HttpsConfigurator(sslContext));
+            httpServer = httpsServer;
+        }
+        httpServer.createContext("/close/", new MyHandler(POST_SIZE));
+        System.out.println("Server created at: " + httpServer.getAddress());
+        httpServer.start();
+        return httpServer;
+    }
+
+    class MyHandler implements HttpHandler {
+        static final int BUFFER_SIZE = 512;
+        final long expected;
+
+        MyHandler(long expected){
+            this.expected = expected;
+        }
+
+        @Override
+        public void handle(HttpExchange t) throws IOException {
+            System.out.println("Server: serving " + t.getRequestURI());
+            boolean chunked = t.getRequestURI().getPath().endsWith("/chunked");
+            boolean fixed = t.getRequestURI().getPath().endsWith("/fixed");
+            boolean end = t.getRequestURI().getPath().endsWith("/end");
+            long responseLength = fixed ? POST_SIZE : 0;
+            responseLength = end ? -1 : responseLength;
+            responseLength = chunked ? 0 : responseLength;
+
+            if (!end) t.getResponseHeaders().add("connection", "CLose");
+            t.sendResponseHeaders(200, responseLength);
+
+            if (!end) {
+                OutputStream os = t.getResponseBody();
+                InputStream is = t.getRequestBody();
+                byte[] ba = new byte[BUFFER_SIZE];
+                int read;
+                long count = 0L;
+                while ((read = is.read(ba)) != -1) {
+                    count += read;
+                    os.write(ba, 0, read);
+                }
+                is.close();
+
+                check(count == expected, "Expected: " + expected + ", received "
+                        + count);
+                debug("Received " + count + " bytes");
+                os.close();
+            }
+
+            t.close();
+        }
+    }
+
+         //--------------------- Infrastructure ---------------------------
+    boolean debug = true;
+    volatile int passed = 0, failed = 0;
+    void pass() {passed++;}
+    void fail() {failed++; Thread.dumpStack();}
+    void fail(String msg) {System.err.println(msg); fail();}
+    void unexpected(Throwable t) {failed++; t.printStackTrace();}
+    void check(boolean cond) {if (cond) pass(); else fail();}
+    void check(boolean cond, String failMessage) {if (cond) pass(); else fail(failMessage);}
+    void debug(String message) {if(debug) { System.out.println(message); }  }
+    public static void main(String[] args) throws Throwable {
+        Class<?> k = new Object(){}.getClass().getEnclosingClass();
+        try {k.getMethod("instanceMain",String[].class)
+                .invoke( k.newInstance(), (Object) args);}
+        catch (Throwable e) {throw e.getCause();}}
+    public void instanceMain(String[] args) throws Throwable {
+        try {test(args);} catch (Throwable t) {unexpected(t);}
+        System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
+        if (failed > 0) throw new AssertionError("Some tests failed");}
+
+}