# HG changeset patch # User prappo # Date 1520442988 0 # Node ID 4933a477d628659c5a616489238baa684b261187 # Parent d818a6a8295a8ff1b5cae18df54f8ee6315f533b http-client-branch: (WebSocket) impl change diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/java/net/http/WebSocket.java --- a/src/java.net.http/share/classes/java/net/http/WebSocket.java Wed Mar 07 15:39:25 2018 +0000 +++ b/src/java.net.http/share/classes/java/net/http/WebSocket.java Wed Mar 07 17:16:28 2018 +0000 @@ -508,8 +508,6 @@ *

A {@code CompletableFuture} returned from this method can * complete exceptionally with: *

* - * @implNote If a partial UTF-16 sequence is passed to this method, a - * {@code CompletableFuture} returned will complete exceptionally with - * {@code IOException}. + * @implNote If a partial or malformed UTF-16 sequence is passed to this + * method, a {@code CompletableFuture} returned will complete exceptionally + * with {@code IOException}. * * @param message * the message @@ -632,7 +630,7 @@ * complete exceptionally with: * diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java Wed Mar 07 15:39:25 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/common/Utils.java Wed Mar 07 17:16:28 2018 +0000 @@ -373,7 +373,7 @@ NetProperties.getInteger(name, defaultValue)); } - static String getNetProperty(String name) { + public static String getNetProperty(String name) { return AccessController.doPrivileged((PrivilegedAction) () -> NetProperties.get(name)); } diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/Frame.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/Frame.java Wed Mar 07 15:39:25 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/Frame.java Wed Mar 07 17:16:28 2018 +0000 @@ -40,6 +40,7 @@ private Frame() { } static final int MAX_HEADER_SIZE_BYTES = 2 + 8 + 4; + static final int MAX_CONTROL_FRAME_PAYLOAD_LENGTH = 125; enum Opcode { diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/FrameConsumer.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/FrameConsumer.java Wed Mar 07 15:39:25 2018 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,288 +0,0 @@ -/* - * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. Oracle designates this - * particular file as subject to the "Classpath" exception as provided - * by Oracle in the LICENSE file that accompanied this code. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ - -package jdk.internal.net.http.websocket; - -import java.net.http.WebSocket.MessagePart; -import jdk.internal.net.http.websocket.Frame.Opcode; - -import java.nio.ByteBuffer; -import java.nio.CharBuffer; -import java.nio.charset.CharacterCodingException; - -import static java.lang.String.format; -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.Objects.requireNonNull; -import static jdk.internal.net.http.common.Utils.dump; -import static jdk.internal.net.http.websocket.StatusCodes.NO_STATUS_CODE; -import static jdk.internal.net.http.websocket.StatusCodes.isLegalToReceiveFromServer; - -/* - * Consumes frame parts and notifies a message consumer, when there is - * sufficient data to produce a message, or part thereof. - * - * Data consumed but not yet translated is accumulated until it's sufficient to - * form a message. - */ -/* Non-final for testing purposes only */ -class FrameConsumer implements Frame.Consumer { - - private final static boolean DEBUG = false; - private final MessageStreamConsumer output; - private final UTF8AccumulatingDecoder decoder = new UTF8AccumulatingDecoder(); - private boolean fin; - private Opcode opcode, originatingOpcode; - private MessagePart part = MessagePart.WHOLE; - private long payloadLen; - private long unconsumedPayloadLen; - private ByteBuffer binaryData; - - FrameConsumer(MessageStreamConsumer output) { - this.output = requireNonNull(output); - } - - /* Exposed for testing purposes only */ - MessageStreamConsumer getOutput() { - return output; - } - - @Override - public void fin(boolean value) { - if (DEBUG) { - System.out.printf("Reading fin: %s%n", value); - } - fin = value; - } - - @Override - public void rsv1(boolean value) { - if (DEBUG) { - System.out.printf("Reading rsv1: %s%n", value); - } - if (value) { - throw new FailWebSocketException("Unexpected rsv1 bit"); - } - } - - @Override - public void rsv2(boolean value) { - if (DEBUG) { - System.out.printf("Reading rsv2: %s%n", value); - } - if (value) { - throw new FailWebSocketException("Unexpected rsv2 bit"); - } - } - - @Override - public void rsv3(boolean value) { - if (DEBUG) { - System.out.printf("Reading rsv3: %s%n", value); - } - if (value) { - throw new FailWebSocketException("Unexpected rsv3 bit"); - } - } - - @Override - public void opcode(Opcode v) { - if (DEBUG) { - System.out.printf("Reading opcode: %s%n", v); - } - if (v == Opcode.PING || v == Opcode.PONG || v == Opcode.CLOSE) { - if (!fin) { - throw new FailWebSocketException("Fragmented control frame " + v); - } - opcode = v; - } else if (v == Opcode.TEXT || v == Opcode.BINARY) { - if (originatingOpcode != null) { - throw new FailWebSocketException( - format("Unexpected frame %s (fin=%s)", v, fin)); - } - opcode = v; - if (!fin) { - originatingOpcode = v; - } - } else if (v == Opcode.CONTINUATION) { - if (originatingOpcode == null) { - throw new FailWebSocketException( - format("Unexpected frame %s (fin=%s)", v, fin)); - } - opcode = v; - } else { - throw new FailWebSocketException("Unexpected opcode " + v); - } - } - - @Override - public void mask(boolean value) { - if (DEBUG) { - System.out.printf("Reading mask: %s%n", value); - } - if (value) { - throw new FailWebSocketException("Masked frame received"); - } - } - - @Override - public void payloadLen(long value) { - if (DEBUG) { - System.out.printf("Reading payloadLen: %s%n", value); - } - if (opcode.isControl()) { - if (value > 125) { - throw new FailWebSocketException( - format("%s's payload length %s", opcode, value)); - } - assert Opcode.CLOSE.isControl(); - if (opcode == Opcode.CLOSE && value == 1) { - throw new FailWebSocketException("Incomplete status code"); - } - } - payloadLen = value; - unconsumedPayloadLen = value; - } - - @Override - public void maskingKey(int value) { - // `FrameConsumer.mask(boolean)` is where a masked frame is detected and - // reported on; `FrameConsumer.mask(boolean)` MUST be invoked before - // this method; - // So this method (`maskingKey`) is not supposed to be invoked while - // reading a frame that has came from the server. If this method is - // invoked, then it's an error in implementation, thus InternalError - throw new InternalError(); - } - - @Override - public void payloadData(ByteBuffer data) { - if (DEBUG) { - System.out.printf("Reading payloadData: %s%n", data); - } - unconsumedPayloadLen -= data.remaining(); - boolean isLast = unconsumedPayloadLen == 0; - if (opcode.isControl()) { - if (binaryData != null) { // An intermediate or the last chunk - binaryData.put(data); - } else if (!isLast) { // The first chunk - int remaining = data.remaining(); - // It shouldn't be 125, otherwise the next chunk will be of size - // 0, which is not what Reader promises to deliver (eager - // reading) - assert remaining < 125 : dump(remaining); - binaryData = ByteBuffer.allocate(125).put(data); - } else { // The only chunk - binaryData = ByteBuffer.allocate(data.remaining()).put(data); - } - } else { - part = determinePart(isLast); - boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT; - if (!text) { - output.onBinary(data.slice(), part); - data.position(data.limit()); // Consume - } else { - boolean binaryNonEmpty = data.hasRemaining(); - CharBuffer textData; - try { - textData = decoder.decode(data, part == MessagePart.WHOLE || part == MessagePart.LAST); - } catch (CharacterCodingException e) { - throw new FailWebSocketException( - "Invalid UTF-8 in frame " + opcode, StatusCodes.NOT_CONSISTENT) - .initCause(e); - } - if (!(binaryNonEmpty && !textData.hasRemaining())) { - // If there's a binary data, that result in no text, then we - // don't deliver anything - output.onText(textData, part); - } - } - } - } - - @Override - public void endFrame() { - if (DEBUG) { - System.out.println("End frame"); - } - if (opcode.isControl()) { - binaryData.flip(); - } - switch (opcode) { - case CLOSE: - char statusCode = NO_STATUS_CODE; - String reason = ""; - if (payloadLen != 0) { - int len = binaryData.remaining(); - assert 2 <= len && len <= 125 : dump(len, payloadLen); - statusCode = binaryData.getChar(); - if (!isLegalToReceiveFromServer(statusCode)) { - throw new FailWebSocketException( - "Illegal status code: " + statusCode); - } - try { - reason = UTF_8.newDecoder().decode(binaryData).toString(); - } catch (CharacterCodingException e) { - throw new FailWebSocketException("Illegal close reason") - .initCause(e); - } - } - output.onClose(statusCode, reason); - break; - case PING: - output.onPing(binaryData); - binaryData = null; - break; - case PONG: - output.onPong(binaryData); - binaryData = null; - break; - default: - assert opcode == Opcode.TEXT || opcode == Opcode.BINARY - || opcode == Opcode.CONTINUATION : dump(opcode); - if (fin) { - // It is always the last chunk: - // either TEXT(FIN=TRUE)/BINARY(FIN=TRUE) or CONT(FIN=TRUE) - originatingOpcode = null; - } - break; - } - payloadLen = 0; - opcode = null; - } - - private MessagePart determinePart(boolean isLast) { - boolean lastChunk = fin && isLast; - switch (part) { - case LAST: - case WHOLE: - return lastChunk ? MessagePart.WHOLE : MessagePart.FIRST; - case FIRST: - case PART: - return lastChunk ? MessagePart.LAST : MessagePart.PART; - default: - throw new InternalError(String.valueOf(part)); - } - } -} diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageDecoder.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageDecoder.java Wed Mar 07 17:16:28 2018 +0000 @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.websocket; + +import jdk.internal.net.http.websocket.Frame.Opcode; + +import java.net.http.WebSocket.MessagePart; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; + +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; +import static jdk.internal.net.http.common.Utils.dump; +import static jdk.internal.net.http.websocket.StatusCodes.NO_STATUS_CODE; +import static jdk.internal.net.http.websocket.StatusCodes.isLegalToReceiveFromServer; + +/* + * Consumes frame parts and notifies a message consumer, when there is + * sufficient data to produce a message, or part thereof. + * + * Data consumed but not yet translated is accumulated until it's sufficient to + * form a message. + */ +/* Non-final for testing purposes only */ +class MessageDecoder implements Frame.Consumer { + + private final static boolean DEBUG = false; + private final MessageStreamConsumer output; + private final UTF8AccumulatingDecoder decoder = new UTF8AccumulatingDecoder(); + private boolean fin; + private Opcode opcode, originatingOpcode; + private MessagePart part = MessagePart.WHOLE; + private long payloadLen; + private long unconsumedPayloadLen; + private ByteBuffer binaryData; + + MessageDecoder(MessageStreamConsumer output) { + this.output = requireNonNull(output); + } + + /* Exposed for testing purposes only */ + MessageStreamConsumer getOutput() { + return output; + } + + @Override + public void fin(boolean value) { + if (DEBUG) { + System.out.printf("[Input] fin %s%n", value); + } + fin = value; + } + + @Override + public void rsv1(boolean value) { + if (DEBUG) { + System.out.printf("[Input] rsv1 %s%n", value); + } + if (value) { + throw new FailWebSocketException("Unexpected rsv1 bit"); + } + } + + @Override + public void rsv2(boolean value) { + if (DEBUG) { + System.out.printf("[Input] rsv2 %s%n", value); + } + if (value) { + throw new FailWebSocketException("Unexpected rsv2 bit"); + } + } + + @Override + public void rsv3(boolean value) { + if (DEBUG) { + System.out.printf("[Input] rsv3 %s%n", value); + } + if (value) { + throw new FailWebSocketException("Unexpected rsv3 bit"); + } + } + + @Override + public void opcode(Opcode v) { + if (DEBUG) { + System.out.printf("[Input] opcode %s%n", v); + } + if (v == Opcode.PING || v == Opcode.PONG || v == Opcode.CLOSE) { + if (!fin) { + throw new FailWebSocketException("Fragmented control frame " + v); + } + opcode = v; + } else if (v == Opcode.TEXT || v == Opcode.BINARY) { + if (originatingOpcode != null) { + throw new FailWebSocketException( + format("Unexpected frame %s (fin=%s)", v, fin)); + } + opcode = v; + if (!fin) { + originatingOpcode = v; + } + } else if (v == Opcode.CONTINUATION) { + if (originatingOpcode == null) { + throw new FailWebSocketException( + format("Unexpected frame %s (fin=%s)", v, fin)); + } + opcode = v; + } else { + throw new FailWebSocketException("Unexpected opcode " + v); + } + } + + @Override + public void mask(boolean value) { + if (DEBUG) { + System.out.printf("[Input] mask %s%n", value); + } + if (value) { + throw new FailWebSocketException("Masked frame received"); + } + } + + @Override + public void payloadLen(long value) { + if (DEBUG) { + System.out.printf("[Input] payloadLen %s%n", value); + } + if (opcode.isControl()) { + if (value > Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH) { + throw new FailWebSocketException( + format("%s's payload length %s", opcode, value)); + } + assert Opcode.CLOSE.isControl(); + if (opcode == Opcode.CLOSE && value == 1) { + throw new FailWebSocketException("Incomplete status code"); + } + } + payloadLen = value; + unconsumedPayloadLen = value; + } + + @Override + public void maskingKey(int value) { + // `MessageDecoder.mask(boolean)` is where a masked frame is detected and + // reported on; `MessageDecoder.mask(boolean)` MUST be invoked before + // this method; + // So this method (`maskingKey`) is not supposed to be invoked while + // reading a frame that has came from the server. If this method is + // invoked, then it's an error in implementation, thus InternalError + throw new InternalError(); + } + + @Override + public void payloadData(ByteBuffer data) { + if (DEBUG) { + System.out.printf("[Input] payload %s%n", data); + } + unconsumedPayloadLen -= data.remaining(); + boolean isLast = unconsumedPayloadLen == 0; + if (opcode.isControl()) { + if (binaryData != null) { // An intermediate or the last chunk + binaryData.put(data); + } else if (!isLast) { // The first chunk + int remaining = data.remaining(); + // It shouldn't be 125, otherwise the next chunk will be of size + // 0, which is not what Reader promises to deliver (eager + // reading) + assert remaining < Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH + : dump(remaining); + binaryData = ByteBuffer.allocate( + Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH).put(data); + } else { // The only chunk + binaryData = ByteBuffer.allocate(data.remaining()).put(data); + } + } else { + part = determinePart(isLast); + boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT; + if (!text) { + output.onBinary(data.slice(), part); + data.position(data.limit()); // Consume + } else { + boolean binaryNonEmpty = data.hasRemaining(); + CharBuffer textData; + try { + boolean eof = part == MessagePart.WHOLE + || part == MessagePart.LAST; + textData = decoder.decode(data, eof); + } catch (CharacterCodingException e) { + throw new FailWebSocketException( + "Invalid UTF-8 in frame " + opcode, + StatusCodes.NOT_CONSISTENT).initCause(e); + } + if (!(binaryNonEmpty && !textData.hasRemaining())) { + // If there's a binary data, that result in no text, then we + // don't deliver anything + output.onText(textData, part); + } + } + } + } + + @Override + public void endFrame() { + if (DEBUG) { + System.out.println("[Input] end frame"); + } + if (opcode.isControl()) { + binaryData.flip(); + } + switch (opcode) { + case CLOSE: + char statusCode = NO_STATUS_CODE; + String reason = ""; + if (payloadLen != 0) { + int len = binaryData.remaining(); + assert 2 <= len + && len <= Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH + : dump(len, payloadLen); + statusCode = binaryData.getChar(); + if (!isLegalToReceiveFromServer(statusCode)) { + throw new FailWebSocketException( + "Illegal status code: " + statusCode); + } + try { + reason = UTF_8.newDecoder().decode(binaryData).toString(); + } catch (CharacterCodingException e) { + throw new FailWebSocketException("Illegal close reason") + .initCause(e); + } + } + output.onClose(statusCode, reason); + break; + case PING: + output.onPing(binaryData); + binaryData = null; + break; + case PONG: + output.onPong(binaryData); + binaryData = null; + break; + default: + assert opcode == Opcode.TEXT || opcode == Opcode.BINARY + || opcode == Opcode.CONTINUATION : dump(opcode); + if (fin) { + // It is always the last chunk: + // either TEXT(FIN=TRUE)/BINARY(FIN=TRUE) or CONT(FIN=TRUE) + originatingOpcode = null; + } + break; + } + payloadLen = 0; + opcode = null; + } + + private MessagePart determinePart(boolean isLast) { + boolean lastChunk = fin && isLast; + switch (part) { + case LAST: + case WHOLE: + return lastChunk ? MessagePart.WHOLE : MessagePart.FIRST; + case FIRST: + case PART: + return lastChunk ? MessagePart.LAST : MessagePart.PART; + default: + throw new InternalError(String.valueOf(part)); + } + } +} diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageEncoder.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageEncoder.java Wed Mar 07 17:16:28 2018 +0000 @@ -0,0 +1,416 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.websocket; + +import jdk.internal.net.http.common.Utils; +import jdk.internal.net.http.websocket.Frame.Opcode; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; + +/* + * A stateful producer of binary representations of WebSocket messages being + * sent from the client to the server. + * + * An encoding methods are given original messages and byte buffers to put the + * resulting bytes to. + * + * The method is called + * repeatedly with a non-empty target buffer. Once the caller finds the buffer + * unmodified after the call returns, the message has been completely encoded. + */ + +/* + * The state of encoding.An instance of this class is passed sequentially between messages, so + * every message in a sequence can check the context it is in and update it + * if necessary. + */ + +public class MessageEncoder { + + // FIXME: write frame method + + private final static boolean DEBUG = false; + + private final SecureRandom maskingKeySource = new SecureRandom(); + private final Frame.HeaderWriter headerWriter = new Frame.HeaderWriter(); + private final Frame.Masker payloadMasker = new Frame.Masker(); + private final CharsetEncoder charsetEncoder + = StandardCharsets.UTF_8.newEncoder() + .onMalformedInput(CodingErrorAction.REPORT) + .onUnmappableCharacter(CodingErrorAction.REPORT); + /* + * This buffer is used both to encode characters to UTF-8 and to calculate + * the length of the resulting frame's payload. The length of the payload + * must be known before the frame's header can be written. + * For implementation reasons, this buffer must have a capacity of at least + * the maximum size of a Close frame payload, which is 125 bytes + * (or Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH). + */ + private final ByteBuffer intermediateBuffer = createIntermediateBuffer( + Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH); + private final ByteBuffer headerBuffer = ByteBuffer.allocate( + Frame.MAX_HEADER_SIZE_BYTES); + + private boolean started; + private boolean flushing; + private boolean moreText = true; + private long headerCount; + private boolean previousLast = true; + private boolean previousText; + private boolean closed; + + /* + * How many bytes of the current message have been already encoded. + * + * Even though the user hands their buffers over to us, they still can + * manipulate these buffers while we are getting data out of them. + * The number of produced bytes guards us from such behaviour in the + * case of messages that must be restricted in size (Ping, Pong and Close). + * For other messages this measure provides a best-effort attempt to detect + * concurrent changes to buffer. + * + * Making a shallow copy (duplicate/slice) and then checking the size + * precondition on it would also solve the problem, but at the cost of this + * extra copy. + */ + private int actualLen; + + /* + * How many bytes were originally there in the message, before the encoding + * started. + */ + private int expectedLen; + + /* Exposed for testing purposes */ + protected ByteBuffer createIntermediateBuffer(int minSize) { + int capacity = Utils.getIntegerNetProperty( + "jdk.httpclient.websocket.intermediateBufferSize", 16384); + return ByteBuffer.allocate(Math.max(minSize, capacity)); + } + + public void reset() { + // Do not reset the message stream state fields, e.g. previousLast, + // previousText. Just an individual message state: + started = false; + flushing = false; + moreText = true; + headerCount = 0; + actualLen = 0; + } + + /* + * Encodes text messages by cutting them into fragments of maximum size of + * intermediateBuffer.capacity() + */ + public boolean encodeText(CharBuffer src, boolean last, ByteBuffer dst) + throws IOException + { + if (DEBUG) { + System.out.printf("[Output] encodeText src.remaining()=%s, %s, %s%n", + src.remaining(), last, dst); + } + if (closed) { + throw new IOException("Output closed"); + } + if (!started) { + if (!previousText && !previousLast) { + // Previous data message was a partial binary message + throw new IllegalStateException("Unexpected text message"); + } + started = true; + headerBuffer.position(0).limit(0); + intermediateBuffer.position(0).limit(0); + charsetEncoder.reset(); + } + while (true) { + if (DEBUG) { + System.out.printf("[Output] put%n"); + } + if (!putAvailable(headerBuffer, dst)) { + return false; + } + if (DEBUG) { + System.out.printf("[Output] mask%n"); + } + if (maskAvailable(intermediateBuffer, dst) < 0) { + return false; + } + if (DEBUG) { + System.out.printf("[Output] moreText%n"); + } + if (!moreText) { + return true; + } + intermediateBuffer.clear(); + CoderResult r = null; + if (!flushing) { + r = charsetEncoder.encode(src, intermediateBuffer, true); + if (r.isUnderflow()) { + flushing = true; + } + } + if (flushing) { + r = charsetEncoder.flush(intermediateBuffer); + if (r.isUnderflow()) { + moreText = false; + } + } + if (r.isError()) { + try { + r.throwException(); + } catch (CharacterCodingException e) { + throw new IOException("Malformed text message", e); + } + } + if (DEBUG) { + System.out.printf("[Output] header #%s%n", headerCount); + } + if (headerCount == 0) { // set once + previousLast = last; + previousText = true; + } + intermediateBuffer.flip(); + headerBuffer.clear(); + int mask = maskingKeySource.nextInt(); + Opcode opcode = previousLast && headerCount == 0 + ? Opcode.TEXT : Opcode.CONTINUATION; + if (DEBUG) { + System.out.printf("[Output] opcode %s%n", opcode); + } + headerWriter.fin(last && !moreText) + .opcode(opcode) + .payloadLen(intermediateBuffer.remaining()) + .mask(mask) + .write(headerBuffer); + headerBuffer.flip(); + headerCount++; + payloadMasker.mask(mask); + } + } + + private boolean putAvailable(ByteBuffer src, ByteBuffer dst) { + int available = dst.remaining(); + if (available >= src.remaining()) { + dst.put(src); + return true; + } else { + int lim = src.limit(); // save the limit + src.limit(src.position() + available); + dst.put(src); + src.limit(lim); // restore the limit + return false; + } + } + + public boolean encodeBinary(ByteBuffer src, boolean last, ByteBuffer dst) + throws IOException + { + if (DEBUG) { + System.out.printf("[Output] encodeBinary %s, %s, %s%n", + src, last, dst); + } + if (closed) { + throw new IOException("Output closed"); + } + if (!started) { + if (previousText && !previousLast) { + // Previous data message was a partial text message + throw new IllegalStateException("Unexpected binary message"); + } + expectedLen = src.remaining(); + int mask = maskingKeySource.nextInt(); + headerBuffer.clear(); + headerWriter.fin(last) + .opcode(previousLast ? Opcode.BINARY : Opcode.CONTINUATION) + .payloadLen(expectedLen) + .mask(mask) + .write(headerBuffer); + headerBuffer.flip(); + payloadMasker.mask(mask); + previousLast = last; + previousText = false; + started = true; + } + if (!putAvailable(headerBuffer, dst)) { + return false; + } + int count = maskAvailable(src, dst); + actualLen += Math.abs(count); + if (count >= 0 && actualLen != expectedLen) { + throw new IOException("Concurrent message modification"); + } + return count >= 0; + } + + private int maskAvailable(ByteBuffer src, ByteBuffer dst) { + int r0 = dst.remaining(); + payloadMasker.transferMasking(src, dst); + int masked = r0 - dst.remaining(); + return src.hasRemaining() ? -masked : masked; + } + + public boolean encodePing(ByteBuffer src, ByteBuffer dst) + throws IOException + { + if (closed) { + throw new IOException("Output closed"); + } + if (DEBUG) System.out.printf("[Output] encodePing %s, %s%n", src, dst); + if (!started) { + expectedLen = src.remaining(); + if (expectedLen > Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH) { + throw new IllegalArgumentException("Long message: " + expectedLen); + } + int mask = maskingKeySource.nextInt(); + headerBuffer.clear(); + headerWriter.fin(true) + .opcode(Opcode.PING) + .payloadLen(expectedLen) + .mask(mask) + .write(headerBuffer); + headerBuffer.flip(); + payloadMasker.mask(mask); + started = true; + } + if (!putAvailable(headerBuffer, dst)) { + return false; + } + int count = maskAvailable(src, dst); + actualLen += Math.abs(count); + if (count >= 0 && actualLen != expectedLen) { + throw new IOException("Concurrent message modification"); + } + return count >= 0; + } + + public boolean encodePong(ByteBuffer src, ByteBuffer dst) + throws IOException + { + if (closed) { + throw new IOException("Output closed"); + } + if (DEBUG) { + System.out.printf("[Output] encodePong %s, %s%n", + src, dst); + } + if (!started) { + expectedLen = src.remaining(); + if (expectedLen > Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH) { + throw new IllegalArgumentException("Long message: " + expectedLen); + } + int mask = maskingKeySource.nextInt(); + headerBuffer.clear(); + headerWriter.fin(true) + .opcode(Opcode.PONG) + .payloadLen(expectedLen) + .mask(mask) + .write(headerBuffer); + headerBuffer.flip(); + payloadMasker.mask(mask); + started = true; + } + if (!putAvailable(headerBuffer, dst)) { + return false; + } + int count = maskAvailable(src, dst); + actualLen += Math.abs(count); + if (count >= 0 && actualLen != expectedLen) { + throw new IOException("Concurrent message modification"); + } + return count >= 0; + } + + public boolean encodeClose(int statusCode, CharBuffer reason, ByteBuffer dst) + throws IOException + { + if (DEBUG) { + System.out.printf("[Output] encodeClose %s, reason.length=%s, %s%n", + statusCode, reason.length(), dst); + } + if (closed) { + throw new IOException("Output closed"); + } + if (!started) { + if (DEBUG) { + System.out.printf("[Output] reason size %s%n", reason.remaining()); + } + intermediateBuffer.position(0).limit(Frame.MAX_CONTROL_FRAME_PAYLOAD_LENGTH); + intermediateBuffer.putChar((char) statusCode); + CoderResult r = charsetEncoder.reset().encode(reason, intermediateBuffer, true); + if (r.isUnderflow()) { + if (DEBUG) { + System.out.printf("[Output] flushing%n"); + } + r = charsetEncoder.flush(intermediateBuffer); + } + if (DEBUG) { + System.out.printf("[Output] encoding result: %s%n", r); + } + if (r.isError()) { + try { + r.throwException(); + } catch (CharacterCodingException e) { + throw new IllegalArgumentException("Malformed reason", e); + } + } else if (r.isOverflow()) { + // Here the 125 bytes size is ensured by the check for overflow + throw new IllegalArgumentException("Long reason"); + } else if (!r.isUnderflow()) { + throw new InternalError(); // assertion + } + intermediateBuffer.flip(); + headerBuffer.clear(); + int mask = maskingKeySource.nextInt(); + headerWriter.fin(true) + .opcode(Opcode.CLOSE) + .payloadLen(intermediateBuffer.remaining()) + .mask(mask) + .write(headerBuffer); + headerBuffer.flip(); + payloadMasker.mask(mask); + started = true; + closed = true; + if (DEBUG) { + System.out.printf("[Output] intermediateBuffer=%s%n", + intermediateBuffer); + } + } + if (!putAvailable(headerBuffer, dst)) { + return false; + } + return maskAvailable(intermediateBuffer, dst) >= 0; + } +} + + diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageQueue.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/MessageQueue.java Wed Mar 07 17:16:28 2018 +0000 @@ -0,0 +1,371 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.websocket; + +import jdk.internal.net.http.common.Utils; +import jdk.internal.vm.annotation.Stable; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; + +/* + * A FIFO message storage facility. + * + * The queue supports at most one consumer and an arbitrary number of producers. + * Methods `peek`, `remove` and `isEmpty` must not be invoked concurrently. + * Methods `addText`, `addBinary`, `addPing`, `addPong` and `addClose` may be + * invoked concurrently. + * + * This queue is of a bounded size. The queue pre-allocates array of the said + * size and fills it with `Message` elements. The resulting structure never + * changes. This allows to avoid re-allocation and garbage collection of + * elements and arrays thereof. For this reason `Message` elements are never + * returned from the `peek` method. Instead their components passed to the + * provided callback. + * + * The queue consists of: + * + * - a ring array of n + 1 `Message` elements + * - indexes H and T denoting the head and the tail elements of the queue + * respectively + * + * Each `Message` element contains a boolean flag. This flag is an auxiliary + * communication between the producers and the consumer. The flag shows + * whether or not the element is ready to be consumed (peeked at, removed). The + * flag is required since updating an element involves many fields and thus is + * not an atomic action. An addition to the queue happens in two steps: + * + * # Step 1 + * + * Producers race with each other to secure an index for the element they add. + * T is atomically advanced [1] only if the advanced value doesn't equal to H + * (a producer doesn't bump into the head of the queue). + * + * # Step 2 + * + * Once T is advanced in the previous step, the producer updates the message + * fields of the element at the previous value of T and then sets the flag of + * this element. + * + * A removal happens in a single step. The consumer gets the element at index H. + * If the flag of this element is set, the consumer clears the fields of the + * element, clears the flag and finally advances H. + * + * ---------------------------------------------------------------------------- + * [1] To advance the index is to change it from i to (i + 1) % (n + 1). + */ +public class MessageQueue { + + private final static boolean DEBUG = false; + + @Stable + private final Message[] elements; + + private final AtomicInteger tail = new AtomicInteger(); + private volatile int head; + + public MessageQueue() { + this(defaultSize()); + } + + /* Exposed for testing */ + protected MessageQueue(int size) { + if (size < 1) { + throw new IllegalArgumentException(); + } + Message[] array = new Message[size + 1]; + for (int i = 0; i < array.length; i++) { + array[i] = new Message(); + } + elements = array; + } + + private static int defaultSize() { + String property = "jdk.httpclient.websocket.outputQueueMaxSize"; + int defaultSize = 128; + String value = Utils.getNetProperty(property); + int size; + if (value == null) { + size = defaultSize; + } else { + try { + size = Integer.parseUnsignedInt(value); + } catch (NumberFormatException ignored) { + size = defaultSize; + } + } + if (DEBUG) { + System.out.printf("[MessageQueue] %s=%s, using size %s%n", + property, value, size); + } + return size; + } + + public void addText(CharBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) + throws IOException + { + add(MessageQueue.Type.TEXT, null, message, isLast, -1, attachment, + action, future); + } + + private void add(Type type, + ByteBuffer binary, + CharBuffer text, + boolean isLast, + int statusCode, + T attachment, + BiConsumer action, + CompletableFuture future) + throws IOException + { + int h, currentTail, newTail; + do { + h = head; + currentTail = tail.get(); + newTail = (currentTail + 1) % elements.length; + if (newTail == h) { + throw new IOException("Queue full"); + } + } while (!tail.compareAndSet(currentTail, newTail)); + Message t = elements[currentTail]; + if (t.ready) { + throw new InternalError(); + } + t.type = type; + t.binary = binary; + t.text = text; + t.isLast = isLast; + t.statusCode = statusCode; + t.attachment = attachment; + t.action = action; + t.future = future; + t.ready = true; + } + + public void addBinary(ByteBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) + throws IOException + { + add(MessageQueue.Type.BINARY, message, null, isLast, -1, attachment, + action, future); + } + + public void addPing(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) + throws IOException + { + add(MessageQueue.Type.PING, message, null, false, -1, attachment, + action, future); + } + + public void addPong(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) + throws IOException + { + add(MessageQueue.Type.PONG, message, null, false, -1, attachment, + action, future); + } + + public void addClose(int statusCode, + CharBuffer reason, + T attachment, + BiConsumer action, + CompletableFuture future) + throws IOException + { + add(MessageQueue.Type.CLOSE, null, reason, false, statusCode, + attachment, action, future); + } + + @SuppressWarnings("unchecked") + public R peek(QueueCallback callback) + throws E + { + Message h = elements[head]; + if (!h.ready) { + return callback.onEmpty(); + } + Type type = h.type; + switch (type) { + case TEXT: + try { + return (R) callback.onText(h.text, h.isLast, h.attachment, + h.action, h.future); + } catch (Throwable t) { + // Something unpleasant is going on here with the compiler. + // If this seemingly useless catch is omitted, the compiler + // reports an error: + // + // java: unreported exception java.lang.Throwable; + // must be caught or declared to be thrown + // + // My guess is there is a problem with both the type + // inference for the method AND @SuppressWarnings("unchecked") + // being working at the same time. + throw (E) t; + } + case BINARY: + try { + return (R) callback.onBinary(h.binary, h.isLast, h.attachment, + h.action, h.future); + } catch (Throwable t) { + throw (E) t; + } + case PING: + try { + return (R) callback.onPing(h.binary, h.attachment, h.action, + h.future); + } catch (Throwable t) { + throw (E) t; + } + case PONG: + try { + return (R) callback.onPong(h.binary, h.attachment, h.action, + h.future); + } catch (Throwable t) { + throw (E) t; + } + case CLOSE: + try { + return (R) callback.onClose(h.statusCode, h.text, h.attachment, + h.action, h.future); + } catch (Throwable t) { + throw (E) t; + } + default: + throw new InternalError(String.valueOf(type)); + } + } + + public boolean isEmpty() { + return !elements[head].ready; + } + + public void remove() { + int currentHead = head; + Message h = elements[currentHead]; + if (!h.ready) { + throw new InternalError("Queue empty"); + } + h.type = null; + h.binary = null; + h.text = null; + h.attachment = null; + h.action = null; + h.future = null; + h.ready = false; + head = (currentHead + 1) % elements.length; + } + + private enum Type { + + TEXT, + BINARY, + PING, + PONG, + CLOSE + } + + /* + * A callback for consuming a queue element's fields. Can return a result of + * type T or throw an exception of type E. This design allows to avoid + * "returning" results or "throwing" errors by updating some objects from + * the outside of the methods. + */ + public interface QueueCallback { + + R onText(CharBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) throws E; + + R onBinary(ByteBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) throws E; + + R onPing(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) throws E; + + R onPong(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) throws E; + + R onClose(int statusCode, + CharBuffer reason, + T attachment, + BiConsumer action, + CompletableFuture future) throws E; + + /* The queue is empty*/ + R onEmpty() throws E; + } + + /* + * A union of components of all WebSocket message types; also a node in a + * queue. + * + * A `Message` never leaves the context of the queue, thus the reference to + * it cannot be retained by anyone other than the queue. + */ + private static class Message { + + private volatile boolean ready; + + // -- The source message fields -- + + private Type type; + private ByteBuffer binary; + private CharBuffer text; + private boolean isLast; + private int statusCode; + private Object attachment; + @SuppressWarnings("rawtypes") + private BiConsumer action; + @SuppressWarnings("rawtypes") + private CompletableFuture future; + } +} diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/OutgoingMessage.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/OutgoingMessage.java Wed Mar 07 15:39:25 2018 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,296 +0,0 @@ -/* - * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. Oracle designates this - * particular file as subject to the "Classpath" exception as provided - * by Oracle in the LICENSE file that accompanied this code. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ - -package jdk.internal.net.http.websocket; - -import jdk.internal.net.http.websocket.Frame.Opcode; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.CharBuffer; -import java.nio.charset.CharacterCodingException; -import java.nio.charset.CharsetEncoder; -import java.nio.charset.CoderResult; -import java.security.SecureRandom; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.Objects.requireNonNull; -import static jdk.internal.net.http.common.Utils.EMPTY_BYTEBUFFER; -import static jdk.internal.net.http.websocket.Frame.MAX_HEADER_SIZE_BYTES; -import static jdk.internal.net.http.websocket.Frame.Opcode.BINARY; -import static jdk.internal.net.http.websocket.Frame.Opcode.CLOSE; -import static jdk.internal.net.http.websocket.Frame.Opcode.CONTINUATION; -import static jdk.internal.net.http.websocket.Frame.Opcode.PING; -import static jdk.internal.net.http.websocket.Frame.Opcode.PONG; -import static jdk.internal.net.http.websocket.Frame.Opcode.TEXT; - -/* - * A stateful object that represents a WebSocket message being sent to the - * channel. - * - * Data provided to the constructors is copied. Otherwise we would have to deal - * with mutability, security, masking/unmasking, readonly status, etc. So - * copying greatly simplifies the implementation. - * - * In the case of memory-sensitive environments an alternative implementation - * could use an internal pool of buffers though at the cost of extra complexity - * and possible performance degradation. - */ -abstract class OutgoingMessage { - - // Share per WebSocket? - private static final SecureRandom maskingKeys = new SecureRandom(); - - protected ByteBuffer[] frame; - protected int offset; - - /* - * Performs contextualization. This method is not a part of the constructor - * so it would be possible to defer the work it does until the most - * convenient moment (up to the point where sentTo is invoked). - */ - protected boolean contextualize(Context context) { - // masking and charset decoding should be performed here rather than in - // the constructor (as of today) - if (context.isCloseSent()) { - throw new IllegalStateException("Close sent"); - } - return true; - } - - protected boolean sendTo(RawChannel channel) throws IOException { - while ((offset = nextUnwrittenIndex()) != -1) { - long n = channel.write(frame, offset, frame.length - offset); - if (n == 0) { - return false; - } - } - return true; - } - - private int nextUnwrittenIndex() { - for (int i = offset; i < frame.length; i++) { - if (frame[i].hasRemaining()) { - return i; - } - } - return -1; - } - - static final class Text extends OutgoingMessage { - - private final ByteBuffer payload; - private final boolean isLast; - - Text(CharSequence characters, boolean isLast) { - CharsetEncoder encoder = UTF_8.newEncoder(); // Share per WebSocket? - try { - payload = encoder.encode(CharBuffer.wrap(characters)); - } catch (CharacterCodingException e) { - throw new IllegalArgumentException( - "Malformed UTF-8 text message"); - } - this.isLast = isLast; - } - - @Override - protected boolean contextualize(Context context) { - super.contextualize(context); - if (context.isPreviousBinary() && !context.isPreviousLast()) { - throw new IllegalStateException("Unexpected text message"); - } - frame = getDataMessageBuffers( - TEXT, context.isPreviousLast(), isLast, payload, payload); - context.setPreviousBinary(false); - context.setPreviousText(true); - context.setPreviousLast(isLast); - return true; - } - } - - static final class Binary extends OutgoingMessage { - - private final ByteBuffer payload; - private final boolean isLast; - - Binary(ByteBuffer payload, boolean isLast) { - this.payload = requireNonNull(payload); - this.isLast = isLast; - } - - @Override - protected boolean contextualize(Context context) { - super.contextualize(context); - if (context.isPreviousText() && !context.isPreviousLast()) { - throw new IllegalStateException("Unexpected binary message"); - } - ByteBuffer newBuffer = ByteBuffer.allocate(payload.remaining()); - frame = getDataMessageBuffers( - BINARY, context.isPreviousLast(), isLast, payload, newBuffer); - context.setPreviousText(false); - context.setPreviousBinary(true); - context.setPreviousLast(isLast); - return true; - } - } - - static final class Ping extends OutgoingMessage { - - Ping(ByteBuffer payload) { - frame = getControlMessageBuffers(PING, payload); - } - } - - static final class Pong extends OutgoingMessage { - - Pong(ByteBuffer payload) { - frame = getControlMessageBuffers(PONG, payload); - } - } - - static final class Close extends OutgoingMessage { - - Close() { - frame = getControlMessageBuffers(CLOSE, EMPTY_BYTEBUFFER); - } - - Close(int statusCode, CharSequence reason) { - ByteBuffer payload = ByteBuffer.allocate(125) - .putChar((char) statusCode); - CoderResult result = UTF_8.newEncoder() - .encode(CharBuffer.wrap(reason), - payload, - true); - if (result.isOverflow()) { - throw new IllegalArgumentException("Long reason"); - } else if (result.isError()) { - try { - result.throwException(); - } catch (CharacterCodingException e) { - throw new IllegalArgumentException( - "Malformed UTF-8 reason", e); - } - } - payload.flip(); - frame = getControlMessageBuffers(CLOSE, payload); - } - - @Override - protected boolean contextualize(Context context) { - if (context.isCloseSent()) { - return false; - } else { - context.setCloseSent(); - return true; - } - } - } - - private static ByteBuffer[] getControlMessageBuffers(Opcode opcode, - ByteBuffer payload) { - assert opcode.isControl() : opcode; - int remaining = payload.remaining(); - if (remaining > 125) { - throw new IllegalArgumentException - ("Long message: " + remaining); - } - ByteBuffer frame = ByteBuffer.allocate(MAX_HEADER_SIZE_BYTES + remaining); - int mask = maskingKeys.nextInt(); - new Frame.HeaderWriter() - .fin(true) - .opcode(opcode) - .payloadLen(remaining) - .mask(mask) - .write(frame); - Frame.Masker.transferMasking(payload, frame, mask); - frame.flip(); - return new ByteBuffer[]{frame}; - } - - private static ByteBuffer[] getDataMessageBuffers(Opcode type, - boolean isPreviousLast, - boolean isLast, - ByteBuffer payloadSrc, - ByteBuffer payloadDst) { - assert !type.isControl() && type != CONTINUATION : type; - ByteBuffer header = ByteBuffer.allocate(MAX_HEADER_SIZE_BYTES); - int mask = maskingKeys.nextInt(); - new Frame.HeaderWriter() - .fin(isLast) - .opcode(isPreviousLast ? type : CONTINUATION) - .payloadLen(payloadDst.remaining()) - .mask(mask) - .write(header); - header.flip(); - Frame.Masker.transferMasking(payloadSrc, payloadDst, mask); - payloadDst.flip(); - return new ByteBuffer[]{header, payloadDst}; - } - - /* - * An instance of this class is passed sequentially between messages, so - * every message in a sequence can check the context it is in and update it - * if necessary. - */ - public static class Context { - - boolean previousLast = true; - boolean previousBinary; - boolean previousText; - boolean closeSent; - - private boolean isPreviousText() { - return this.previousText; - } - - private void setPreviousText(boolean value) { - this.previousText = value; - } - - private boolean isPreviousBinary() { - return this.previousBinary; - } - - private void setPreviousBinary(boolean value) { - this.previousBinary = value; - } - - private boolean isPreviousLast() { - return this.previousLast; - } - - private void setPreviousLast(boolean value) { - this.previousLast = value; - } - - private boolean isCloseSent() { - return closeSent; - } - - private void setCloseSent() { - closeSent = true; - } - } -} diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/Transport.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/Transport.java Wed Mar 07 15:39:25 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/Transport.java Wed Mar 07 17:16:28 2018 +0000 @@ -28,60 +28,69 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; /* - * Transport needs some way to asynchronously notify the send operation has been - * completed. It can have several different designs each of which has its own - * pros and cons: + * A WebSocket view of the underlying communication channel. This view provides + * an asynchronous exchange of WebSocket messages rather than asynchronous + * exchange of bytes. * - * (1) void sendMessage(..., Callback) - * (2) CompletableFuture sendMessage(...) - * (3) CompletableFuture sendMessage(..., Callback) - * (4) boolean sendMessage(..., Callback) throws IOException - * ... + * Methods sendText, sendBinary, sendPing, sendPong and sendClose initiate a + * corresponding operation and return a CompletableFuture (CF) which will + * complete once the operation has completed (succeeded or failed). * - * If Transport's users use CFs, (1) forces these users to create CFs and pass - * them to the callback. If any additional (dependant) action needs to be - * attached to the returned CF, this means an extra object (CF) must be created - * in (2). (3) and (4) solves both issues, however (4) does not abstract out - * when exactly the operation has been performed. So the handling code needs to - * be repeated twice. And that leads to 2 different code paths (more bugs). - * Unless designed for this, the user should not assume any specific order of - * completion in (3) (e.g. callback first and then the returned CF). + * These methods are designed such that their clients may take an advantage on + * possible implementation optimizations. Namely, these methods: + * + * 1. May return null which is considered the same as a CF completed normally + * 2. Accept an arbitrary attachment to complete a CF with + * 3. Accept an action to take once the operation has completed * - * The only parametrization of Transport used is Transport. The - * type parameter T was introduced solely to avoid circular dependency between - * Transport and WebSocket. After all, instances of T are used solely to - * complete CompletableFutures. Transport doesn't care about the exact type of - * T. - * - * This way the Transport is fully in charge of creating CompletableFutures. - * On the one hand, Transport may use it to cache/reuse CompletableFutures. On - * the other hand, the class that uses Transport, may benefit by not converting - * from CompletableFuture returned from Transport, to CompletableFuture - * needed by the said class. + * All of the above allows not to create unnecessary instances of CF. + * For example, if a message has been sent straight away, there's no need to + * create a CF (given the parties agree on the meaning of null and are prepared + * to handle it). + * If the result of a returned CF is useless to the client, they may specify the + * exact instance (attachment) they want the CF to complete with. Thus, no need + * to create transforming stages (e.g. thenApply(useless -> myResult)). + * If there is the same action that needs to be done each time the CF completes, + * the client may pass it directly to the method instead of creating a dependant + * stage (e.g. whenComplete(action)). */ -public interface Transport { +public interface Transport { - CompletableFuture sendText(CharSequence message, boolean isLast); + CompletableFuture sendText(CharSequence message, + boolean isLast, + T attachment, + BiConsumer action); - CompletableFuture sendBinary(ByteBuffer message, boolean isLast); + CompletableFuture sendBinary(ByteBuffer message, + boolean isLast, + T attachment, + BiConsumer action); - CompletableFuture sendPing(ByteBuffer message); + CompletableFuture sendPing(ByteBuffer message, + T attachment, + BiConsumer action); - CompletableFuture sendPong(ByteBuffer message); + CompletableFuture sendPong(ByteBuffer message, + T attachment, + BiConsumer action); - CompletableFuture sendClose(int statusCode, String reason); + CompletableFuture sendClose(int statusCode, + String reason, + T attachment, + BiConsumer action); void request(long n); /* - * Why is this method needed? Since Receiver operates through callbacks - * this method allows to abstract out what constitutes as a message being - * received (i.e. to decide outside this type when exactly one should - * decrement the demand). + * Why is this method needed? Since receiving of messages operates through + * callbacks this method allows to abstract out what constitutes as a + * message being received (i.e. to decide outside this type when exactly one + * should decrement the demand). */ - void acknowledgeReception(); + void acknowledgeReception(); // TODO: hide void closeOutput() throws IOException; diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportFactory.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportFactory.java Wed Mar 07 15:39:25 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportFactory.java Wed Mar 07 17:16:28 2018 +0000 @@ -1,10 +1,32 @@ -package jdk.internal.net.http.websocket; +/* + * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ -import java.util.function.Supplier; +package jdk.internal.net.http.websocket; @FunctionalInterface public interface TransportFactory { - Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer); + Transport createTransport(MessageStreamConsumer consumer); } diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportFactoryImpl.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportFactoryImpl.java Wed Mar 07 15:39:25 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportFactoryImpl.java Wed Mar 07 17:16:28 2018 +0000 @@ -22,10 +22,9 @@ * or visit www.oracle.com if you need additional information or have any * questions. */ + package jdk.internal.net.http.websocket; -import java.util.function.Supplier; - public class TransportFactoryImpl implements TransportFactory { private final RawChannel channel; @@ -35,8 +34,7 @@ } @Override - public Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer) { - return new TransportImpl(sendResultSupplier, consumer, channel); + public Transport createTransport(MessageStreamConsumer consumer) { + return new TransportImpl(consumer, channel); } } diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportImpl.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportImpl.java Wed Mar 07 15:39:25 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/TransportImpl.java Wed Mar 07 17:16:28 2018 +0000 @@ -27,62 +27,58 @@ import jdk.internal.net.http.common.Demand; import jdk.internal.net.http.common.MinimalFuture; -import jdk.internal.net.http.common.Pair; import jdk.internal.net.http.common.SequentialScheduler; +import jdk.internal.net.http.common.SequentialScheduler.CompleteRestartableTask; +import jdk.internal.net.http.common.Utils; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.CharBuffer; import java.nio.channels.SelectionKey; -import java.util.Queue; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import java.util.function.Supplier; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; -import static java.util.Objects.requireNonNull; -import static jdk.internal.net.http.common.MinimalFuture.failedFuture; -import static jdk.internal.net.http.common.Pair.pair; +import static jdk.internal.net.http.websocket.TransportImpl.ChannelState.AVAILABLE; +import static jdk.internal.net.http.websocket.TransportImpl.ChannelState.UNREGISTERED; +import static jdk.internal.net.http.websocket.TransportImpl.ChannelState.WAITING; -public class TransportImpl implements Transport { +public class TransportImpl implements Transport { + + // -- Debugging infrastructure -- + + private final static boolean DEBUG = false; - /* This flag is used solely for assertions */ - private final AtomicBoolean busy = new AtomicBoolean(); - private OutgoingMessage message; - private Consumer completionHandler; - private final RawChannel channel; - private final RawChannel.RawEvent writeEvent = createWriteEvent(); + /* Used for correlating enters to and exists from a method */ + private final static AtomicLong COUNTER = new AtomicLong(); + private final SequentialScheduler sendScheduler = new SequentialScheduler(new SendTask()); - private final Queue>> - queue = new ConcurrentLinkedQueue<>(); - private final OutgoingMessage.Context context = new OutgoingMessage.Context(); - private final Supplier resultSupplier; + private final MessageQueue queue = new MessageQueue(); + private final MessageEncoder encoder = new MessageEncoder(); + /* A reusable buffer for writing, initially with no remaining bytes */ + private final ByteBuffer dst = createWriteBuffer().position(0).limit(0); + /* This array is created once for gathering writes accepted by RawChannel */ + private final ByteBuffer[] dstArray = new ByteBuffer[]{dst}; private final MessageStreamConsumer messageConsumer; - private final FrameConsumer frameConsumer; + private final MessageDecoder decoder; private final Frame.Reader reader = new Frame.Reader(); - private final RawChannel.RawEvent readEvent = createReadEvent(); + private final Demand demand = new Demand(); private final SequentialScheduler receiveScheduler; - + private final RawChannel channel; + private final Object closeLock = new Object(); + private final RawChannel.RawEvent writeEvent = new WriteEvent(); + private final RawChannel.RawEvent readEvent = new ReadEvent(); + private volatile ChannelState writeState = UNREGISTERED; private ByteBuffer data; - private volatile int state; - - private static final int UNREGISTERED = 0; - private static final int AVAILABLE = 1; - private static final int WAITING = 2; - - private final Object lock = new Object(); + private volatile ChannelState readState = UNREGISTERED; private boolean inputClosed; private boolean outputClosed; - - public TransportImpl(Supplier sendResultSupplier, - MessageStreamConsumer consumer, - RawChannel channel) { - this.resultSupplier = sendResultSupplier; + public TransportImpl(MessageStreamConsumer consumer, RawChannel channel) { this.messageConsumer = consumer; this.channel = channel; - this.frameConsumer = new FrameConsumer(this.messageConsumer); + this.decoder = new MessageDecoder(this.messageConsumer); this.data = channel.initialByteBuffer(); // To ensure the initial non-final `data` will be visible // (happens-before) when `readEvent.handle()` invokes `receiveScheduler` @@ -90,190 +86,164 @@ receiveScheduler = new SequentialScheduler(new ReceiveTask()); } - /** - * The supplied handler may be invoked in the calling thread. - * A {@code StackOverflowError} may thus occur if there's a possibility - * that this method is called again by the supplied handler. - */ - public void send(OutgoingMessage message, - Consumer completionHandler) { - requireNonNull(message); - requireNonNull(completionHandler); - if (!busy.compareAndSet(false, true)) { - throw new IllegalStateException(); + private ByteBuffer createWriteBuffer() { + String name = "jdk.httpclient.websocket.writeBufferSize"; + int capacity = Utils.getIntegerNetProperty(name, 16384); + if (DEBUG) { + System.out.printf("[Transport] write buffer capacity %s", capacity); } - send0(message, completionHandler); + // TODO (optimization?): allocateDirect if SSL? + return ByteBuffer.allocate(capacity); } - private RawChannel.RawEvent createWriteEvent() { - return new RawChannel.RawEvent() { - - @Override - public int interestOps() { - return SelectionKey.OP_WRITE; + private boolean write() throws IOException { + if (DEBUG) { + System.out.printf("[Transport] writing to the channel%n"); + } + long count = channel.write(dstArray, 0, dstArray.length); + if (DEBUG) { + System.out.printf("[Transport] %s bytes written%n", count); + } + for (ByteBuffer b : dstArray) { + if (b.hasRemaining()) { + return false; } - - @Override - public void handle() { - // registerEvent(e) happens-before subsequent e.handle(), so - // we're fine reading the stored message and the completionHandler - send0(message, completionHandler); - } - }; + } + return true; } - private void send0(OutgoingMessage message, Consumer handler) { - boolean b = busy.get(); - assert b; // Please don't inline this, as busy.get() has memory - // visibility effects and we don't want the program behaviour - // to depend on whether the assertions are turned on - // or turned off + @Override + public CompletableFuture sendText(CharSequence message, + boolean isLast, + T attachment, + BiConsumer action) { + long id; + if (DEBUG) { + id = COUNTER.incrementAndGet(); + System.out.printf("[Transport] %s: sendText message.length()=%s, last=%s%n", + id, message.length(), isLast); + } + // TODO (optimization?): + // These sendXXX methods might be a good place to decide whether or not + // we can write straight ahead, possibly returning null instead of + // creating a CompletableFuture + + // Even if the text is already CharBuffer, the client will not be happy + // if they discover the position is changing. So, no instanceof + // cheating, wrap always. + CharBuffer text = CharBuffer.wrap(message); + MinimalFuture f = new MinimalFuture<>(); try { - boolean sent = message.sendTo(channel); - if (sent) { - busy.set(false); - handler.accept(null); - } else { - // The message has not been fully sent, the transmitter needs to - // remember the message until it can continue with sending it - this.message = message; - this.completionHandler = handler; - try { - channel.registerEvent(writeEvent); - } catch (IOException e) { - this.message = null; - this.completionHandler = null; - busy.set(false); - handler.accept(e); - } - } + queue.addText(text, isLast, attachment, action, f); + sendScheduler.runOrSchedule(); } catch (IOException e) { - busy.set(false); - handler.accept(e); - } - } - - public CompletableFuture sendText(CharSequence message, - boolean isLast) { - OutgoingMessage.Text m; - try { - m = new OutgoingMessage.Text(message, isLast); - } catch (IllegalArgumentException e) { - return failedFuture(e); + f.completeExceptionally(e); } - return enqueue(m); - } - - public CompletableFuture sendBinary(ByteBuffer message, - boolean isLast) { - return enqueue(new OutgoingMessage.Binary(message, isLast)); - } - - public CompletableFuture sendPing(ByteBuffer message) { - OutgoingMessage.Ping m; - try { - m = new OutgoingMessage.Ping(message); - } catch (IllegalArgumentException e) { - return failedFuture(e); + if (DEBUG) { + System.out.printf("[Transport] %s: sendText returned %s%n", id, f); } - return enqueue(m); + return f; } - public CompletableFuture sendPong(ByteBuffer message) { - OutgoingMessage.Pong m; - try { - m = new OutgoingMessage.Pong(message); - } catch (IllegalArgumentException e) { - return failedFuture(e); + @Override + public CompletableFuture sendBinary(ByteBuffer message, + boolean isLast, + T attachment, + BiConsumer action) { + long id; + if (DEBUG) { + id = COUNTER.incrementAndGet(); + System.out.printf("[Transport] %s: sendBinary message.remaining()=%s, last=%s%n", + id, message.remaining(), isLast); } - return enqueue(m); - } - - public CompletableFuture sendClose(int statusCode, String reason) { - OutgoingMessage.Close m; + MinimalFuture f = new MinimalFuture<>(); try { - m = new OutgoingMessage.Close(statusCode, reason); - } catch (IllegalArgumentException e) { - return failedFuture(e); + queue.addBinary(message, isLast, attachment, action, f); + sendScheduler.runOrSchedule(); + } catch (IOException e) { + f.completeExceptionally(e); } - return enqueue(m); - } - - private CompletableFuture enqueue(OutgoingMessage m) { - CompletableFuture cf = new MinimalFuture<>(); - boolean added = queue.add(pair(m, cf)); - if (!added) { - // The queue is supposed to be unbounded - throw new InternalError(); + if (DEBUG) { + System.out.printf("[Transport] %s: sendBinary returned %s%n", id, f); } - sendScheduler.runOrSchedule(); - return cf; + return f; } - /* - * This is a message sending task. It pulls messages from the queue one by - * one and sends them. It may be run in different threads, but never - * concurrently. - */ - private class SendTask implements SequentialScheduler.RestartableTask { - - @Override - public void run(SequentialScheduler.DeferredCompleter taskCompleter) { - Pair> p = queue.poll(); - if (p == null) { - taskCompleter.complete(); - return; - } - OutgoingMessage message = p.first; - CompletableFuture cf = p.second; - try { - if (!message.contextualize(context)) { // Do not send the message - cf.complete(resultSupplier.get()); - repeat(taskCompleter); - return; - } - Consumer h = e -> { - if (e == null) { - cf.complete(resultSupplier.get()); - } else { - cf.completeExceptionally(e); - } - repeat(taskCompleter); - }; - send(message, h); - } catch (Throwable t) { - cf.completeExceptionally(t); - repeat(taskCompleter); - } + @Override + public CompletableFuture sendPing(ByteBuffer message, + T attachment, + BiConsumer action) { + long id; + if (DEBUG) { + id = COUNTER.incrementAndGet(); + System.out.printf("[Transport] %s: sendPing message.remaining()=%s%n", + id, message.remaining()); } - - private void repeat(SequentialScheduler.DeferredCompleter taskCompleter) { - taskCompleter.complete(); - // More than a single message may have been enqueued while - // the task has been busy with the current message, but - // there is only a single signal recorded + MinimalFuture f = new MinimalFuture<>(); + try { + queue.addPing(message, attachment, action, f); sendScheduler.runOrSchedule(); + } catch (IOException e) { + f.completeExceptionally(e); } + if (DEBUG) { + System.out.printf("[Transport] %s: sendPing returned %s%n", id, f); + } + return f; } - private RawChannel.RawEvent createReadEvent() { - return new RawChannel.RawEvent() { + @Override + public CompletableFuture sendPong(ByteBuffer message, + T attachment, + BiConsumer action) { + long id; + if (DEBUG) { + id = COUNTER.incrementAndGet(); + System.out.printf("[Transport] %s: sendPong message.remaining()=%s%n", + id, message.remaining()); + } + MinimalFuture f = new MinimalFuture<>(); + try { + queue.addPong(message, attachment, action, f); + sendScheduler.runOrSchedule(); + } catch (IOException e) { + f.completeExceptionally(e); + } + if (DEBUG) { + System.out.printf("[Transport] %s: sendPong returned %s%n", id, f); + } + return f; + } - @Override - public int interestOps() { - return SelectionKey.OP_READ; - } - - @Override - public void handle() { - state = AVAILABLE; - receiveScheduler.runOrSchedule(); - } - }; + @Override + public CompletableFuture sendClose(int statusCode, + String reason, + T attachment, + BiConsumer action) { + long id; + if (DEBUG) { + id = COUNTER.incrementAndGet(); + System.out.printf("[Transport] %s: sendClose statusCode=%s, reason.length()=%s%n", + id, statusCode, reason.length()); + } + MinimalFuture f = new MinimalFuture<>(); + try { + queue.addClose(statusCode, CharBuffer.wrap(reason), attachment, action, f); + sendScheduler.runOrSchedule(); + } catch (IOException e) { + f.completeExceptionally(e); + } + if (DEBUG) { + System.out.printf("[Transport] %s: sendClose returned %s%n", id, f); + } + return f; } @Override public void request(long n) { + if (DEBUG) { + System.out.printf("[Transport] request %s%n", n); + } if (demand.increase(n)) { receiveScheduler.runOrSchedule(); } @@ -287,58 +257,20 @@ } } - private class ReceiveTask extends SequentialScheduler.CompleteRestartableTask { - - @Override - public void run() { - while (!receiveScheduler.isStopped()) { - if (data.hasRemaining()) { - if (!demand.isFulfilled()) { - try { - int oldPos = data.position(); - reader.readFrame(data, frameConsumer); - int newPos = data.position(); - assert oldPos != newPos : data; // reader always consumes bytes - } catch (Throwable e) { - receiveScheduler.stop(); - messageConsumer.onError(e); - } - continue; + @Override + public void closeOutput() throws IOException { + if (DEBUG) { + System.out.printf("[Transport] closeOutput%n"); + } + synchronized (closeLock) { + if (!outputClosed) { + outputClosed = true; + try { + channel.shutdownOutput(); + } finally { + if (inputClosed) { + channel.close(); } - break; - } - switch (state) { - case WAITING: - return; - case UNREGISTERED: - try { - state = WAITING; - channel.registerEvent(readEvent); - } catch (Throwable e) { - receiveScheduler.stop(); - messageConsumer.onError(e); - } - return; - case AVAILABLE: - try { - data = channel.read(); - } catch (Throwable e) { - receiveScheduler.stop(); - messageConsumer.onError(e); - return; - } - if (data == null) { // EOF - receiveScheduler.stop(); - messageConsumer.onComplete(); - return; - } else if (!data.hasRemaining()) { - // No data at the moment Pretty much a "goto", - // reusing the existing code path for registration - state = UNREGISTERED; - } - continue; - default: - throw new InternalError(String.valueOf(state)); } } } @@ -350,7 +282,10 @@ */ @Override public void closeInput() throws IOException { - synchronized (lock) { + if (DEBUG) { + System.out.printf("[Transport] closeInput%n"); + } + synchronized (closeLock) { if (!inputClosed) { inputClosed = true; try { @@ -365,19 +300,394 @@ } } - @Override - public void closeOutput() throws IOException { - synchronized (lock) { - if (!outputClosed) { - outputClosed = true; + /* Common states for send and receive tasks */ + enum ChannelState { + UNREGISTERED, + AVAILABLE, + WAITING + } + + @SuppressWarnings({"rawtypes"}) + private class SendTask extends CompleteRestartableTask { + + private final MessageQueue.QueueCallback + encodingCallback = new MessageQueue.QueueCallback<>() { + + @Override + public Boolean onText(CharBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) throws IOException + { + return encoder.encodeText(message, isLast, dst); + } + + @Override + public Boolean onBinary(ByteBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) throws IOException + { + return encoder.encodeBinary(message, isLast, dst); + } + + @Override + public Boolean onPing(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) throws IOException + { + return encoder.encodePing(message, dst); + } + + @Override + public Boolean onPong(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) throws IOException + { + return encoder.encodePong(message, dst); + } + + @Override + public Boolean onClose(int statusCode, + CharBuffer reason, + T attachment, + BiConsumer action, + CompletableFuture future) throws IOException + { + return encoder.encodeClose(statusCode, reason, dst); + } + + @Override + public Boolean onEmpty() { + return false; + } + }; + + /* Whether the task sees the current head message for first time */ + private boolean firstPass = true; + /* Whether the message has been fully encoded */ + private boolean encoded; + + // -- Current message completion communication fields -- + + private Object attachment; + private BiConsumer action; + private CompletableFuture future; + private final MessageQueue.QueueCallback + /* If there is a message, loads its completion communication fields */ + loadCallback = new MessageQueue.QueueCallback() { + + @Override + public Boolean onText(CharBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) + { + SendTask.this.attachment = attachment; + SendTask.this.action = action; + SendTask.this.future = future; + return true; + } + + @Override + public Boolean onBinary(ByteBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) + { + SendTask.this.attachment = attachment; + SendTask.this.action = action; + SendTask.this.future = future; + return true; + } + + @Override + public Boolean onPing(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) + { + SendTask.this.attachment = attachment; + SendTask.this.action = action; + SendTask.this.future = future; + return true; + } + + @Override + public Boolean onPong(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) + { + SendTask.this.attachment = attachment; + SendTask.this.action = action; + SendTask.this.future = future; + return true; + } + + @Override + public Boolean onClose(int statusCode, + CharBuffer reason, + T attachment, + BiConsumer action, + CompletableFuture future) + { + SendTask.this.attachment = attachment; + SendTask.this.action = action; + SendTask.this.future = future; + return true; + } + + @Override + public Boolean onEmpty() { + return false; + } + }; + + @Override + public void run() { + // Could have been only called in one of the following cases: + // (a) A message has been added to the queue + // (b) The channel is ready for writing + if (DEBUG) { + System.out.printf("[Transport] begin send task%n"); + } + while (!queue.isEmpty()) { try { - channel.shutdownOutput(); + if (dst.hasRemaining()) { + if (DEBUG) { + System.out.printf("[Transport] %s bytes in buffer%n", + dst.remaining()); + } + // The previous part of the binary representation of the message + // hasn't been fully written + if (!tryCompleteWrite()) { + return; + } + } else if (!encoded) { + if (firstPass) { + firstPass = false; + queue.peek(loadCallback); + if (DEBUG) { + System.out.printf("[Transport] loaded message%n"); + } + } + dst.clear(); + encoded = queue.peek(encodingCallback); + dst.flip(); + if (!tryCompleteWrite()) { + return; + } + } else { + // All done, remove and complete + encoder.reset(); + removeAndComplete(null); + } + } catch (Throwable t) { + if (DEBUG) { + System.out.printf("[Transport] exception %s; cleanup%n", t); + } + // buffer cleanup: if there is an exception, the buffer + // should appear empty for the next write as there is + // nothing to write + dst.position(dst.limit()); + encoder.reset(); + removeAndComplete(t); + } + } + if (DEBUG) { + System.out.printf("[Transport] end send task%n"); + } + } + + private boolean tryCompleteWrite() throws IOException { + if (DEBUG) { + System.out.printf("[Transport] begin writing%n"); + } + boolean finished = false; + loop: + while (true) { + final ChannelState ws = writeState; + if (DEBUG) { + System.out.printf("[Transport] write state: %s%n", ws); + } + switch (ws) { + case WAITING: + break loop; + case UNREGISTERED: + if (DEBUG) { + System.out.printf("[Transport] registering write event%n"); + } + writeState = WAITING; + try { + channel.registerEvent(writeEvent); + } catch (Throwable t) { + writeState = UNREGISTERED; + throw t; + } + if (DEBUG) { + System.out.printf("[Transport] registered write event%n"); + } + break loop; + case AVAILABLE: + boolean written = write(); + if (written) { + if (DEBUG) { + System.out.printf("[Transport] finished writing to the channel%n"); + } + finished = true; + break loop; // All done + } else { + writeState = UNREGISTERED; + continue loop; // Effectively "goto UNREGISTERED" + } + default: + throw new InternalError(String.valueOf(ws)); + } + } + if (DEBUG) { + System.out.printf("[Transport] end writing%n"); + } + return finished; + } + + @SuppressWarnings("unchecked") + private void removeAndComplete(Throwable error) { + if (DEBUG) { + System.out.printf("[Transport] removeAndComplete error=%s%n", error); + } + queue.remove(); + if (error != null) { + try { + action.accept(null, error); } finally { - if (inputClosed) { - channel.close(); + future.completeExceptionally(error); + } + } else { + try { + action.accept(attachment, null); + } finally { + future.complete(attachment); + } + } + encoded = false; + firstPass = true; + attachment = null; + action = null; + future = null; + } + } + + private class ReceiveTask extends CompleteRestartableTask { + + @Override + public void run() { + if (DEBUG) { + System.out.printf("[Transport] begin receive task%n"); + } + loop: + while (!receiveScheduler.isStopped()) { + if (data.hasRemaining()) { + if (DEBUG) { + System.out.printf("[Transport] remaining bytes received %s%n", + data.remaining()); } + if (!demand.isFulfilled()) { + try { + int oldPos = data.position(); + reader.readFrame(data, decoder); + int newPos = data.position(); + // Reader always consumes bytes: + assert oldPos != newPos : data; + } catch (Throwable e) { + receiveScheduler.stop(); + messageConsumer.onError(e); + } + continue; + } + break loop; } + final ChannelState rs = readState; + if (DEBUG) { + System.out.printf("[Transport] receive state: %s%n", rs); + } + switch (rs) { + case WAITING: + break loop; + case UNREGISTERED: + try { + readState = WAITING; + channel.registerEvent(readEvent); + } catch (Throwable e) { + receiveScheduler.stop(); + messageConsumer.onError(e); + } + break loop; + case AVAILABLE: + try { + data = channel.read(); + } catch (Throwable e) { + receiveScheduler.stop(); + messageConsumer.onError(e); + break loop; + } + if (data == null) { // EOF + receiveScheduler.stop(); + messageConsumer.onComplete(); + break loop; + } else if (!data.hasRemaining()) { + // No data at the moment. Pretty much a "goto", + // reusing the existing code path for registration + readState = UNREGISTERED; + } + continue loop; + default: + throw new InternalError(String.valueOf(rs)); + } + } + if (DEBUG) { + System.out.printf("[Transport] end receive task%n"); } } } + + private class WriteEvent implements RawChannel.RawEvent { + + @Override + public int interestOps() { + return SelectionKey.OP_WRITE; + } + + @Override + public void handle() { + if (DEBUG) { + System.out.printf("[Transport] ready to write%n"); + } + writeState = AVAILABLE; + sendScheduler.runOrSchedule(); + } + } + + private class ReadEvent implements RawChannel.RawEvent { + + @Override + public int interestOps() { + return SelectionKey.OP_READ; + } + + @Override + public void handle() { + if (DEBUG) { + System.out.printf("[Transport] ready to read%n"); + } + readState = AVAILABLE; + receiveScheduler.runOrSchedule(); + } + } } diff -r d818a6a8295a -r 4933a477d628 src/java.net.http/share/classes/jdk/internal/net/http/websocket/WebSocketImpl.java --- a/src/java.net.http/share/classes/jdk/internal/net/http/websocket/WebSocketImpl.java Wed Mar 07 15:39:25 2018 +0000 +++ b/src/java.net.http/share/classes/jdk/internal/net/http/websocket/WebSocketImpl.java Wed Mar 07 17:16:28 2018 +0000 @@ -25,7 +25,6 @@ package jdk.internal.net.http.websocket; -import java.net.http.WebSocket; import jdk.internal.net.http.common.Demand; import jdk.internal.net.http.common.Log; import jdk.internal.net.http.common.MinimalFuture; @@ -37,6 +36,7 @@ import java.lang.ref.Reference; import java.net.ProtocolException; import java.net.URI; +import java.net.http.WebSocket; import java.nio.ByteBuffer; import java.util.Objects; import java.util.concurrent.CompletableFuture; @@ -44,6 +44,7 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Function; import static java.util.Objects.requireNonNull; @@ -66,6 +67,8 @@ */ public final class WebSocketImpl implements WebSocket { + private final static boolean DEBUG = false; + enum State { OPEN, IDLE, @@ -78,6 +81,7 @@ ERROR; } + private final MinimalFuture DONE = MinimalFuture.completedFuture(this); private volatile boolean inputClosed; private volatile boolean outputClosed; @@ -96,8 +100,9 @@ private final Listener listener; private final AtomicBoolean outstandingSend = new AtomicBoolean(); - private final Transport transport; - private final SequentialScheduler receiveScheduler = new SequentialScheduler(new ReceiveTask()); + private final Transport transport; + private final SequentialScheduler receiveScheduler + = new SequentialScheduler(new ReceiveTask()); private final Demand demand = new Demand(); public static CompletableFuture newInstanceAsync(BuilderImpl b) { @@ -142,10 +147,11 @@ this.subprotocol = requireNonNull(subprotocol); this.listener = requireNonNull(listener); this.transport = transportFactory.createTransport( - () -> WebSocketImpl.this, // What about escape of WebSocketImpl.this? new SignallingMessageConsumer()); } + // FIXME: add to action handling of errors -> signalError() + @Override public CompletableFuture sendText(CharSequence message, boolean isLast) { @@ -153,8 +159,10 @@ if (!outstandingSend.compareAndSet(false, true)) { return failedFuture(new IllegalStateException("Send pending")); } - CompletableFuture cf = transport.sendText(message, isLast); - return cf.whenComplete((r, e) -> outstandingSend.set(false)); + CompletableFuture cf + = transport.sendText(message, isLast, this, + (r, e) -> outstandingSend.set(false)); + return replaceNull(cf); } @Override @@ -164,61 +172,88 @@ if (!outstandingSend.compareAndSet(false, true)) { return failedFuture(new IllegalStateException("Send pending")); } - CompletableFuture cf = transport.sendBinary(message, isLast); - // Optimize? - // if (cf.isDone()) { - // outstandingSend.set(false); - // } else { - // cf.whenComplete((r, e) -> outstandingSend.set(false)); - // } - return cf.whenComplete((r, e) -> outstandingSend.set(false)); + CompletableFuture cf + = transport.sendBinary(message, isLast, this, + (r, e) -> outstandingSend.set(false)); + return replaceNull(cf); + } + + private CompletableFuture replaceNull( + CompletableFuture cf) + { + if (cf == null) { + return DONE; + } else { + return cf; + } } @Override public CompletableFuture sendPing(ByteBuffer message) { - return transport.sendPing(message); + Objects.requireNonNull(message); + CompletableFuture cf + = transport.sendPing(message, this, (r, e) -> { }); + return replaceNull(cf); } @Override public CompletableFuture sendPong(ByteBuffer message) { - return transport.sendPong(message); + Objects.requireNonNull(message); + CompletableFuture cf + = transport.sendPong(message, this, (r, e) -> { }); + return replaceNull(cf); } @Override - public CompletableFuture sendClose(int statusCode, String reason) { + public CompletableFuture sendClose(int statusCode, + String reason) { Objects.requireNonNull(reason); if (!isLegalToSendFromClient(statusCode)) { return failedFuture(new IllegalArgumentException("statusCode")); } - return sendClose0(statusCode, reason); + CompletableFuture cf = sendClose0(statusCode, reason); + return replaceNull(cf); } /* * Sends a Close message, then shuts down the output since no more - * messages are expected to be sent after this. + * messages are expected to be sent at this point. */ - private CompletableFuture sendClose0(int statusCode, String reason ) { + private CompletableFuture sendClose0(int statusCode, + String reason) { outputClosed = true; - return transport.sendClose(statusCode, reason) - .whenComplete((result, error) -> { - try { - transport.closeOutput(); - } catch (IOException e) { - Log.logError(e); - } - Throwable cause = Utils.getCompletionCause(error); - if (cause instanceof TimeoutException) { - try { - transport.closeInput(); - } catch (IOException e) { - Log.logError(e); - } - } - }); + BiConsumer closer = (r, e) -> { + Throwable cause = Utils.getCompletionCause(e); + if (cause instanceof IllegalArgumentException) { + // or pre=check it (isLegalToSendFromClient(statusCode)) + return; + } + try { + transport.closeOutput(); + } catch (IOException ex) { + Log.logError(ex); + } + if (cause instanceof TimeoutException) { // FIXME: it is not the case anymore + if (DEBUG) { + System.out.println("[WebSocket] sendClose0 error: " + e); + } + try { + transport.closeInput(); + } catch (IOException ex) { + Log.logError(ex); + } + } + }; + CompletableFuture cf + = transport.sendClose(statusCode, reason, this, closer); + return cf; } @Override public void request(long n) { + if (DEBUG) { + System.out.printf("[WebSocket] request(%s)%n", n); + } if (demand.increase(n)) { receiveScheduler.runOrSchedule(); } @@ -241,6 +276,9 @@ @Override public void abort() { + if (DEBUG) { + System.out.printf("[WebSocket] abort()%n"); + } inputClosed = true; outputClosed = true; receiveScheduler.stop(); @@ -327,6 +365,9 @@ } private void processError() throws IOException { + if (DEBUG) { + System.out.println("[WebSocket] processError"); + } transport.closeInput(); receiveScheduler.stop(); Throwable err = error.get(); @@ -345,24 +386,33 @@ } private void processClose() throws IOException { + if (DEBUG) { + System.out.println("[WebSocket] processClose"); + } transport.closeInput(); receiveScheduler.stop(); CompletionStage readyToClose; readyToClose = listener.onClose(WebSocketImpl.this, statusCode, reason); if (readyToClose == null) { - readyToClose = MinimalFuture.completedFuture(null); + readyToClose = DONE; } int code; if (statusCode == NO_STATUS_CODE || statusCode == CLOSED_ABNORMALLY) { code = NORMAL_CLOSURE; + if (DEBUG) { + System.out.printf("[WebSocket] using statusCode %s instead of %s%n", + statusCode, code); + } } else { code = statusCode; } readyToClose.whenComplete((r, e) -> { - sendClose0(code, "") + sendClose0(code, "") // FIXME errors from here? .whenComplete((r1, e1) -> { - if (e1 != null) { - Log.logError(e1); + if (DEBUG) { + if (e1 != null) { + e1.printStackTrace(System.out); + } } }); }); @@ -381,14 +431,12 @@ .put(binaryData) .flip(); // Non-exclusive send; - CompletableFuture pongSent = transport.sendPong(copy); - pongSent.whenComplete( - (r, e) -> { - if (e != null) { - signalError(Utils.getCompletionCause(e)); - } - } - ); + BiConsumer reporter = (r, e) -> { + if (e != null) { + signalError(Utils.getCompletionCause(e)); + } + }; + transport.sendPong(copy, WebSocketImpl.this, reporter); listener.onPing(WebSocketImpl.this, slice); } @@ -406,10 +454,16 @@ } private void signalOpen() { + if (DEBUG) { + System.out.printf("[WebSocket] signalOpen%n"); + } receiveScheduler.runOrSchedule(); } private void signalError(Throwable error) { + if (DEBUG) { + System.out.printf("[WebSocket] signalError %s%n", error); + } inputClosed = true; outputClosed = true; if (!this.error.compareAndSet(null, error) || !trySetState(ERROR)) { @@ -420,32 +474,56 @@ } private void close() { + if (DEBUG) { + System.out.println("[WebSocket] close"); + } + Throwable first = null; try { + transport.closeInput(); + } catch (Throwable t1) { + first = t1; + } finally { + Throwable second = null; try { - transport.closeInput(); + transport.closeOutput(); + } catch (Throwable t2) { + second = t2; } finally { - transport.closeOutput(); + Throwable e = null; + if (first != null && second != null) { + first.addSuppressed(second); + e = first; + } else if (first != null) { + e = first; + } else if (second != null) { + e = second; + } + if (DEBUG) { + if (e != null) { + e.printStackTrace(System.out); + } + } } - } catch (Throwable t) { - Log.logError(t); } } - /* - * Signals a Close event (might not correspond to anything happened on the - * channel, i.e. might be synthetic). - */ private void signalClose(int statusCode, String reason) { + // FIXME: make sure no race reason & close are not intermixed inputClosed = true; this.statusCode = statusCode; this.reason = reason; - if (!trySetState(CLOSE)) { - Log.logTrace("Close: {0}, ''{1}''", statusCode, reason); - } else { + boolean managed = trySetState(CLOSE); + if (DEBUG) { + System.out.printf("[WebSocket] signalClose statusCode=%s, reason.length()=%s: %s%n", + statusCode, reason.length(), managed); + } + if (managed) { try { transport.closeInput(); } catch (Throwable t) { - Log.logError(t); + if (DEBUG) { + t.printStackTrace(System.out); + } } } } @@ -501,33 +579,45 @@ } private boolean trySetState(State newState) { + State currentState; + boolean success = false; while (true) { - State currentState = state.get(); + currentState = state.get(); if (currentState == ERROR || currentState == CLOSE) { - return false; + break; } else if (state.compareAndSet(currentState, newState)) { receiveScheduler.runOrSchedule(); - return true; + success = true; + break; } } + if (DEBUG) { + System.out.printf("[WebSocket] set state %s (previous %s) %s%n", + newState, currentState, success); + } + return success; } private boolean tryChangeState(State expectedState, State newState) { State witness = state.compareAndExchange(expectedState, newState); + boolean success = false; if (witness == expectedState) { receiveScheduler.runOrSchedule(); - return true; - } - // This should be the only reason for inability to change the state from - // IDLE to WAITING: the state has changed to terminal - if (witness != ERROR && witness != CLOSE) { + success = true; + } else if (witness != ERROR && witness != CLOSE) { + // This should be the only reason for inability to change the state + // from IDLE to WAITING: the state has changed to terminal throw new InternalError(); } - return false; + if (DEBUG) { + System.out.printf("[WebSocket] change state from %s to %s %s%n", + expectedState, newState, success); + } + return success; } /* Exposed for testing purposes */ - protected final Transport transport() { + protected Transport transport() { return transport; } } diff -r d818a6a8295a -r 4933a477d628 test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java --- a/test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java Wed Mar 07 15:39:25 2018 +0000 +++ b/test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java Wed Mar 07 17:16:28 2018 +0000 @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -42,6 +42,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.regex.Pattern; @@ -86,6 +87,8 @@ private final Thread thread; private volatile ServerSocketChannel ssc; private volatile InetSocketAddress address; + private ByteBuffer read = ByteBuffer.allocate(1024); + private final CountDownLatch readReady = new CountDownLatch(1); public DummyWebSocketServer() { this(defaultMapping()); @@ -114,6 +117,7 @@ } finally { err.println("Closed: " + channel); close(channel); + readReady.countDown(); } } } catch (ClosedByInterruptException ignored) { @@ -133,8 +137,26 @@ // or the input is shutdown ByteBuffer b = ByteBuffer.allocate(1024); while (channel.read(b) != -1) { + b.flip(); + if (read.remaining() < b.remaining()) { + int required = read.capacity() - read.remaining() + b.remaining(); + int log2required = 32 - Integer.numberOfLeadingZeros(required - 1); + ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required); + newBuffer.put(read.flip()); + read = newBuffer; + } + read.put(b); b.clear(); } + ByteBuffer close = ByteBuffer.wrap(new byte[]{(byte) 0x88, 0x00}); + while (close.hasRemaining()) { + channel.write(close); + } + } + + public ByteBuffer read() throws InterruptedException { + readReady.await(); + return read.duplicate().asReadOnlyBuffer().flip(); } public void open() throws IOException { diff -r d818a6a8295a -r 4933a477d628 test/jdk/java/net/httpclient/websocket/Frame.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test/jdk/java/net/httpclient/websocket/Frame.java Wed Mar 07 17:16:28 2018 +0000 @@ -0,0 +1,497 @@ +/* + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import java.nio.ByteBuffer; + +/* Copied from jdk.internal.net.http.websocket.Frame */ +final class Frame { + + final Opcode opcode; + final ByteBuffer data; + final boolean last; + + public Frame(Opcode opcode, ByteBuffer data, boolean last) { + this.opcode = opcode; + /* copy */ + this.data = ByteBuffer.allocate(data.remaining()).put(data.slice()).flip(); + this.last = last; + } + + static final int MAX_HEADER_SIZE_BYTES = 2 + 8 + 4; + static final int MAX_CONTROL_FRAME_PAYLOAD_SIZE = 125; + + enum Opcode { + + CONTINUATION (0x0), + TEXT (0x1), + BINARY (0x2), + NON_CONTROL_0x3(0x3), + NON_CONTROL_0x4(0x4), + NON_CONTROL_0x5(0x5), + NON_CONTROL_0x6(0x6), + NON_CONTROL_0x7(0x7), + CLOSE (0x8), + PING (0x9), + PONG (0xA), + CONTROL_0xB (0xB), + CONTROL_0xC (0xC), + CONTROL_0xD (0xD), + CONTROL_0xE (0xE), + CONTROL_0xF (0xF); + + private static final Opcode[] opcodes; + + static { + Opcode[] values = values(); + opcodes = new Opcode[values.length]; + for (Opcode c : values) { + opcodes[c.code] = c; + } + } + + private final byte code; + + Opcode(int code) { + this.code = (byte) code; + } + + boolean isControl() { + return (code & 0x8) != 0; + } + + static Opcode ofCode(int code) { + return opcodes[code & 0xF]; + } + } + + /* + * A utility for masking frame payload data. + */ + static final class Masker { + + // Exploiting ByteBuffer's ability to read/write multi-byte integers + private final ByteBuffer acc = ByteBuffer.allocate(8); + private final int[] maskBytes = new int[4]; + private int offset; + private long maskLong; + + /* + * Reads all remaining bytes from the given input buffer, masks them + * with the supplied mask and writes the resulting bytes to the given + * output buffer. + * + * The source and the destination buffers may be the same instance. + */ + static void transferMasking(ByteBuffer src, ByteBuffer dst, int mask) { + if (src.remaining() > dst.remaining()) { + throw new IllegalArgumentException(); + } + new Masker().mask(mask).transferMasking(src, dst); + } + + /* + * Clears this instance's state and sets the mask. + * + * The behaviour is as if the mask was set on a newly created instance. + */ + Masker mask(int value) { + acc.clear().putInt(value).putInt(value).flip(); + for (int i = 0; i < maskBytes.length; i++) { + maskBytes[i] = acc.get(i); + } + offset = 0; + maskLong = acc.getLong(0); + return this; + } + + /* + * Reads as many remaining bytes as possible from the given input + * buffer, masks them with the previously set mask and writes the + * resulting bytes to the given output buffer. + * + * The source and the destination buffers may be the same instance. If + * the mask hasn't been previously set it is assumed to be 0. + */ + Masker transferMasking(ByteBuffer src, ByteBuffer dst) { + begin(src, dst); + loop(src, dst); + end(src, dst); + return this; + } + + /* + * Applies up to 3 remaining from the previous pass bytes of the mask. + */ + private void begin(ByteBuffer src, ByteBuffer dst) { + if (offset == 0) { // No partially applied mask from the previous invocation + return; + } + int i = src.position(), j = dst.position(); + final int srcLim = src.limit(), dstLim = dst.limit(); + for (; offset < 4 && i < srcLim && j < dstLim; i++, j++, offset++) + { + dst.put(j, (byte) (src.get(i) ^ maskBytes[offset])); + } + offset &= 3; // Will become 0 if the mask has been fully applied + src.position(i); + dst.position(j); + } + + /* + * Gallops one long (mask + mask) at a time. + */ + private void loop(ByteBuffer src, ByteBuffer dst) { + int i = src.position(); + int j = dst.position(); + final int srcLongLim = src.limit() - 7, dstLongLim = dst.limit() - 7; + for (; i < srcLongLim && j < dstLongLim; i += 8, j += 8) { + dst.putLong(j, src.getLong(i) ^ maskLong); + } + if (i > src.limit()) { + src.position(i - 8); + } else { + src.position(i); + } + if (j > dst.limit()) { + dst.position(j - 8); + } else { + dst.position(j); + } + } + + /* + * Applies up to 7 remaining from the "galloping" phase bytes of the + * mask. + */ + private void end(ByteBuffer src, ByteBuffer dst) { + assert Math.min(src.remaining(), dst.remaining()) < 8; + final int srcLim = src.limit(), dstLim = dst.limit(); + int i = src.position(), j = dst.position(); + for (; i < srcLim && j < dstLim; + i++, j++, offset = (offset + 1) & 3) // offset cycles through 0..3 + { + dst.put(j, (byte) (src.get(i) ^ maskBytes[offset])); + } + src.position(i); + dst.position(j); + } + } + + /* + * A builder-style writer of frame headers. + * + * The writer does not enforce any protocol-level rules, it simply writes a + * header structure to the given buffer. The order of calls to intermediate + * methods is NOT significant. + */ + static final class HeaderWriter { + + private char firstChar; + private long payloadLen; + private int maskingKey; + private boolean mask; + + HeaderWriter fin(boolean value) { + if (value) { + firstChar |= 0b10000000_00000000; + } else { + firstChar &= ~0b10000000_00000000; + } + return this; + } + + HeaderWriter rsv1(boolean value) { + if (value) { + firstChar |= 0b01000000_00000000; + } else { + firstChar &= ~0b01000000_00000000; + } + return this; + } + + HeaderWriter rsv2(boolean value) { + if (value) { + firstChar |= 0b00100000_00000000; + } else { + firstChar &= ~0b00100000_00000000; + } + return this; + } + + HeaderWriter rsv3(boolean value) { + if (value) { + firstChar |= 0b00010000_00000000; + } else { + firstChar &= ~0b00010000_00000000; + } + return this; + } + + HeaderWriter opcode(Opcode value) { + firstChar = (char) ((firstChar & 0xF0FF) | (value.code << 8)); + return this; + } + + HeaderWriter payloadLen(long value) { + if (value < 0) { + throw new IllegalArgumentException("Negative: " + value); + } + payloadLen = value; + firstChar &= 0b11111111_10000000; // Clear previous payload length leftovers + if (payloadLen < 126) { + firstChar |= payloadLen; + } else if (payloadLen < 65536) { + firstChar |= 126; + } else { + firstChar |= 127; + } + return this; + } + + HeaderWriter mask(int value) { + firstChar |= 0b00000000_10000000; + maskingKey = value; + mask = true; + return this; + } + + HeaderWriter noMask() { + firstChar &= ~0b00000000_10000000; + mask = false; + return this; + } + + /* + * Writes the header to the given buffer. + * + * The buffer must have at least MAX_HEADER_SIZE_BYTES remaining. The + * buffer's position is incremented by the number of bytes written. + */ + void write(ByteBuffer buffer) { + buffer.putChar(firstChar); + if (payloadLen >= 126) { + if (payloadLen < 65536) { + buffer.putChar((char) payloadLen); + } else { + buffer.putLong(payloadLen); + } + } + if (mask) { + buffer.putInt(maskingKey); + } + } + } + + /* + * A consumer of frame parts. + * + * Frame.Reader invokes the consumer's methods in the following order: + * + * fin rsv1 rsv2 rsv3 opcode mask payloadLength maskingKey? payloadData+ endFrame + */ + interface Consumer { + + void fin(boolean value); + + void rsv1(boolean value); + + void rsv2(boolean value); + + void rsv3(boolean value); + + void opcode(Opcode value); + + void mask(boolean value); + + void payloadLen(long value); + + void maskingKey(int value); + + /* + * Called by the Frame.Reader when a part of the (or a complete) payload + * is ready to be consumed. + * + * The sum of numbers of bytes consumed in each invocation of this + * method corresponding to the given frame WILL be equal to + * 'payloadLen', reported to `void payloadLen(long value)` before that. + * + * In particular, if `payloadLen` is 0, then there WILL be a single + * invocation to this method. + * + * No unmasking is done. + */ + void payloadData(ByteBuffer data); + + void endFrame(); + } + + /* + * A Reader of frames. + * + * No protocol-level rules are checked. + */ + static final class Reader { + + private static final int AWAITING_FIRST_BYTE = 1; + private static final int AWAITING_SECOND_BYTE = 2; + private static final int READING_16_LENGTH = 4; + private static final int READING_64_LENGTH = 8; + private static final int READING_MASK = 16; + private static final int READING_PAYLOAD = 32; + + // Exploiting ByteBuffer's ability to read multi-byte integers + private final ByteBuffer accumulator = ByteBuffer.allocate(8); + private int state = AWAITING_FIRST_BYTE; + private boolean mask; + private long remainingPayloadLength; + + /* + * Reads at most one frame from the given buffer invoking the consumer's + * methods corresponding to the frame parts found. + * + * As much of the frame's payload, if any, is read. The buffer's + * position is updated to reflect the number of bytes read. + * + * Throws FailWebSocketException if detects the frame is malformed. + */ + void readFrame(ByteBuffer input, Consumer consumer) { + loop: + while (true) { + byte b; + switch (state) { + case AWAITING_FIRST_BYTE: + if (!input.hasRemaining()) { + break loop; + } + b = input.get(); + consumer.fin( (b & 0b10000000) != 0); + consumer.rsv1((b & 0b01000000) != 0); + consumer.rsv2((b & 0b00100000) != 0); + consumer.rsv3((b & 0b00010000) != 0); + consumer.opcode(Opcode.ofCode(b)); + state = AWAITING_SECOND_BYTE; + continue loop; + case AWAITING_SECOND_BYTE: + if (!input.hasRemaining()) { + break loop; + } + b = input.get(); + consumer.mask(mask = (b & 0b10000000) != 0); + byte p1 = (byte) (b & 0b01111111); + if (p1 < 126) { + assert p1 >= 0 : p1; + consumer.payloadLen(remainingPayloadLength = p1); + state = mask ? READING_MASK : READING_PAYLOAD; + } else if (p1 < 127) { + state = READING_16_LENGTH; + } else { + state = READING_64_LENGTH; + } + continue loop; + case READING_16_LENGTH: + if (!input.hasRemaining()) { + break loop; + } + b = input.get(); + if (accumulator.put(b).position() < 2) { + continue loop; + } + remainingPayloadLength = accumulator.flip().getChar(); + if (remainingPayloadLength < 126) { + throw notMinimalEncoding(remainingPayloadLength); + } + consumer.payloadLen(remainingPayloadLength); + accumulator.clear(); + state = mask ? READING_MASK : READING_PAYLOAD; + continue loop; + case READING_64_LENGTH: + if (!input.hasRemaining()) { + break loop; + } + b = input.get(); + if (accumulator.put(b).position() < 8) { + continue loop; + } + remainingPayloadLength = accumulator.flip().getLong(); + if (remainingPayloadLength < 0) { + throw negativePayload(remainingPayloadLength); + } else if (remainingPayloadLength < 65536) { + throw notMinimalEncoding(remainingPayloadLength); + } + consumer.payloadLen(remainingPayloadLength); + accumulator.clear(); + state = mask ? READING_MASK : READING_PAYLOAD; + continue loop; + case READING_MASK: + if (!input.hasRemaining()) { + break loop; + } + b = input.get(); + if (accumulator.put(b).position() != 4) { + continue loop; + } + consumer.maskingKey(accumulator.flip().getInt()); + accumulator.clear(); + state = READING_PAYLOAD; + continue loop; + case READING_PAYLOAD: + // This state does not require any bytes to be available + // in the input buffer in order to proceed + int deliverable = (int) Math.min(remainingPayloadLength, + input.remaining()); + int oldLimit = input.limit(); + input.limit(input.position() + deliverable); + if (deliverable != 0 || remainingPayloadLength == 0) { + consumer.payloadData(input); + } + int consumed = deliverable - input.remaining(); + if (consumed < 0) { + // Consumer cannot consume more than there was available + throw new InternalError(); + } + input.limit(oldLimit); + remainingPayloadLength -= consumed; + if (remainingPayloadLength == 0) { + consumer.endFrame(); + state = AWAITING_FIRST_BYTE; + } + break loop; + default: + throw new InternalError(String.valueOf(state)); + } + } + } + + private static IllegalArgumentException negativePayload(long payloadLength) + { + return new IllegalArgumentException("Negative payload length: " + + payloadLength); + } + + private static IllegalArgumentException notMinimalEncoding(long payloadLength) + { + return new IllegalArgumentException("Not minimally-encoded payload length:" + + payloadLength); + } + } +} diff -r d818a6a8295a -r 4933a477d628 test/jdk/java/net/httpclient/websocket/MessageQueueTestDriver.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test/jdk/java/net/httpclient/websocket/MessageQueueTestDriver.java Wed Mar 07 17:16:28 2018 +0000 @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8159053 + * @modules java.net.http/jdk.internal.net.http.websocket:open + * @run testng/othervm + * --add-reads java.net.http=ALL-UNNAMED + * java.net.http/jdk.internal.net.http.websocket.MessageQueueTest + */ +public final class MessageQueueTestDriver { } diff -r d818a6a8295a -r 4933a477d628 test/jdk/java/net/httpclient/websocket/WebSocketImplDriver.java --- a/test/jdk/java/net/httpclient/websocket/WebSocketImplDriver.java Wed Mar 07 15:39:25 2018 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ - -/* - * @test - * @modules java.net.http/jdk.internal.net.http.websocket:open - * @run testng/othervm - * --add-reads java.net.http=ALL-UNNAMED - * java.net.http/jdk.internal.net.http.websocket.WebSocketImplTest - */ -public class WebSocketImplDriver { } diff -r d818a6a8295a -r 4933a477d628 test/jdk/java/net/httpclient/websocket/WebSocketTest.java --- a/test/jdk/java/net/httpclient/websocket/WebSocketTest.java Wed Mar 07 15:39:25 2018 +0000 +++ b/test/jdk/java/net/httpclient/websocket/WebSocketTest.java Wed Mar 07 17:16:28 2018 +0000 @@ -24,9 +24,8 @@ /* * @test * @build DummyWebSocketServer - * @run testng/othervm -Djdk.httpclient.HttpClient.log=trace WebSocketTest + * @run testng/othervm WebSocketTest */ - import org.testng.annotations.Test; import java.io.IOException; @@ -36,12 +35,14 @@ import java.nio.channels.SocketChannel; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; import static java.net.http.HttpClient.newHttpClient; import static java.net.http.WebSocket.NORMAL_CLOSURE; @@ -58,16 +59,19 @@ private static final Class ISE = IllegalStateException.class; private static final Class IOE = IOException.class; - @Test - public void abort() throws Exception { +// @Test + public void immediateAbort() throws Exception { try (DummyWebSocketServer server = serverWithCannedData(0x81, 0x00, 0x88, 0x00)) { server.open(); CompletableFuture messageReceived = new CompletableFuture<>(); WebSocket ws = newHttpClient() .newWebSocketBuilder() .buildAsync(server.getURI(), new WebSocket.Listener() { + @Override - public void onOpen(WebSocket webSocket) { /* no initial request */ } + public void onOpen(WebSocket webSocket) { + /* no initial request */ + } @Override public CompletionStage onText(WebSocket webSocket, @@ -133,15 +137,17 @@ try { messageReceived.get(10, TimeUnit.SECONDS); fail(); - } catch (TimeoutException expected) { } - // TODO: No send operations MUST succeed -// assertCompletesExceptionally(IOE, ws.sendText("text!", false)); -// assertCompletesExceptionally(IOE, ws.sendText("text!", true)); -// assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(16), false)); -// assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(16), true)); -// assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(16))); -// assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(16))); -// assertCompletesExceptionally(IOE, ws.sendClose(NORMAL_CLOSURE, "a reason")); + } catch (TimeoutException expected) { + System.out.println("Finished waiting"); + } + assertCompletesExceptionally(IOE, ws.sendText("text!", false)); + assertCompletesExceptionally(IOE, ws.sendText("text!", true)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(16), false)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(16), true)); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(16))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(16))); + // Checked last because it changes the state of WebSocket + assertCompletesExceptionally(IOE, ws.sendClose(NORMAL_CLOSURE, "a reason")); } } @@ -154,15 +160,40 @@ @Override protected void serve(SocketChannel channel) throws IOException { ByteBuffer closeMessage = ByteBuffer.wrap(copy); - int wrote = channel.write(closeMessage); - System.out.println("Wrote bytes: " + wrote); + channel.write(closeMessage); super.serve(channel); } }; } + private static void assertCompletesExceptionally(Class clazz, + CompletableFuture stage) { + stage.handle((result, error) -> { + if (error instanceof CompletionException) { + Throwable cause = error.getCause(); + if (cause == null) { + throw new AssertionError("Unexpected null cause: " + error); + } + assertException(clazz, cause); + } else { + assertException(clazz, error); + } + return null; + }).join(); + } + + private static void assertException(Class clazz, + Throwable t) { + if (t == null) { + throw new AssertionError("Expected " + clazz + ", caught nothing"); + } + if (!clazz.isInstance(t)) { + throw new AssertionError("Expected " + clazz + ", caught " + t); + } + } + @Test - public void testNull() throws IOException { + public void sendMethodsThrowNPE() throws IOException { try (DummyWebSocketServer server = new DummyWebSocketServer()) { server.open(); WebSocket ws = newHttpClient() @@ -177,11 +208,25 @@ assertThrows(NPE, () -> ws.sendPing(null)); assertThrows(NPE, () -> ws.sendPong(null)); assertThrows(NPE, () -> ws.sendClose(NORMAL_CLOSURE, null)); + + ws.abort(); + + assertThrows(NPE, () -> ws.sendText(null, false)); + assertThrows(NPE, () -> ws.sendText(null, true)); + assertThrows(NPE, () -> ws.sendBinary(null, false)); + assertThrows(NPE, () -> ws.sendBinary(null, true)); + assertThrows(NPE, () -> ws.sendPing(null)); + assertThrows(NPE, () -> ws.sendPong(null)); + assertThrows(NPE, () -> ws.sendClose(NORMAL_CLOSURE, null)); } } + // TODO: request in onClose/onError + // TODO: throw exception in onClose/onError + // TODO: exception is thrown from request() + @Test - public void testSendClose1() throws IOException { + public void sendCloseCompleted() throws IOException { try (DummyWebSocketServer server = new DummyWebSocketServer()) { server.open(); WebSocket ws = newHttpClient() @@ -197,7 +242,7 @@ } @Test - public void testSendClose2() throws Exception { + public void sendClosePending() throws Exception { try (DummyWebSocketServer server = notReadingServer()) { server.open(); WebSocket ws = newHttpClient() @@ -205,7 +250,7 @@ .buildAsync(server.getURI(), new WebSocket.Listener() { }) .join(); ByteBuffer data = ByteBuffer.allocate(65536); - for (int i = 0; ; i++) { + for (int i = 0; ; i++) { // fill up the send buffer System.out.println("cycle #" + i); try { ws.sendBinary(data, true).get(10, TimeUnit.SECONDS); @@ -215,12 +260,11 @@ } } CompletableFuture cf = ws.sendClose(NORMAL_CLOSURE, ""); + // The output closes even if the Close message has not been sent + assertFalse(cf.isDone()); assertTrue(ws.isOutputClosed()); assertFalse(ws.isInputClosed()); assertEquals(ws.getSubprotocol(), ""); - // The output closes regardless of whether or not the Close message - // has been sent - assertFalse(cf.isDone()); } } @@ -242,6 +286,78 @@ }; } +// @Test + public void abortPendingSendBinary() throws Exception { + try (DummyWebSocketServer server = notReadingServer()) { + server.open(); + WebSocket ws = newHttpClient() + .newWebSocketBuilder() + .buildAsync(server.getURI(), new WebSocket.Listener() { }) + .join(); + ByteBuffer data = ByteBuffer.allocate(65536); + CompletableFuture cf = null; + for (int i = 0; ; i++) { // fill up the send buffer + System.out.println("cycle #" + i); + try { + cf = ws.sendBinary(data, true); + cf.get(10, TimeUnit.SECONDS); + data.clear(); + } catch (TimeoutException e) { + break; + } + } + ws.abort(); + assertTrue(ws.isOutputClosed()); + assertTrue(ws.isInputClosed()); + assertCompletesExceptionally(IOException.class, cf); + } + } + +// @Test + public void abortPendingSendText() throws Exception { + try (DummyWebSocketServer server = notReadingServer()) { + server.open(); + WebSocket ws = newHttpClient() + .newWebSocketBuilder() + .buildAsync(server.getURI(), new WebSocket.Listener() { }) + .join(); + String data = stringWith2NBytes(32768); + CompletableFuture cf = null; + for (int i = 0; ; i++) { // fill up the send buffer + System.out.println("cycle #" + i); + try { + cf = ws.sendText(data, true); + cf.get(10, TimeUnit.SECONDS); + } catch (TimeoutException e) { + break; + } + } + ws.abort(); + assertTrue(ws.isOutputClosed()); + assertTrue(ws.isInputClosed()); + assertCompletesExceptionally(IOException.class, cf); + } + } + + private static String stringWith2NBytes(int n) { + // -- Russian Alphabet (33 characters, 2 bytes per char) -- + char[] abc = { + 0x0410, 0x0411, 0x0412, 0x0413, 0x0414, 0x0415, 0x0401, 0x0416, + 0x0417, 0x0418, 0x0419, 0x041A, 0x041B, 0x041C, 0x041D, 0x041E, + 0x041F, 0x0420, 0x0421, 0x0422, 0x0423, 0x0424, 0x0425, 0x0426, + 0x0427, 0x0428, 0x0429, 0x042A, 0x042B, 0x042C, 0x042D, 0x042E, + 0x042F, + }; + // repeat cyclically + StringBuilder sb = new StringBuilder(n); + for (int i = 0, j = 0; i < n; i++, j = (j + 1) % abc.length) { + sb.append(abc[j]); + } + String s = sb.toString(); + assert s.length() == n && s.getBytes(StandardCharsets.UTF_8).length == 2 * n; + return s; + } + @Test public void testIllegalArgument() throws IOException { try (DummyWebSocketServer server = new DummyWebSocketServer()) { @@ -263,10 +379,10 @@ assertCompletesExceptionally(IAE, ws.sendPong(ByteBuffer.allocate(129))); assertCompletesExceptionally(IAE, ws.sendPong(ByteBuffer.allocate(256))); - assertCompletesExceptionally(IAE, ws.sendText(incompleteString(), true)); - assertCompletesExceptionally(IAE, ws.sendText(incompleteString(), false)); - assertCompletesExceptionally(IAE, ws.sendText(malformedString(), true)); - assertCompletesExceptionally(IAE, ws.sendText(malformedString(), false)); + assertCompletesExceptionally(IOE, ws.sendText(incompleteString(), true)); + assertCompletesExceptionally(IOE, ws.sendText(incompleteString(), false)); + assertCompletesExceptionally(IOE, ws.sendText(malformedString(), true)); + assertCompletesExceptionally(IOE, ws.sendText(malformedString(), false)); assertCompletesExceptionally(IAE, ws.sendClose(NORMAL_CLOSURE, stringWithNBytes(124))); assertCompletesExceptionally(IAE, ws.sendClose(NORMAL_CLOSURE, stringWithNBytes(125))); @@ -316,58 +432,13 @@ } private static String stringWithNBytes(int n) { - StringBuilder sb = new StringBuilder(n); - for (int i = 0; i < n; i++) { - sb.append("A"); - } - return sb.toString(); - } - - private static String stringWith2NBytes(int n) { - // Russian alphabet repeated cyclically - char FIRST = '\u0410'; - char LAST = '\u042F'; - StringBuilder sb = new StringBuilder(n); - char c = FIRST; - for (int i = 0; i < n; i++) { - if (++c > LAST) { - c = FIRST; - } - sb.append(c); - } - String s = sb.toString(); - assert s.length() == n && s.getBytes(StandardCharsets.UTF_8).length == 2 * n; - return s; - } - - private static void assertCompletesExceptionally(Class clazz, - CompletableFuture stage) { - stage.handle((result, error) -> { - if (error instanceof CompletionException) { - Throwable cause = error.getCause(); - if (cause == null) { - throw new AssertionError("Unexpected null cause: " + error); - } - assertException(clazz, cause); - } else { - assertException(clazz, error); - } - return null; - }).join(); - } - - private static void assertException(Class clazz, - Throwable t) { - if (t == null) { - throw new AssertionError("Expected " + clazz + ", caught nothing"); - } - if (!clazz.isInstance(t)) { - throw new AssertionError("Expected " + clazz + ", caught " + t); - } + char[] chars = new char[n]; + Arrays.fill(chars, 'A'); + return new String(chars); } @Test - public void testIllegalStateOutstanding1() throws Exception { + public void outstanding1() throws Exception { try (DummyWebSocketServer server = notReadingServer()) { server.open(); WebSocket ws = newHttpClient() @@ -376,7 +447,7 @@ .join(); ByteBuffer data = ByteBuffer.allocate(65536); - for (int i = 0; ; i++) { + for (int i = 0; ; i++) { // fill up the send buffer System.out.println("cycle #" + i); try { ws.sendBinary(data, true).get(10, TimeUnit.SECONDS); @@ -391,7 +462,7 @@ } @Test - public void testIllegalStateOutstanding2() throws Exception { + public void outstanding2() throws Exception { try (DummyWebSocketServer server = notReadingServer()) { server.open(); WebSocket ws = newHttpClient() @@ -400,7 +471,7 @@ .join(); CharBuffer data = CharBuffer.allocate(65536); - for (int i = 0; ; i++) { + for (int i = 0; ; i++) { // fill up the send buffer System.out.println("cycle #" + i); try { ws.sendText(data, true).get(10, TimeUnit.SECONDS); @@ -415,7 +486,7 @@ } @Test - public void testIllegalStateIntermixed1() throws IOException { + public void interleavingTypes1() throws IOException { try (DummyWebSocketServer server = new DummyWebSocketServer()) { server.open(); WebSocket ws = newHttpClient() @@ -430,7 +501,7 @@ } @Test - public void testIllegalStateIntermixed2() throws IOException { + public void interleavingTypes2() throws IOException { try (DummyWebSocketServer server = new DummyWebSocketServer()) { server.open(); WebSocket ws = newHttpClient() @@ -445,7 +516,7 @@ } @Test - public void testIllegalStateSendClose() throws IOException { + public void sendMethodsThrowIOE1() throws IOException { try (DummyWebSocketServer server = new DummyWebSocketServer()) { server.open(); WebSocket ws = newHttpClient() @@ -453,31 +524,33 @@ .buildAsync(server.getURI(), new WebSocket.Listener() { }) .join(); - ws.sendClose(NORMAL_CLOSURE, "normal close").join(); + ws.sendClose(NORMAL_CLOSURE, "ok").join(); + + assertCompletesExceptionally(IOE, ws.sendClose(WebSocket.NORMAL_CLOSURE, "ok")); - assertCompletesExceptionally(ISE, ws.sendText("", true)); - assertCompletesExceptionally(ISE, ws.sendText("", false)); - assertCompletesExceptionally(ISE, ws.sendText("abc", true)); - assertCompletesExceptionally(ISE, ws.sendText("abc", false)); - assertCompletesExceptionally(ISE, ws.sendBinary(ByteBuffer.allocate(0), true)); - assertCompletesExceptionally(ISE, ws.sendBinary(ByteBuffer.allocate(0), false)); - assertCompletesExceptionally(ISE, ws.sendBinary(ByteBuffer.allocate(1), true)); - assertCompletesExceptionally(ISE, ws.sendBinary(ByteBuffer.allocate(1), false)); + assertCompletesExceptionally(IOE, ws.sendText("", true)); + assertCompletesExceptionally(IOE, ws.sendText("", false)); + assertCompletesExceptionally(IOE, ws.sendText("abc", true)); + assertCompletesExceptionally(IOE, ws.sendText("abc", false)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(0), true)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(0), false)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(1), true)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(1), false)); - assertCompletesExceptionally(ISE, ws.sendPing(ByteBuffer.allocate(125))); - assertCompletesExceptionally(ISE, ws.sendPing(ByteBuffer.allocate(124))); - assertCompletesExceptionally(ISE, ws.sendPing(ByteBuffer.allocate(1))); - assertCompletesExceptionally(ISE, ws.sendPing(ByteBuffer.allocate(0))); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(125))); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(124))); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(1))); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(0))); - assertCompletesExceptionally(ISE, ws.sendPong(ByteBuffer.allocate(125))); - assertCompletesExceptionally(ISE, ws.sendPong(ByteBuffer.allocate(124))); - assertCompletesExceptionally(ISE, ws.sendPong(ByteBuffer.allocate(1))); - assertCompletesExceptionally(ISE, ws.sendPong(ByteBuffer.allocate(0))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(125))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(124))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(1))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(0))); } } @Test - public void testIllegalStateOnClose() throws Exception { + public void sendMethodsThrowIOE2() throws Exception { try (DummyWebSocketServer server = serverWithCannedData(0x88, 0x00)) { server.open(); CompletableFuture onCloseCalled = new CompletableFuture<>(); @@ -490,7 +563,7 @@ public CompletionStage onClose(WebSocket webSocket, int statusCode, String reason) { - System.out.println("onClose(" + statusCode + ")"); + System.out.printf("onClose(%s, '%s')%n", statusCode, reason); onCloseCalled.complete(null); return canClose; } @@ -498,38 +571,143 @@ @Override public void onError(WebSocket webSocket, Throwable error) { System.out.println("onError(" + error + ")"); - error.printStackTrace(); + onCloseCalled.completeExceptionally(error); } }) .join(); onCloseCalled.join(); // Wait for onClose to be called + canClose.complete(null); // Signal to the WebSocket it can close the output TimeUnit.SECONDS.sleep(5); // Give canClose some time to reach the WebSocket - canClose.complete(null); // Signal to the WebSocket it can close the output + + assertCompletesExceptionally(IOE, ws.sendClose(WebSocket.NORMAL_CLOSURE, "ok")); - assertCompletesExceptionally(ISE, ws.sendText("", true)); - assertCompletesExceptionally(ISE, ws.sendText("", false)); - assertCompletesExceptionally(ISE, ws.sendText("abc", true)); - assertCompletesExceptionally(ISE, ws.sendText("abc", false)); - assertCompletesExceptionally(ISE, ws.sendBinary(ByteBuffer.allocate(0), true)); - assertCompletesExceptionally(ISE, ws.sendBinary(ByteBuffer.allocate(0), false)); - assertCompletesExceptionally(ISE, ws.sendBinary(ByteBuffer.allocate(1), true)); - assertCompletesExceptionally(ISE, ws.sendBinary(ByteBuffer.allocate(1), false)); + assertCompletesExceptionally(IOE, ws.sendText("", true)); + assertCompletesExceptionally(IOE, ws.sendText("", false)); + assertCompletesExceptionally(IOE, ws.sendText("abc", true)); + assertCompletesExceptionally(IOE, ws.sendText("abc", false)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(0), true)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(0), false)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(1), true)); + assertCompletesExceptionally(IOE, ws.sendBinary(ByteBuffer.allocate(1), false)); - assertCompletesExceptionally(ISE, ws.sendPing(ByteBuffer.allocate(125))); - assertCompletesExceptionally(ISE, ws.sendPing(ByteBuffer.allocate(124))); - assertCompletesExceptionally(ISE, ws.sendPing(ByteBuffer.allocate(1))); - assertCompletesExceptionally(ISE, ws.sendPing(ByteBuffer.allocate(0))); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(125))); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(124))); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(1))); + assertCompletesExceptionally(IOE, ws.sendPing(ByteBuffer.allocate(0))); - assertCompletesExceptionally(ISE, ws.sendPong(ByteBuffer.allocate(125))); - assertCompletesExceptionally(ISE, ws.sendPong(ByteBuffer.allocate(124))); - assertCompletesExceptionally(ISE, ws.sendPong(ByteBuffer.allocate(1))); - assertCompletesExceptionally(ISE, ws.sendPong(ByteBuffer.allocate(0))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(125))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(124))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(1))); + assertCompletesExceptionally(IOE, ws.sendPong(ByteBuffer.allocate(0))); } } @Test - public void simpleAggregatingMessages() throws IOException { + public void simpleAggregatingBinaryMessages() throws IOException { + List expected = List.of("alpha", "beta", "gamma", "delta") + .stream() + .map(s -> s.getBytes(StandardCharsets.US_ASCII)) + .collect(Collectors.toList()); + int[] binary = new int[]{ + 0x82, 0x05, 0x61, 0x6c, 0x70, 0x68, 0x61, // [alpha] + 0x02, 0x02, 0x62, 0x65, // [be + 0x80, 0x02, 0x74, 0x61, // ta] + 0x02, 0x01, 0x67, // [g + 0x00, 0x01, 0x61, // a + 0x00, 0x00, // + 0x00, 0x00, // + 0x00, 0x01, 0x6d, // m + 0x00, 0x01, 0x6d, // m + 0x80, 0x01, 0x61, // a] + 0x8a, 0x00, // + 0x02, 0x04, 0x64, 0x65, 0x6c, 0x74, // [delt + 0x00, 0x01, 0x61, // a + 0x80, 0x00, // ] + 0x88, 0x00 // + }; + CompletableFuture> actual = new CompletableFuture<>(); + + try (DummyWebSocketServer server = serverWithCannedData(binary)) { + server.open(); + + WebSocket.Listener listener = new WebSocket.Listener() { + + List collectedBytes = new ArrayList<>(); + ByteBuffer binary; + + @Override + public CompletionStage onBinary(WebSocket webSocket, + ByteBuffer message, + WebSocket.MessagePart part) { + System.out.printf("onBinary(%s, %s)%n", message, part); + webSocket.request(1); + byte[] bytes = null; + switch (part) { + case FIRST: + binary = ByteBuffer.allocate(message.remaining() * 2); + case PART: + append(message); + return null; + case LAST: + append(message); + binary.flip(); + bytes = new byte[binary.remaining()]; + binary.get(bytes); + binary.clear(); + break; + case WHOLE: + bytes = new byte[message.remaining()]; + message.get(bytes); + break; + } + processWholeBinary(bytes); + return null; + } + + private void append(ByteBuffer message) { + if (binary.remaining() < message.remaining()) { + assert message.remaining() > 0; + int cap = (binary.capacity() + message.remaining()) * 2; + ByteBuffer b = ByteBuffer.allocate(cap); + b.put(binary.flip()); + binary = b; + } + binary.put(message); + } + + private void processWholeBinary(byte[] bytes) { + String stringBytes = new String(bytes, StandardCharsets.UTF_8); + System.out.println("processWholeBinary: " + stringBytes); + collectedBytes.add(bytes); + } + + @Override + public CompletionStage onClose(WebSocket webSocket, + int statusCode, + String reason) { + actual.complete(collectedBytes); + return null; + } + + @Override + public void onError(WebSocket webSocket, Throwable error) { + actual.completeExceptionally(error); + } + }; + + newHttpClient().newWebSocketBuilder() + .buildAsync(server.getURI(), listener) + .join(); + + List a = actual.join(); + System.out.println("joined"); + assertEquals(a, expected); + } + } + + @Test + public void simpleAggregatingTextMessages() throws IOException { List expected = List.of("alpha", "beta", "gamma", "delta"); @@ -557,24 +735,25 @@ WebSocket.Listener listener = new WebSocket.Listener() { - List collected = new ArrayList<>(); - StringBuilder text = new StringBuilder(); + List collectedStrings = new ArrayList<>(); + StringBuilder text; @Override public CompletionStage onText(WebSocket webSocket, CharSequence message, WebSocket.MessagePart part) { + System.out.printf("onText(%s, %s)%n", message, part); webSocket.request(1); String str = null; switch (part) { case FIRST: + text = new StringBuilder(message.length() * 2); case PART: text.append(message); return null; case LAST: text.append(message); str = text.toString(); - text.setLength(0); break; case WHOLE: str = message.toString(); @@ -586,30 +765,38 @@ private void processWholeText(String string) { System.out.println(string); - // -- your code here -- - collected.add(string); + collectedStrings.add(string); } @Override public CompletionStage onClose(WebSocket webSocket, int statusCode, String reason) { - actual.complete(collected); + actual.complete(collectedStrings); return null; } + + @Override + public void onError(WebSocket webSocket, Throwable error) { + actual.completeExceptionally(error); + } }; newHttpClient().newWebSocketBuilder() - .buildAsync(server.getURI(), listener) - .join(); + .buildAsync(server.getURI(), listener) + .join(); List a = actual.join(); assertEquals(a, expected); } } + /* + * Exercises the scenario where requests for more messages are made prior to + * completing the returned CompletionStage instances. + */ @Test - public void aggregatingMessages() throws IOException { + public void aggregatingTextMessages() throws IOException { List expected = List.of("alpha", "beta", "gamma", "delta"); @@ -638,7 +825,12 @@ WebSocket.Listener listener = new WebSocket.Listener() { - List parts = new ArrayList<>(); + List parts; + /* + * A CompletableFuture which will complete once the current + * message has been fully assembled (LAST/WHOLE). Until then + * the listener returns this instance for every call. + */ CompletableFuture currentCf; List collected = new ArrayList<>(); @@ -653,6 +845,7 @@ processWholeMessage(List.of(message), cf); return cf; case FIRST: + parts = new ArrayList<>(); parts.add(message); currentCf = new CompletableFuture<>(); currentCf.thenRun(() -> webSocket.request(1)); @@ -664,12 +857,11 @@ break; case LAST: parts.add(message); - List copy = List.copyOf(parts); - parts.clear(); - CompletableFuture cf1 = currentCf; + CompletableFuture copyCf = this.currentCf; + processWholeMessage(parts, copyCf); currentCf = null; - processWholeMessage(copy, cf1); - return cf1; + parts = null; + return copyCf; } return currentCf; } @@ -682,6 +874,11 @@ return null; } + @Override + public void onError(WebSocket webSocket, Throwable error) { + actual.completeExceptionally(error); + } + public void processWholeMessage(List data, CompletableFuture cf) { StringBuilder b = new StringBuilder(); @@ -692,10 +889,10 @@ collected.add(s); } }; - WebSocket ws = newHttpClient() - .newWebSocketBuilder() - .buildAsync(server.getURI(), listener) - .join(); + + newHttpClient().newWebSocketBuilder() + .buildAsync(server.getURI(), listener) + .join(); List a = actual.join(); assertEquals(a, expected); diff -r d818a6a8295a -r 4933a477d628 test/jdk/java/net/httpclient/websocket/WebSocketTextTest.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test/jdk/java/net/httpclient/websocket/WebSocketTextTest.java Wed Mar 07 17:16:28 2018 +0000 @@ -0,0 +1,318 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.http.WebSocket; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static java.net.http.HttpClient.newHttpClient; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +/* + * @test + * @bug 8159053 + * + * + * @run testng/othervm + * -Djdk.httpclient.websocket.writeBufferSize=1024 + * -Djdk.httpclient.websocket.intermediateBufferSize=2048 WebSocketTextTest + */ +public class WebSocketTextTest { + + private final static Random random; + static { + long seed = System.currentTimeMillis(); + System.out.println("seed=" + seed); + random = new Random(seed); + } + +// * @run testng/othervm +// * -Djdk.httpclient.websocket.writeBufferSize=16 +// * -Djdk.httpclient.sendBufferSize=32 WebSocketTextTest + + + + // FIXME ensure subsequent (sendText/Binary, false) only CONTINUATIONs + + @Test(dataProvider = "binary") + public void binary(ByteBuffer expected) throws IOException, InterruptedException { + try (DummyWebSocketServer server = new DummyWebSocketServer()) { + server.open(); + WebSocket ws = newHttpClient() + .newWebSocketBuilder() + .buildAsync(server.getURI(), new WebSocket.Listener() { }) + .join(); + ws.sendBinary(expected.duplicate(), true).join(); + ws.abort(); + ByteBuffer data = server.read(); + List frames = readFrames(data); + assertEquals(frames.size(), 1); + Frame f = frames.get(0); + assertTrue(f.last); + assertEquals(f.opcode, Frame.Opcode.BINARY); + assertEquals(f.data, expected); + } + } + + private static List readFrames(ByteBuffer src) { + List frames = new ArrayList<>(); + Frame.Consumer consumer = new Frame.Consumer() { + + ByteBuffer data; + Frame.Opcode opcode; + Frame.Masker masker = new Frame.Masker(); + boolean last; + + @Override + public void fin(boolean value) { + last = value; + } + + @Override + public void rsv1(boolean value) { + if (value) { + throw new AssertionError(); + } + } + + @Override + public void rsv2(boolean value) { + if (value) { + throw new AssertionError(); + } + } + + @Override + public void rsv3(boolean value) { + if (value) { + throw new AssertionError(); + } + } + + @Override + public void opcode(Frame.Opcode value) { + opcode = value; + } + + @Override + public void mask(boolean value) { + if (!value) { // Frames from the client MUST be masked + throw new AssertionError(); + } + } + + @Override + public void payloadLen(long value) { + data = ByteBuffer.allocate((int) value); + } + + @Override + public void maskingKey(int value) { + masker.mask(value); + } + + @Override + public void payloadData(ByteBuffer data) { + masker.transferMasking(data, this.data); + } + + @Override + public void endFrame() { + frames.add(new Frame(opcode, this.data.flip(), last)); + } + }; + + Frame.Reader r = new Frame.Reader(); + while (src.hasRemaining()) { + r.readFrame(src, consumer); + } + return frames; + } + + @Test(dataProvider = "pingPong") + public void ping(ByteBuffer expected) throws Exception { + try (DummyWebSocketServer server = new DummyWebSocketServer()) { + server.open(); + WebSocket ws = newHttpClient() + .newWebSocketBuilder() + .buildAsync(server.getURI(), new WebSocket.Listener() { }) + .join(); + ws.sendPing(expected.duplicate()).join(); + ws.abort(); + ByteBuffer data = server.read(); + List frames = readFrames(data); + assertEquals(frames.size(), 1); + Frame f = frames.get(0); + assertEquals(f.opcode, Frame.Opcode.PING); + ByteBuffer actual = ByteBuffer.allocate(expected.remaining()); + actual.put(f.data); + actual.flip(); + assertEquals(actual, expected); + } + } + + @Test(dataProvider = "pingPong") + public void pong(ByteBuffer expected) throws Exception { + try (DummyWebSocketServer server = new DummyWebSocketServer()) { + server.open(); + WebSocket ws = newHttpClient() + .newWebSocketBuilder() + .buildAsync(server.getURI(), new WebSocket.Listener() { }) + .join(); + ws.sendPong(expected.duplicate()).join(); + ws.abort(); + ByteBuffer data = server.read(); + List frames = readFrames(data); + assertEquals(frames.size(), 1); + Frame f = frames.get(0); + assertEquals(f.opcode, Frame.Opcode.PONG); + ByteBuffer actual = ByteBuffer.allocate(expected.remaining()); + actual.put(f.data); + actual.flip(); + assertEquals(actual, expected); + } + } + + @Test(dataProvider = "close") + public void close(int statusCode, String reason) throws Exception { + try (DummyWebSocketServer server = new DummyWebSocketServer()) { + server.open(); + WebSocket ws = newHttpClient() + .newWebSocketBuilder() + .buildAsync(server.getURI(), new WebSocket.Listener() { }) + .join(); + ws.sendClose(statusCode, reason).join(); + ws.abort(); + ByteBuffer data = server.read(); + List frames = readFrames(data); + assertEquals(frames.size(), 1); + Frame f = frames.get(0); + assertEquals(f.opcode, Frame.Opcode.CLOSE); + ByteBuffer actual = ByteBuffer.allocate(Frame.MAX_CONTROL_FRAME_PAYLOAD_SIZE); + actual.put(f.data); + actual.flip(); + assertEquals(actual.getChar(), statusCode); + assertEquals(StandardCharsets.UTF_8.decode(actual).toString(), reason); + } + } + + @Test(dataProvider = "text") + public void text(String expected) throws Exception { + try (DummyWebSocketServer server = new DummyWebSocketServer()) { + server.open(); + WebSocket ws = newHttpClient() + .newWebSocketBuilder() + .buildAsync(server.getURI(), new WebSocket.Listener() { }) + .join(); + ws.sendText(expected, true).join(); + ws.abort(); + ByteBuffer data = server.read(); + List frames = readFrames(data); + + int maxBytes = (int) StandardCharsets.UTF_8.newEncoder().maxBytesPerChar() * expected.length(); + ByteBuffer actual = ByteBuffer.allocate(maxBytes); + frames.stream().forEachOrdered(f -> actual.put(f.data)); + actual.flip(); + assertEquals(StandardCharsets.UTF_8.decode(actual).toString(), expected); + } + } + + @DataProvider(name = "pingPong") + public Object[][] pingPongSizes() { + return new Object[][]{ + {bytes( 0)}, + {bytes( 1)}, + {bytes( 63)}, + {bytes(125)}, + }; + } + + @DataProvider(name = "close") + public Object[][] closeArguments() { + return new Object[][]{ + {WebSocket.NORMAL_CLOSURE, utf8String( 0)}, + {WebSocket.NORMAL_CLOSURE, utf8String( 1)}, + // 123 / 3 = max reason bytes / max bytes per char + {WebSocket.NORMAL_CLOSURE, utf8String(41)}, + }; + } + + private static String utf8String(int n) { + char[] abc = { + // -- English Alphabet (26 characters, 1 byte per char) -- + 0x0041, 0x0042, 0x0043, 0x0044, 0x0045, 0x0046, 0x0047, 0x0048, + 0x0049, 0x004A, 0x004B, 0x004C, 0x004D, 0x004E, 0x004F, 0x0050, + 0x0051, 0x0052, 0x0053, 0x0054, 0x0055, 0x0056, 0x0057, 0x0058, + 0x0059, 0x005A, + // -- Russian Alphabet (33 characters, 2 bytes per char) -- + 0x0410, 0x0411, 0x0412, 0x0413, 0x0414, 0x0415, 0x0401, 0x0416, + 0x0417, 0x0418, 0x0419, 0x041A, 0x041B, 0x041C, 0x041D, 0x041E, + 0x041F, 0x0420, 0x0421, 0x0422, 0x0423, 0x0424, 0x0425, 0x0426, + 0x0427, 0x0428, 0x0429, 0x042A, 0x042B, 0x042C, 0x042D, 0x042E, + 0x042F, + // -- Hiragana base characters (46 characters, 3 bytes per char) -- + 0x3042, 0x3044, 0x3046, 0x3048, 0x304A, 0x304B, 0x304D, 0x304F, + 0x3051, 0x3053, 0x3055, 0x3057, 0x3059, 0x305B, 0x305D, 0x305F, + 0x3061, 0x3064, 0x3066, 0x3068, 0x306A, 0x306B, 0x306C, 0x306D, + 0x306E, 0x306F, 0x3072, 0x3075, 0x3078, 0x307B, 0x307E, 0x307F, + 0x3080, 0x3081, 0x3082, 0x3084, 0x3086, 0x3088, 0x3089, 0x308A, + 0x308B, 0x308C, 0x308D, 0x308F, 0x3092, 0x3093, + }; + + assert new String(abc).getBytes(StandardCharsets.UTF_8).length > abc.length; + + StringBuilder str = new StringBuilder(n); + random.ints(0, abc.length).limit(n).forEach(i -> str.append(abc[i])); + return str.toString(); + } + + @DataProvider(name = "text") + public Object[][] texts() { + return new Object[][]{ + {utf8String( 0)}, + {utf8String(1024)}, + }; + } + + @DataProvider(name = "binary") + public Object[][] binary() { + return new Object[][]{ + {bytes( 0)}, + {bytes(1024)}, + }; + } + + private static ByteBuffer bytes(int n) { + byte[] array = new byte[n]; + random.nextBytes(array); + return ByteBuffer.wrap(array); + } +} diff -r d818a6a8295a -r 4933a477d628 test/jdk/java/net/httpclient/websocket/java.net.http/jdk/internal/net/http/websocket/MessageQueueTest.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test/jdk/java/net/httpclient/websocket/java.net.http/jdk/internal/net/http/websocket/MessageQueueTest.java Wed Mar 07 17:16:28 2018 +0000 @@ -0,0 +1,449 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package jdk.internal.net.http.websocket; + +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.function.BiConsumer; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +public class MessageQueueTest { + + private static final Random r = new SecureRandom(); + + @DataProvider(name = "illegalCapacities") + public static Object[][] illegalCapacities() { + return new Object[][]{ + new Object[]{Integer.MIN_VALUE}, + new Object[]{-2}, + new Object[]{-1}, + new Object[]{ 0}, + }; + } + + @Test(dataProvider = "illegalCapacities") + public void illegalCapacity(int n) { + assertThrows(IllegalArgumentException.class, () -> new MessageQueue(n)); + } + + @Test(dataProvider = "capacities") + public void emptiness(int n) { + assertTrue(new MessageQueue(n).isEmpty()); + } + + @Test(dataProvider = "capacities") + public void fullness(int n) throws IOException { + MessageQueue q = new MessageQueue(n); + Adder adder = new Adder(); + Queue referenceQueue = new LinkedList<>(); + for (int i = 0; i < n; i++) { + Message m = createRandomMessage(); + referenceQueue.add(m); + adder.add(q, m); + } + for (int i = 0; i < n + 1; i++) { + Message m = createRandomMessage(); + assertThrows(IOException.class, () -> adder.add(q, m)); + } + for (int i = 0; i < n; i++) { + Message expected = referenceQueue.remove(); + Message actual = new Remover().removeFrom(q); + assertEquals(actual, expected); + } + } + + private Message createRandomMessage() { + Message.Type[] values = Message.Type.values(); + Message.Type type = values[r.nextInt(values.length)]; + ByteBuffer binary = null; + CharBuffer text = null; + boolean isLast = false; + int statusCode = -1; + switch (type) { + case TEXT: + text = CharBuffer.allocate(r.nextInt(17)); + isLast = r.nextBoolean(); + break; + case BINARY: + binary = ByteBuffer.allocate(r.nextInt(19)); + isLast = r.nextBoolean(); + break; + case PING: + binary = ByteBuffer.allocate(r.nextInt(19)); + break; + case PONG: + binary = ByteBuffer.allocate(r.nextInt(19)); + break; + case CLOSE: + text = CharBuffer.allocate(r.nextInt(17)); + statusCode = r.nextInt(); + break; + default: + throw new AssertionError(); + } + BiConsumer action = new BiConsumer<>() { + @Override + public void accept(Integer o, Throwable throwable) { } + }; + CompletableFuture future = new CompletableFuture<>(); + return new Message(type, binary, text, isLast, statusCode, r.nextInt(), + action, future); + } + + @Test(dataProvider = "capacities") + public void caterpillarWalk(int n) throws IOException { +// System.out.println("n: " + n); + for (int p = 1; p <= n; p++) { // pace +// System.out.println(" pace: " + p); + MessageQueue q = new MessageQueue(n); + Queue referenceQueue = new LinkedList<>(); + Adder adder = new Adder(); + for (int k = 0; k < (n / p) + 1; k++) { +// System.out.println(" cycle: " + k); + for (int i = 0; i < p; i++) { + Message m = createRandomMessage(); + referenceQueue.add(m); + adder.add(q, m); + } + Remover remover = new Remover(); + for (int i = 0; i < p; i++) { + Message expected = referenceQueue.remove(); + Message actual = remover.removeFrom(q); + assertEquals(actual, expected); + } + assertTrue(q.isEmpty()); + } + } + } + + /* Exercises only concurrent additions */ + @Test + public void halfConcurrency() throws ExecutionException, InterruptedException { + int n = Runtime.getRuntime().availableProcessors() + 2; + ExecutorService executorService = Executors.newFixedThreadPool(n); + CyclicBarrier start = new CyclicBarrier(n); + Adder adder = new Adder(); + List> futures = new ArrayList<>(n); + try { + for (int k = 0; k < 1024; k++) { + MessageQueue q = new MessageQueue(n); + for (int i = 0; i < n; i++) { + Message m = createRandomMessage(); + Future f = executorService.submit(() -> { + start.await(); + adder.add(q, m); + return null; + }); + futures.add(f); + } + for (Future f : futures) { + f.get(); // Just to check for exceptions + } + futures.clear(); + // Make sure the queue is full + assertThrows(IOException.class, + () -> adder.add(q, createRandomMessage())); + } + } finally { + executorService.shutdownNow(); + } + } + + // TODO: same message; different messages; a mix thereof + + @Test + public void concurrency() throws ExecutionException, InterruptedException { + int nProducers = Runtime.getRuntime().availableProcessors() + 2; + int nThreads = nProducers + 1; + ExecutorService executorService = Executors.newFixedThreadPool(nThreads); + CyclicBarrier start = new CyclicBarrier(nThreads); + MessageQueue q = new MessageQueue(nProducers); + Adder adder = new Adder(); + Remover remover = new Remover(); + List expectedList = new ArrayList<>(nProducers); + List actualList = new ArrayList<>(nProducers); + List> futures = new ArrayList<>(nProducers); + try { + for (int k = 0; k < 1024; k++) { + for (int i = 0; i < nProducers; i++) { + Message m = createRandomMessage(); + expectedList.add(m); + Future f = executorService.submit(() -> { + start.await(); + adder.add(q, m); + return null; + }); + futures.add(f); + } + Future consumer = executorService.submit(() -> { + int i = 0; + start.await(); + while (i < nProducers) { + Message m = remover.removeFrom(q); + if (m != null) { + actualList.add(m); + i++; + } + } + return null; + }); + for (Future f : futures) { + f.get(); // Just to check for exceptions + } + consumer.get(); // Waiting for consumer to collect all the messages + assertEquals(actualList.size(), expectedList.size()); + for (Message m : expectedList) { + assertTrue(actualList.remove(m)); + } + assertTrue(actualList.isEmpty()); + assertTrue(q.isEmpty()); + expectedList.clear(); + futures.clear(); + } + } finally { + executorService.shutdownNow(); + } + } + + @Test(dataProvider = "capacities") + public void testSingleThreaded(int n) throws IOException { + Queue referenceQueue = new LinkedList<>(); + MessageQueue q = new MessageQueue(n); + Adder adder = new Adder(); + for (int i = 0; i < n; i++) { + Message m = createRandomMessage(); + referenceQueue.add(m); + adder.add(q, m); + } + for (int i = 0; i < n; i++) { + Message expected = referenceQueue.remove(); + Message actual = new Remover().removeFrom(q); + assertEquals(actual, expected); + } + assertTrue(q.isEmpty()); + } + + @DataProvider(name = "capacities") + public Object[][] capacities() { + return new Object[][]{ + new Object[]{ 1}, + new Object[]{ 2}, + new Object[]{ 3}, + new Object[]{ 4}, + new Object[]{ 5}, + new Object[]{ 6}, + new Object[]{ 7}, + new Object[]{ 8}, + new Object[]{ 9}, + new Object[]{128}, + new Object[]{256}, + }; + } + + // -- auxiliary test infrastructure -- + + static class Adder { + + @SuppressWarnings("unchecked") + void add(MessageQueue q, Message m) throws IOException { + switch (m.type) { + case TEXT: + q.addText(m.text, m.isLast, m.attachment, m.action, m.future); + break; + case BINARY: + q.addBinary(m.binary, m.isLast, m.attachment, m.action, m.future); + break; + case PING: + q.addPing(m.binary, m.attachment, m.action, m.future); + break; + case PONG: + q.addPong(m.binary, m.attachment, m.action, m.future); + break; + case CLOSE: + q.addClose(m.statusCode, m.text, m.attachment, m.action, m.future); + break; + default: + throw new InternalError(); + } + } + } + + static class Remover { + + Message removeFrom(MessageQueue q) { + Message m = q.peek(new MessageQueue.QueueCallback<>() { + + boolean called; + + @Override + public Message onText(CharBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) { + assertFalse(called); + called = true; + return new Message(Message.Type.TEXT, null, message, isLast, + -1, attachment, action, future); + } + + @Override + public Message onBinary(ByteBuffer message, + boolean isLast, + T attachment, + BiConsumer action, + CompletableFuture future) { + assertFalse(called); + called = true; + return new Message(Message.Type.BINARY, message, null, isLast, + -1, attachment, action, future); + } + + @Override + public Message onPing(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) { + assertFalse(called); + called = true; + return new Message(Message.Type.PING, message, null, false, + -1, attachment, action, future); + } + + @Override + public Message onPong(ByteBuffer message, + T attachment, + BiConsumer action, + CompletableFuture future) { + assertFalse(called); + called = true; + return new Message(Message.Type.PONG, message, null, false, + -1, attachment, action, future); + } + + @Override + public Message onClose(int statusCode, + CharBuffer reason, + T attachment, + BiConsumer action, + CompletableFuture future) { + assertFalse(called); + called = true; + return new Message(Message.Type.CLOSE, null, reason, false, + statusCode, attachment, action, future); + } + + @Override + public Message onEmpty() throws RuntimeException { + return null; + } + }); + if (m != null) { + q.remove(); + } + return m; + } + } + + static class Message { + + private final Type type; + private final ByteBuffer binary; + private final CharBuffer text; + private final boolean isLast; + private final int statusCode; + private final Object attachment; + @SuppressWarnings("rawtypes") + private final BiConsumer action; + @SuppressWarnings("rawtypes") + private final CompletableFuture future; + + Message(Type type, + ByteBuffer binary, + CharBuffer text, + boolean isLast, + int statusCode, + T attachment, + BiConsumer action, + CompletableFuture future) { + this.type = type; + this.binary = binary; + this.text = text; + this.isLast = isLast; + this.statusCode = statusCode; + this.attachment = attachment; + this.action = action; + this.future = future; + } + + @Override + public int hashCode() { + return Objects.hash(type, binary, text, isLast, statusCode, attachment, action, future); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Message message = (Message) o; + return isLast == message.isLast && + statusCode == message.statusCode && + type == message.type && + Objects.equals(binary, message.binary) && + Objects.equals(text, message.text) && + Objects.equals(attachment, message.attachment) && + Objects.equals(action, message.action) && + Objects.equals(future, message.future); + } + + enum Type { + TEXT, + BINARY, + PING, + PONG, + CLOSE + } + } +} diff -r d818a6a8295a -r 4933a477d628 test/jdk/java/net/httpclient/websocket/java.net.http/jdk/internal/net/http/websocket/WebSocketImplTest.java --- a/test/jdk/java/net/httpclient/websocket/java.net.http/jdk/internal/net/http/websocket/WebSocketImplTest.java Wed Mar 07 15:39:25 2018 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,453 +0,0 @@ -/* - * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ - -package jdk.internal.net.http.websocket; - -import java.net.http.WebSocket; -import org.testng.annotations.Test; - -import java.net.URI; -import java.nio.ByteBuffer; -import java.util.Collection; -import java.util.List; -import java.util.Random; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; -import java.util.function.Supplier; - -import static java.net.http.WebSocket.MessagePart.FIRST; -import static java.net.http.WebSocket.MessagePart.LAST; -import static java.net.http.WebSocket.MessagePart.PART; -import static java.net.http.WebSocket.MessagePart.WHOLE; -import static java.net.http.WebSocket.NORMAL_CLOSURE; -import static jdk.internal.net.http.websocket.MockListener.Invocation.onClose; -import static jdk.internal.net.http.websocket.MockListener.Invocation.onError; -import static jdk.internal.net.http.websocket.MockListener.Invocation.onOpen; -import static jdk.internal.net.http.websocket.MockListener.Invocation.onPing; -import static jdk.internal.net.http.websocket.MockListener.Invocation.onPong; -import static jdk.internal.net.http.websocket.MockListener.Invocation.onText; -import static jdk.internal.net.http.websocket.MockTransport.onClose; -import static jdk.internal.net.http.websocket.MockTransport.onPing; -import static jdk.internal.net.http.websocket.MockTransport.onPong; -import static jdk.internal.net.http.websocket.MockTransport.onText; -import static jdk.internal.net.http.websocket.TestSupport.assertCompletesExceptionally; -import static org.testng.Assert.assertEquals; - -/* - * Formatting in this file may seem strange: - * - * ( - * ( ...) - * ... - * ) - * ... - * - * However there is a rationale behind it. Sometimes the level of argument - * nesting is high, which makes it hard to manage parentheses. - */ -public class WebSocketImplTest { - - // TODO: request in onClose/onError - // TODO: throw exception in onClose/onError - // TODO: exception is thrown from request() - // TODO: repeated sendClose complete normally - // TODO: default Close message is sent if IAE is thrown from sendClose - - @Test - public void testNonPositiveRequest() throws Exception { - MockListener listener = new MockListener(Long.MAX_VALUE) { - @Override - protected void onOpen0(WebSocket webSocket) { - webSocket.request(0); - } - }; - WebSocket ws = newInstance(listener, List.of(now(onText("1", WHOLE)))); - listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS); - List invocations = listener.invocations(); - assertEquals( - invocations, - List.of( - onOpen(ws), - onError(ws, IllegalArgumentException.class) - ) - ); - } - - @Test - public void testText1() throws Exception { - MockListener listener = new MockListener(Long.MAX_VALUE); - WebSocket ws = newInstance( - listener, - List.of( - now(onText("1", FIRST)), - now(onText("2", PART)), - now(onText("3", LAST)), - now(onClose(NORMAL_CLOSURE, "no reason")) - ) - ); - listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS); - List invocations = listener.invocations(); - assertEquals( - invocations, - List.of( - onOpen(ws), - onText(ws, "1", FIRST), - onText(ws, "2", PART), - onText(ws, "3", LAST), - onClose(ws, NORMAL_CLOSURE, "no reason") - ) - ); - } - - @Test - public void testText2() throws Exception { - MockListener listener = new MockListener(Long.MAX_VALUE); - WebSocket ws = newInstance( - listener, - List.of( - now(onText("1", FIRST)), - seconds(1, onText("2", PART)), - now(onText("3", LAST)), - seconds(1, onClose(NORMAL_CLOSURE, "no reason")) - ) - ); - listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS); - List invocations = listener.invocations(); - assertEquals( - invocations, - List.of( - onOpen(ws), - onText(ws, "1", FIRST), - onText(ws, "2", PART), - onText(ws, "3", LAST), - onClose(ws, NORMAL_CLOSURE, "no reason") - ) - ); - } - - @Test - public void testTextIntermixedWithPongs() throws Exception { - MockListener listener = new MockListener(Long.MAX_VALUE); - WebSocket ws = newInstance( - listener, - List.of( - now(onText("1", FIRST)), - now(onText("2", PART)), - now(onPong(ByteBuffer.allocate(16))), - seconds(1, onPong(ByteBuffer.allocate(32))), - now(onText("3", LAST)), - now(onPong(ByteBuffer.allocate(64))), - now(onClose(NORMAL_CLOSURE, "no reason")) - ) - ); - listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS); - List invocations = listener.invocations(); - assertEquals( - invocations, - List.of( - onOpen(ws), - onText(ws, "1", FIRST), - onText(ws, "2", PART), - onPong(ws, ByteBuffer.allocate(16)), - onPong(ws, ByteBuffer.allocate(32)), - onText(ws, "3", LAST), - onPong(ws, ByteBuffer.allocate(64)), - onClose(ws, NORMAL_CLOSURE, "no reason") - ) - ); - } - - @Test - public void testTextIntermixedWithPings() throws Exception { - MockListener listener = new MockListener(Long.MAX_VALUE); - WebSocket ws = newInstance( - listener, - List.of( - now(onText("1", FIRST)), - now(onText("2", PART)), - now(onPing(ByteBuffer.allocate(16))), - seconds(1, onPing(ByteBuffer.allocate(32))), - now(onText("3", LAST)), - now(onPing(ByteBuffer.allocate(64))), - now(onClose(NORMAL_CLOSURE, "no reason")) - ) - ); - listener.onCloseOrOnErrorCalled().get(10, TimeUnit.SECONDS); - List invocations = listener.invocations(); - assertEquals( - invocations, - List.of( - onOpen(ws), - onText(ws, "1", FIRST), - onText(ws, "2", PART), - onPing(ws, ByteBuffer.allocate(16)), - onPing(ws, ByteBuffer.allocate(32)), - onText(ws, "3", LAST), - onPing(ws, ByteBuffer.allocate(64)), - onClose(ws, NORMAL_CLOSURE, "no reason")) - ); - } - - // Tease out "java.lang.IllegalStateException: Send pending" due to possible - // race between sending a message and replenishing the permit - @Test - public void testManyTextMessages() { - WebSocketImpl ws = newInstance( - new MockListener(1), - new TransportFactory() { - @Override - public Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer) { - - final Random r = new Random(); - - return new MockTransport<>(sendResultSupplier, consumer) { - @Override - protected CompletableFuture defaultSend() { - return millis(r.nextInt(100), result()); - } - }; - } - }); - int NUM_MESSAGES = 512; - CompletableFuture current = CompletableFuture.completedFuture(ws); - for (int i = 0; i < NUM_MESSAGES; i++) { - current = current.thenCompose(w -> w.sendText(" ", true)); - } - current.join(); - MockTransport transport = (MockTransport) ws.transport(); - assertEquals(transport.invocations().size(), NUM_MESSAGES); - } - - @Test - public void testManyBinaryMessages() { - WebSocketImpl ws = newInstance( - new MockListener(1), - new TransportFactory() { - @Override - public Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer) { - - final Random r = new Random(); - - return new MockTransport<>(sendResultSupplier, consumer) { - @Override - protected CompletableFuture defaultSend() { - return millis(r.nextInt(150), result()); - } - }; - } - }); - CompletableFuture start = new CompletableFuture<>(); - - int NUM_MESSAGES = 512; - CompletableFuture current = start; - for (int i = 0; i < NUM_MESSAGES; i++) { - current = current.thenComposeAsync(w -> w.sendBinary(ByteBuffer.allocate(1), true)); - } - - start.completeAsync(() -> ws); - current.join(); - - MockTransport transport = (MockTransport) ws.transport(); - assertEquals(transport.invocations().size(), NUM_MESSAGES); - } - - - @Test - public void sendTextImmediately() { - WebSocketImpl ws = newInstance( - new MockListener(1), - new TransportFactory() { - @Override - public Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer) { - return new MockTransport<>(sendResultSupplier, consumer); - } - }); - CompletableFuture.completedFuture(ws) - .thenCompose(w -> w.sendText("1", true)) - .thenCompose(w -> w.sendText("2", true)) - .thenCompose(w -> w.sendText("3", true)) - .join(); - MockTransport transport = (MockTransport) ws.transport(); - assertEquals(transport.invocations().size(), 3); - } - - @Test - public void sendTextWithDelay() { - MockListener listener = new MockListener(1); - WebSocketImpl ws = newInstance( - listener, - new TransportFactory() { - @Override - public Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer) { - return new MockTransport<>(sendResultSupplier, consumer) { - @Override - protected CompletableFuture defaultSend() { - return seconds(1, result()); - } - }; - } - }); - CompletableFuture.completedFuture(ws) - .thenCompose(w -> w.sendText("1", true)) - .thenCompose(w -> w.sendText("2", true)) - .thenCompose(w -> w.sendText("3", true)) - .join(); - assertEquals(listener.invocations(), List.of(onOpen(ws))); - MockTransport transport = (MockTransport) ws.transport(); - assertEquals(transport.invocations().size(), 3); - } - - @Test - public void sendTextMixedDelay() { - MockListener listener = new MockListener(1); - WebSocketImpl ws = newInstance( - listener, - new TransportFactory() { - - final Random r = new Random(); - - @Override - public Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer) { - return new MockTransport<>(sendResultSupplier, consumer) { - @Override - protected CompletableFuture defaultSend() { - return r.nextBoolean() - ? seconds(1, result()) - : now(result()); - } - }; - } - }); - CompletableFuture.completedFuture(ws) - .thenCompose(w -> w.sendText("1", true)) - .thenCompose(w -> w.sendText("2", true)) - .thenCompose(w -> w.sendText("3", true)) - .thenCompose(w -> w.sendText("4", true)) - .thenCompose(w -> w.sendText("5", true)) - .thenCompose(w -> w.sendText("6", true)) - .thenCompose(w -> w.sendText("7", true)) - .thenCompose(w -> w.sendText("8", true)) - .thenCompose(w -> w.sendText("9", true)) - .join(); - assertEquals(listener.invocations(), List.of(onOpen(ws))); - MockTransport transport = (MockTransport) ws.transport(); - assertEquals(transport.invocations().size(), 9); - } - - @Test(enabled = false) // temporarily disabled - public void sendControlMessagesConcurrently() { - MockListener listener = new MockListener(1); - - CompletableFuture first = new CompletableFuture<>(); // barrier - - WebSocketImpl ws = newInstance( - listener, - new TransportFactory() { - - final AtomicInteger i = new AtomicInteger(); - - @Override - public Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer) { - return new MockTransport<>(sendResultSupplier, consumer) { - @Override - protected CompletableFuture defaultSend() { - if (i.incrementAndGet() == 1) { - return first.thenApply(o -> result()); - } else { - return now(result()); - } - } - }; - } - }); - - CompletableFuture cf1 = ws.sendPing(ByteBuffer.allocate(0)); - CompletableFuture cf2 = ws.sendPong(ByteBuffer.allocate(0)); - CompletableFuture cf3 = ws.sendClose(NORMAL_CLOSURE, ""); - CompletableFuture cf4 = ws.sendClose(NORMAL_CLOSURE, ""); - CompletableFuture cf5 = ws.sendPing(ByteBuffer.allocate(0)); - CompletableFuture cf6 = ws.sendPong(ByteBuffer.allocate(0)); - - first.complete(null); - // Don't care about exceptional completion, only that all of them have - // completed - CompletableFuture.allOf(cf1, cf2, cf3, cf4, cf5, cf6) - .handle((v, e) -> null).join(); - - cf3.join(); /* Check that sendClose has completed normally */ - cf4.join(); /* Check that repeated sendClose has completed normally */ - assertCompletesExceptionally(IllegalStateException.class, cf5); - assertCompletesExceptionally(IllegalStateException.class, cf6); - - assertEquals(listener.invocations(), List.of(onOpen(ws))); - MockTransport transport = (MockTransport) ws.transport(); - assertEquals(transport.invocations().size(), 3); // 6 minus 3 that were not accepted - } - - private static CompletableFuture seconds(long val, T result) { - return new CompletableFuture() - .completeOnTimeout(result, val, TimeUnit.SECONDS); - } - - private static CompletableFuture millis(long val, T result) { - return new CompletableFuture() - .completeOnTimeout(result, val, TimeUnit.MILLISECONDS); - } - - private static CompletableFuture now(T result) { - return CompletableFuture.completedFuture(result); - } - - private static WebSocketImpl newInstance( - WebSocket.Listener listener, - Collection>> input) { - TransportFactory factory = new TransportFactory() { - @Override - public Transport createTransport(Supplier sendResultSupplier, - MessageStreamConsumer consumer) { - return new MockTransport(sendResultSupplier, consumer) { - @Override - protected Collection>> receive() { - return input; - } - }; - } - }; - return newInstance(listener, factory); - } - - private static WebSocketImpl newInstance(WebSocket.Listener listener, - TransportFactory factory) { - URI uri = URI.create("ws://localhost"); - String subprotocol = ""; - return WebSocketImpl.newInstance(uri, subprotocol, listener, factory); - } -}