src/java.base/share/classes/sun/security/ssl/DTLSOutputRecord.java
changeset 51407 910f7b56592f
parent 50768 68fa3d4026ea
child 54443 dfba4e321ab3
--- a/src/java.base/share/classes/sun/security/ssl/DTLSOutputRecord.java	Tue Aug 14 19:52:34 2018 -0400
+++ b/src/java.base/share/classes/sun/security/ssl/DTLSOutputRecord.java	Tue Aug 14 18:16:47 2018 -0700
@@ -44,7 +44,7 @@
     Authenticator       prevWriteAuthenticator;
     SSLWriteCipher      prevWriteCipher;
 
-    private final LinkedList<RecordMemo> alertMemos = new LinkedList<>();
+    private volatile boolean isCloseWaiting = false;
 
     DTLSOutputRecord(HandshakeHash handshakeHash) {
         super(handshakeHash, SSLWriteCipher.nullDTlsWriteCipher());
@@ -58,6 +58,21 @@
     }
 
     @Override
+    public synchronized void close() throws IOException {
+        if (!isClosed) {
+            if (fragmenter != null && fragmenter.hasAlert()) {
+                isCloseWaiting = true;
+            } else {
+                super.close();
+            }
+        }
+    }
+
+    boolean isClosed() {
+        return isClosed || isCloseWaiting;
+    }
+
+    @Override
     void initHandshaker() {
         // clean up
         fragmenter = null;
@@ -71,6 +86,14 @@
     @Override
     void changeWriteCiphers(SSLWriteCipher writeCipher,
             boolean useChangeCipherSpec) throws IOException {
+        if (isClosed()) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.warning("outbound has closed, ignore outbound " +
+                    "change_cipher_spec message");
+            }
+            return;
+        }
+
         if (useChangeCipherSpec) {
             encodeChangeCipherSpec();
         }
@@ -91,23 +114,31 @@
 
     @Override
     void encodeAlert(byte level, byte description) throws IOException {
-        RecordMemo memo = new RecordMemo();
+        if (isClosed()) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.warning("outbound has closed, ignore outbound " +
+                    "alert message: " + Alert.nameOf(description));
+            }
+            return;
+        }
 
-        memo.contentType = ContentType.ALERT.id;
-        memo.majorVersion = protocolVersion.major;
-        memo.minorVersion = protocolVersion.minor;
-        memo.encodeEpoch = writeEpoch;
-        memo.encodeCipher = writeCipher;
+        if (fragmenter == null) {
+           fragmenter = new DTLSFragmenter();
+        }
 
-        memo.fragment = new byte[2];
-        memo.fragment[0] = level;
-        memo.fragment[1] = description;
-
-        alertMemos.add(memo);
+        fragmenter.queueUpAlert(level, description);
     }
 
     @Override
     void encodeChangeCipherSpec() throws IOException {
+        if (isClosed()) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.warning("outbound has closed, ignore outbound " +
+                    "change_cipher_spec message");
+            }
+            return;
+        }
+
         if (fragmenter == null) {
            fragmenter = new DTLSFragmenter();
         }
@@ -117,6 +148,15 @@
     @Override
     void encodeHandshake(byte[] source,
             int offset, int length) throws IOException {
+        if (isClosed()) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.warning("outbound has closed, ignore outbound " +
+                        "handshake message",
+                        ByteBuffer.wrap(source, offset, length));
+            }
+            return;
+        }
+
         if (firstMessage) {
             firstMessage = false;
         }
@@ -132,6 +172,23 @@
     Ciphertext encode(
         ByteBuffer[] srcs, int srcsOffset, int srcsLength,
         ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
+
+        if (isClosed) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.warning("outbound has closed, ignore outbound " +
+                    "application data or cached messages");
+            }
+
+            return null;
+        } else if (isCloseWaiting) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.warning("outbound has closed, ignore outbound " +
+                    "application data");
+            }
+
+            srcs = null;    // use no application data.
+        }
+
         return encode(srcs, srcsOffset, srcsLength, dsts[0]);
     }
 
@@ -237,48 +294,6 @@
 
     private Ciphertext acquireCiphertext(
             ByteBuffer destination) throws IOException {
-        if (alertMemos != null && !alertMemos.isEmpty()) {
-            RecordMemo memo = alertMemos.pop();
-
-            int dstPos = destination.position();
-            int dstLim = destination.limit();
-            int dstContent = dstPos + headerSize +
-                                writeCipher.getExplicitNonceSize();
-            destination.position(dstContent);
-
-            destination.put(memo.fragment);
-
-            destination.limit(destination.position());
-            destination.position(dstContent);
-
-            if (SSLLogger.isOn && SSLLogger.isOn("record")) {
-                SSLLogger.fine(
-                        "WRITE: " + protocolVersion + " " +
-                        ContentType.ALERT.name +
-                        ", length = " + destination.remaining());
-            }
-
-            // Encrypt the fragment and wrap up a record.
-            long recordSN = encrypt(memo.encodeCipher,
-                    ContentType.ALERT.id,
-                    destination, dstPos, dstLim, headerSize,
-                    ProtocolVersion.valueOf(memo.majorVersion,
-                            memo.minorVersion));
-
-            if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
-                ByteBuffer temporary = destination.duplicate();
-                temporary.limit(temporary.position());
-                temporary.position(dstPos);
-                SSLLogger.fine("Raw write", temporary);
-            }
-
-            // remain the limit unchanged
-            destination.limit(dstLim);
-
-            return new Ciphertext(ContentType.ALERT.id,
-                    SSLHandshake.NOT_APPLICABLE.id, recordSN);
-        }
-
         if (fragmenter != null) {
             return fragmenter.acquireCiphertext(destination);
         }
