4918870: Examine session cache implementation (sun.misc.Cache)
authorxuelei
Fri, 20 Feb 2009 12:50:02 +0800
changeset 2060 75e464ce81af
parent 2058 577525f89bd4
child 2061 5a32855b67d4
4918870: Examine session cache implementation (sun.misc.Cache) Summary: replace sun.misc.Cache with sun.security.util.Cache Reviewed-by: weijun
jdk/src/share/classes/sun/security/ssl/SSLSessionContextImpl.java
jdk/src/share/classes/sun/security/util/Cache.java
--- a/jdk/src/share/classes/sun/security/ssl/SSLSessionContextImpl.java	Mon Feb 16 17:19:05 2009 +0000
+++ b/jdk/src/share/classes/sun/security/ssl/SSLSessionContextImpl.java	Fri Feb 20 12:50:02 2009 +0800
@@ -1,5 +1,5 @@
 /*
- * Copyright 1999-2007 Sun Microsystems, Inc.  All Rights Reserved.
+ * Copyright 1999-2009 Sun Microsystems, Inc.  All Rights Reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -41,88 +41,112 @@
 import javax.net.ssl.SSLPeerUnverifiedException;
 import javax.net.ssl.SSLSession;
 
-import sun.misc.Cache;
+import sun.security.util.Cache;
 
 
-final class SSLSessionContextImpl implements SSLSessionContext
-{
-    private Cache       sessionCache = new Cache();
-    private Cache       sessionHostPortCache = new Cache();
-    private int         cacheLimit;
-    private long        timeoutMillis;
+final class SSLSessionContextImpl implements SSLSessionContext {
+    private Cache sessionCache;         // session cache, session id as key
+    private Cache sessionHostPortCache; // session cache, "host:port" as key
+    private int cacheLimit;             // the max cache size
+    private int timeout;                // timeout in seconds
+
     private static final Debug debug = Debug.getInstance("ssl");
 
-    // file private
-    SSLSessionContextImpl()
-    {
-        cacheLimit = getCacheLimit();
-        timeoutMillis = 86400000; // default, 24 hours
+    // package private
+    SSLSessionContextImpl() {
+        cacheLimit = getDefaultCacheLimit();    // default cache size
+        timeout = 86400;                        // default, 24 hours
+
+        // use soft reference
+        sessionCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
+        sessionHostPortCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
     }
 
     /**
-     * Returns the SSL session object associated with the
-     * specific session ID passed.
+     * Returns the <code>SSLSession</code> bound to the specified session id.
      */
