jdk/src/java.httpclient/share/classes/java/net/http/WS.java
author prappo
Wed, 08 Jun 2016 15:19:58 +0100
changeset 38864 bf2b41533aed
parent 38856 cc3a0d1e96e0
child 39133 b5641ce64cf7
permissions -rw-r--r--
8156693: Improve usability of CompletableFuture use in WebSocket API Reviewed-by: rriggs

/*
 * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package java.net.http;

import java.io.IOException;
import java.net.ProtocolException;
import java.net.http.WSOpeningHandshake.Result;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static java.lang.System.Logger.Level.ERROR;
import static java.lang.System.Logger.Level.WARNING;
import static java.net.http.WSUtils.logger;
import static java.util.Objects.requireNonNull;

/*
 * A WebSocket client.
 *
 * Consists of two independent parts; a transmitter responsible for sending
 * messages, and a receiver which notifies the listener of incoming messages.
 */
final class WS implements WebSocket {

    private final String subprotocol;
    private final RawChannel channel;
    private final WSTransmitter transmitter;
    private final WSReceiver receiver;
    private final Listener listener;
    private final Object stateLock = new Object();
    private volatile State state = State.CONNECTED;
    private final CompletableFuture<Void> whenClosed = new CompletableFuture<>();

    static CompletableFuture<WebSocket> newInstanceAsync(WSBuilder b) {
        CompletableFuture<Result> result = new WSOpeningHandshake(b).performAsync();
        Listener listener = b.getListener();
        Executor executor = b.getClient().executorService();
        return result.thenApply(r -> {
            WS ws = new WS(listener, r.subprotocol, r.channel, executor);
            ws.start();
            return ws;
        });
    }

    private WS(Listener listener, String subprotocol, RawChannel channel,
               Executor executor) {
        this.listener = wrapListener(listener);
        this.channel = channel;
        this.subprotocol = subprotocol;
        Consumer<Throwable> errorHandler = error -> {
            if (error == null) {
                throw new InternalError();
            }
            // If the channel is closed, we need to update the state, to denote
            // there's no point in trying to continue using WebSocket
            if (!channel.isOpen()) {
                synchronized (stateLock) {
                    tryChangeState(State.ERROR);
                }
            }
        };
        transmitter = new WSTransmitter(this, executor, channel, errorHandler);
        receiver = new WSReceiver(this.listener, this, executor, channel);
    }

    private void start() {
        receiver.start();
    }

    @Override
    public CompletableFuture<WebSocket> sendText(CharSequence message, boolean isLast) {
        requireNonNull(message, "message");
        synchronized (stateLock) {
            checkState();
            return transmitter.sendText(message, isLast);
        }
    }

    @Override
    public CompletableFuture<WebSocket> sendText(Stream<? extends CharSequence> message) {
        requireNonNull(message, "message");
        synchronized (stateLock) {
            checkState();
            return transmitter.sendText(message);
        }
    }

    @Override
    public CompletableFuture<WebSocket> sendBinary(ByteBuffer message, boolean isLast) {
        requireNonNull(message, "message");
        synchronized (stateLock) {
            checkState();
            return transmitter.sendBinary(message, isLast);
        }
    }

    @Override
    public CompletableFuture<WebSocket> sendPing(ByteBuffer message) {
        requireNonNull(message, "message");
        synchronized (stateLock) {
            checkState();
            return transmitter.sendPing(message);
        }
    }

    @Override
    public CompletableFuture<WebSocket> sendPong(ByteBuffer message) {
        requireNonNull(message, "message");
        synchronized (stateLock) {
            checkState();
            return transmitter.sendPong(message);
        }
    }

    @Override
    public CompletableFuture<WebSocket> sendClose(CloseCode code, CharSequence reason) {
        requireNonNull(code, "code");
        requireNonNull(reason, "reason");
        synchronized (stateLock) {
            return doSendClose(() -> transmitter.sendClose(code, reason));
        }
    }

    @Override
    public CompletableFuture<WebSocket> sendClose() {
        synchronized (stateLock) {
            return doSendClose(() -> transmitter.sendClose());
        }
    }

    private CompletableFuture<WebSocket> doSendClose(Supplier<CompletableFuture<WebSocket>> s) {
        checkState();
        boolean closeChannel = false;
        synchronized (stateLock) {
            if (state == State.CLOSED_REMOTELY) {
                closeChannel = tryChangeState(State.CLOSED);
            } else {
                tryChangeState(State.CLOSED_LOCALLY);
            }
        }
        CompletableFuture<WebSocket> sent = s.get();
        if (closeChannel) {
            sent.whenComplete((v, t) -> {
                try {
                    channel.close();
                } catch (IOException e) {
                    logger.log(ERROR, "Error transitioning to state " + State.CLOSED, e);
                }
            });
        }
        return sent;
    }

    @Override
    public long request(long n) {
        if (n < 0L) {
            throw new IllegalArgumentException("The number must not be negative: " + n);
        }
        return receiver.request(n);
    }

    @Override
    public String getSubprotocol() {
        return subprotocol;
    }

    @Override
    public boolean isClosed() {
        return state.isTerminal();
    }

