src/java.net.http/share/classes/java/net/http/internal/websocket/OutgoingMessage.java
branchhttp-client-branch
changeset 56089 42208b2f224e
parent 56088 38fac6d0521d
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/java.net.http/share/classes/java/net/http/internal/websocket/OutgoingMessage.java	Wed Feb 07 14:17:24 2018 +0000
@@ -0,0 +1,296 @@
+/*
+ * Copyright (c) 2015, 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package java.net.http.internal.websocket;
+
+import java.net.http.internal.websocket.Frame.Opcode;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.charset.CharacterCodingException;
+import java.nio.charset.CharsetEncoder;
+import java.nio.charset.CoderResult;
+import java.security.SecureRandom;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.Objects.requireNonNull;
+import static java.net.http.internal.common.Utils.EMPTY_BYTEBUFFER;
+import static java.net.http.internal.websocket.Frame.MAX_HEADER_SIZE_BYTES;
+import static java.net.http.internal.websocket.Frame.Opcode.BINARY;
+import static java.net.http.internal.websocket.Frame.Opcode.CLOSE;
+import static java.net.http.internal.websocket.Frame.Opcode.CONTINUATION;
+import static java.net.http.internal.websocket.Frame.Opcode.PING;
+import static java.net.http.internal.websocket.Frame.Opcode.PONG;
+import static java.net.http.internal.websocket.Frame.Opcode.TEXT;
+
+/*
+ * A stateful object that represents a WebSocket message being sent to the
+ * channel.
+ *
+ * Data provided to the constructors is copied. Otherwise we would have to deal
+ * with mutability, security, masking/unmasking, readonly status, etc. So
+ * copying greatly simplifies the implementation.
+ *
+ * In the case of memory-sensitive environments an alternative implementation
+ * could use an internal pool of buffers though at the cost of extra complexity
+ * and possible performance degradation.
+ */
+abstract class OutgoingMessage {
+
+    // Share per WebSocket?
+    private static final SecureRandom maskingKeys = new SecureRandom();
+
+    protected ByteBuffer[] frame;
+    protected int offset;
+
+    /*
+     * Performs contextualization. This method is not a part of the constructor
+     * so it would be possible to defer the work it does until the most
+     * convenient moment (up to the point where sentTo is invoked).
+     */
+    protected boolean contextualize(Context context) {
+        // masking and charset decoding should be performed here rather than in
+        // the constructor (as of today)
+        if (context.isCloseSent()) {
+            throw new IllegalStateException("Close sent");
+        }
+        return true;
+    }
+
+    protected boolean sendTo(RawChannel channel) throws IOException {
+        while ((offset = nextUnwrittenIndex()) != -1) {
+            long n = channel.write(frame, offset, frame.length - offset);
+            if (n == 0) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    private int nextUnwrittenIndex() {
+        for (int i = offset; i < frame.length; i++) {
+            if (frame[i].hasRemaining()) {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    static final class Text extends OutgoingMessage {
+
+        private final ByteBuffer payload;
+        private final boolean isLast;
+
+        Text(CharSequence characters, boolean isLast) {
+            CharsetEncoder encoder = UTF_8.newEncoder(); // Share per WebSocket?
+            try {
+                payload = encoder.encode(CharBuffer.wrap(characters));
+            } catch (CharacterCodingException e) {
+                throw new IllegalArgumentException(
+                        "Malformed UTF-8 text message");
+            }
+            this.isLast = isLast;
+        }
+
+        @Override
+        protected boolean contextualize(Context context) {
+            super.contextualize(context);
+            if (context.isPreviousBinary() && !context.isPreviousLast()) {
+                throw new IllegalStateException("Unexpected text message");
+            }
+            frame = getDataMessageBuffers(
+                    TEXT, context.isPreviousLast(), isLast, payload, payload);
+            context.setPreviousBinary(false);
+            context.setPreviousText(true);
+            context.setPreviousLast(isLast);
+            return true;
+        }
+    }
+
+    static final class Binary extends OutgoingMessage {
+
+        private final ByteBuffer payload;
+        private final boolean isLast;
+
+        Binary(ByteBuffer payload, boolean isLast) {
+            this.payload = requireNonNull(payload);
+            this.isLast = isLast;
+        }
+
+        @Override
+        protected boolean contextualize(Context context) {
+            super.contextualize(context);
+            if (context.isPreviousText() && !context.isPreviousLast()) {
+                throw new IllegalStateException("Unexpected binary message");
+            }
+            ByteBuffer newBuffer = ByteBuffer.allocate(payload.remaining());
+            frame = getDataMessageBuffers(
+                    BINARY, context.isPreviousLast(), isLast, payload, newBuffer);
+            context.setPreviousText(false);
+            context.setPreviousBinary(true);
+            context.setPreviousLast(isLast);
+            return true;
+        }
+    }
+
+    static final class Ping extends OutgoingMessage {
+
+        Ping(ByteBuffer payload) {
+            frame = getControlMessageBuffers(PING, payload);
+        }
+    }
+
+    static final class Pong extends OutgoingMessage {
+
+        Pong(ByteBuffer payload) {
+            frame = getControlMessageBuffers(PONG, payload);
+        }
+    }
+
+    static final class Close extends OutgoingMessage {
+
+        Close() {
+            frame = getControlMessageBuffers(CLOSE, EMPTY_BYTEBUFFER);
+        }
+
+        Close(int statusCode, CharSequence reason) {
+            ByteBuffer payload = ByteBuffer.allocate(125)
+                    .putChar((char) statusCode);
+            CoderResult result = UTF_8.newEncoder()
+                    .encode(CharBuffer.wrap(reason),
+                            payload,
+                            true);
+            if (result.isOverflow()) {
+                throw new IllegalArgumentException("Long reason");
+            } else if (result.isError()) {
+                try {
+                    result.throwException();
+                } catch (CharacterCodingException e) {
+                    throw new IllegalArgumentException(
+                            "Malformed UTF-8 reason", e);
+                }
+            }
+            payload.flip();
+            frame = getControlMessageBuffers(CLOSE, payload);
+        }
+
+        @Override
+        protected boolean contextualize(Context context) {
+            if (context.isCloseSent()) {
+                return false;
+            } else {
+                context.setCloseSent();
+                return true;
+            }
+        }
+    }
+
+    private static ByteBuffer[] getControlMessageBuffers(Opcode opcode,
+                                                         ByteBuffer payload) {
+        assert opcode.isControl() : opcode;
+        int remaining = payload.remaining();
+        if (remaining > 125) {
+            throw new IllegalArgumentException
+                    ("Long message: " + remaining);
+        }
+        ByteBuffer frame = ByteBuffer.allocate(MAX_HEADER_SIZE_BYTES + remaining);
+        int mask = maskingKeys.nextInt();
+        new Frame.HeaderWriter()
+                .fin(true)
+                .opcode(opcode)
+                .payloadLen(remaining)
+                .mask(mask)
+                .write(frame);
+        Frame.Masker.transferMasking(payload, frame, mask);
+        frame.flip();
+        return new ByteBuffer[]{frame};
+    }
+
+    private static ByteBuffer[] getDataMessageBuffers(Opcode type,
+                                                      boolean isPreviousLast,
+                                                      boolean isLast,
+                                                      ByteBuffer payloadSrc,
+                                                      ByteBuffer payloadDst) {
+        assert !type.isControl() && type != CONTINUATION : type;
+        ByteBuffer header = ByteBuffer.allocate(MAX_HEADER_SIZE_BYTES);
+        int mask = maskingKeys.nextInt();
+        new Frame.HeaderWriter()
+                .fin(isLast)
+                .opcode(isPreviousLast ? type : CONTINUATION)
+                .payloadLen(payloadDst.remaining())
+                .mask(mask)
+                .write(header);
+        header.flip();
+        Frame.Masker.transferMasking(payloadSrc, payloadDst, mask);
+        payloadDst.flip();
+        return new ByteBuffer[]{header, payloadDst};
+    }
+
+    /*
+     * An instance of this class is passed sequentially between messages, so
+     * every message in a sequence can check the context it is in and update it
+     * if necessary.
+     */
+    public static class Context {
+
+        boolean previousLast = true;
+        boolean previousBinary;
+        boolean previousText;
+        boolean closeSent;
+
+        private boolean isPreviousText() {
+            return this.previousText;
+        }
+
+        private void setPreviousText(boolean value) {
+            this.previousText = value;
+        }
+
+        private boolean isPreviousBinary() {
+            return this.previousBinary;
+        }
+
+        private void setPreviousBinary(boolean value) {
+            this.previousBinary = value;
+        }
+
+        private boolean isPreviousLast() {
+            return this.previousLast;
+        }
+
+        private void setPreviousLast(boolean value) {
+            this.previousLast = value;
+        }
+
+        private boolean isCloseSent() {
+            return closeSent;
+        }
+
+        private void setCloseSent() {
+            closeSent = true;
+        }
+    }
+}