# HG changeset patch # User prappo # Date 1512038126 -10800 # Node ID 77feac3903d9ebcadcf8c6cd255034edb49c828d # Parent dfa9489d1cb1cc89804d847213af162137d3eb5a http-client-branch: (WebSocket) bug fix & test diff -r dfa9489d1cb1 -r 77feac3903d9 src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/Receiver.java --- a/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/Receiver.java Wed Nov 29 16:59:38 2017 +0000 +++ b/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/Receiver.java Thu Nov 30 13:35:26 2017 +0300 @@ -25,13 +25,13 @@ package jdk.incubator.http.internal.websocket; +import jdk.incubator.http.internal.common.Demand; +import jdk.incubator.http.internal.common.SequentialScheduler; + import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; -import jdk.incubator.http.internal.common.Demand; -import jdk.incubator.http.internal.common.SequentialScheduler; - /* * Receives incoming data from the channel on demand and converts it into a * stream of WebSocket messages which are then delivered to the supplied message @@ -101,11 +101,9 @@ } public void request(long n) { - if (n <= 0L) { - throw new IllegalArgumentException("Non-positive request: " + n); + if (demand.increase(n)) { + pushScheduler.runOrSchedule(); } - demand.increase(n); - pushScheduler.runOrSchedule(); } /* diff -r dfa9489d1cb1 -r 77feac3903d9 src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java --- a/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java Wed Nov 29 16:59:38 2017 +0000 +++ b/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java Thu Nov 30 13:35:26 2017 +0300 @@ -26,6 +26,7 @@ package jdk.incubator.http.internal.websocket; import jdk.incubator.http.WebSocket; +import jdk.incubator.http.internal.common.Demand; import jdk.incubator.http.internal.common.Log; import jdk.incubator.http.internal.common.MinimalFuture; import jdk.incubator.http.internal.common.Pair; @@ -52,7 +53,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -63,26 +63,37 @@ import static jdk.incubator.http.internal.websocket.StatusCodes.CLOSED_ABNORMALLY; import static jdk.incubator.http.internal.websocket.StatusCodes.NO_STATUS_CODE; import static jdk.incubator.http.internal.websocket.StatusCodes.isLegalToSendFromClient; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.BINARY; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.CLOSE; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.ERROR; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.IDLE; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.OPEN; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.PING; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.PONG; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.TEXT; +import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.WAITING; /* * A WebSocket client. */ public final class WebSocketImpl implements WebSocket { - private static final int IDLE = 0; - private static final int OPEN = 1; - private static final int TEXT = 2; - private static final int BINARY = 4; - private static final int PING = 8; - private static final int PONG = 16; - private static final int CLOSE = 32; - private static final int ERROR = 64; + enum State { + OPEN, + IDLE, + WAITING, + TEXT, + BINARY, + PING, + PONG, + CLOSE, + ERROR; + } private volatile boolean inputClosed; private volatile boolean outputClosed; - /* Which of the listener's methods to call next? */ - private final AtomicInteger state = new AtomicInteger(OPEN); + private final AtomicReference state = new AtomicReference<>(OPEN); /* Components of calls to Listener's methods */ private MessagePart part; @@ -104,6 +115,7 @@ private final Transmitter transmitter; private final Receiver receiver; private final SequentialScheduler receiveScheduler = new SequentialScheduler(new ReceiveTask()); + private final Demand demand = new Demand(); public static CompletableFuture newInstanceAsync(BuilderImpl b) { Function newWebSocket = r -> { @@ -142,8 +154,7 @@ private WebSocketImpl(URI uri, String subprotocol, Listener listener, - TransportSupplier transport) - { + TransportSupplier transport) { this.uri = requireNonNull(uri); this.subprotocol = requireNonNull(subprotocol); this.listener = requireNonNull(listener); @@ -219,8 +230,7 @@ * completes. This method is used to enforce "one outstanding send * operation" policy. */ - private CompletableFuture enqueueExclusively(OutgoingMessage m) - { + private CompletableFuture enqueueExclusively(OutgoingMessage m) { if (!outstandingSend.compareAndSet(false, true)) { return failedFuture(new IllegalStateException("Send pending")); } @@ -286,9 +296,9 @@ @Override public void request(long n) { - // TODO: delay until state becomes ACTIVE, otherwise messages might be - // requested and consecutively become pending before onOpen is signalled - receiver.request(n); + if (demand.increase(n)) { + receiveScheduler.runOrSchedule(); + } } @Override @@ -338,41 +348,58 @@ */ private class ReceiveTask extends SequentialScheduler.CompleteRestartableTask { + // Receiver only asked here and nowhere else because we must make sure + // onOpen is invoked first and no messages become pending before onOpen + // finishes + @Override public void run() { - final int s = state.getAndSet(IDLE); - try { - switch (s) { - case OPEN: - processOpen(); - break; - case TEXT: - processText(); - break; - case BINARY: - processBinary(); - break; - case PING: - processPing(); - break; - case PONG: - processPong(); - break; - case CLOSE: - processClose(); - break; - case ERROR: - processError(); - break; - case IDLE: - // For debugging spurious signalling: when there was a - // signal, but apparently nothing has changed - break; - default: - throw new InternalError(String.valueOf(s)); + while (true) { + State s = state.get(); + try { + switch (s) { + case OPEN: + processOpen(); + tryChangeState(OPEN, IDLE); + break; + case TEXT: + processText(); + tryChangeState(TEXT, IDLE); + break; + case BINARY: + processBinary(); + tryChangeState(BINARY, IDLE); + break; + case PING: + processPing(); + tryChangeState(PING, IDLE); + break; + case PONG: + processPong(); + tryChangeState(PONG, IDLE); + break; + case CLOSE: + processClose(); + return; + case ERROR: + processError(); + return; + case IDLE: + if (demand.tryDecrement() + && tryChangeState(IDLE, WAITING)) { + receiver.request(1); + } + return; + case WAITING: + // For debugging spurious signalling: when there was a + // signal, but apparently nothing has changed + return; + default: + throw new InternalError(String.valueOf(s)); + } + } catch (Throwable t) { + signalError(t); } - } catch (Throwable t) { - signalError(t); } } @@ -462,7 +489,7 @@ private void signalError(Throwable error) { inputClosed = true; outputClosed = true; - if (!this.error.compareAndSet(null, error) || !tryChangeState(ERROR)) { + if (!this.error.compareAndSet(null, error) || !trySetState(ERROR)) { Log.logError(error); } else { close(); @@ -489,7 +516,7 @@ inputClosed = true; this.statusCode = statusCode; this.reason = reason; - if (!tryChangeState(CLOSE)) { + if (!trySetState(CLOSE)) { Log.logTrace("Close: {0}, ''{1}''", statusCode, reason); } else { try { @@ -507,7 +534,7 @@ receiver.acknowledge(); text = data; WebSocketImpl.this.part = part; - tryChangeState(TEXT); + tryChangeState(WAITING, TEXT); } @Override @@ -515,21 +542,21 @@ receiver.acknowledge(); binaryData = data; WebSocketImpl.this.part = part; - tryChangeState(BINARY); + tryChangeState(WAITING, BINARY); } @Override public void onPing(ByteBuffer data) { receiver.acknowledge(); binaryData = data; - tryChangeState(PING); + tryChangeState(WAITING, PING); } @Override public void onPong(ByteBuffer data) { receiver.acknowledge(); binaryData = data; - tryChangeState(PONG); + tryChangeState(WAITING, PONG); } @Override @@ -540,6 +567,7 @@ @Override public void onComplete() { + receiver.acknowledge(); signalClose(CLOSED_ABNORMALLY, ""); } @@ -549,9 +577,9 @@ } } - private boolean tryChangeState(int newState) { + private boolean trySetState(State newState) { while (true) { - int currentState = state.get(); + State currentState = state.get(); if (currentState == ERROR || currentState == CLOSE) { return false; } else if (state.compareAndSet(currentState, newState)) { @@ -560,4 +588,18 @@ } } } + + private boolean tryChangeState(State expectedState, State newState) { + State witness = state.compareAndExchange(expectedState, newState); + if (witness == expectedState) { + receiveScheduler.runOrSchedule(); + return true; + } + // This should be the only reason for inability to change the state from + // IDLE to WAITING: the state has changed to terminal + if (witness != ERROR && witness != CLOSE) { + throw new InternalError(); + } + return false; + } } diff -r dfa9489d1cb1 -r 77feac3903d9 test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockListener.java --- a/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockListener.java Wed Nov 29 16:59:38 2017 +0000 +++ b/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockListener.java Thu Nov 30 13:35:26 2017 +0300 @@ -1,15 +1,46 @@ +/* + * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + package jdk.incubator.http.internal.websocket; import jdk.incubator.http.WebSocket; import jdk.incubator.http.WebSocket.MessagePart; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import static jdk.incubator.http.internal.websocket.TestSupport.fullCopy; + public class MockListener implements WebSocket.Listener { private final long bufferSize; private long count; + private final List invocations = new ArrayList<>(); + private final CompletableFuture lastCall = new CompletableFuture<>(); /* * Typical buffer sizes: 1, n, Long.MAX_VALUE @@ -24,7 +55,12 @@ @Override public void onOpen(WebSocket webSocket) { System.out.printf("onOpen(%s)%n", webSocket); - replenishDemandIfNeeded(webSocket); + invocations.add(new OnOpen(webSocket)); + onOpen0(webSocket); + } + + protected void onOpen0(WebSocket webSocket) { + replenish(webSocket); } @Override @@ -32,7 +68,14 @@ CharSequence message, MessagePart part) { System.out.printf("onText(%s, %s, %s)%n", webSocket, message, part); - replenishDemandIfNeeded(webSocket); + invocations.add(new OnText(webSocket, message.toString(), part)); + return onText0(webSocket, message, part); + } + + protected CompletionStage onText0(WebSocket webSocket, + CharSequence message, + MessagePart part) { + replenish(webSocket); return null; } @@ -41,21 +84,38 @@ ByteBuffer message, MessagePart part) { System.out.printf("onBinary(%s, %s, %s)%n", webSocket, message, part); - replenishDemandIfNeeded(webSocket); + invocations.add(new OnBinary(webSocket, fullCopy(message), part)); + return onBinary0(webSocket, message, part); + } + + protected CompletionStage onBinary0(WebSocket webSocket, + ByteBuffer message, + MessagePart part) { + replenish(webSocket); return null; } @Override public CompletionStage onPing(WebSocket webSocket, ByteBuffer message) { System.out.printf("onPing(%s, %s)%n", webSocket, message); - replenishDemandIfNeeded(webSocket); + invocations.add(new OnPing(webSocket, fullCopy(message))); + return onPing0(webSocket, message); + } + + protected CompletionStage onPing0(WebSocket webSocket, ByteBuffer message) { + replenish(webSocket); return null; } @Override public CompletionStage onPong(WebSocket webSocket, ByteBuffer message) { System.out.printf("onPong(%s, %s)%n", webSocket, message); - replenishDemandIfNeeded(webSocket); + invocations.add(new OnPong(webSocket, fullCopy(message))); + return onPong0(webSocket, message); + } + + protected CompletionStage onPong0(WebSocket webSocket, ByteBuffer message) { + replenish(webSocket); return null; } @@ -64,19 +124,249 @@ int statusCode, String reason) { System.out.printf("onClose(%s, %s, %s)%n", webSocket, statusCode, reason); + invocations.add(new OnClose(webSocket, statusCode, reason)); + lastCall.complete(null); return null; } @Override public void onError(WebSocket webSocket, Throwable error) { System.out.printf("onError(%s, %s)%n", webSocket, error); + invocations.add(new OnError(webSocket, error == null ? null : error.getClass())); + lastCall.complete(null); + } + + public CompletableFuture onCloseOrOnErrorCalled() { + return lastCall.copy(); + } + + protected void replenish(WebSocket webSocket) { + if (--count <= 0) { + count = bufferSize - bufferSize / 2; + } + webSocket.request(count); + } + + public List invocations() { + return new ArrayList<>(invocations); + } + + public abstract static class ListenerInvocation { + + public static OnOpen onOpen(WebSocket webSocket) { + return new OnOpen(webSocket); + } + + public static OnText onText(WebSocket webSocket, + String text, + MessagePart part) { + return new OnText(webSocket, text, part); + } + + public static OnBinary onBinary(WebSocket webSocket, + ByteBuffer data, + MessagePart part) { + return new OnBinary(webSocket, data, part); + } + + public static OnPing onPing(WebSocket webSocket, + ByteBuffer data) { + return new OnPing(webSocket, data); + } + + public static OnPong onPong(WebSocket webSocket, + ByteBuffer data) { + return new OnPong(webSocket, data); + } + + public static OnClose onClose(WebSocket webSocket, + int statusCode, + String reason) { + return new OnClose(webSocket, statusCode, reason); + } + + public static OnError onError(WebSocket webSocket, + Class clazz) { + return new OnError(webSocket, clazz); + } + + final WebSocket webSocket; + + private ListenerInvocation(WebSocket webSocket) { + this.webSocket = webSocket; + } + } + + public static final class OnOpen extends ListenerInvocation { + + public OnOpen(WebSocket webSocket) { + super(webSocket); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ListenerInvocation that = (ListenerInvocation) o; + return Objects.equals(webSocket, that.webSocket); + } + + @Override + public int hashCode() { + return Objects.hashCode(webSocket); + } + } + + public static final class OnText extends ListenerInvocation { + + final String text; + final MessagePart part; + + public OnText(WebSocket webSocket, String text, MessagePart part) { + super(webSocket); + this.text = text; + this.part = part; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OnText onText = (OnText) o; + return Objects.equals(text, onText.text) && + part == onText.part && + Objects.equals(webSocket, onText.webSocket); + } + + @Override + public int hashCode() { + return Objects.hash(text, part, webSocket); + } } - private void replenishDemandIfNeeded(WebSocket webSocket) { - if (--count <= 0) { - count = bufferSize - bufferSize / 2; - System.out.printf("request(%s)%n", count); - webSocket.request(count); + public static final class OnBinary extends ListenerInvocation { + + final ByteBuffer data; + final MessagePart part; + + public OnBinary(WebSocket webSocket, ByteBuffer data, MessagePart part) { + super(webSocket); + this.data = data; + this.part = part; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OnBinary onBinary = (OnBinary) o; + return Objects.equals(data, onBinary.data) && + part == onBinary.part && + Objects.equals(webSocket, onBinary.webSocket); + } + + @Override + public int hashCode() { + return Objects.hash(data, part, webSocket); + } + } + + public static final class OnPing extends ListenerInvocation { + + final ByteBuffer data; + + public OnPing(WebSocket webSocket, ByteBuffer data) { + super(webSocket); + this.data = data; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OnPing onPing = (OnPing) o; + return Objects.equals(data, onPing.data) && + Objects.equals(webSocket, onPing.webSocket); + } + + @Override + public int hashCode() { + return Objects.hash(data, webSocket); + } + } + + public static final class OnPong extends ListenerInvocation { + + final ByteBuffer data; + + public OnPong(WebSocket webSocket, ByteBuffer data) { + super(webSocket); + this.data = data; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OnPong onPong = (OnPong) o; + return Objects.equals(data, onPong.data) && + Objects.equals(webSocket, onPong.webSocket); + } + + @Override + public int hashCode() { + return Objects.hash(data, webSocket); + } + } + + public static final class OnClose extends ListenerInvocation { + + final int statusCode; + final String reason; + + public OnClose(WebSocket webSocket, int statusCode, String reason) { + super(webSocket); + this.statusCode = statusCode; + this.reason = reason; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OnClose onClose = (OnClose) o; + return statusCode == onClose.statusCode && + Objects.equals(reason, onClose.reason) && + Objects.equals(webSocket, onClose.webSocket); + } + + @Override + public int hashCode() { + return Objects.hash(statusCode, reason, webSocket); + } + } + + public static final class OnError extends ListenerInvocation { + + final Class clazz; + + public OnError(WebSocket webSocket, Class clazz) { + super(webSocket); + this.clazz = clazz; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OnError onError = (OnError) o; + return Objects.equals(clazz, onError.clazz) && + Objects.equals(webSocket, onError.webSocket); + } + + @Override + public int hashCode() { + return Objects.hash(clazz, webSocket); } } } diff -r dfa9489d1cb1 -r 77feac3903d9 test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockReceiver.java --- a/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockReceiver.java Wed Nov 29 16:59:38 2017 +0000 +++ b/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockReceiver.java Thu Nov 30 13:35:26 2017 +0300 @@ -70,6 +70,8 @@ repeat(taskCompleter); }); } + } else { + taskCompleter.complete(); } } diff -r dfa9489d1cb1 -r 77feac3903d9 test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/ReceivingTest.java --- a/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/ReceivingTest.java Wed Nov 29 16:59:38 2017 +0000 +++ b/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/ReceivingTest.java Thu Nov 30 13:35:26 2017 +0300 @@ -24,15 +24,13 @@ package jdk.incubator.http.internal.websocket; import jdk.incubator.http.WebSocket; -import jdk.incubator.http.WebSocket.MessagePart; import org.testng.annotations.Test; import java.net.URI; -import java.nio.ByteBuffer; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; import static java.util.concurrent.CompletableFuture.completedStage; import static jdk.incubator.http.WebSocket.MessagePart.FIRST; @@ -41,115 +39,75 @@ import static jdk.incubator.http.WebSocket.MessagePart.WHOLE; import static jdk.incubator.http.WebSocket.NORMAL_CLOSURE; import static jdk.incubator.http.internal.common.Pair.pair; -import static jdk.incubator.http.internal.websocket.WebSocketImpl.newInstance; +import static jdk.incubator.http.internal.websocket.MockListener.ListenerInvocation.onClose; +import static jdk.incubator.http.internal.websocket.MockListener.ListenerInvocation.onError; +import static jdk.incubator.http.internal.websocket.MockListener.ListenerInvocation.onOpen; +import static jdk.incubator.http.internal.websocket.MockListener.ListenerInvocation.onText; +import static org.testng.Assert.assertEquals; public class ReceivingTest { // TODO: request in onClose/onError // TODO: throw exception in onClose/onError + // TODO: exception is thrown from request() @Test - public void testNonPositiveRequest() { - URI uri = URI.create("ws://localhost"); - String subprotocol = ""; - CompletableFuture result = new CompletableFuture<>(); - newInstance(uri, subprotocol, new MockListener(Long.MAX_VALUE) { - - final AtomicInteger onOpenCount = new AtomicInteger(); - volatile WebSocket webSocket; - + public void testNonPositiveRequest() throws Exception { + MockListener listener = new MockListener(Long.MAX_VALUE) { @Override - public void onOpen(WebSocket webSocket) { - int i = onOpenCount.incrementAndGet(); - if (i > 1) { - result.completeExceptionally(new IllegalStateException()); - } else { - this.webSocket = webSocket; - webSocket.request(0); - } - } - - @Override - public CompletionStage onBinary(WebSocket webSocket, - ByteBuffer message, - MessagePart part) { - result.completeExceptionally(new IllegalStateException()); - return null; - } - - @Override - public CompletionStage onText(WebSocket webSocket, - CharSequence message, - MessagePart part) { - result.completeExceptionally(new IllegalStateException()); - return null; + protected void onOpen0(WebSocket webSocket) { + webSocket.request(0); } - - @Override - public CompletionStage onPing(WebSocket webSocket, - ByteBuffer message) { - result.completeExceptionally(new IllegalStateException()); - return null; - } - - @Override - public CompletionStage onPong(WebSocket webSocket, - ByteBuffer message) { - result.completeExceptionally(new IllegalStateException()); - return null; - } - - @Override - public CompletionStage onClose(WebSocket webSocket, - int statusCode, - String reason) { - result.completeExceptionally(new IllegalStateException()); - return null; - } - - @Override - public void onError(WebSocket webSocket, Throwable error) { - if (!this.webSocket.equals(webSocket)) { - result.completeExceptionally(new IllegalArgumentException()); - } else if (error == null || error.getClass() != IllegalArgumentException.class) { - result.completeExceptionally(new IllegalArgumentException()); - } else { - result.complete(null); - } - } - }, new MockTransport() { + }; + MockTransport transport = new MockTransport() { @Override protected Receiver newReceiver(MessageStreamConsumer consumer) { - return new MockReceiver(consumer, channel, pair(now(), m -> m.onText("1", WHOLE) )); + return new MockReceiver(consumer, channel, pair(now(), m -> m.onText("1", WHOLE))); } - }); - result.join(); + }; + WebSocket ws = newInstance(listener, transport); + listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS); + List invocations = listener.invocations(); + assertEquals(invocations, List.of(onOpen(ws), onError(ws, IllegalArgumentException.class))); } @Test - public void testText1() throws InterruptedException { - URI uri = URI.create("ws://localhost"); - String subprotocol = ""; - newInstance(uri, subprotocol, new MockListener(Long.MAX_VALUE), - new MockTransport() { - @Override - protected Receiver newReceiver(MessageStreamConsumer consumer) { - return new MockReceiver(consumer, channel, - pair(now(), m -> m.onText("1", FIRST)), - pair(now(), m -> m.onText("2", PART)), - pair(now(), m -> m.onText("3", PART)), - pair(now(), m -> m.onText("4", LAST)), - pair(now(), m -> m.onClose(NORMAL_CLOSURE, "no reason"))); - } - }); - Thread.sleep(2000); + public void testText1() throws Exception { + MockListener listener = new MockListener(Long.MAX_VALUE); + MockTransport transport = new MockTransport() { + @Override + protected Receiver newReceiver(MessageStreamConsumer consumer) { + return new MockReceiver(consumer, channel, + pair(now(), m -> m.onText("1", FIRST)), + pair(now(), m -> m.onText("2", PART)), + pair(now(), m -> m.onText("3", PART)), + pair(now(), m -> m.onText("4", LAST)), + pair(now(), m -> m.onClose(NORMAL_CLOSURE, "no reason"))); + } + }; + WebSocket ws = newInstance(listener, transport); + listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS); + List invocations = listener.invocations(); + assertEquals(invocations, List.of(onOpen(ws), + onText(ws, "1", FIRST), + onText(ws, "2", PART), + onText(ws, "3", PART), + onText(ws, "4", LAST), + onClose(ws, NORMAL_CLOSURE, "no reason"))); } - private CompletionStage inSeconds(long s) { + private static CompletionStage seconds(long s) { return new CompletableFuture<>().completeOnTimeout(null, s, TimeUnit.SECONDS); } - private CompletionStage now() { + private static CompletionStage now() { return completedStage(null); } + + private static WebSocket newInstance(WebSocket.Listener listener, + TransportSupplier transport) { + URI uri = URI.create("ws://localhost"); + String subprotocol = ""; + return WebSocketImpl.newInstance(uri, subprotocol, listener, transport); + } }