http-client-branch: fix race condition between WebSocket and PlainHttpConnection::detachChannel http-client-branch
authordfuchs
Tue, 13 Mar 2018 20:17:12 +0000
branchhttp-client-branch
changeset 56299 903ff8ec239d
parent 56298 81d4669da207
child 56300 13a2ec671e62
http-client-branch: fix race condition between WebSocket and PlainHttpConnection::detachChannel
src/java.net.http/share/classes/jdk/internal/net/http/HttpClientImpl.java
src/java.net.http/share/classes/jdk/internal/net/http/PlainHttpConnection.java
src/java.net.http/share/classes/jdk/internal/net/http/SocketTube.java
--- a/src/java.net.http/share/classes/jdk/internal/net/http/HttpClientImpl.java	Tue Mar 13 17:37:30 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/HttpClientImpl.java	Tue Mar 13 20:17:12 2018 +0000
@@ -46,6 +46,7 @@
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedList;
@@ -373,6 +374,10 @@
         selmgr.cancel(s);
     }
 
+    void detachChannel(SocketChannel s, AsyncEvent... events) {
+        selmgr.detach(s, events);
+    }
+
     /**
      * Allows an AsyncEvent to modify its interestOps.
      * @param event The modified event.
@@ -509,6 +514,7 @@
         private final Selector selector;
         private volatile boolean closed;
         private final List<AsyncEvent> registrations;
+        private final List<AsyncTriggerEvent> deregistrations;
         private final System.Logger debug;
         private final System.Logger debugtimeout;
         HttpClientImpl owner;
@@ -521,9 +527,41 @@
             debugtimeout = ref.debugtimeout;
             pool = ref.connectionPool();
             registrations = new ArrayList<>();
+            deregistrations = new ArrayList<>();
             selector = Selector.open();
         }
 
+        void detach(SelectableChannel channel, AsyncEvent... events) {
+            if (Thread.currentThread() == this) {
+                SelectionKey key = channel.keyFor(selector);
+                if (key != null) {
+                    boolean removed = false;
+                    SelectorAttachment sa = (SelectorAttachment) key.attachment();
+                    if (sa != null) {
+                        for (AsyncEvent e : events) {
+                            if (sa.pending.remove(e)) removed = true;
+                        }
+                        // The key could already have been cancelled, in which
+                        // case the events would already have been removed.
+                        if (removed) {
+                            // We found at least one of the events, so we
+                            // should now cancel the key.
+                            sa.resetInterestOps(0);
+                            key.cancel();
+                        }
+                    }
+                }
+                registrations.removeAll(Arrays.asList(events));
+            } else {
+                synchronized (this) {
+                    deregistrations.add(new AsyncTriggerEvent(
+                            (x) -> debug.log(Level.DEBUG,
+                                    "Unexpected exception raised while detaching channel", x),
+                            () -> detach(channel, events)));
+                }
+            }
+        }
+
         void eventUpdated(AsyncEvent e) throws ClosedChannelException {
             if (Thread.currentThread() == this) {
                 SelectionKey key = e.channel().keyFor(selector);
@@ -585,6 +623,10 @@
                         assert errorList.isEmpty();
                         assert readyList.isEmpty();
                         assert resetList.isEmpty();
+                        for (AsyncTriggerEvent event : deregistrations) {
+                            event.handle();
+                        }
+                        deregistrations.clear();
                         for (AsyncEvent event : registrations) {
                             if (event instanceof AsyncTriggerEvent) {
                                 readyList.add(event);
@@ -829,6 +871,10 @@
             }
         }
 
+        boolean deregister(AsyncEvent e) {
+            return pending.remove(e);
+        }
+
         /**
          * Returns a Stream<AsyncEvents> containing only events that are
          * registered with the given {@code interestOps}.
--- a/src/java.net.http/share/classes/jdk/internal/net/http/PlainHttpConnection.java	Tue Mar 13 17:37:30 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/PlainHttpConnection.java	Tue Mar 13 20:17:12 2018 +0000
@@ -315,7 +315,7 @@
     // It should be removed when RawChannelImpl moves to using asynchronous APIs.
     @Override
     DetachedConnectionChannel detachChannel() {
-        client().cancelRegistration(channel());
+        tube.detach();
         return new PlainDetachedChannel(this);
     }
 
--- a/src/java.net.http/share/classes/jdk/internal/net/http/SocketTube.java	Tue Mar 13 17:37:30 2018 +0000
+++ b/src/java.net.http/share/classes/jdk/internal/net/http/SocketTube.java	Tue Mar 13 20:17:12 2018 +0000
@@ -32,6 +32,7 @@
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.Flow;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.nio.channels.SelectableChannel;
@@ -64,6 +65,7 @@
     private final Supplier<ByteBuffer> buffersSource;
     private final Object lock = new Object();
     private final AtomicReference<Throwable> errorRef = new AtomicReference<>();
+    private final AtomicBoolean detached = new AtomicBoolean();
     private final InternalReadPublisher readPublisher;
     private final InternalWriteSubscriber writeSubscriber;
     private final long id = IDS.incrementAndGet();
@@ -162,6 +164,20 @@
                 new IOException("connection closed locally"));
     }
 
+    void detach() {
+        if (detached.compareAndSet(false, true)) {
+            readPublisher.subscriptionImpl.readScheduler.stop();
+            SocketFlowEvent[] events = {
+                    readPublisher.subscriptionImpl.readEvent,
+                    writeSubscriber.writeEvent
+            };
+            for (SocketFlowEvent event : events) {
+                event.pause();
+            }
+            client.detachChannel(channel, events);
+        }
+    }
+
     /**
      * A restartable task used to process tasks in sequence.
      */
@@ -436,6 +452,7 @@
         }
 
         void signalError(Throwable error) {
+            if (detached.get()) return;
             debug.log(Level.DEBUG, () -> "write error: " + error);
             completed = true;
             readPublisher.signalError(error);
@@ -528,6 +545,7 @@
         }
 
         void signalError(Throwable error) {
+            if (detached.get()) return;
             debug.log(Level.DEBUG, () -> "error signalled " + error);
             if (!errorRef.compareAndSet(null, error)) {
                 return;
@@ -695,6 +713,7 @@
             }
 
             final void signalError(Throwable error) {
+                if (detached.get()) return;
                 if (!errorRef.compareAndSet(null, error)) {
                     return;
                 }
@@ -703,6 +722,7 @@
             }
 
             final void signalReadable() {
+                if (detached.get()) return;
                 readScheduler.runOrSchedule();
             }
 
@@ -717,6 +737,7 @@
                 try {
                     while(!readScheduler.isStopped()) {
                         if (completed) return;
+                        if (detached.get()) return;
 
                         // make sure we have a subscriber
                         if (handlePending()) {
@@ -855,6 +876,7 @@
             }
             @Override
             protected final void signalEvent() {
+                if (detached.get()) return;
                 try {
                     client.eventUpdated(this);
                     sub.signalReadable();
@@ -865,6 +887,7 @@
 
             @Override
             protected final void signalError(Throwable error) {
+                if (detached.get()) return;
                 sub.signalError(error);
             }