@@ -288,16 +303,14 @@
 
     @Override
     boolean isEmpty() {
-        return ((fragmenter == null) || fragmenter.isEmpty()) &&
-               ((alertMemos == null) || alertMemos.isEmpty());
+        return (fragmenter == null) || fragmenter.isEmpty();
     }
 
     @Override
     void launchRetransmission() {
         // Note: Please don't retransmit if there are handshake messages
         // or alerts waiting in the queue.
-        if (((alertMemos == null) || alertMemos.isEmpty()) &&
-                (fragmenter != null) && fragmenter.isRetransmittable()) {
+        if ((fragmenter != null) && fragmenter.isRetransmittable()) {
             fragmenter.setRetransmission();
         }
     }
@@ -338,29 +351,6 @@
         // size is bigger than 256 bytes.
         private int retransmits = 2;            // attemps of retransmits
 
-        void queueUpChangeCipherSpec() {
-
-            // Cleanup if a new flight starts.
-            if (flightIsReady) {
-                handshakeMemos.clear();
-                acquireIndex = 0;
-                flightIsReady = false;
-            }
-
-            RecordMemo memo = new RecordMemo();
-
-            memo.contentType = ContentType.CHANGE_CIPHER_SPEC.id;
-            memo.majorVersion = protocolVersion.major;
-            memo.minorVersion = protocolVersion.minor;
-            memo.encodeEpoch = writeEpoch;
-            memo.encodeCipher = writeCipher;
-
-            memo.fragment = new byte[1];
-            memo.fragment[0] = 1;
-
-            handshakeMemos.add(memo);
-        }
-
         void queueUpHandshake(byte[] buf,
                 int offset, int length) throws IOException {
 
@@ -401,6 +391,45 @@
             }
         }
 
+        void queueUpChangeCipherSpec() {
+
+            // Cleanup if a new flight starts.
+            if (flightIsReady) {
+                handshakeMemos.clear();
+                acquireIndex = 0;
+                flightIsReady = false;
+            }
+
+            RecordMemo memo = new RecordMemo();
+
+            memo.contentType = ContentType.CHANGE_CIPHER_SPEC.id;
+            memo.majorVersion = protocolVersion.major;
+            memo.minorVersion = protocolVersion.minor;
+            memo.encodeEpoch = writeEpoch;
+            memo.encodeCipher = writeCipher;
+
+            memo.fragment = new byte[1];
+            memo.fragment[0] = 1;
+
+            handshakeMemos.add(memo);
+        }
+
+        void queueUpAlert(byte level, byte description) throws IOException {
+            RecordMemo memo = new RecordMemo();
+
+            memo.contentType = ContentType.ALERT.id;
+            memo.majorVersion = protocolVersion.major;
+            memo.minorVersion = protocolVersion.minor;
+            memo.encodeEpoch = writeEpoch;
+            memo.encodeCipher = writeCipher;
+
+            memo.fragment = new byte[2];
+            memo.fragment[0] = level;
+            memo.fragment[1] = description;
+
+            handshakeMemos.add(memo);
+        }
+
         Ciphertext acquireCiphertext(ByteBuffer dstBuf) throws IOException {
             if (isEmpty()) {
                 if (isRetransmittable()) {
@@ -500,8 +529,13 @@
                 return new Ciphertext(hsMemo.contentType,
                         hsMemo.handshakeType, recordSN);
             } else {
+                if (isCloseWaiting &&
+                        memo.contentType == ContentType.ALERT.id) {
+                    close();
+                }
+
                 acquireIndex++;
-                return new Ciphertext(ContentType.CHANGE_CIPHER_SPEC.id,
+                return new Ciphertext(memo.contentType,
                         SSLHandshake.NOT_APPLICABLE.id, recordSN);
             }
         }
@@ -552,6 +586,16 @@
             return false;
         }
 
+        boolean hasAlert() {
+            for (RecordMemo memo : handshakeMemos) {
+                if (memo.contentType == ContentType.ALERT.id) {
+                    return true;
+                }
+            }
+
+            return false;
+        }
+
         boolean isRetransmittable() {
             return (flightIsReady && !handshakeMemos.isEmpty() &&
                                 (acquireIndex >= handshakeMemos.size()));