http-client-branch: (WebSocket) swapping automatic pong replies http-client-branch
authorprappo
Wed, 14 Mar 2018 13:03:11 +0000
branchhttp-client-branch
changeset 56303 a82058c084ef
parent 56300 13a2ec671e62
child 56304 065641767a75
http-client-branch: (WebSocket) swapping automatic pong replies
src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageQueue.java
src/java.net.http/share/classes/jdk/internal/net/http/websocket/Transport.java
src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportImpl.java
src/java.net.http/share/classes/jdk/internal/net/http/websocket/WebSocketImpl.java
test/jdk/java/net/httpclient/websocket/MockListener.java
test/jdk/java/net/httpclient/websocket/WebSocketTest.java
test/jdk/java/net/httpclient/websocket/java.net.http/jdk/internal/net/http/websocket/MessageQueueTest.java
--- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageQueue.java	Wed Mar 14 09:01:15 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageQueue.java	Wed Mar 14 13:03:11 2018 +0000
@@ -34,6 +34,7 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiConsumer;
+import java.util.function.Supplier;
 
 /*
  * A FIFO message storage facility.
@@ -135,11 +136,12 @@
                             CompletableFuture<T> future)
             throws IOException
     {
-        add(MessageQueue.Type.TEXT, null, message, isLast, -1, attachment,
+        add(MessageQueue.Type.TEXT, null, null, message, isLast, -1, attachment,
             action, future);
     }
 
     private <T> void add(Type type,
+                         Supplier<? extends ByteBuffer> binarySupplier,
                          ByteBuffer binary,
                          CharBuffer text,
                          boolean isLast,
@@ -149,6 +151,9 @@
                          CompletableFuture<? super T> future)
             throws IOException
     {
+        // Pong "subtype" is determined by whichever field (data carrier)
+        // is not null. Both fields cannot be null or non-null simultaneously.
+        assert type != Type.PONG || (binary == null ^ binarySupplier == null);
         int h, currentTail, newTail;
         do {
             h = head;
@@ -163,6 +168,7 @@
             throw new InternalError();
         }
         t.type = type;
+        t.binarySupplier = binarySupplier;
         t.binary = binary;
         t.text = text;
         t.isLast = isLast;
@@ -180,7 +186,7 @@
                               CompletableFuture<? super T> future)
             throws IOException
     {
-        add(MessageQueue.Type.BINARY, message, null, isLast, -1, attachment,
+        add(MessageQueue.Type.BINARY, null, message, null, isLast, -1, attachment,
             action, future);
     }
 
@@ -190,7 +196,7 @@
                             CompletableFuture<? super T> future)
             throws IOException
     {
-        add(MessageQueue.Type.PING, message, null, false, -1, attachment,
+        add(MessageQueue.Type.PING, null, message, null, false, -1, attachment,
             action, future);
     }
 
@@ -200,7 +206,17 @@
                             CompletableFuture<? super T> future)
             throws IOException
     {
-        add(MessageQueue.Type.PONG, message, null, false, -1, attachment,
+        add(MessageQueue.Type.PONG, null, message, null, false, -1, attachment,
+            action, future);
+    }
+
+    public <T> void addPong(Supplier<? extends ByteBuffer> message,
+                            T attachment,
+                            BiConsumer<? super T, ? super Throwable> action,
+                            CompletableFuture<? super T> future)
+            throws IOException
+    {
+        add(MessageQueue.Type.PONG, message, null, null, false, -1, attachment,
             action, future);
     }
 
@@ -211,7 +227,7 @@
                              CompletableFuture<? super T> future)
             throws IOException
     {
-        add(MessageQueue.Type.CLOSE, null, reason, false, statusCode,
+        add(MessageQueue.Type.CLOSE, null, null, reason, false, statusCode,
             attachment, action, future);
     }
 
@@ -258,8 +274,13 @@
                 }
             case PONG:
                 try {
-                    return (R) callback.onPong(h.binary, h.attachment, h.action,
-                                               h.future);
+                    if (h.binarySupplier != null) {
+                        return (R) callback.onPong(h.binarySupplier, h.attachment,
+                                                   h.action, h.future);
+                    } else {
+                        return (R) callback.onPong(h.binary, h.attachment, h.action,
+                                                   h.future);
+                    }
                 } catch (Throwable t) {
                     throw (E) t;
                 }
@@ -286,6 +307,7 @@
             throw new InternalError("Queue empty");
         }
         h.type = null;
+        h.binarySupplier = null;
         h.binary = null;
         h.text = null;
         h.attachment = null;
@@ -334,6 +356,11 @@
                      BiConsumer<? super T, ? super Throwable> action,
                      CompletableFuture<? super T> future) throws E;
 
+        <T> R onPong(Supplier<? extends ByteBuffer> message,
+                     T attachment,
+                     BiConsumer<? super T, ? super Throwable> action,
+                     CompletableFuture<? super T> future) throws E;
+
         <T> R onClose(int statusCode,
                       CharBuffer reason,
                       T attachment,
@@ -358,6 +385,7 @@
         // -- The source message fields --
 
         private Type type;
+        private Supplier<? extends ByteBuffer> binarySupplier;
         private ByteBuffer binary;
         private CharBuffer text;
         private boolean isLast;
--- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/Transport.java	Wed Mar 14 09:01:15 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/Transport.java	Wed Mar 14 13:03:11 2018 +0000
@@ -29,6 +29,7 @@
 import java.nio.ByteBuffer;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.BiConsumer;
+import java.util.function.Supplier;
 
 /*
  * A WebSocket view of the underlying communication channel. This view provides
@@ -77,6 +78,14 @@
                                       T attachment,
                                       BiConsumer<? super T, ? super Throwable> action);
 
+    /*
+     * Sends a Pong message with initially unknown data. Used for sending the
+     * most recent automatic Pong reply.
+     */
+    <T> CompletableFuture<T> sendPong(Supplier<? extends ByteBuffer> message,
+                                      T attachment,
+                                      BiConsumer<? super T, ? super Throwable> action);
+
     <T> CompletableFuture<T> sendClose(int statusCode,
                                        String reason,
                                        T attachment,
--- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportImpl.java	Wed Mar 14 09:01:15 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportImpl.java	Wed Mar 14 13:03:11 2018 +0000
@@ -39,6 +39,7 @@
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
+import java.util.function.Supplier;
 
 import static jdk.internal.net.http.websocket.TransportImpl.ChannelState.AVAILABLE;
 import static jdk.internal.net.http.websocket.TransportImpl.ChannelState.CLOSED;
@@ -218,6 +219,29 @@
     }
 
     @Override