-    public SSLSession   getSession(byte[] id)
-    {
-        SSLSession sess = (SSLSession) sessionCache.get(
-                                new SessionId(id));
-        return checkTimeValidity(sess);
+    public SSLSession getSession(byte[] sessionId) {
+        if (sessionId == null) {
+            throw new NullPointerException("session id cannot be null");
+        }
+
+        SSLSessionImpl sess =
+                (SSLSessionImpl)sessionCache.get(new SessionId(sessionId));
+        if (!isTimedout(sess)) {
+            return sess;
+        }
+
+        return null;
     }
 
     /**
      * Returns an enumeration of the active SSL sessions.
      */
     public Enumeration<byte[]> getIds() {
-        Vector<byte[]> v = new Vector<byte[]>(sessionCache.size());
-        SessionId sessId;
+        SessionCacheVisitor scVisitor = new SessionCacheVisitor();
+        sessionCache.accept(scVisitor);
 
-        for (Enumeration e = sessionCache.keys(); e.hasMoreElements(); ) {
-            sessId = (SessionId) e.nextElement();
-            if (!isTimedout((SSLSession)sessionCache.get(sessId)))
-                v.addElement(sessId.getId());
-        }
-        return v.elements();
+        return scVisitor.getSessionIds();
     }
 
+    /**
+     * Sets the timeout limit for cached <code>SSLSession</code> objects
+     *
+     * Note that after reset the timeout, the cached session before
+     * should be timed within the shorter one of the old timeout and the
+     * new timeout.
+     */
     public void setSessionTimeout(int seconds)
                  throws IllegalArgumentException {
-        if (seconds < 0)
+        if (seconds < 0) {
             throw new IllegalArgumentException();
-        timeoutMillis = seconds * 1000L;
+        }
+
+        if (timeout != seconds) {
+            sessionCache.setTimeout(seconds);
+            sessionHostPortCache.setTimeout(seconds);
+            timeout = seconds;
+        }
     }
 
+    /**
+     * Gets the timeout limit for cached <code>SSLSession</code> objects
+     */
     public int getSessionTimeout() {
-        return (int) (timeoutMillis / 1000);
+        return timeout;
     }
 
+    /**
+     * Sets the size of the cache used for storing
+     * <code>SSLSession</code> objects.
+     */
     public void setSessionCacheSize(int size)
                  throws IllegalArgumentException {
         if (size < 0)
             throw new IllegalArgumentException();
-        cacheLimit = size;
 
-        /**
-         * If cache size limit is reduced, when the cache is full to its
-         * previous limit, trim the cache before its contents
-         * are used.
-         */
-        if ((cacheLimit != 0) && (sessionCache.size() > cacheLimit))
-            adjustCacheSizeTo(cacheLimit);
+        if (cacheLimit != size) {
+            sessionCache.setCapacity(size);
+            sessionHostPortCache.setCapacity(size);
+            cacheLimit = size;
+        }
     }
 
+    /**
+     * Gets the size of the cache used for storing
+     * <code>SSLSession</code> objects.
+     */
     public int getSessionCacheSize() {
         return cacheLimit;
     }
 
+
+    // package-private method, used ONLY by ServerHandshaker
     SSLSessionImpl get(byte[] id) {
-        return (SSLSessionImpl) getSession(id);
+        return (SSLSessionImpl)getSession(id);
     }
 
-    /**
-     * Returns the SSL session object associated with the
-     * specific host name and port number passed.
-     */
+    // package-private method, used ONLY by ClientHandshaker
     SSLSessionImpl get(String hostname, int port) {
         /*
          * If no session caching info is available, we won't
@@ -131,96 +155,51 @@
         if (hostname == null && port == -1) {
             return null;
         }
-        SSLSession sess =  (SSLSessionImpl) sessionHostPortCache
-                                .get(getKey(hostname, port));
-        return (SSLSessionImpl) checkTimeValidity(sess);
+
+        SSLSessionImpl sess =
+            (SSLSessionImpl)sessionHostPortCache.get(getKey(hostname, port));
+        if (!isTimedout(sess)) {
+            return sess;
+        }
+
+        return null;
     }
 
     private String getKey(String hostname, int port) {
-        return (hostname + ":" + String.valueOf(port))
-                        .toLowerCase();
+        return (hostname + ":" + String.valueOf(port)).toLowerCase();
     }
 
+    // cache a SSLSession
+    //
+    // In SunJSSE implementation, a session is created while getting a
+    // client hello or a server hello message, and cached while the
+    // handshaking finished.
+    // Here we time the session from the time it cached instead of the
+    // time it created, which is a little longer than the expected. So
+    // please do check isTimedout() while getting entry from the cache.
     void put(SSLSessionImpl s) {
-        // make space for the new session to be added
-        if ((cacheLimit != 0) && (sessionCache.size() >= cacheLimit))
-            adjustCacheSizeTo(cacheLimit - 1);
-
-        /*
-         * Can always add the session id.
-         */
         sessionCache.put(s.getSessionId(), s);
 
-        /*
-         * If no hostname/port info is available, don't add this one.
-         */
+        // If no hostname/port info is available, don't add this one.
         if ((s.getPeerHost() != null) && (s.getPeerPort() != -1)) {
             sessionHostPortCache.put(
                 getKey(s.getPeerHost(), s.getPeerPort()), s);
         }
+
         s.setContext(this);
     }
 
-    private void adjustCacheSizeTo(int targetSize) {
-
-        int cacheSize = sessionCache.size();
-
-        if (targetSize < 0)
-           return;
-
-        while (cacheSize > targetSize) {
-            SSLSessionImpl lru = null;
-            SSLSessionImpl s = null;
-            Enumeration e;
-
-            if (debug != null && Debug.isOn("sessioncache")) {
-                System.out.println("exceeded cache limit of " + cacheLimit);
-            }
-
-            /*
-             * Count the number of elements in the cache. The size() method
-             * does not reflect the cache entries that are no longer available,
-             * i.e entries that are garbage collected (the cache entries are
-             * held using soft references and are garbage collected when not
-             * in use).
-             */
-            int count;
-            for (count = 0, e = sessionCache.elements();
-                         e.hasMoreElements(); count++) {
-                try {
-                    s = (SSLSessionImpl)e.nextElement();
-                } catch (NoSuchElementException nsee) {
-                    break;
-                }
-                if (isTimedout(s)) {
-                    lru = s;
-                    break;
-                } else if ((lru == null) || (s.getLastAccessedTime()
-                         < lru.getLastAccessedTime())) {
-                    lru = s;
-                }
-            }
-            if ((lru != null) && (count > targetSize)) {
-                if (debug != null && Debug.isOn("sessioncache")) {
-                    System.out.println("uncaching " + lru);
-                }
-                lru.invalidate();
-                count--; // element removed from the cache
-            }
-            cacheSize = count;
+    // package-private method, remove a cached SSLSession
+    void remove(SessionId key) {
+        SSLSessionImpl s = (SSLSessionImpl)sessionCache.get(key);
+        if (s != null) {
+            sessionCache.remove(key);
+            sessionHostPortCache.remove(
+                        getKey(s.getPeerHost(), s.getPeerPort()));
         }
     }
 
-    // file private
-    void remove(SessionId key)
-    {
-        SSLSessionImpl s = (SSLSessionImpl) sessionCache.get(key);
-        sessionCache.remove(key);
-        sessionHostPortCache.remove(getKey(s.getPeerHost(),
-                                         s.getPeerPort()));
-    }
-
-    private int getCacheLimit() {
+    private int getDefaultCacheLimit() {
         int cacheLimit = 0;
         try {
         String s = java.security.AccessController.doPrivileged(
@@ -237,21 +216,40 @@
         return (cacheLimit > 0) ? cacheLimit : 0;
     }
 
-    SSLSession checkTimeValidity(SSLSession sess) {
-        if (isTimedout(sess)) {
+    boolean isTimedout(SSLSession sess) {
+        if (timeout == 0) {
+            return false;
+        }
+
+        if ((sess != null) && ((sess.getCreationTime() + timeout * 1000L)
+                                        <= (System.currentTimeMillis()))) {
             sess.invalidate();
-            return null;
-        } else
-            return sess;
+            return true;
+        }
+
+        return false;
     }
 
-    boolean isTimedout(SSLSession sess) {
-        if (timeoutMillis == 0)
-            return false;
-        if ((sess != null) &&
-            ((sess.getCreationTime() + timeoutMillis)
-                <= (System.currentTimeMillis())))
-            return true;
-        return false;
+    final class SessionCacheVisitor
+            implements sun.security.util.Cache.CacheVisitor {
+        Vector<byte[]> ids = null;
+
+        // public void visit(java.util.Map<Object, Object> map) {}
+        public void visit(java.util.Map<Object, Object> map) {
+            ids = new Vector<byte[]>(map.size());
+
+            for (Object key : map.keySet()) {
+                SSLSessionImpl value = (SSLSessionImpl)map.get(key);
+                if (!isTimedout(value)) {
+                    ids.addElement(((SessionId)key).getId());
+                }
+            }
+        }
+
+        public Enumeration<byte[]> getSessionIds() {
+            return  ids != null ? ids.elements() :
+                                  new Vector<byte[]>().elements();
+        }
     }
+
 }
--- a/jdk/src/share/classes/sun/security/util/Cache.java	Mon Feb 16 17:19:05 2009 +0000
+++ b/jdk/src/share/classes/sun/security/util/Cache.java	Fri Feb 20 12:50:02 2009 +0800
@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2006 Sun Microsystems, Inc.  All Rights Reserved.
+ * Copyright 2002-2009 Sun Microsystems, Inc.  All Rights Reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -101,6 +101,21 @@
     public abstract void remove(Object key);
 
     /**
+     * Set the maximum size.
+     */
+    public abstract void setCapacity(int size);
+
+    /**
+     * Set the timeout(in seconds).
+     */
+    public abstract void setTimeout(int timeout);
+
+    /**
+     * accept a visitor
+     */
+    public abstract void accept(CacheVisitor visitor);
+
+    /**
      * Return a new memory cache with the specified maximum size, unlimited
      * lifetime for entries, with the values held by SoftReferences.
      */
@@ -178,6 +193,10 @@
         }
     }
 
+    public interface CacheVisitor {
+        public void visit(Map<Object, Object> map);
+    }
+
 }
 
 class NullCache extends Cache {
@@ -208,6 +227,18 @@
         // empty
     }
 
+    public void setCapacity(int size) {
+        // empty
+    }
+
+    public void setTimeout(int timeout) {
+        // empty
+    }
+
+    public void accept(CacheVisitor visitor) {
+        // empty
+    }
+
 }
 
 class MemoryCache extends Cache {
@@ -218,8 +249,8 @@
     private final static boolean DEBUG = false;
 
     private final Map<Object, CacheEntry> cacheMap;
-    private final int maxSize;
-    private final int lifetime;
+    private int maxSize;
+    private long lifetime;
     private final ReferenceQueue queue;
 
     public MemoryCache(boolean soft, int maxSize) {
@@ -328,7 +359,7 @@
             oldEntry.invalidate();
             return;
         }
-        if (cacheMap.size() > maxSize) {
+        if (maxSize > 0 && cacheMap.size() > maxSize) {
             expungeExpiredEntries();
             if (cacheMap.size() > maxSize) { // still too large?
                 Iterator<CacheEntry> t = cacheMap.values().iterator();
@@ -368,6 +399,55 @@
         }
     }
 
+    public synchronized void setCapacity(int size) {
+        expungeExpiredEntries();
+        if (size > 0 && cacheMap.size() > size) {
+            Iterator<CacheEntry> t = cacheMap.values().iterator();
+            for (int i = cacheMap.size() - size; i > 0; i--) {
+                CacheEntry lruEntry = t.next();
+                if (DEBUG) {
+                    System.out.println("** capacity reset removal "
+                        + lruEntry.getKey() + " | " + lruEntry.getValue());
+                }
+                t.remove();
+                lruEntry.invalidate();
+            }
+        }
+
+        maxSize = size > 0 ? size : 0;
+
+        if (DEBUG) {
+            System.out.println("** capacity reset to " + size);
+        }
+    }
+
+    public synchronized void setTimeout(int timeout) {
+        emptyQueue();
+        lifetime = timeout > 0 ? timeout * 1000L : 0L;
+
+        if (DEBUG) {
+            System.out.println("** lifetime reset to " + timeout);
+        }
+    }
+
+    // it is a heavyweight method.
+    public synchronized void accept(CacheVisitor visitor) {
+        expungeExpiredEntries();
+        Map<Object, Object> cached = getCachedEntries();
+
+        visitor.visit(cached);
+    }
+
+    private Map<Object, Object> getCachedEntries() {
+        Map<Object,Object> kvmap = new HashMap<Object,Object>(cacheMap.size());
+
+        for (CacheEntry entry : cacheMap.values()) {
+            kvmap.put(entry.getKey(), entry.getValue());
+        }
+
+        return kvmap;
+    }
+
     protected CacheEntry newEntry(Object key, Object value,
             long expirationTime, ReferenceQueue queue) {
         if (queue != null) {