http-client-branch: (WebSocket) bug fix & test http-client-branch
authorprappo
Thu, 30 Nov 2017 13:35:26 +0300
branchhttp-client-branch
changeset 55922 77feac3903d9
parent 55912 dfa9489d1cb1
child 55923 67a9df429e0b
http-client-branch: (WebSocket) bug fix & test
src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/Receiver.java
src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java
test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockListener.java
test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockReceiver.java
test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/ReceivingTest.java
--- a/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/Receiver.java	Wed Nov 29 16:59:38 2017 +0000
+++ b/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/Receiver.java	Thu Nov 30 13:35:26 2017 +0300
@@ -25,13 +25,13 @@
 
 package jdk.incubator.http.internal.websocket;
 
+import jdk.incubator.http.internal.common.Demand;
+import jdk.incubator.http.internal.common.SequentialScheduler;
+
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.SelectionKey;
 
-import jdk.incubator.http.internal.common.Demand;
-import jdk.incubator.http.internal.common.SequentialScheduler;
-
 /*
  * Receives incoming data from the channel on demand and converts it into a
  * stream of WebSocket messages which are then delivered to the supplied message
@@ -101,11 +101,9 @@
     }
 
     public void request(long n) {
-        if (n <= 0L) {
-            throw new IllegalArgumentException("Non-positive request: " + n);
+        if (demand.increase(n)) {
+            pushScheduler.runOrSchedule();
         }
-        demand.increase(n);
-        pushScheduler.runOrSchedule();
     }
 
     /*
--- a/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java	Wed Nov 29 16:59:38 2017 +0000
+++ b/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java	Thu Nov 30 13:35:26 2017 +0300
@@ -26,6 +26,7 @@
 package jdk.incubator.http.internal.websocket;
 
 import jdk.incubator.http.WebSocket;
+import jdk.incubator.http.internal.common.Demand;
 import jdk.incubator.http.internal.common.Log;
 import jdk.incubator.http.internal.common.MinimalFuture;
 import jdk.incubator.http.internal.common.Pair;
@@ -52,7 +53,6 @@
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -63,26 +63,37 @@
 import static jdk.incubator.http.internal.websocket.StatusCodes.CLOSED_ABNORMALLY;
 import static jdk.incubator.http.internal.websocket.StatusCodes.NO_STATUS_CODE;
 import static jdk.incubator.http.internal.websocket.StatusCodes.isLegalToSendFromClient;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.BINARY;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.CLOSE;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.ERROR;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.IDLE;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.OPEN;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.PING;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.PONG;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.TEXT;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.WAITING;
 
 /*
  * A WebSocket client.
  */
 public final class WebSocketImpl implements WebSocket {
 
-    private static final int IDLE   =  0;
-    private static final int OPEN   =  1;
-    private static final int TEXT   =  2;
-    private static final int BINARY =  4;
-    private static final int PING   =  8;
-    private static final int PONG   = 16;
-    private static final int CLOSE  = 32;
-    private static final int ERROR  = 64;
+    enum State {
+        OPEN,
+        IDLE,
+        WAITING,
+        TEXT,
+        BINARY,
+        PING,
+        PONG,
+        CLOSE,
+        ERROR;
+    }
 
     private volatile boolean inputClosed;
     private volatile boolean outputClosed;
 
-    /* Which of the listener's methods to call next? */
-    private final AtomicInteger state = new AtomicInteger(OPEN);
+    private final AtomicReference<State> state = new AtomicReference<>(OPEN);
 
     /* Components of calls to Listener's methods */
     private MessagePart part;
@@ -104,6 +115,7 @@
     private final Transmitter transmitter;
     private final Receiver receiver;
     private final SequentialScheduler receiveScheduler = new SequentialScheduler(new ReceiveTask());
+    private final Demand demand = new Demand();
 
     public static CompletableFuture<WebSocket> newInstanceAsync(BuilderImpl b) {
         Function<Result, WebSocket> newWebSocket = r -> {
@@ -142,8 +154,7 @@
     private WebSocketImpl(URI uri,
                           String subprotocol,
                           Listener listener,
-                          TransportSupplier transport)
-    {
+                          TransportSupplier transport) {
         this.uri = requireNonNull(uri);
         this.subprotocol = requireNonNull(subprotocol);
         this.listener = requireNonNull(listener);
@@ -219,8 +230,7 @@
      * completes. This method is used to enforce "one outstanding send
      * operation" policy.
      */
-    private CompletableFuture<WebSocket> enqueueExclusively(OutgoingMessage m)
-    {
+    private CompletableFuture<WebSocket> enqueueExclusively(OutgoingMessage m) {
         if (!outstandingSend.compareAndSet(false, true)) {
             return failedFuture(new IllegalStateException("Send pending"));
         }
@@ -286,9 +296,9 @@
 
     @Override
     public void request(long n) {
-        // TODO: delay until state becomes ACTIVE, otherwise messages might be
-        // requested and consecutively become pending before onOpen is signalled
-        receiver.request(n);
+        if (demand.increase(n)) {
+            receiveScheduler.runOrSchedule();
+        }
     }
 
     @Override
@@ -338,41 +348,58 @@
      */
     private class ReceiveTask extends SequentialScheduler.CompleteRestartableTask {
 
+        // Receiver only asked here and nowhere else because we must make sure
+        // onOpen is invoked first and no messages become pending before onOpen
+        // finishes
+
         @Override
         public void run() {
-            final int s = state.getAndSet(IDLE);
-            try {
-                switch (s) {
-                    case OPEN:
-                        processOpen();
-                        break;
-                    case TEXT:
-                        processText();
-                        break;
-                    case BINARY:
-                        processBinary();
-                        break;
-                    case PING:
-                        processPing();
-                        break;
-                    case PONG:
-                        processPong();
-                        break;
-                    case CLOSE:
-                        processClose();
-                        break;
-                    case ERROR:
-                        processError();
-                        break;
-                    case IDLE:
-                        // For debugging spurious signalling: when there was a
-                        // signal, but apparently nothing has changed
-                        break;
-                    default:
-                        throw new InternalError(String.valueOf(s));
+            while (true) {
+                State s = state.get();
+                try {
+                    switch (s) {
+                        case OPEN:
+                            processOpen();
+                            tryChangeState(OPEN, IDLE);
+                            break;
+                        case TEXT:
+                            processText();
+                            tryChangeState(TEXT, IDLE);
+                            break;
+                        case BINARY:
+                            processBinary();
+                            tryChangeState(BINARY, IDLE);
+                            break;
+                        case PING:
+                            processPing();
+                            tryChangeState(PING, IDLE);
+                            break;
+                        case PONG:
+                            processPong();
+                            tryChangeState(PONG, IDLE);
+                            break;
+                        case CLOSE:
+                            processClose();
+                            return;
+                        case ERROR:
+                            processError();
+                            return;
+                        case IDLE:
+                            if (demand.tryDecrement()
+                                    && tryChangeState(IDLE, WAITING)) {
+                                receiver.request(1);
+                            }
+                            return;
+                        case WAITING:
+                            // For debugging spurious signalling: when there was a
+                            // signal, but apparently nothing has changed
+                            return;
+                        default:
+                            throw new InternalError(String.valueOf(s));
+                    }
+                } catch (Throwable t) {
+                    signalError(t);
                 }
-            } catch (Throwable t) {
-                signalError(t);
             }
         }
 
@@ -462,7 +489,7 @@
     private void signalError(Throwable error) {
         inputClosed = true;
         outputClosed = true;
-        if (!this.error.compareAndSet(null, error) || !tryChangeState(ERROR)) {
+        if (!this.error.compareAndSet(null, error) || !trySetState(ERROR)) {
             Log.logError(error);
         } else {
             close();
@@ -489,7 +516,7 @@
         inputClosed = true;
         this.statusCode = statusCode;
         this.reason = reason;
-        if (!tryChangeState(CLOSE)) {
+        if (!trySetState(CLOSE)) {
             Log.logTrace("Close: {0}, ''{1}''", statusCode, reason);
         } else {
             try {
@@ -507,7 +534,7 @@
             receiver.acknowledge();
             text = data;
             WebSocketImpl.this.part = part;
-            tryChangeState(TEXT);
+            tryChangeState(WAITING, TEXT);
         }
 
         @Override
@@ -515,21 +542,21 @@
             receiver.acknowledge();
             binaryData = data;
             WebSocketImpl.this.part = part;
-            tryChangeState(BINARY);
+            tryChangeState(WAITING, BINARY);
         }
 
         @Override
         public void onPing(ByteBuffer data) {
             receiver.acknowledge();
             binaryData = data;
-            tryChangeState(PING);
+            tryChangeState(WAITING, PING);
         }
 
         @Override
         public void onPong(ByteBuffer data) {
             receiver.acknowledge();
             binaryData = data;
-            tryChangeState(PONG);
+            tryChangeState(WAITING, PONG);
         }
 
         @Override
@@ -540,6 +567,7 @@
 
         @Override
         public void onComplete() {
+            receiver.acknowledge();
             signalClose(CLOSED_ABNORMALLY, "");
         }
 
@@ -549,9 +577,9 @@
         }
     }
 
-    private boolean tryChangeState(int newState) {
+    private boolean trySetState(State newState) {
         while (true) {
-            int currentState = state.get();
+            State currentState = state.get();
             if (currentState == ERROR || currentState == CLOSE) {
                 return false;
             } else if (state.compareAndSet(currentState, newState)) {
@@ -560,4 +588,18 @@
             }
         }
     }
+
+    private boolean tryChangeState(State expectedState, State newState) {
+        State witness = state.compareAndExchange(expectedState, newState);
+        if (witness == expectedState) {
+            receiveScheduler.runOrSchedule();
+            return true;
+        }
+        // This should be the only reason for inability to change the state from
+        // IDLE to WAITING: the state has changed to terminal
+        if (witness != ERROR && witness != CLOSE) {
+            throw new InternalError();
+        }
+        return false;
+    }
 }
--- a/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockListener.java	Wed Nov 29 16:59:38 2017 +0000
+++ b/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockListener.java	Thu Nov 30 13:35:26 2017 +0300
@@ -1,15 +1,46 @@
+/*
+ * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
 package jdk.incubator.http.internal.websocket;
 
 import jdk.incubator.http.WebSocket;
 import jdk.incubator.http.WebSocket.MessagePart;
 
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
 
+import static jdk.incubator.http.internal.websocket.TestSupport.fullCopy;
+
 public class MockListener implements WebSocket.Listener {
 
     private final long bufferSize;
     private long count;
+    private final List<ListenerInvocation> invocations = new ArrayList<>();
+    private final CompletableFuture<?> lastCall = new CompletableFuture<>();
 
     /*
      * Typical buffer sizes: 1, n, Long.MAX_VALUE
@@ -24,7 +55,12 @@
     @Override
     public void onOpen(WebSocket webSocket) {
         System.out.printf("onOpen(%s)%n", webSocket);
-        replenishDemandIfNeeded(webSocket);
+        invocations.add(new OnOpen(webSocket));
+        onOpen0(webSocket);
+    }
+
+    protected void onOpen0(WebSocket webSocket) {
+        replenish(webSocket);
     }
 
     @Override
@@ -32,7 +68,14 @@
                                      CharSequence message,
                                      MessagePart part) {
         System.out.printf("onText(%s, %s, %s)%n", webSocket, message, part);
-        replenishDemandIfNeeded(webSocket);
+        invocations.add(new OnText(webSocket, message.toString(), part));
+        return onText0(webSocket, message, part);
+    }
+
+    protected CompletionStage<?> onText0(WebSocket webSocket,
+                                         CharSequence message,
+                                         MessagePart part) {
+        replenish(webSocket);
         return null;
     }
 
@@ -41,21 +84,38 @@
                                        ByteBuffer message,
                                        MessagePart part) {
         System.out.printf("onBinary(%s, %s, %s)%n", webSocket, message, part);
-        replenishDemandIfNeeded(webSocket);
+        invocations.add(new OnBinary(webSocket, fullCopy(message), part));
+        return onBinary0(webSocket, message, part);
+    }
+
+    protected CompletionStage<?> onBinary0(WebSocket webSocket,
+                                           ByteBuffer message,
+                                           MessagePart part) {
+        replenish(webSocket);
         return null;
     }
 
     @Override
     public CompletionStage<?> onPing(WebSocket webSocket, ByteBuffer message) {
         System.out.printf("onPing(%s, %s)%n", webSocket, message);
-        replenishDemandIfNeeded(webSocket);
+        invocations.add(new OnPing(webSocket, fullCopy(message)));
+        return onPing0(webSocket, message);
+    }
+
+    protected CompletionStage<?> onPing0(WebSocket webSocket, ByteBuffer message) {
+        replenish(webSocket);
         return null;
     }
 
     @Override
     public CompletionStage<?> onPong(WebSocket webSocket, ByteBuffer message) {
         System.out.printf("onPong(%s, %s)%n", webSocket, message);
-        replenishDemandIfNeeded(webSocket);
+        invocations.add(new OnPong(webSocket, fullCopy(message)));
+        return onPong0(webSocket, message);
+    }
+
+    protected CompletionStage<?> onPong0(WebSocket webSocket, ByteBuffer message) {
+        replenish(webSocket);
         return null;
     }
 
@@ -64,19 +124,249 @@
                                       int statusCode,
                                       String reason) {
         System.out.printf("onClose(%s, %s, %s)%n", webSocket, statusCode, reason);
+        invocations.add(new OnClose(webSocket, statusCode, reason));
+        lastCall.complete(null);
         return null;
     }
 
     @Override
     public void onError(WebSocket webSocket, Throwable error) {
         System.out.printf("onError(%s, %s)%n", webSocket, error);
+        invocations.add(new OnError(webSocket, error == null ? null : error.getClass()));
+        lastCall.complete(null);
+    }
+
+    public CompletableFuture<?> onCloseOrOnErrorCalled() {
+        return lastCall.copy();
+    }
+
+    protected void replenish(WebSocket webSocket) {
+        if (--count <= 0) {
+            count = bufferSize - bufferSize / 2;
+        }
+        webSocket.request(count);
+    }
+
+    public List<ListenerInvocation> invocations() {
+        return new ArrayList<>(invocations);
+    }
+
+    public abstract static class ListenerInvocation {
+
+        public static OnOpen onOpen(WebSocket webSocket) {
+            return new OnOpen(webSocket);
+        }
+
+        public static OnText onText(WebSocket webSocket,
+                                    String text,
+                                    MessagePart part) {
+            return new OnText(webSocket, text, part);
+        }
+
+        public static OnBinary onBinary(WebSocket webSocket,
+                                        ByteBuffer data,
+                                        MessagePart part) {
+            return new OnBinary(webSocket, data, part);
+        }
+
+        public static OnPing onPing(WebSocket webSocket,
+                                    ByteBuffer data) {
+            return new OnPing(webSocket, data);
+        }
+
+        public static OnPong onPong(WebSocket webSocket,
+                                    ByteBuffer data) {
+            return new OnPong(webSocket, data);
+        }
+
+        public static OnClose onClose(WebSocket webSocket,
+                                      int statusCode,
+                                      String reason) {
+            return new OnClose(webSocket, statusCode, reason);
+        }
+
+        public static OnError onError(WebSocket webSocket,
+                                      Class<? extends Throwable> clazz) {
+            return new OnError(webSocket, clazz);
+        }
+
+        final WebSocket webSocket;
+
+        private ListenerInvocation(WebSocket webSocket) {
+            this.webSocket = webSocket;
+        }
+    }
+
+    public static final class OnOpen extends ListenerInvocation {
+
+        public OnOpen(WebSocket webSocket) {
+            super(webSocket);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            ListenerInvocation that = (ListenerInvocation) o;
+            return Objects.equals(webSocket, that.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hashCode(webSocket);
+        }
+    }
+
+    public static final class OnText extends ListenerInvocation {
+
+        final String text;
+        final MessagePart part;
+
+        public OnText(WebSocket webSocket, String text, MessagePart part) {
+            super(webSocket);
+            this.text = text;
+            this.part = part;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnText onText = (OnText) o;
+            return Objects.equals(text, onText.text) &&
+                    part == onText.part &&
+                    Objects.equals(webSocket, onText.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(text, part, webSocket);
+        }
     }
 
-    private void replenishDemandIfNeeded(WebSocket webSocket) {
-        if (--count <= 0) {
-            count = bufferSize - bufferSize / 2;
-            System.out.printf("request(%s)%n", count);
-            webSocket.request(count);
+    public static final class OnBinary extends ListenerInvocation {
+
+        final ByteBuffer data;
+        final MessagePart part;
+
+        public OnBinary(WebSocket webSocket, ByteBuffer data, MessagePart part) {
+            super(webSocket);
+            this.data = data;
+            this.part = part;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnBinary onBinary = (OnBinary) o;
+            return Objects.equals(data, onBinary.data) &&
+                    part == onBinary.part &&
+                    Objects.equals(webSocket, onBinary.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(data, part, webSocket);
+        }
+    }
+
+    public static final class OnPing extends ListenerInvocation {
+
+        final ByteBuffer data;
+
+        public OnPing(WebSocket webSocket, ByteBuffer data) {
+            super(webSocket);
+            this.data = data;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnPing onPing = (OnPing) o;
+            return Objects.equals(data, onPing.data) &&
+                    Objects.equals(webSocket, onPing.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(data, webSocket);
+        }
+    }
+
+    public static final class OnPong extends ListenerInvocation {
+
+        final ByteBuffer data;
+
+        public OnPong(WebSocket webSocket, ByteBuffer data) {
+            super(webSocket);
+            this.data = data;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnPong onPong = (OnPong) o;
+            return Objects.equals(data, onPong.data) &&
+                    Objects.equals(webSocket, onPong.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(data, webSocket);
+        }
+    }
+
+    public static final class OnClose extends ListenerInvocation {
+
+        final int statusCode;
+        final String reason;
+
+        public OnClose(WebSocket webSocket, int statusCode, String reason) {
+            super(webSocket);
+            this.statusCode = statusCode;
+            this.reason = reason;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnClose onClose = (OnClose) o;
+            return statusCode == onClose.statusCode &&
+                    Objects.equals(reason, onClose.reason) &&
+                    Objects.equals(webSocket, onClose.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(statusCode, reason, webSocket);
+        }
+    }
+
+    public static final class OnError extends ListenerInvocation {
+
+        final Class<? extends Throwable> clazz;
+
+        public OnError(WebSocket webSocket, Class<? extends Throwable> clazz) {
+            super(webSocket);
+            this.clazz = clazz;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnError onError = (OnError) o;
+            return Objects.equals(clazz, onError.clazz) &&
+                    Objects.equals(webSocket, onError.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(clazz, webSocket);
         }
     }
 }
--- a/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockReceiver.java	Wed Nov 29 16:59:38 2017 +0000
+++ b/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/MockReceiver.java	Thu Nov 30 13:35:26 2017 +0300
@@ -70,6 +70,8 @@
                                     repeat(taskCompleter);
                                 });
                             }
+                        } else {
+                            taskCompleter.complete();
                         }
                     }
 
--- a/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/ReceivingTest.java	Wed Nov 29 16:59:38 2017 +0000
+++ b/test/jdk/java/net/httpclient/websocket/jdk.incubator.httpclient/jdk/incubator/http/internal/websocket/ReceivingTest.java	Thu Nov 30 13:35:26 2017 +0300
@@ -24,15 +24,13 @@
 package jdk.incubator.http.internal.websocket;
 
 import jdk.incubator.http.WebSocket;
-import jdk.incubator.http.WebSocket.MessagePart;
 import org.testng.annotations.Test;
 
 import java.net.URI;
-import java.nio.ByteBuffer;
+import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
 
 import static java.util.concurrent.CompletableFuture.completedStage;
 import static jdk.incubator.http.WebSocket.MessagePart.FIRST;
@@ -41,115 +39,75 @@
 import static jdk.incubator.http.WebSocket.MessagePart.WHOLE;
 import static jdk.incubator.http.WebSocket.NORMAL_CLOSURE;
 import static jdk.incubator.http.internal.common.Pair.pair;
-import static jdk.incubator.http.internal.websocket.WebSocketImpl.newInstance;
+import static jdk.incubator.http.internal.websocket.MockListener.ListenerInvocation.onClose;
+import static jdk.incubator.http.internal.websocket.MockListener.ListenerInvocation.onError;
+import static jdk.incubator.http.internal.websocket.MockListener.ListenerInvocation.onOpen;
+import static jdk.incubator.http.internal.websocket.MockListener.ListenerInvocation.onText;
+import static org.testng.Assert.assertEquals;
 
 public class ReceivingTest {
 
     // TODO: request in onClose/onError
     // TODO: throw exception in onClose/onError
+    // TODO: exception is thrown from request()
 
     @Test
-    public void testNonPositiveRequest() {
-        URI uri = URI.create("ws://localhost");
-        String subprotocol = "";
-        CompletableFuture<Throwable> result = new CompletableFuture<>();
-        newInstance(uri, subprotocol, new MockListener(Long.MAX_VALUE) {
-
-            final AtomicInteger onOpenCount = new AtomicInteger();
-            volatile WebSocket webSocket;
-
+    public void testNonPositiveRequest() throws Exception {
+        MockListener listener = new MockListener(Long.MAX_VALUE) {
             @Override
-            public void onOpen(WebSocket webSocket) {
-                int i = onOpenCount.incrementAndGet();
-                if (i > 1) {
-                    result.completeExceptionally(new IllegalStateException());
-                } else {
-                    this.webSocket = webSocket;
-                    webSocket.request(0);
-                }
-            }
-
-            @Override
-            public CompletionStage<?> onBinary(WebSocket webSocket,
-                                               ByteBuffer message,
-                                               MessagePart part) {
-                result.completeExceptionally(new IllegalStateException());
-                return null;
-            }
-
-            @Override
-            public CompletionStage<?> onText(WebSocket webSocket,
-                                             CharSequence message,
-                                             MessagePart part) {
-                result.completeExceptionally(new IllegalStateException());
-                return null;
+            protected void onOpen0(WebSocket webSocket) {
+                webSocket.request(0);
             }
-
-            @Override
-            public CompletionStage<?> onPing(WebSocket webSocket,
-                                             ByteBuffer message) {
-                result.completeExceptionally(new IllegalStateException());
-                return null;
-            }
-
-            @Override
-            public CompletionStage<?> onPong(WebSocket webSocket,
-                                             ByteBuffer message) {
-                result.completeExceptionally(new IllegalStateException());
-                return null;
-            }
-
-            @Override
-            public CompletionStage<?> onClose(WebSocket webSocket,
-                                              int statusCode,
-                                              String reason) {
-                result.completeExceptionally(new IllegalStateException());
-                return null;
-            }
-
-            @Override
-            public void onError(WebSocket webSocket, Throwable error) {
-                if (!this.webSocket.equals(webSocket)) {
-                    result.completeExceptionally(new IllegalArgumentException());
-                } else if (error == null || error.getClass() != IllegalArgumentException.class) {
-                    result.completeExceptionally(new IllegalArgumentException());
-                } else {
-                    result.complete(null);
-                }
-            }
-        }, new MockTransport() {
+        };
+        MockTransport transport = new MockTransport() {
             @Override
             protected Receiver newReceiver(MessageStreamConsumer consumer) {
-                return new MockReceiver(consumer, channel, pair(now(), m -> m.onText("1", WHOLE) ));
+                return new MockReceiver(consumer, channel, pair(now(), m -> m.onText("1", WHOLE)));
             }
-        });
-        result.join();
+        };
+        WebSocket ws = newInstance(listener, transport);
+        listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS);
+        List<MockListener.ListenerInvocation> invocations = listener.invocations();
+        assertEquals(invocations, List.of(onOpen(ws), onError(ws, IllegalArgumentException.class)));
     }
 
     @Test
-    public void testText1() throws InterruptedException {
-        URI uri = URI.create("ws://localhost");
-        String subprotocol = "";
-        newInstance(uri, subprotocol, new MockListener(Long.MAX_VALUE),
-                    new MockTransport() {
-                        @Override
-                        protected Receiver newReceiver(MessageStreamConsumer consumer) {
-                            return new MockReceiver(consumer, channel,
-                                                    pair(now(), m -> m.onText("1", FIRST)),
-                                                    pair(now(), m -> m.onText("2", PART)),
-                                                    pair(now(), m -> m.onText("3", PART)),
-                                                    pair(now(), m -> m.onText("4", LAST)),
-                                                    pair(now(), m -> m.onClose(NORMAL_CLOSURE, "no reason")));
-                        }
-                    });
-        Thread.sleep(2000);
+    public void testText1() throws Exception {
+        MockListener listener = new MockListener(Long.MAX_VALUE);
+        MockTransport transport = new MockTransport() {
+            @Override
+            protected Receiver newReceiver(MessageStreamConsumer consumer) {
+                return new MockReceiver(consumer, channel,
+                                        pair(now(), m -> m.onText("1", FIRST)),
+                                        pair(now(), m -> m.onText("2", PART)),
+                                        pair(now(), m -> m.onText("3", PART)),
+                                        pair(now(), m -> m.onText("4", LAST)),
+                                        pair(now(), m -> m.onClose(NORMAL_CLOSURE, "no reason")));
+            }
+        };
+        WebSocket ws = newInstance(listener, transport);
+        listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS);
+        List<MockListener.ListenerInvocation> invocations = listener.invocations();
+        assertEquals(invocations, List.of(onOpen(ws),
+                                          onText(ws, "1", FIRST),
+                                          onText(ws, "2", PART),
+                                          onText(ws, "3", PART),
+                                          onText(ws, "4", LAST),
+                                          onClose(ws, NORMAL_CLOSURE, "no reason")));
     }
 
-    private CompletionStage<?> inSeconds(long s) {
+    private static CompletionStage<?> seconds(long s) {
         return new CompletableFuture<>().completeOnTimeout(null, s, TimeUnit.SECONDS);
     }
 
-    private CompletionStage<?> now() {
+    private static CompletionStage<?> now() {
         return completedStage(null);
     }
+
+    private static WebSocket newInstance(WebSocket.Listener listener,
+                                         TransportSupplier transport) {
+        URI uri = URI.create("ws://localhost");
+        String subprotocol = "";
+        return WebSocketImpl.newInstance(uri, subprotocol, listener, transport);
+    }
 }