+    public <T> CompletableFuture<T> sendPong(Supplier<? extends ByteBuffer> message,
+                                             T attachment,
+                                             BiConsumer<? super T, ? super Throwable> action) {
+        long id;
+        if (DEBUG) {
+            id = counter.incrementAndGet();
+            System.out.printf("[Transport] enter send pong %s supplier=%s%n",
+                              id, message);
+        }
+        MinimalFuture<T> f = new MinimalFuture<>();
+        try {
+            queue.addPong(message, attachment, action, f);
+            sendScheduler.runOrSchedule();
+        } catch (IOException e) {
+            f.completeExceptionally(e);
+        }
+        if (DEBUG) {
+            System.out.printf("[Transport] exit send pong %s returned %s%n", id, f);
+        }
+        return f;
+    }
+
+    @Override
     public <T> CompletableFuture<T> sendClose(int statusCode,
                                               String reason,
                                               T attachment,
@@ -357,6 +381,14 @@
             }
 
             @Override
+            public <T> Boolean onPong(Supplier<? extends ByteBuffer> message,
+                                      T attachment,
+                                      BiConsumer<? super T, ? super Throwable> action,
+                                      CompletableFuture<? super T> future) throws IOException {
+                return encoder.encodePong(message.get(), dst);
+            }
+
+            @Override
             public <T> Boolean onClose(int statusCode,
                                        CharBuffer reason,
                                        T attachment,
@@ -437,6 +469,18 @@
             }
 
             @Override
+            public <T> Boolean onPong(Supplier<? extends ByteBuffer> message,
+                                      T attachment,
+                                      BiConsumer<? super T, ? super Throwable> action,
+                                      CompletableFuture<? super T> future)
+            {
+                SendTask.this.attachment = attachment;
+                SendTask.this.action = action;
+                SendTask.this.future = future;
+                return true;
+            }
+
+            @Override
             public <T> Boolean onClose(int statusCode,
                                        CharBuffer reason,
                                        T attachment,
--- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/WebSocketImpl.java	Wed Mar 14 09:01:15 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/WebSocketImpl.java	Wed Mar 14 13:03:11 2018 +0000
@@ -87,6 +87,7 @@
         ERROR;
     }
 
