src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java
branchhttp-client-branch
changeset 55988 7f1e0cf933a6
parent 55973 4d9b002587db
child 55989 76ac25076fdc
--- a/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java	Thu Dec 14 18:41:57 2017 +0000
+++ b/src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java	Fri Dec 15 00:47:16 2017 +0300
@@ -29,37 +29,25 @@
 import jdk.incubator.http.internal.common.Demand;
 import jdk.incubator.http.internal.common.Log;
 import jdk.incubator.http.internal.common.MinimalFuture;
-import jdk.incubator.http.internal.common.Pair;
 import jdk.incubator.http.internal.common.SequentialScheduler;
-import jdk.incubator.http.internal.common.SequentialScheduler.DeferredCompleter;
 import jdk.incubator.http.internal.common.Utils;
 import jdk.incubator.http.internal.websocket.OpeningHandshake.Result;
-import jdk.incubator.http.internal.websocket.OutgoingMessage.Binary;
-import jdk.incubator.http.internal.websocket.OutgoingMessage.Close;
-import jdk.incubator.http.internal.websocket.OutgoingMessage.Context;
-import jdk.incubator.http.internal.websocket.OutgoingMessage.Ping;
-import jdk.incubator.http.internal.websocket.OutgoingMessage.Pong;
-import jdk.incubator.http.internal.websocket.OutgoingMessage.Text;
 
 import java.io.IOException;
 import java.lang.ref.Reference;
 import java.net.ProtocolException;
 import java.net.URI;
 import java.nio.ByteBuffer;
-import java.util.Queue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
-import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Consumer;
 import java.util.function.Function;
 
 import static java.util.Objects.requireNonNull;
 import static jdk.incubator.http.internal.common.MinimalFuture.failedFuture;
-import static jdk.incubator.http.internal.common.Pair.pair;
 import static jdk.incubator.http.internal.websocket.StatusCodes.CLOSED_ABNORMALLY;
 import static jdk.incubator.http.internal.websocket.StatusCodes.NO_STATUS_CODE;
 import static jdk.incubator.http.internal.websocket.StatusCodes.isLegalToSendFromClient;
@@ -108,12 +96,7 @@
     private final Listener listener;
 
     private final AtomicBoolean outstandingSend = new AtomicBoolean();
-    private final SequentialScheduler sendScheduler = new SequentialScheduler(new SendTask());
-    private final Queue<Pair<OutgoingMessage, CompletableFuture<WebSocket>>>
-            queue = new ConcurrentLinkedQueue<>();
-    private final Context context = new OutgoingMessage.Context();
-    private final Transmitter transmitter;
-    private final Receiver receiver;
+    private final Transport<WebSocket> transport;
     private final SequentialScheduler receiveScheduler = new SequentialScheduler(new ReceiveTask());
     private final Demand demand = new Demand();
 
@@ -140,10 +123,10 @@
     }
 
     /* Exposed for testing purposes */
