8202788: Explicitly reclaim cached thread-local direct buffers at thread exit
authorplevart
Fri, 22 Jun 2018 17:56:55 +0200
changeset 50719 106dc156ce6b
parent 50715 46492a773912
child 50720 c55b1386f119
8202788: Explicitly reclaim cached thread-local direct buffers at thread exit Summary: Add internal TerminatingThreadLocal and use it to free cached thread-local direct buffers and nio-fs native buffers Reviewed-by: tonyp, alanb
src/java.base/share/classes/java/lang/Thread.java
src/java.base/share/classes/java/lang/ThreadLocal.java
src/java.base/share/classes/jdk/internal/misc/TerminatingThreadLocal.java
src/java.base/share/classes/sun/nio/ch/Util.java
src/java.base/share/classes/sun/nio/fs/NativeBuffers.java
test/jdk/java/nio/channels/FileChannel/TempDirectBuffersReclamation.java
test/jdk/jdk/internal/misc/TerminatingThreadLocal/TestTerminatingThreadLocal.java
test/jdk/sun/nio/ch/TestMaxCachedBufferSize.java
--- a/src/java.base/share/classes/java/lang/Thread.java	Fri Jun 22 21:42:00 2018 +0800
+++ b/src/java.base/share/classes/java/lang/Thread.java	Fri Jun 22 17:56:55 2018 +0200
@@ -36,6 +36,8 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.locks.LockSupport;
+
+import jdk.internal.misc.TerminatingThreadLocal;
 import sun.nio.ch.Interruptible;
 import jdk.internal.reflect.CallerSensitive;
 import jdk.internal.reflect.Reflection;