+    private final AtomicReference<ByteBuffer> lastAutomaticPong = new AtomicReference<>();
     private final MinimalFuture<WebSocket> DONE = MinimalFuture.completedFuture(this);
     private final long closeTimeout;
     private volatile boolean inputClosed;
@@ -582,13 +583,17 @@
             ByteBuffer copy = ByteBuffer.allocate(binaryData.remaining())
                     .put(binaryData)
                     .flip();
-            // Non-exclusive send;
-            BiConsumer<WebSocketImpl, Throwable> reporter = (r, e) -> {
-                if (e != null) { // TODO: better error handing. What if already closed?
-                    signalError(Utils.getCompletionCause(e));
-                }
-            };
-            transport.sendPong(copy, WebSocketImpl.this, reporter);
+            if (!trySwapAutomaticPong(copy)) {
+                // Non-exclusive send;
+                BiConsumer<WebSocketImpl, Throwable> reporter = (r, e) -> {
+                    if (e != null) { // TODO: better error handing. What if already closed?
+                        signalError(Utils.getCompletionCause(e));
+                    }
+                };
+                transport.sendPong(WebSocketImpl.this::clearAutomaticPong,
+                                   WebSocketImpl.this,
+                                   reporter);
+            }
             long id;
             if (DEBUG) {
                 id = receiveCounter.incrementAndGet();
@@ -658,6 +663,46 @@
         }
     }
 
