# HG changeset patch # User dfuchs # Date 1520594677 0 # Node ID 481d8c9acc7f87711c00254835622cb280f85315 # Parent fe6f17faa23afb4e8d9960b7d02f1e5e13979d08 http-client-branch: Add a test that throws unexpected exception in PushPromiseHandler diff -r fe6f17faa23a -r 481d8c9acc7f src/java.net.http/share/classes/jdk/internal/net/http/PushGroup.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/PushGroup.java Thu Mar 08 21:24:47 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/PushGroup.java Fri Mar 09 11:24:37 2018 +0000 @@ -32,6 +32,8 @@ import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.PushPromiseHandler; +import java.util.concurrent.Executor; + import jdk.internal.net.http.common.MinimalFuture; import jdk.internal.net.http.common.Log; @@ -98,10 +100,17 @@ @Override public boolean accepted() { return cf != null; } } - Acceptor acceptPushRequest(HttpRequest pushRequest) { + Acceptor acceptPushRequest(HttpRequest pushRequest, Executor e) { AcceptorImpl acceptor = new AcceptorImpl<>(); - - pushPromiseHandler.applyPushPromise(initiatingRequest, pushRequest, acceptor::accept); + try { + pushPromiseHandler.applyPushPromise(initiatingRequest, pushRequest, acceptor::accept); + } catch (Throwable t) { + if (acceptor.accepted()) { + CompletableFuture cf = acceptor.cf(); + e.execute(() -> cf.completeExceptionally(t)); + } + throw t; + } synchronized (this) { if (acceptor.accepted()) { diff -r fe6f17faa23a -r 481d8c9acc7f src/java.net.http/share/classes/jdk/internal/net/http/Stream.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java Thu Mar 08 21:24:47 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java Fri Mar 09 11:24:37 2018 +0000 @@ -473,9 +473,18 @@ return; } - PushGroup.Acceptor acceptor = pushGroup.acceptPushRequest(pushRequest); - - if (!acceptor.accepted()) { + PushGroup.Acceptor acceptor = null; + boolean accepted = false; + try { + acceptor = pushGroup.acceptPushRequest(pushRequest, + connection.client().theExecutor()); + accepted = acceptor.accepted(); + } catch (Throwable t) { + debug.log(Level.DEBUG, + "PushPromiseHandler::applyPushPromise threw exception %s", + (Object)t); + } + if (!accepted) { // cancel / reject IOException ex = new IOException("Stream " + streamid + " cancelled by users handler"); if (Log.trace()) { @@ -486,6 +495,7 @@ return; } + assert accepted && acceptor != null; CompletableFuture> pushResponseCF = acceptor.cf(); HttpResponse.BodyHandler pushHandler = acceptor.bodyHandler(); assert pushHandler != null; @@ -495,7 +505,7 @@ // setup housekeeping for when the push is received // TODO: deal with ignoring of CF anti-pattern CompletableFuture> cf = pushStream.responseCF(); - cf.whenComplete((HttpResponse resp, Throwable t) -> { + cf.whenCompleteAsync((HttpResponse resp, Throwable t) -> { t = Utils.getCompletionCause(t); if (Log.trace()) { Log.logTrace("Push completed on stream {0} for {1}{2}", @@ -509,7 +519,7 @@ pushResponseCF.complete(resp); } pushGroup.pushCompleted(); - }); + }, connection.client().theExecutor()); } diff -r fe6f17faa23a -r 481d8c9acc7f test/jdk/java/net/httpclient/HttpServerAdapters.java --- a/test/jdk/java/net/httpclient/HttpServerAdapters.java Thu Mar 08 21:24:47 2018 +0000 +++ b/test/jdk/java/net/httpclient/HttpServerAdapters.java Fri Mar 09 11:24:37 2018 +0000 @@ -29,6 +29,7 @@ import com.sun.net.httpserver.HttpServer; import java.net.InetAddress; +import java.io.ByteArrayInputStream; import java.net.http.HttpClient.Version; import jdk.internal.net.http.common.HttpHeadersImpl; import java.io.ByteArrayOutputStream; @@ -182,7 +183,16 @@ public abstract URI getRequestURI(); public abstract String getRequestMethod(); public abstract void close(); - + public void serverPush(URI uri, HttpTestHeaders headers, byte[] body) { + ByteArrayInputStream bais = new ByteArrayInputStream(body); + serverPush(uri, headers, bais); + } + public void serverPush(URI uri, HttpTestHeaders headers, InputStream body) { + throw new UnsupportedOperationException("serverPush with " + getExchangeVersion()); + } + public boolean serverPushAllowed() { + return false; + } public static HttpTestExchange of(HttpExchange exchange) { return new Http1TestExchange(exchange); } @@ -271,6 +281,26 @@ else if (contentLength < 0) contentLength = 0; exchange.sendResponseHeaders(code, contentLength); } + @Override + public boolean serverPushAllowed() { + return exchange.serverPushAllowed(); + } + @Override + public void serverPush(URI uri, HttpTestHeaders headers, InputStream body) { + HttpHeadersImpl headersImpl; + if (headers instanceof HttpTestHeaders.Http2TestHeaders) { + headersImpl = ((HttpTestHeaders.Http2TestHeaders)headers).headers.deepCopy(); + } else { + headersImpl = new HttpHeadersImpl(); + for (Map.Entry> e : headers.entrySet()) { + String name = e.getKey(); + for (String v : e.getValue()) { + headersImpl.addHeader(name, v); + } + } + } + exchange.serverPush(uri, headersImpl, body); + } void doFilter(Filter.Chain filter) throws IOException { throw new IOException("cannot use HTTP/1.1 filter with HTTP/2 server"); } @@ -367,7 +397,7 @@ } catch (Throwable t) { System.out.println("WARNING: exception caught in Http1Chain::doFilter " + t); System.err.println("WARNING: exception caught in Http1Chain::doFilter " + t); - if (PRINTSTACK && !expectException(exchange)) t.printStackTrace(); + if (PRINTSTACK && !expectException(exchange)) t.printStackTrace(System.out); throw t; } } @@ -391,7 +421,7 @@ } catch (Throwable t) { System.out.println("WARNING: exception caught in Http2Chain::doFilter " + t); System.err.println("WARNING: exception caught in Http2Chain::doFilter " + t); - if (PRINTSTACK && !expectException(exchange)) t.printStackTrace(); + if (PRINTSTACK && !expectException(exchange)) t.printStackTrace(System.out); throw t; } } @@ -472,9 +502,15 @@ this.impl = server; } @Override - public void start() { impl.start(); } + public void start() { + System.out.println("Http1TestServer: start"); + impl.start(); + } @Override - public void stop() { impl.stop(0); } + public void stop() { + System.out.println("Http1TestServer: stop"); + impl.stop(0); + } @Override public HttpTestContext addHandler(HttpTestHandler handler, String path) { System.out.println("Http1TestServer[" + getAddress() @@ -499,6 +535,7 @@ } @Override public void addFilter(HttpTestFilter filter) { + System.out.println("Http1TestContext::addFilter " + filter.description()); context.getFilters().add(filter.toFilter()); } @Override diff -r fe6f17faa23a -r 481d8c9acc7f test/jdk/java/net/httpclient/ThrowingPushPromises.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test/jdk/java/net/httpclient/ThrowingPushPromises.java Fri Mar 09 11:24:37 2018 +0000 @@ -0,0 +1,739 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @summary Tests what happens when push promise handlers and their + * response body handlers and subscribers throw unexpected exceptions. + * @library /lib/testlibrary http2/server + * @build jdk.testlibrary.SimpleSSLContext HttpServerAdapters ThrowingPushPromises + * @modules java.base/sun.net.www.http + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.frame + * java.net.http/jdk.internal.net.http.hpack + * @run testng/othervm -Djdk.internal.httpclient.debug=true ThrowingPushPromises + */ + +import jdk.internal.net.http.common.HttpHeadersImpl; +import jdk.testlibrary.SimpleSSLContext; +import org.testng.annotations.AfterTest; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.PushPromiseHandler; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.lang.System.out; +import static java.lang.System.err; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class ThrowingPushPromises implements HttpServerAdapters { + + SSLContext sslContext; + HttpTestServer http2TestServer; // HTTP/2 ( h2c ) + HttpTestServer https2TestServer; // HTTP/2 ( h2 ) + String http2URI_fixed; + String http2URI_chunk; + String https2URI_fixed; + String https2URI_chunk; + + static final int ITERATION_COUNT = 1; + // a shared executor helps reduce the amount of threads created by the test + static final Executor executor = new TestExecutor(Executors.newCachedThreadPool()); + static final ConcurrentMap FAILURES = new ConcurrentHashMap<>(); + static volatile boolean tasksFailed; + static final AtomicLong serverCount = new AtomicLong(); + static final AtomicLong clientCount = new AtomicLong(); + static final long start = System.nanoTime(); + public static String now() { + long now = System.nanoTime() - start; + long secs = now / 1000_000_000; + long mill = (now % 1000_000_000) / 1000_000; + long nan = now % 1000_000; + return String.format("[%d s, %d ms, %d ns] ", secs, mill, nan); + } + + private volatile HttpClient sharedClient; + + static class TestExecutor implements Executor { + final AtomicLong tasks = new AtomicLong(); + Executor executor; + TestExecutor(Executor executor) { + this.executor = executor; + } + + @Override + public void execute(Runnable command) { + long id = tasks.incrementAndGet(); + executor.execute(() -> { + try { + command.run(); + } catch (Throwable t) { + tasksFailed = true; + out.printf(now() + "Task %s failed: %s%n", id, t); + err.printf(now() + "Task %s failed: %s%n", id, t); + FAILURES.putIfAbsent("Task " + id, t); + throw t; + } + }); + } + } + + @AfterClass + static final void printFailedTests() { + out.println("\n========================="); + try { + out.printf("%n%sCreated %d servers and %d clients%n", + now(), serverCount.get(), clientCount.get()); + if (FAILURES.isEmpty()) return; + out.println("Failed tests: "); + FAILURES.entrySet().forEach((e) -> { + out.printf("\t%s: %s%n", e.getKey(), e.getValue()); + e.getValue().printStackTrace(out); + e.getValue().printStackTrace(); + }); + if (tasksFailed) { + out.println("WARNING: Some tasks failed"); + } + } finally { + out.println("\n=========================\n"); + } + } + + private String[] uris() { + return new String[] { + http2URI_fixed, + http2URI_chunk, + https2URI_fixed, + https2URI_chunk, + }; + } + + @DataProvider(name = "noThrows") + public Object[][] noThrows() { + String[] uris = uris(); + Object[][] result = new Object[uris.length * 2][]; + + int i = 0; + for (boolean sameClient : List.of(false, true)) { + for (String uri: uris()) { + result[i++] = new Object[] {uri, sameClient}; + } + } + assert i == uris.length * 2; + return result; + } + + @DataProvider(name = "variants") + public Object[][] variants() { + String[] uris = uris(); + Object[][] result = new Object[uris.length * 2 * 2][]; + int i = 0; + for (Thrower thrower : List.of( + new UncheckedIOExceptionThrower(), + new UncheckedCustomExceptionThrower())) { + for (boolean sameClient : List.of(false, true)) { + for (String uri : uris()) { + result[i++] = new Object[]{uri, sameClient, thrower}; + } + } + } + assert i == uris.length * 2 * 2; + return result; + } + + private HttpClient makeNewClient() { + clientCount.incrementAndGet(); + return HttpClient.newBuilder() + .executor(executor) + .sslContext(sslContext) + .build(); + } + + HttpClient newHttpClient(boolean share) { + if (!share) return makeNewClient(); + HttpClient shared = sharedClient; + if (shared != null) return shared; + synchronized (this) { + shared = sharedClient; + if (shared == null) { + shared = sharedClient = makeNewClient(); + } + return shared; + } + } + + @Test(dataProvider = "noThrows") + public void testNoThrows(String uri, boolean sameClient) + throws Exception { + HttpClient client = null; + out.printf("%ntestNoThrows(%s, %b)%n", uri, sameClient); + for (int i=0; i< ITERATION_COUNT; i++) { + if (!sameClient || client == null) + client = newHttpClient(sameClient); + + HttpRequest req = HttpRequest.newBuilder(URI.create(uri)) + .build(); + BodyHandler> handler = + new ThrowingBodyHandler((w) -> {}, + BodyHandlers.ofLines()); + Map>>> pushPromises = + new ConcurrentHashMap<>(); + PushPromiseHandler> pushHandler = new PushPromiseHandler<>() { + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + Function>, + CompletableFuture>>> + acceptor) { + pushPromises.putIfAbsent(pushPromiseRequest, acceptor.apply(handler)); + } + }; + HttpResponse> response = + client.sendAsync(req, BodyHandlers.ofLines(), pushHandler).get(); + String body = response.body().collect(Collectors.joining("|")); + assertEquals(URI.create(body).getPath(), URI.create(uri).getPath()); + for (HttpRequest promised : pushPromises.keySet()) { + out.printf("%s Received promise: %s%n\tresponse: %s%n", + now(), promised, pushPromises.get(promised).get()); + String promisedBody = pushPromises.get(promised).get().body() + .collect(Collectors.joining("|")); + assertEquals(promisedBody, promised.uri().toASCIIString()); + } + assertEquals(3, pushPromises.size()); + } + } + + @Test(dataProvider = "variants") + public void testThrowingAsString(String uri, + boolean sameClient, + Thrower thrower) + throws Exception + { + String test = format("testThrowingAsString(%s, %b, %s)", + uri, sameClient, thrower); + testThrowing(test, uri, sameClient, BodyHandlers::ofString, + this::checkAsString, thrower); + } + + @Test(dataProvider = "variants") + public void testThrowingAsLines(String uri, + boolean sameClient, + Thrower thrower) + throws Exception + { + String test = format("testThrowingAsLines(%s, %b, %s)", + uri, sameClient, thrower); + testThrowing(test, uri, sameClient, BodyHandlers::ofLines, + this::checkAsLines, thrower); + } + + @Test(dataProvider = "variants") + public void testThrowingAsInputStream(String uri, + boolean sameClient, + Thrower thrower) + throws Exception + { + String test = format("testThrowingAsInputStream(%s, %b, %s)", + uri, sameClient, thrower); + testThrowing(test, uri, sameClient, BodyHandlers::ofInputStream, + this::checkAsInputStream, thrower); + } + + private void testThrowing(String name, String uri, boolean sameClient, + Supplier> handlers, + Finisher finisher, Thrower thrower) + throws Exception + { + out.printf("%n%s%s%n", now(), name); + try { + testThrowing(uri, sameClient, handlers, finisher, thrower); + } catch (Error | Exception x) { + FAILURES.putIfAbsent(name, x); + throw x; + } + } + + private void testThrowing(String uri, boolean sameClient, + Supplier> handlers, + Finisher finisher, Thrower thrower) + throws Exception + { + HttpClient client = null; + for (Where where : Where.values()) { + if (where == Where.ON_ERROR) continue; + if (!sameClient || client == null) + client = newHttpClient(sameClient); + + HttpRequest req = HttpRequest. + newBuilder(URI.create(uri)) + .build(); + ConcurrentMap>> promiseMap = + new ConcurrentHashMap<>(); + Supplier> throwing = () -> + new ThrowingBodyHandler(where.select(thrower), handlers.get()); + PushPromiseHandler pushHandler = new ThrowingPromiseHandler<>( + where.select(thrower), + PushPromiseHandler.of((r) -> throwing.get(), promiseMap)); + out.println("try throwing in " + where); + HttpResponse response = null; + try { + response = client.sendAsync(req, handlers.get(), pushHandler).join(); + } catch (Error | Exception x) { + throw x; + } + if (response != null) { + finisher.finish(where, req.uri(), response, thrower, promiseMap); + } + } + } + + enum Where { + BODY_HANDLER, ON_SUBSCRIBE, ON_NEXT, ON_COMPLETE, ON_ERROR, GET_BODY, BODY_CF, + BEFORE_ACCEPTING, AFTER_ACCEPTING; + public Consumer select(Consumer consumer) { + return new Consumer() { + @Override + public void accept(Where where) { + if (Where.this == where) { + consumer.accept(where); + } + } + }; + } + } + + interface Thrower extends Consumer, Predicate { + + } + + interface Finisher { + U finish(Where w, URI requestURI, HttpResponse resp, Thrower thrower, + Map>> promises); + } + + final U shouldHaveThrown(Where w, HttpResponse resp, Thrower thrower) { + throw new RuntimeException("Expected exception not thrown in " + w); + } + + final List checkAsString(Where w, URI reqURI, + HttpResponse resp, + Thrower thrower, + Map>> promises) { + Function, List> extractor = + (r) -> List.of(r.body()); + return check(w, reqURI, resp, thrower, promises, extractor); + } + + final List checkAsLines(Where w, URI reqURI, + HttpResponse> resp, + Thrower thrower, + Map>>> promises) { + Function>, List> extractor = + (r) -> r.body().collect(Collectors.toList()); + return check(w, reqURI, resp, thrower, promises, extractor); + } + + final List checkAsInputStream(Where w, URI reqURI, + HttpResponse resp, + Thrower thrower, + Map>> promises) + { + Function, List> extractor = (r) -> { + List result; + try (InputStream is = r.body()) { + result = new BufferedReader(new InputStreamReader(is)) + .lines().collect(Collectors.toList()); + } catch (Throwable t) { + throw new CompletionException(t); + } + return result; + }; + return check(w, reqURI, resp, thrower, promises, extractor); + } + + private final List check(Where w, URI reqURI, + HttpResponse resp, + Thrower thrower, + Map>> promises, + Function, List> extractor) + { + List result = extractor.apply(resp); + for (HttpRequest req : promises.keySet()) { + switch (w) { + case BEFORE_ACCEPTING: + throw new RuntimeException("No push promise should have been received" + + " for " + reqURI + " in " + w + ": got " + promises.keySet()); + default: + break; + } + HttpResponse presp; + try { + presp = promises.get(req).join(); + } catch (Error | Exception x) { + Throwable cause = findCause(x, thrower); + if (cause != null) { + out.println(now() + "Got expected exception in " + + w + ": " + cause); + continue; + } + throw x; + } + switch (w) { + case BEFORE_ACCEPTING: + case AFTER_ACCEPTING: + case BODY_HANDLER: + case ON_SUBSCRIBE: + case GET_BODY: + case BODY_CF: + return shouldHaveThrown(w, presp, thrower); + default: + break; + } + List presult = null; + try { + presult = extractor.apply(presp); + } catch (Error | Exception x) { + Throwable cause = findCause(x, thrower); + if (cause != null) { + out.println(now() + "Got expected exception for " + + req + " in " + w + ": " + cause); + continue; + } + throw x; + } + throw new RuntimeException("Expected exception not thrown for " + + req + " in " + w); + } + final int expectedCount; + switch (w) { + case BEFORE_ACCEPTING: + expectedCount = 0; + break; + default: + expectedCount = 3; + } + assertEquals(promises.size(), expectedCount, + "bad promise count for " + reqURI + " with " + w); + assertEquals(result, List.of(reqURI.toASCIIString())); + return result; + } + + private static Throwable findCause(Throwable x, + Predicate filter) { + while (x != null && !filter.test(x)) x = x.getCause(); + return x; + } + + static final class UncheckedCustomExceptionThrower implements Thrower { + @Override + public void accept(Where where) { + out.println(now() + "Throwing in " + where); + throw new UncheckedCustomException(where.name()); + } + + @Override + public boolean test(Throwable throwable) { + return UncheckedCustomException.class.isInstance(throwable); + } + + @Override + public String toString() { + return "UncheckedCustomExceptionThrower"; + } + } + + static final class UncheckedIOExceptionThrower implements Thrower { + @Override + public void accept(Where where) { + out.println(now() + "Throwing in " + where); + throw new UncheckedIOException(new CustomIOException(where.name())); + } + + @Override + public boolean test(Throwable throwable) { + return UncheckedIOException.class.isInstance(throwable) + && CustomIOException.class.isInstance(throwable.getCause()); + } + + @Override + public String toString() { + return "UncheckedIOExceptionThrower"; + } + } + + static final class UncheckedCustomException extends RuntimeException { + UncheckedCustomException(String message) { + super(message); + } + UncheckedCustomException(String message, Throwable cause) { + super(message, cause); + } + } + + static final class CustomIOException extends IOException { + CustomIOException(String message) { + super(message); + } + CustomIOException(String message, Throwable cause) { + super(message, cause); + } + } + + static final class ThrowingPromiseHandler implements PushPromiseHandler { + final Consumer throwing; + final PushPromiseHandler pushHandler; + ThrowingPromiseHandler(Consumer throwing, PushPromiseHandler pushHandler) { + this.throwing = throwing; + this.pushHandler = pushHandler; + } + + @Override + public void applyPushPromise(HttpRequest initiatingRequest, + HttpRequest pushPromiseRequest, + Function, + CompletableFuture>> acceptor) { + throwing.accept(Where.BEFORE_ACCEPTING); + pushHandler.applyPushPromise(initiatingRequest, pushPromiseRequest, acceptor); + throwing.accept(Where.AFTER_ACCEPTING); + } + } + + static final class ThrowingBodyHandler implements BodyHandler { + final Consumer throwing; + final BodyHandler bodyHandler; + ThrowingBodyHandler(Consumer throwing, BodyHandler bodyHandler) { + this.throwing = throwing; + this.bodyHandler = bodyHandler; + } + @Override + public BodySubscriber apply(int statusCode, HttpHeaders responseHeaders) { + throwing.accept(Where.BODY_HANDLER); + BodySubscriber subscriber = bodyHandler.apply(statusCode, responseHeaders); + return new ThrowingBodySubscriber(throwing, subscriber); + } + } + + static final class ThrowingBodySubscriber implements BodySubscriber { + private final BodySubscriber subscriber; + volatile boolean onSubscribeCalled; + final Consumer throwing; + ThrowingBodySubscriber(Consumer throwing, BodySubscriber subscriber) { + this.throwing = throwing; + this.subscriber = subscriber; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + //out.println("onSubscribe "); + onSubscribeCalled = true; + throwing.accept(Where.ON_SUBSCRIBE); + subscriber.onSubscribe(subscription); + } + + @Override + public void onNext(List item) { + // out.println("onNext " + item); + assertTrue(onSubscribeCalled); + throwing.accept(Where.ON_NEXT); + subscriber.onNext(item); + } + + @Override + public void onError(Throwable throwable) { + //out.println("onError"); + assertTrue(onSubscribeCalled); + throwing.accept(Where.ON_ERROR); + subscriber.onError(throwable); + } + + @Override + public void onComplete() { + //out.println("onComplete"); + assertTrue(onSubscribeCalled, "onComplete called before onSubscribe"); + throwing.accept(Where.ON_COMPLETE); + subscriber.onComplete(); + } + + @Override + public CompletionStage getBody() { + throwing.accept(Where.GET_BODY); + try { + throwing.accept(Where.BODY_CF); + } catch (Throwable t) { + return CompletableFuture.failedFuture(t); + } + return subscriber.getBody(); + } + } + + + @BeforeTest + public void setup() throws Exception { + sslContext = new SimpleSSLContext().get(); + if (sslContext == null) + throw new AssertionError("Unexpected null sslContext"); + + // HTTP/2 + HttpTestHandler h2_fixedLengthHandler = new HTTP_FixedLengthHandler(); + HttpTestHandler h2_chunkedHandler = new HTTP_ChunkedHandler(); + + http2TestServer = HttpTestServer.of(new Http2TestServer("localhost", false, 0)); + http2TestServer.addHandler(h2_fixedLengthHandler, "/http2/fixed"); + http2TestServer.addHandler(h2_chunkedHandler, "/http2/chunk"); + http2URI_fixed = "http://" + http2TestServer.serverAuthority() + "/http2/fixed/x"; + http2URI_chunk = "http://" + http2TestServer.serverAuthority() + "/http2/chunk/x"; + + https2TestServer = HttpTestServer.of(new Http2TestServer("localhost", true, 0)); + https2TestServer.addHandler(h2_fixedLengthHandler, "/https2/fixed"); + https2TestServer.addHandler(h2_chunkedHandler, "/https2/chunk"); + https2URI_fixed = "https://" + https2TestServer.serverAuthority() + "/https2/fixed/x"; + https2URI_chunk = "https://" + https2TestServer.serverAuthority() + "/https2/chunk/x"; + + serverCount.addAndGet(2); + http2TestServer.start(); + https2TestServer.start(); + } + + @AfterTest + public void teardown() throws Exception { + sharedClient = null; + http2TestServer.stop(); + https2TestServer.stop(); + } + + private static void pushPromiseFor(HttpTestExchange t, URI requestURI, String pushPath, boolean fixed) + throws IOException + { + try { + URI promise = new URI(requestURI.getScheme(), + requestURI.getAuthority(), + pushPath, null, null); + byte[] promiseBytes = promise.toASCIIString().getBytes(UTF_8); + out.printf("TestServer: %s Pushing promise: %s%n", now(), promise); + err.printf("TestServer: %s Pushing promise: %s%n", now(), promise); + HttpTestHeaders headers = HttpTestHeaders.of(new HttpHeadersImpl()); + if (fixed) { + headers.addHeader("Content-length", String.valueOf(promiseBytes.length)); + } + t.serverPush(promise, headers, promiseBytes); + } catch (URISyntaxException x) { + throw new IOException(x.getMessage(), x); + } + } + + static class HTTP_FixedLengthHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange t) throws IOException { + out.println("HTTP_FixedLengthHandler received request to " + t.getRequestURI()); + try (InputStream is = t.getRequestBody()) { + is.readAllBytes(); + } + URI requestURI = t.getRequestURI(); + for (int i = 1; i<2; i++) { + String path = requestURI.getPath() + "/before/promise-" + i; + pushPromiseFor(t, requestURI, path, true); + } + byte[] resp = t.getRequestURI().toString().getBytes(StandardCharsets.UTF_8); + t.sendResponseHeaders(200, resp.length); //fixed content length + try (OutputStream os = t.getResponseBody()) { + int bytes = resp.length/3; + for (int i = 0; i<2; i++) { + String path = requestURI.getPath() + "/after/promise-" + (i + 2); + os.write(resp, i * bytes, bytes); + os.flush(); + pushPromiseFor(t, requestURI, path, true); + } + os.write(resp, 2*bytes, resp.length - 2*bytes); + } + } + + } + + static class HTTP_ChunkedHandler implements HttpTestHandler { + @Override + public void handle(HttpTestExchange t) throws IOException { + out.println("HTTP_ChunkedHandler received request to " + t.getRequestURI()); + byte[] resp = t.getRequestURI().toString().getBytes(StandardCharsets.UTF_8); + try (InputStream is = t.getRequestBody()) { + is.readAllBytes(); + } + URI requestURI = t.getRequestURI(); + for (int i = 1; i<2; i++) { + String path = requestURI.getPath() + "/before/promise-" + i; + pushPromiseFor(t, requestURI, path, false); + } + t.sendResponseHeaders(200, -1); // chunked/variable + try (OutputStream os = t.getResponseBody()) { + int bytes = resp.length/3; + for (int i = 0; i<2; i++) { + String path = requestURI.getPath() + "/after/promise-" + (i + 2); + os.write(resp, i * bytes, bytes); + os.flush(); + pushPromiseFor(t, requestURI, path, false); + } + os.write(resp, 2*bytes, resp.length - 2*bytes); + } + } + } + +}