    @Override
    public void abort() throws IOException {
        synchronized (stateLock) {
            tryChangeState(State.ABORTED);
        }
        channel.close();
    }

    @Override
    public String toString() {
        return super.toString() + "[" + state + "]";
    }

    private void checkState() {
        if (state.isTerminal() || state == State.CLOSED_LOCALLY) {
            throw new IllegalStateException("WebSocket is closed [" + state + "]");
        }
    }

    /*
     * Wraps the user's listener passed to the constructor into own listener to
     * intercept transitions to terminal states (onClose and onError) and to act
     * upon exceptions and values from the user's listener.
     */
    private Listener wrapListener(Listener listener) {
        return new Listener() {

            // Listener's method MUST be invoked in a happen-before order
            private final Object visibilityLock = new Object();

            @Override
            public void onOpen(WebSocket webSocket) {
                synchronized (visibilityLock) {
                    listener.onOpen(webSocket);
                }
            }

            @Override
            public CompletionStage<?> onText(WebSocket webSocket, CharSequence message,
                                             MessagePart part) {
                synchronized (visibilityLock) {
                    return listener.onText(webSocket, message, part);
                }
            }

            @Override
            public CompletionStage<?> onBinary(WebSocket webSocket, ByteBuffer message,
                                               MessagePart part) {
                synchronized (visibilityLock) {
                    return listener.onBinary(webSocket, message, part);
                }
            }

            @Override
            public CompletionStage<?> onPing(WebSocket webSocket, ByteBuffer message) {
                synchronized (visibilityLock) {
                    return listener.onPing(webSocket, message);
                }
            }

            @Override
            public CompletionStage<?> onPong(WebSocket webSocket, ByteBuffer message) {
                synchronized (visibilityLock) {
                    return listener.onPong(webSocket, message);
                }
            }

            @Override
            public void onClose(WebSocket webSocket, Optional<CloseCode> code, String reason) {
                synchronized (stateLock) {
                    if (state == State.CLOSED_REMOTELY || state.isTerminal()) {
                        throw new InternalError("Unexpected onClose in state " + state);
                    } else if (state == State.CLOSED_LOCALLY) {
                        try {
                            channel.close();
                        } catch (IOException e) {
                            logger.log(ERROR, "Error transitioning to state " + State.CLOSED, e);
                        }
                        tryChangeState(State.CLOSED);
                    } else if (state == State.CONNECTED) {
                        tryChangeState(State.CLOSED_REMOTELY);
                    }
                }
                synchronized (visibilityLock) {
                    listener.onClose(webSocket, code, reason);
                }
            }

            @Override
            public void onError(WebSocket webSocket, Throwable error) {
                // An error doesn't necessarily mean the connection must be
                // closed automatically
                if (!channel.isOpen()) {
                    synchronized (stateLock) {
                        tryChangeState(State.ERROR);
                    }
                } else if (error instanceof ProtocolException
                        && error.getCause() instanceof WSProtocolException) {
                    WSProtocolException cause = (WSProtocolException) error.getCause();
                    logger.log(WARNING, "Failing connection {0}, reason: ''{1}''",
                            webSocket, cause.getMessage());
                    CloseCode cc = cause.getCloseCode();
                    transmitter.sendClose(cc, "").whenComplete((v, t) -> {
                        synchronized (stateLock) {
                            tryChangeState(State.ERROR);
                        }
                        try {
                            channel.close();
                        } catch (IOException e) {
                            logger.log(ERROR, e);
                        }
                    });
                }
                synchronized (visibilityLock) {
                    listener.onError(webSocket, error);
                }
            }
        };
    }

    private boolean tryChangeState(State newState) {
        assert Thread.holdsLock(stateLock);
        if (state.isTerminal()) {
            return false;
        }
        state = newState;
        if (newState.isTerminal()) {
            whenClosed.complete(null);
        }
        return true;
    }

    CompletionStage<Void> whenClosed() {
        return whenClosed;
    }

    /*
     * WebSocket connection internal state.
     */
    private enum State {

        /*
         * Initial WebSocket state. The WebSocket is connected (i.e. remains in
         * this state) unless proven otherwise. For example, by reading or
         * writing operations on the channel.
         */
        CONNECTED,

        /*
         * A Close message has been received by the client. No more messages
         * will be received.
         */
        CLOSED_REMOTELY,

        /*
         * A Close message has been sent by the client. No more messages can be
         * sent.
         */
        CLOSED_LOCALLY,

        /*
         * Close messages has been both sent and received (closing handshake)
         * and TCP connection closed. Closed _cleanly_ in terms of RFC 6455.
         */
        CLOSED,

        /*
         * The connection has been aborted by the client. Closed not _cleanly_
         * in terms of RFC 6455.
         */
        ABORTED,

        /*
         * The connection has been terminated due to a protocol or I/O error.
         * Might happen during sending or receiving.
         */
        ERROR;

        /*
         * Returns `true` if this state is terminal. If WebSocket has transited
         * to such a state, if remains in it forever.
         */
        boolean isTerminal() {
            return this == CLOSED || this == ABORTED || this == ERROR;
        }
    }
}