+    private ByteBuffer clearAutomaticPong() {
+        ByteBuffer data;
+        do {
+            data = lastAutomaticPong.get();
+            if (data == null) {
+                // This method must never be called unless a message that is
+                // using it has been added previously
+                throw new InternalError();
+            }
+        } while (!lastAutomaticPong.compareAndSet(data, null));
+        return data;
+    }
+
+    private boolean trySwapAutomaticPong(ByteBuffer copy) {
+        ByteBuffer message;
+        boolean swapped;
+        while (true) {
+            message = lastAutomaticPong.get();
+            if (message == null) {
+                if (!lastAutomaticPong.compareAndSet(null, copy)) {
+                    // It's only this method that can change null to ByteBuffer,
+                    // and this method is invoked at most by one thread at a
+                    // time. Thus no failure in the atomic operation above is
+                    // expected.
+                    throw new InternalError();
+                }
+                swapped = false;
+                break;
+            } else if (lastAutomaticPong.compareAndSet(message, copy)) {
+                swapped = true;
+                break;
+            }
+        }
+        if (DEBUG) {
+            System.out.printf("[WebSocket] swapped automatic pong from %s to %s%n",
+                              message, copy);
+        }
+        return swapped;
+    }
+
     private void signalOpen() {
         if (DEBUG) {
             System.out.printf("[WebSocket] signalOpen%n");
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/java/net/httpclient/websocket/MockListener.java	Wed Mar 14 13:03:11 2018 +0000
@@ -0,0 +1,448 @@
+/*
+ * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import java.net.http.WebSocket;
+import java.net.http.WebSocket.MessagePart;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
+import java.util.function.Predicate;
+
+public class MockListener implements WebSocket.Listener {
+
+    private final long bufferSize;
+    private long count;
+    private final List<Invocation> invocations = new ArrayList<>();
+    private final CompletableFuture<?> lastCall = new CompletableFuture<>();
+    private final Predicate<? super Invocation> collectUntil;
+
+    public MockListener() {
+        this(i -> i instanceof OnClose || i instanceof OnError);
+    }
+
+    public MockListener(Predicate<? super Invocation> collectUntil) {
+        this(2, collectUntil);
+    }
+
+    /*
+     * Typical buffer sizes: 1, n, Long.MAX_VALUE
+     */
+    public MockListener(long bufferSize,
+                        Predicate<? super Invocation> collectUntil) {
+        if (bufferSize < 1) {
+            throw new IllegalArgumentException();
+        }
+        Objects.requireNonNull(collectUntil);
+        this.bufferSize = bufferSize;
+        this.collectUntil = collectUntil;
+    }
+
+    @Override
+    public void onOpen(WebSocket webSocket) {
+        System.out.printf("onOpen(%s)%n", webSocket);
+        OnOpen inv = new OnOpen(webSocket);
+        invocations.add(inv);
+        if (collectUntil.test(inv)) {
+            lastCall.complete(null);
+        }
+        onOpen0(webSocket);
+    }
+
+    protected void onOpen0(WebSocket webSocket) {
+        replenish(webSocket);
+    }
+
+    @Override
+    public CompletionStage<?> onText(WebSocket webSocket,
+                                     CharSequence message,
+                                     MessagePart part) {
+        System.out.printf("onText(%s, %s, %s)%n", webSocket, message, part);
+        OnText inv = new OnText(webSocket, message.toString(), part);
+        invocations.add(inv);
+        if (collectUntil.test(inv)) {
+            lastCall.complete(null);
+        }
+        return onText0(webSocket, message, part);
+    }
+
+    protected CompletionStage<?> onText0(WebSocket webSocket,
+                                         CharSequence message,
+                                         MessagePart part) {
+        replenish(webSocket);
+        return null;
+    }
+
+    @Override
+    public CompletionStage<?> onBinary(WebSocket webSocket,
+                                       ByteBuffer message,
+                                       MessagePart part) {
+        System.out.printf("onBinary(%s, %s, %s)%n", webSocket, message, part);
+        OnBinary inv = new OnBinary(webSocket, fullCopy(message), part);
+        invocations.add(inv);
+        if (collectUntil.test(inv)) {
+            lastCall.complete(null);
+        }
+        return onBinary0(webSocket, message, part);
+    }
+
+    protected CompletionStage<?> onBinary0(WebSocket webSocket,
+                                           ByteBuffer message,
+                                           MessagePart part) {
+        replenish(webSocket);
+        return null;
+    }
+
+    @Override
+    public CompletionStage<?> onPing(WebSocket webSocket, ByteBuffer message) {
+        System.out.printf("onPing(%s, %s)%n", webSocket, message);
+        OnPing inv = new OnPing(webSocket, fullCopy(message));
+        invocations.add(inv);
+        if (collectUntil.test(inv)) {
+            lastCall.complete(null);
+        }
+        return onPing0(webSocket, message);
+    }
+
+    protected CompletionStage<?> onPing0(WebSocket webSocket, ByteBuffer message) {
+        replenish(webSocket);
+        return null;
+    }
+
+    @Override
+    public CompletionStage<?> onPong(WebSocket webSocket, ByteBuffer message) {
+        System.out.printf("onPong(%s, %s)%n", webSocket, message);
+        OnPong inv = new OnPong(webSocket, fullCopy(message));
+        invocations.add(inv);
+        if (collectUntil.test(inv)) {
+            lastCall.complete(null);
+        }
+        return onPong0(webSocket, message);
+    }
+
+    protected CompletionStage<?> onPong0(WebSocket webSocket, ByteBuffer message) {
+        replenish(webSocket);
+        return null;
+    }
+
+    @Override
+    public CompletionStage<?> onClose(WebSocket webSocket,
+                                      int statusCode,
+                                      String reason) {
+        System.out.printf("onClose(%s, %s, %s)%n", webSocket, statusCode, reason);
+        OnClose inv = new OnClose(webSocket, statusCode, reason);
+        invocations.add(inv);
+        if (collectUntil.test(inv)) {
+            lastCall.complete(null);
+        }
+        return null;
+    }
+
+    @Override
+    public void onError(WebSocket webSocket, Throwable error) {
+        System.out.printf("onError(%s, %s)%n", webSocket, error);
+        OnError inv = new OnError(webSocket, error == null ? null : error.getClass());
+        invocations.add(inv);
+        if (collectUntil.test(inv)) {
+            lastCall.complete(null);
+        }
+    }
+
+    public List<Invocation> invocations() {
+        lastCall.join();
+        return new ArrayList<>(invocations);
+    }
+
+    protected void replenish(WebSocket webSocket) {
+        if (--count <= 0) {
+            count = bufferSize - bufferSize / 2;
+        }
+        webSocket.request(count);
+    }
+
+    public abstract static class Invocation {
+
+        public static OnOpen onOpen(WebSocket webSocket) {
+            return new OnOpen(webSocket);
+        }
+
+        public static OnText onText(WebSocket webSocket,
+                                    String text,
+                                    MessagePart part) {
+            return new OnText(webSocket, text, part);
+        }
+
+        public static OnBinary onBinary(WebSocket webSocket,
+                                        ByteBuffer data,
+                                        MessagePart part) {
+            return new OnBinary(webSocket, data, part);
+        }
+
+        public static OnPing onPing(WebSocket webSocket,
+                                    ByteBuffer data) {
+            return new OnPing(webSocket, data);
+        }
+
+        public static OnPong onPong(WebSocket webSocket,
+                                    ByteBuffer data) {
+            return new OnPong(webSocket, data);
+        }
+
+        public static OnClose onClose(WebSocket webSocket,
+                                      int statusCode,
+                                      String reason) {
+            return new OnClose(webSocket, statusCode, reason);
+        }
+
+        public static OnError onError(WebSocket webSocket,
+                                      Class<? extends Throwable> clazz) {
+            return new OnError(webSocket, clazz);
+        }
+
+        final WebSocket webSocket;
+
+        private Invocation(WebSocket webSocket) {
+            this.webSocket = webSocket;
+        }
+    }
+
+    public static final class OnOpen extends Invocation {
+
+        public OnOpen(WebSocket webSocket) {
+            super(webSocket);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Invocation that = (Invocation) o;
+            return Objects.equals(webSocket, that.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hashCode(webSocket);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("onOpen(%s)", webSocket);
+        }
+    }
+
+    public static final class OnText extends Invocation {
+
+        final String text;
+        final MessagePart part;
+
+        public OnText(WebSocket webSocket, String text, MessagePart part) {
+            super(webSocket);
+            this.text = text;
+            this.part = part;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnText onText = (OnText) o;
+            return Objects.equals(text, onText.text) &&
+                    part == onText.part &&
+                    Objects.equals(webSocket, onText.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(text, part, webSocket);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("onText(%s, %s, %s)", webSocket, text, part);
+        }
+    }
+
+    public static final class OnBinary extends Invocation {
+
+        final ByteBuffer data;
+        final MessagePart part;
+
+        public OnBinary(WebSocket webSocket, ByteBuffer data, MessagePart part) {
+            super(webSocket);
+            this.data = data;
+            this.part = part;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnBinary onBinary = (OnBinary) o;
+            return Objects.equals(data, onBinary.data) &&
+                    part == onBinary.part &&
+                    Objects.equals(webSocket, onBinary.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(data, part, webSocket);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("onBinary(%s, %s, %s)", webSocket, data, part);
+        }
+    }
+
+    public static final class OnPing extends Invocation {
+
+        final ByteBuffer data;
+
+        public OnPing(WebSocket webSocket, ByteBuffer data) {
+            super(webSocket);
+            this.data = data;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnPing onPing = (OnPing) o;
+            return Objects.equals(data, onPing.data) &&
+                    Objects.equals(webSocket, onPing.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(data, webSocket);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("onPing(%s, %s)", webSocket, data);
+        }
+    }
+
+    public static final class OnPong extends Invocation {
+
+        final ByteBuffer data;
+
+        public OnPong(WebSocket webSocket, ByteBuffer data) {
+            super(webSocket);
+            this.data = data;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnPong onPong = (OnPong) o;
+            return Objects.equals(data, onPong.data) &&
+                    Objects.equals(webSocket, onPong.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(data, webSocket);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("onPong(%s, %s)", webSocket, data);
+        }
+    }
+
+    public static final class OnClose extends Invocation {
+
+        final int statusCode;
+        final String reason;
+
+        public OnClose(WebSocket webSocket, int statusCode, String reason) {
+            super(webSocket);
+            this.statusCode = statusCode;
+            this.reason = reason;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnClose onClose = (OnClose) o;
+            return statusCode == onClose.statusCode &&
+                    Objects.equals(reason, onClose.reason) &&
+                    Objects.equals(webSocket, onClose.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(statusCode, reason, webSocket);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("onClose(%s, %s, %s)", webSocket, statusCode, reason);
+        }
+    }
+
+    public static final class OnError extends Invocation {
+
+        final Class<? extends Throwable> clazz;
+
+        public OnError(WebSocket webSocket, Class<? extends Throwable> clazz) {
+            super(webSocket);
+            this.clazz = clazz;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            OnError onError = (OnError) o;
+            return Objects.equals(clazz, onError.clazz) &&
+                    Objects.equals(webSocket, onError.webSocket);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(clazz, webSocket);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("onError(%s, %s)", webSocket, clazz);
+        }
+    }
+
+    private static ByteBuffer fullCopy(ByteBuffer src) {
+        ByteBuffer copy = ByteBuffer.allocate(src.capacity());
+        int p = src.position();
+        int l = src.limit();
+        src.clear();
+        copy.put(src).position(p).limit(l);
+        src.position(p).limit(l);
+        return copy;
+    }
+}
--- a/test/jdk/java/net/httpclient/websocket/WebSocketTest.java	Wed Mar 14 09:01:15 2018 +0000
+++ b/test/jdk/java/net/httpclient/websocket/WebSocketTest.java	Wed Mar 14 13:03:11 2018 +0000
@@ -26,6 +26,7 @@
  * @build DummyWebSocketServer
  * @run testng/othervm WebSocketTest
  */
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import java.io.IOException;
@@ -156,6 +157,11 @@
         for (int i = 0; i < data.length; i++) {
             copy[i] = (byte) data[i];
         }
+        return serverWithCannedData(copy);
+    }
+
+    private static DummyWebSocketServer serverWithCannedData(byte... data) {
+        byte[] copy = Arrays.copyOf(data, data.length);
         return new DummyWebSocketServer() {
             @Override
             protected void serve(SocketChannel channel) throws IOException {
@@ -953,4 +959,41 @@
             assertEquals(a, expected);
         }
     }
+
+    @Test(dataProvider = "nPings")
+    public void swappingPongs(int nPings) throws Exception {
+        // big enough to not bother with resize
+        ByteBuffer buffer = ByteBuffer.allocate(16384);
+        Frame.HeaderWriter w = new Frame.HeaderWriter();
+        for (int i = 0; i < nPings; i++) {
+            w.fin(true)
+             .opcode(Frame.Opcode.PING)
+             .noMask()
+             .payloadLen(4)
+             .write(buffer);
+            buffer.putInt(i);
+        }
+        w.fin(true)
+         .opcode(Frame.Opcode.CLOSE)
+         .noMask()
+         .payloadLen(2)
+         .write(buffer);
+        buffer.putChar((char) 1000);
+        buffer.flip();
+        try (DummyWebSocketServer server = serverWithCannedData(buffer.array())) {
+            MockListener listener = new MockListener();
+            server.open();
+            WebSocket ws = newHttpClient()
+                    .newWebSocketBuilder()
+                    .buildAsync(server.getURI(), listener)
+                    .join();
+            List<MockListener.Invocation> inv = listener.invocations();
+            assertEquals(inv.size(), nPings + 2); // onOpen + onClose + n*onPing
+        }
+    }
+
+    @DataProvider(name = "nPings")
+    public Object[][] nPings() {
+        return new Object[][]{{1}, {2}, {4}, {8}, {9}, {1023}};
+    }
 }
--- a/test/jdk/java/net/httpclient/websocket/java.net.http/jdk/internal/net/http/websocket/MessageQueueTest.java	Wed Mar 14 09:01:15 2018 +0000
+++ b/test/jdk/java/net/httpclient/websocket/java.net.http/jdk/internal/net/http/websocket/MessageQueueTest.java	Wed Mar 14 13:03:11 2018 +0000
@@ -43,6 +43,7 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.function.BiConsumer;
+import java.util.function.Supplier;
 
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertFalse;
@@ -97,6 +98,7 @@
     private Message createRandomMessage() {
         Message.Type[] values = Message.Type.values();
         Message.Type type = values[r.nextInt(values.length)];
+        Supplier<? extends ByteBuffer> binarySupplier = null;
         ByteBuffer binary = null;
         CharBuffer text = null;
         boolean isLast = false;
@@ -114,7 +116,11 @@
                 binary = ByteBuffer.allocate(r.nextInt(19));
                 break;
             case PONG:
-                binary = ByteBuffer.allocate(r.nextInt(19));
+                if (r.nextBoolean()) {
+                    binary = ByteBuffer.allocate(r.nextInt(19));
+                } else {
+                    binarySupplier = () -> ByteBuffer.allocate(r.nextInt(19));
+                }
                 break;
             case CLOSE:
                 text = CharBuffer.allocate(r.nextInt(17));
@@ -128,7 +134,7 @@
             public void accept(Integer o, Throwable throwable) { }
         };
         CompletableFuture<Integer> future = new CompletableFuture<>();
-        return new Message(type, binary, text, isLast, statusCode, r.nextInt(),
+        return new Message(type, binarySupplier, binary, text, isLast, statusCode, r.nextInt(),
                            action, future);
     }
 
@@ -299,7 +305,11 @@
                     q.addPing(m.binary, m.attachment, m.action, m.future);
                     break;
                 case PONG:
-                    q.addPong(m.binary, m.attachment, m.action, m.future);
+                    if (m.binarySupplier != null) {
+                        q.addPong(m.binarySupplier, m.attachment, m.action, m.future);
+                    } else {
+                        q.addPong(m.binary, m.attachment, m.action, m.future);
+                    }
                     break;
                 case CLOSE:
                     q.addClose(m.statusCode, m.text, m.attachment, m.action, m.future);
@@ -325,7 +335,7 @@
                                           CompletableFuture<? super T> future) {
                     assertFalse(called);
                     called = true;
-                    return new Message(Message.Type.TEXT, null, message, isLast,
+                    return new Message(Message.Type.TEXT, null, null, message, isLast,
                                        -1, attachment, action, future);
                 }
 
@@ -337,7 +347,7 @@
                                             CompletableFuture<? super T> future) {
                     assertFalse(called);
                     called = true;
-                    return new Message(Message.Type.BINARY, message, null, isLast,
+                    return new Message(Message.Type.BINARY, null, message, null, isLast,
                                        -1, attachment, action, future);
                 }
 
@@ -348,7 +358,7 @@
                                           CompletableFuture<? super T> future) {
                     assertFalse(called);
                     called = true;
-                    return new Message(Message.Type.PING, message, null, false,
+                    return new Message(Message.Type.PING, null, message, null, false,
                                        -1, attachment, action, future);
                 }
 
@@ -359,7 +369,18 @@
                                           CompletableFuture<? super T> future) {
                     assertFalse(called);
                     called = true;
-                    return new Message(Message.Type.PONG, message, null, false,
+                    return new Message(Message.Type.PONG, null, message, null, false,
+                                       -1, attachment, action, future);
+                }
+
+                @Override
+                public <T> Message onPong(Supplier<? extends ByteBuffer> message,
+                                          T attachment,
+                                          BiConsumer<? super T, ? super Throwable> action,
+                                          CompletableFuture<? super T> future) {
+                    assertFalse(called);
+                    called = true;
+                    return new Message(Message.Type.PONG, message, null, null, false,
                                        -1, attachment, action, future);
                 }
 
@@ -371,7 +392,7 @@
                                            CompletableFuture<? super T> future) {
                     assertFalse(called);
                     called = true;
-                    return new Message(Message.Type.CLOSE, null, reason, false,
+                    return new Message(Message.Type.CLOSE, null, null, reason, false,
                                        statusCode, attachment, action, future);
                 }
 
@@ -390,6 +411,7 @@
     static class Message {
 
         private final Type type;
+        private final Supplier<? extends ByteBuffer> binarySupplier;
         private final ByteBuffer binary;
         private final CharBuffer text;
         private final boolean isLast;
@@ -401,6 +423,7 @@
         private final CompletableFuture future;
 
         <T> Message(Type type,
+                    Supplier<? extends ByteBuffer> binarySupplier,
                     ByteBuffer binary,
                     CharBuffer text,
                     boolean isLast,
@@ -409,6 +432,7 @@
                     BiConsumer<? super T, ? super Throwable> action,
                     CompletableFuture<? super T> future) {
             this.type = type;
+            this.binarySupplier = binarySupplier;
             this.binary = binary;
             this.text = text;
             this.isLast = isLast;
@@ -420,7 +444,7 @@
 
         @Override
         public int hashCode() {
-            return Objects.hash(type, binary, text, isLast, statusCode, attachment, action, future);
+            return Objects.hash(type, binarySupplier, binary, text, isLast, statusCode, attachment, action, future);
         }
 
         @Override
@@ -431,6 +455,7 @@
             return isLast == message.isLast &&
                     statusCode == message.statusCode &&
                     type == message.type &&
+                    Objects.equals(binarySupplier, message.binarySupplier) &&
                     Objects.equals(binary, message.binary) &&
                     Objects.equals(text, message.text) &&
                     Objects.equals(attachment, message.attachment) &&