-    static WebSocket newInstance(URI uri,
-                                 String subprotocol,
-                                 Listener listener,
-                                 TransportSupplier transport) {
+    static WebSocketImpl newInstance(URI uri,
+                                     String subprotocol,
+                                     Listener listener,
+                                     TransportFactory transport) {
         WebSocketImpl ws = new WebSocketImpl(uri, subprotocol, listener, transport);
         // This initialisation is outside of the constructor for the sake of
         // safe publication of WebSocketImpl.this
@@ -154,68 +137,82 @@
     private WebSocketImpl(URI uri,
                           String subprotocol,
                           Listener listener,
-                          TransportSupplier transport) {
+                          TransportFactory transportFactory) {
         this.uri = requireNonNull(uri);
         this.subprotocol = requireNonNull(subprotocol);
         this.listener = requireNonNull(listener);
-        this.transmitter = transport.transmitter();
-        this.receiver = transport.receiver(new SignallingMessageConsumer());
+        this.transport = transportFactory.createTransport(
+                () -> WebSocketImpl.this, // What about escape of WebSocketImpl.this?
+                new SignallingMessageConsumer());
     }
 
     @Override
-    public CompletableFuture<WebSocket> sendText(CharSequence message, boolean isLast) {
-        return enqueueExclusively(new Text(message, isLast));
+    public CompletableFuture<WebSocket> sendText(CharSequence message,
+                                                 boolean isLast) {
+        if (!outstandingSend.compareAndSet(false, true)) {
+            return failedFuture(new IllegalStateException("Send pending"));
+        }
+        CompletableFuture<WebSocket> cf = transport.sendText(message, isLast);
+        cf.whenComplete((r, e) -> outstandingSend.set(false));
+        return cf;
     }
 
     @Override
-    public CompletableFuture<WebSocket> sendBinary(ByteBuffer message, boolean isLast) {
-        return enqueueExclusively(new Binary(message, isLast));
+    public CompletableFuture<WebSocket> sendBinary(ByteBuffer message,
+                                                   boolean isLast) {
+        if (!outstandingSend.compareAndSet(false, true)) {
+            return failedFuture(new IllegalStateException("Send pending"));
+        }
+        CompletableFuture<WebSocket> cf = transport.sendBinary(message, isLast);
+        // Optimize?
+        //        if (cf.isDone()) {
+        //            outstandingSend.set(false);
+        //        } else {
+        //            cf.whenComplete((r, e) -> outstandingSend.set(false));
+        //        }
+        cf.whenComplete((r, e) -> outstandingSend.set(false));
+        return cf;
     }
 
     @Override
     public CompletableFuture<WebSocket> sendPing(ByteBuffer message) {
-        return enqueue(new Ping(message));
+        return transport.sendPing(message);
     }
 
     @Override
     public CompletableFuture<WebSocket> sendPong(ByteBuffer message) {
-        return enqueue(new Pong(message));
+        return transport.sendPong(message);
     }
 
     @Override
     public CompletableFuture<WebSocket> sendClose(int statusCode, String reason) {
         if (!isLegalToSendFromClient(statusCode)) {
-            return failedFuture(
-                    new IllegalArgumentException("statusCode: " + statusCode));
+            return failedFuture(new IllegalArgumentException("statusCode"));
         }
-        Close msg;
-        try {
-            msg = new Close(statusCode, reason);
-        } catch (IllegalArgumentException e) {
-            return failedFuture(e);
-        }
-        outputClosed = true;
-        return enqueueClose(msg);
+        return sendClose0(statusCode, reason);
     }
 
     /*
-     * Sends a Close message, then shuts down the transmitter since no more
+     * Sends a Close message, then shuts down the output since no more
      * messages are expected to be sent after this.
+     *
+     * TODO: Even if arguments are illegal the default message will be sent.
      */
-    private CompletableFuture<WebSocket> enqueueClose(Close m) {
+    private CompletableFuture<WebSocket> sendClose0(int statusCode, String reason ) {
         // TODO: MUST be a CF created once and shared across sendClose, otherwise
         // a second sendClose may prematurely close the channel
-        return enqueue(m)
+        outputClosed = true;
+        return transport.sendClose(statusCode, reason)
                 .orTimeout(60, TimeUnit.SECONDS)
                 .whenComplete((r, error) -> {
                     try {
-                        transmitter.close();
+                        transport.closeOutput();
                     } catch (IOException e) {
                         Log.logError(e);
                     }
                     if (error instanceof TimeoutException) {
                         try {
-                            receiver.close();
+                            transport.closeInput();
                         } catch (IOException e) {
                             Log.logError(e);
                         }
@@ -223,77 +220,6 @@
                 });
     }
 
-    /*
-     * Accepts the given message into the outgoing queue in a mutually-exclusive
-     * fashion in respect to other messages accepted through this method. No
-     * further messages will be accepted until the returned CompletableFuture
-     * completes. This method is used to enforce "one outstanding send
-     * operation" policy.
-     */
-    private CompletableFuture<WebSocket> enqueueExclusively(OutgoingMessage m) {
-        if (!outstandingSend.compareAndSet(false, true)) {
-            return failedFuture(new IllegalStateException("Send pending"));
-        }
-        return enqueue(m).whenComplete((r, e) -> outstandingSend.set(false));
-    }
-
-    private CompletableFuture<WebSocket> enqueue(OutgoingMessage m) {
-        CompletableFuture<WebSocket> cf = new MinimalFuture<>();
-        boolean added = queue.add(pair(m, cf));
-        if (!added) {
-            // The queue is supposed to be unbounded
-            throw new InternalError();
-        }
-        sendScheduler.runOrSchedule();
-        return cf;
-    }
-
-    /*
-     * This is a message sending task. It pulls messages from the queue one by
-     * one and sends them. It may be run in different threads, but never
-     * concurrently.
-     */
-    private class SendTask implements SequentialScheduler.RestartableTask {
-
-        @Override
-        public void run(DeferredCompleter taskCompleter) {
-            Pair<OutgoingMessage, CompletableFuture<WebSocket>> p = queue.poll();
-            if (p == null) {
-                taskCompleter.complete();
-                return;
-            }
-            OutgoingMessage message = p.first;
-            CompletableFuture<WebSocket> cf = p.second;
-            try {
-                if (!message.contextualize(context)) { // Do not send the message
-                    cf.complete(null);
-                    repeat(taskCompleter);
-                    return;
-                }
-                Consumer<Exception> h = e -> {
-                    if (e == null) {
-                        cf.complete(WebSocketImpl.this);
-                    } else {
-                        cf.completeExceptionally(e);
-                    }
-                    repeat(taskCompleter);
-                };
-                transmitter.send(message, h);
-            } catch (Throwable t) {
-                cf.completeExceptionally(t);
-                repeat(taskCompleter);
-            }
-        }
-
-        private void repeat(DeferredCompleter taskCompleter) {
-            taskCompleter.complete();
-            // More than a single message may have been enqueued while
-            // the task has been busy with the current message, but
-            // there is only a single signal recorded
-            sendScheduler.runOrSchedule();
-        }
-    }
-
     @Override
     public void request(long n) {
         if (demand.increase(n)) {
@@ -348,7 +274,7 @@
      */
     private class ReceiveTask extends SequentialScheduler.CompleteRestartableTask {
 
-        // Receiver only asked here and nowhere else because we must make sure
+        // Transport only asked here and nowhere else because we must make sure
         // onOpen is invoked first and no messages become pending before onOpen
         // finishes
 
@@ -387,7 +313,7 @@
                         case IDLE:
                             if (demand.tryDecrement()
                                     && tryChangeState(IDLE, WAITING)) {
-                                receiver.request(1);
+                                transport.request(1);
                             }
                             return;
                         case WAITING:
@@ -404,13 +330,13 @@
         }
 
         private void processError() throws IOException {
-            receiver.close();
+            transport.closeInput();
             receiveScheduler.stop();
             Throwable err = error.get();
             if (err instanceof FailWebSocketException) {
                 int code1 = ((FailWebSocketException) err).getStatusCode();
                 err = new ProtocolException().initCause(err);
-                enqueueClose(new Close(code1, ""))
+                sendClose0(code1, "")
                         .whenComplete(
                                 (r, e) -> {
                                     if (e != null) {
@@ -422,7 +348,7 @@
         }
 
         private void processClose() throws IOException {
-            receiver.close();
+            transport.closeInput();
             receiveScheduler.stop();
             CompletionStage<?> readyToClose;
             readyToClose = listener.onClose(WebSocketImpl.this, statusCode, reason);
@@ -436,7 +362,7 @@
                 code = statusCode;
             }
             readyToClose.whenComplete((r, e) -> {
-                enqueueClose(new Close(code, ""))
+                sendClose0(code, "")
                         .whenComplete((r1, e1) -> {
                             if (e1 != null) {
                                 Log.logError(e1);
@@ -458,7 +384,7 @@
                     .put(binaryData)
                     .flip();
             // Non-exclusive send;
-            CompletableFuture<WebSocket> pongSent = enqueue(new Pong(copy));
+            CompletableFuture<WebSocket> pongSent = transport.sendPong(copy);
             pongSent.whenComplete(
                     (r, e) -> {
                         if (e != null) {
@@ -499,9 +425,9 @@
     private void close() {
         try {
             try {
-                receiver.close();
+                transport.closeInput();
             } finally {
-                transmitter.close();
+                transport.closeOutput();
             }
         } catch (Throwable t) {
             Log.logError(t);
@@ -520,7 +446,7 @@
             Log.logTrace("Close: {0}, ''{1}''", statusCode, reason);
         } else {
             try {
-                receiver.close();
+                transport.closeInput();
             } catch (Throwable t) {
                 Log.logError(t);
             }
@@ -531,7 +457,7 @@
 
         @Override
         public void onText(CharSequence data, MessagePart part) {
-            receiver.acknowledge();
+            transport.acknowledgeReception();
             text = data;
             WebSocketImpl.this.part = part;
             tryChangeState(WAITING, TEXT);
@@ -539,7 +465,7 @@
 
         @Override
         public void onBinary(ByteBuffer data, MessagePart part) {
-            receiver.acknowledge();
+            transport.acknowledgeReception();
             binaryData = data;
             WebSocketImpl.this.part = part;
             tryChangeState(WAITING, BINARY);
@@ -547,27 +473,27 @@
 
         @Override
         public void onPing(ByteBuffer data) {
-            receiver.acknowledge();
+            transport.acknowledgeReception();
             binaryData = data;
             tryChangeState(WAITING, PING);
         }
 
         @Override
         public void onPong(ByteBuffer data) {
-            receiver.acknowledge();
+            transport.acknowledgeReception();
             binaryData = data;
             tryChangeState(WAITING, PONG);
         }
 
         @Override
         public void onClose(int statusCode, CharSequence reason) {
-            receiver.acknowledge();
+            transport.acknowledgeReception();
             signalClose(statusCode, reason.toString());
         }
 
         @Override
         public void onComplete() {
-            receiver.acknowledge();
+            transport.acknowledgeReception();
             signalClose(CLOSED_ABNORMALLY, "");
         }
 
@@ -602,4 +528,9 @@
         }
         return false;
     }
+
+    /* Exposed for testing purposes */
+    protected final Transport<WebSocket> transport() {
+        return transport;
+    }
 }