@@ -838,6 +840,9 @@
      * a chance to clean up before it actually exits.
      */
     private void exit() {
+        if (TerminatingThreadLocal.REGISTRY.isPresent()) {
+            TerminatingThreadLocal.threadTerminated();
+        }
         if (group != null) {
             group.threadTerminated(this);
             group = null;
--- a/src/java.base/share/classes/java/lang/ThreadLocal.java	Fri Jun 22 21:42:00 2018 +0800
+++ b/src/java.base/share/classes/java/lang/ThreadLocal.java	Fri Jun 22 17:56:55 2018 +0200
@@ -24,6 +24,8 @@
  */
 
 package java.lang;
+import jdk.internal.misc.TerminatingThreadLocal;
+
 import java.lang.ref.*;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -171,6 +173,19 @@
     }
 
     /**
+     * Returns {@code true} if there is a value in the current thread's copy of
+     * this thread-local variable, even if that values is {@code null}.
+     *
+     * @return {@code true} if current thread has associated value in this
+     *         thread-local variable; {@code false} if not
+     */
+    boolean isPresent() {
+        Thread t = Thread.currentThread();
+        ThreadLocalMap map = getMap(t);
+        return map != null && map.getEntry(this) != null;
+    }
+
+    /**
      * Variant of set() to establish initialValue. Used instead
      * of set() in case user has overridden the set() method.
      *
@@ -180,10 +195,14 @@
         T value = initialValue();
         Thread t = Thread.currentThread();
         ThreadLocalMap map = getMap(t);
-        if (map != null)
+        if (map != null) {
             map.set(this, value);
-        else
+        } else {
             createMap(t, value);
+        }
+        if (this instanceof TerminatingThreadLocal) {
+            TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
+        }
         return value;
     }
 
@@ -199,10 +218,11 @@
     public void set(T value) {
         Thread t = Thread.currentThread();
         ThreadLocalMap map = getMap(t);
-        if (map != null)
+        if (map != null) {
             map.set(this, value);
-        else
+        } else {
             createMap(t, value);
+        }
     }
 
     /**
@@ -218,8 +238,9 @@
      */
      public void remove() {
          ThreadLocalMap m = getMap(Thread.currentThread());
-         if (m != null)
+         if (m != null) {
              m.remove(this);
+         }
      }
 
     /**
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/java.base/share/classes/jdk/internal/misc/TerminatingThreadLocal.java	Fri Jun 22 17:56:55 2018 +0200
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+package jdk.internal.misc;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.IdentityHashMap;
+
+/**
+ * A thread-local variable that is notified when a thread terminates and
+ * it has been initialized in the terminating thread (even if it was
+ * initialized with a null value).
+ */
+public class TerminatingThreadLocal<T> extends ThreadLocal<T> {
+
+    @Override
+    public void set(T value) {
+        super.set(value);
+        register(this);
+    }
+
+    @Override
+    public void remove() {
+        super.remove();
+        unregister(this);
+    }
+
+    /**
+     * Invoked by a thread when terminating and this thread-local has an associated
+     * value for the terminating thread (even if that value is null), so that any
+     * native resources maintained by the value can be released.
+     *
+     * @param value current thread's value of this thread-local variable
+     *              (may be null but only if null value was explicitly initialized)
+     */
+    protected void threadTerminated(T value) {
+    }
+
+    // following methods and field are implementation details and should only be
+    // called from the corresponding code int Thread/ThreadLocal class.
+
+    /**
+     * Invokes the TerminatingThreadLocal's {@link #threadTerminated()} method
+     * on all instances registered in current thread.
+     */
+    public static void threadTerminated() {
+        for (TerminatingThreadLocal<?> ttl : REGISTRY.get()) {
+            ttl._threadTerminated();
+        }
+    }
+
+    private void _threadTerminated() { threadTerminated(get()); }
+
+    /**
+     * Register given TerminatingThreadLocal
+     *
+     * @param tl the ThreadLocal to register
+     */
+    public static void register(TerminatingThreadLocal<?> tl) {
+        REGISTRY.get().add(tl);
+    }
+
+    /**
+     * Unregister given TerminatingThreadLocal
+     *
+     * @param tl the ThreadLocal to unregister
+     */
+    private static void unregister(TerminatingThreadLocal<?> tl) {
+        REGISTRY.get().remove(tl);
+    }
+
+    /**
+     * a per-thread registry of TerminatingThreadLocal(s) that have been registered
+     * but later not unregistered in a particular thread.
+     */
+    public static final ThreadLocal<Collection<TerminatingThreadLocal<?>>> REGISTRY =
+        new ThreadLocal<>() {
+            @Override
+            protected Collection<TerminatingThreadLocal<?>> initialValue() {
+                return Collections.newSetFromMap(new IdentityHashMap<>(4));
+            }
+        };
+}
--- a/src/java.base/share/classes/sun/nio/ch/Util.java	Fri Jun 22 21:42:00 2018 +0800
+++ b/src/java.base/share/classes/sun/nio/ch/Util.java	Fri Jun 22 17:56:55 2018 +0200
@@ -26,6 +26,7 @@
 package sun.nio.ch;
 
 import java.io.FileDescriptor;
+import java.io.IOException;
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.nio.ByteBuffer;
@@ -35,9 +36,10 @@
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.Set;
+
+import jdk.internal.misc.TerminatingThreadLocal;
 import jdk.internal.misc.Unsafe;
 import sun.security.action.GetPropertyAction;
-import java.io.IOException;
 
 public class Util {
 
@@ -50,13 +52,18 @@
     private static final long MAX_CACHED_BUFFER_SIZE = getMaxCachedBufferSize();
 
     // Per-thread cache of temporary direct buffers
-    private static ThreadLocal<BufferCache> bufferCache =
-        new ThreadLocal<BufferCache>()
-    {
+    private static ThreadLocal<BufferCache> bufferCache = new TerminatingThreadLocal<>() {
         @Override
         protected BufferCache initialValue() {
             return new BufferCache();
         }
+        @Override
+        protected void threadTerminated(BufferCache cache) { // will never be null
+            while (!cache.isEmpty()) {
+                ByteBuffer bb = cache.removeFirst();
+                free(bb);
+            }
+        }
     };
 
     /**
--- a/src/java.base/share/classes/sun/nio/fs/NativeBuffers.java	Fri Jun 22 21:42:00 2018 +0800
+++ b/src/java.base/share/classes/sun/nio/fs/NativeBuffers.java	Fri Jun 22 17:56:55 2018 +0200
@@ -25,6 +25,7 @@
 
 package sun.nio.fs;
 
+import jdk.internal.misc.TerminatingThreadLocal;
 import jdk.internal.misc.Unsafe;
 
 /**
@@ -37,8 +38,21 @@
     private static final Unsafe unsafe = Unsafe.getUnsafe();
 
     private static final int TEMP_BUF_POOL_SIZE = 3;
-    private static ThreadLocal<NativeBuffer[]> threadLocal =
-        new ThreadLocal<NativeBuffer[]>();
+    private static ThreadLocal<NativeBuffer[]> threadLocal = new TerminatingThreadLocal<>() {
+        @Override
+        protected void threadTerminated(NativeBuffer[] buffers) {
+            // threadLocal may be initialized but with initialValue of null
+            if (buffers != null) {
+                for (int i = 0; i < TEMP_BUF_POOL_SIZE; i++) {
+                    NativeBuffer buffer = buffers[i];
+                    if (buffer != null) {
+                        buffer.free();
+                        buffers[i] = null;
+                    }
+                }
+            }
+        }
+    };
 
     /**
      * Allocates a native buffer, of at least the given size, from the heap.
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/java/nio/channels/FileChannel/TempDirectBuffersReclamation.java	Fri Jun 22 17:56:55 2018 +0200
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.lang.management.BufferPoolMXBean;
+import java.lang.management.ManagementFactory;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+import static java.nio.file.StandardOpenOption.CREATE;
+import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING;
+import static java.nio.file.StandardOpenOption.WRITE;
+
+/*
+ * @test
+ * @bug 8202788
+ * @summary Test reclamation of thread-local temporary direct byte buffers at thread exit
+ * @modules java.management
+ * @run main/othervm TempDirectBuffersReclamation
+ */
+public class TempDirectBuffersReclamation {
+
+    public static void main(String[] args) throws IOException {
+
+        BufferPoolMXBean dbPool = ManagementFactory
+            .getPlatformMXBeans(BufferPoolMXBean.class)
+            .stream()
+            .filter(bp -> bp.getName().equals("direct"))
+            .findFirst()
+            .orElseThrow(() -> new RuntimeException("Can't obtain direct BufferPoolMXBean"));
+
+        long count0 = dbPool.getCount();
+        long memoryUsed0 = dbPool.getMemoryUsed();
+
+        Thread thread = new Thread(TempDirectBuffersReclamation::doFileChannelWrite);
+        thread.start();
+        try {
+            thread.join();
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+
+        long count1 = dbPool.getCount();
+        long memoryUsed1 = dbPool.getMemoryUsed();
+
+        if (count0 != count1 || memoryUsed0 != memoryUsed1) {
+            throw new AssertionError(
+                "Direct BufferPool not same before thread activity and after thread exit.\n" +
+                "Before: # of buffers: " + count0 + ", memory used: " + memoryUsed0 + "\n" +
+                " After: # of buffers: " + count1 + ", memory used: " + memoryUsed1 + "\n"
+            );
+        }
+    }
+
+    static void doFileChannelWrite() {
+        try {
+            Path file = Files.createTempFile("test", ".tmp");
+            try (FileChannel fc = FileChannel.open(file, CREATE, WRITE, TRUNCATE_EXISTING)) {
+                fc.write(ByteBuffer.wrap("HELLO".getBytes(StandardCharsets.UTF_8)));
+            } finally {
+                Files.delete(file);
+            }
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/jdk/internal/misc/TerminatingThreadLocal/TestTerminatingThreadLocal.java	Fri Jun 22 17:56:55 2018 +0200
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import jdk.internal.misc.TerminatingThreadLocal;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.function.Consumer;
+
+/*
+ * @test
+ * @bug 8202788
+ * @summary TerminatingThreadLocal unit test
+ * @modules java.base/jdk.internal.misc
+ * @run main TestTerminatingThreadLocal
+ */
+public class TestTerminatingThreadLocal {
+
+    public static void main(String[] args) {
+        ttlTestSet(42,   112);
+        ttlTestSet(null, 112);
+        ttlTestSet(42,  null);
+    }
+
+    static <T> void ttlTestSet(T v0, T v1) {
+        ttlTest(v0, ttl -> {                                         }    );
+        ttlTest(v0, ttl -> { ttl.get();                              }, v0);
+        ttlTest(v0, ttl -> { ttl.get();   ttl.remove();              }    );
+        ttlTest(v0, ttl -> { ttl.get();   ttl.set(v1);               }, v1);
+        ttlTest(v0, ttl -> { ttl.set(v1);                            }, v1);
+        ttlTest(v0, ttl -> { ttl.set(v1); ttl.remove();              }    );
+        ttlTest(v0, ttl -> { ttl.set(v1); ttl.remove(); ttl.get();   }, v0);
+        ttlTest(v0, ttl -> { ttl.get();   ttl.remove(); ttl.set(v1); }, v1);
+    }
+
+    @SafeVarargs
+    static <T> void ttlTest(T initialValue,
+                            Consumer<? super TerminatingThreadLocal<T>> ttlOps,
+                            T... expectedTerminatedValues)
+    {
+        List<T> terminatedValues = new CopyOnWriteArrayList<>();
+
+        TerminatingThreadLocal<T> ttl = new TerminatingThreadLocal<>() {
+            @Override
+            protected void threadTerminated(T value) {
+                terminatedValues.add(value);
+            }
+
+            @Override
+            protected T initialValue() {
+                return initialValue;
+            }
+        };
+
+        Thread thread = new Thread(() -> ttlOps.accept(ttl));
+        thread.start();
+        try {
+            thread.join();
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+
+        if (!terminatedValues.equals(Arrays.asList(expectedTerminatedValues))) {
+            throw new AssertionError("Expected terminated values: " +
+                                     Arrays.toString(expectedTerminatedValues) +
+                                     " but got: " + terminatedValues);
+        }
+    }
+}
--- a/test/jdk/sun/nio/ch/TestMaxCachedBufferSize.java	Fri Jun 22 21:42:00 2018 +0800
+++ b/test/jdk/sun/nio/ch/TestMaxCachedBufferSize.java	Fri Jun 22 17:56:55 2018 +0200
@@ -39,6 +39,7 @@
 
 import java.util.List;
 import java.util.Random;
+import java.util.concurrent.CountDownLatch;
 
 /*
  * @test
@@ -93,6 +94,7 @@
     // by setting the jdk.nio.maxCachedBufferSize property.
     private static class Worker implements Runnable {
         private final int id;
+        private final CountDownLatch finishLatch, exitLatch;
         private final Random random = new Random();
         private long smallBufferCount = 0;
         private long largeBufferCount = 0;
@@ -152,6 +154,13 @@
                 }
             } catch (IOException e) {
                 throw new Error("I/O error", e);
+            } finally {
+                finishLatch.countDown();
+                try {
+                    exitLatch.await();
+                } catch (InterruptedException e) {
+                    // ignore
+                }
             }
         }
 
@@ -160,8 +169,10 @@
             loop();
         }
 
-        public Worker(int id) {
+        public Worker(int id, CountDownLatch finishLatch, CountDownLatch exitLatch) {
             this.id = id;
+            this.finishLatch = finishLatch;
+            this.exitLatch = exitLatch;
         }
     }
 
@@ -171,10 +182,6 @@
         System.out.printf("Direct %d / %dK\n",
                           directCount, directTotalCapacity / 1024);
 
-        // Note that directCount could be < expectedCount. This can
-        // happen if a GC occurs after one of the worker threads exits
-        // since its thread-local DirectByteBuffer could be cleaned up
-        // before we reach here.
         if (directCount > expectedCount) {
             throw new Error(String.format(
                 "inconsistent direct buffer total count, expected = %d, found = %d",
@@ -208,46 +215,57 @@
                           threadNum, iters, maxBufferSize);
         System.out.println();
 
+        final CountDownLatch finishLatch = new CountDownLatch(threadNum);
+        final CountDownLatch exitLatch = new CountDownLatch(1);
         final Thread[] threads = new Thread[threadNum];
         for (int i = 0; i < threadNum; i += 1) {
-            threads[i] = new Thread(new Worker(i));
+            threads[i] = new Thread(new Worker(i, finishLatch, exitLatch));
             threads[i].start();
         }
 
         try {
-            for (int i = 0; i < threadNum; i += 1) {
-                threads[i].join();
+            try {
+                finishLatch.await();
+            } catch (InterruptedException e) {
+                throw new Error("finishLatch.await() interrupted!", e);
             }
-        } catch (InterruptedException e) {
-            throw new Error("join() interrupted!", e);
-        }
 
-        // There is an assumption here that, at this point, only the
-        // cached DirectByteBuffers should be active. Given we
-        // haven't used any other DirectByteBuffers in this test, this
-        // should hold.
-        //
-        // Also note that we can only do the sanity checking at the
-        // end and not during the run given that, at any time, there
-        // could be buffers currently in use by some of the workers
-        // that will not be cached.
+            // There is an assumption here that, at this point, only the
+            // cached DirectByteBuffers should be active. Given we
+            // haven't used any other DirectByteBuffers in this test, this
+            // should hold.
+            //
+            // Also note that we can only do the sanity checking at the
+            // end and not during the run given that, at any time, there
+            // could be buffers currently in use by some of the workers
+            // that will not be cached.
 
-        System.out.println();
-        if (maxBufferSize < SMALL_BUFFER_MAX_SIZE) {
-            // The max buffer size is smaller than all buffers that
-            // were allocated. No buffers should have been cached.
-            checkDirectBuffers(0, 0);
-        } else if (maxBufferSize < LARGE_BUFFER_MIN_SIZE) {
-            // The max buffer size is larger than all small buffers
-            // but smaller than all large buffers that were
-            // allocated. Only small buffers could have been cached.
-            checkDirectBuffers(threadNum,
-                               (long) threadNum * (long) SMALL_BUFFER_MAX_SIZE);
-        } else {
-            // The max buffer size is larger than all buffers that
-            // were allocated. All buffers could have been cached.
-            checkDirectBuffers(threadNum,
-                               (long) threadNum * (long) LARGE_BUFFER_MAX_SIZE);
+            System.out.println();
+            if (maxBufferSize < SMALL_BUFFER_MAX_SIZE) {
+                // The max buffer size is smaller than all buffers that
+                // were allocated. No buffers should have been cached.
+                checkDirectBuffers(0, 0);
+            } else if (maxBufferSize < LARGE_BUFFER_MIN_SIZE) {
+                // The max buffer size is larger than all small buffers
+                // but smaller than all large buffers that were
+                // allocated. Only small buffers could have been cached.
+                checkDirectBuffers(threadNum,
+                                   (long) threadNum * (long) SMALL_BUFFER_MAX_SIZE);
+            } else {
+                // The max buffer size is larger than all buffers that
+                // were allocated. All buffers could have been cached.
+                checkDirectBuffers(threadNum,
+                                   (long) threadNum * (long) LARGE_BUFFER_MAX_SIZE);
+            }
+        } finally {
+            exitLatch.countDown();
+            try {
+                for (int i = 0; i < threadNum; i += 1) {
+                    threads[i].join();
+                }
+            } catch (InterruptedException e) {
+                // ignore
+            }
         }
     }
 }