src/java.base/share/classes/sun/security/ssl/DTLSOutputRecord.java
changeset 50768 68fa3d4026ea
parent 47216 71c04702a3d5
child 51407 910f7b56592f
--- a/src/java.base/share/classes/sun/security/ssl/DTLSOutputRecord.java	Mon Jun 25 21:22:16 2018 +0300
+++ b/src/java.base/share/classes/sun/security/ssl/DTLSOutputRecord.java	Mon Jun 25 13:41:39 2018 -0700
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -28,13 +28,8 @@
 import java.io.*;
 import java.nio.*;
 import java.util.*;
-
-import javax.crypto.BadPaddingException;
-
 import javax.net.ssl.*;
-
-import sun.security.util.HexDumpEncoder;
-import static sun.security.ssl.Ciphertext.RecordType;
+import sun.security.ssl.SSLCipher.SSLWriteCipher;
 
 /**
  * DTLS {@code OutputRecord} implementation for {@code SSLEngine}.
@@ -47,54 +42,62 @@
 
     int                 prevWriteEpoch;
     Authenticator       prevWriteAuthenticator;
-    CipherBox           prevWriteCipher;
+    SSLWriteCipher      prevWriteCipher;
 
-    private LinkedList<RecordMemo> alertMemos = new LinkedList<>();
+    private final LinkedList<RecordMemo> alertMemos = new LinkedList<>();
 
-    DTLSOutputRecord() {
-        this.writeAuthenticator = new MAC(true);
+    DTLSOutputRecord(HandshakeHash handshakeHash) {
+        super(handshakeHash, SSLWriteCipher.nullDTlsWriteCipher());
 
         this.writeEpoch = 0;
         this.prevWriteEpoch = 0;
-        this.prevWriteCipher = CipherBox.NULL;
-        this.prevWriteAuthenticator = new MAC(true);
+        this.prevWriteCipher = SSLWriteCipher.nullDTlsWriteCipher();
 
         this.packetSize = DTLSRecord.maxRecordSize;
-        this.protocolVersion = ProtocolVersion.DEFAULT_DTLS;
+        this.protocolVersion = ProtocolVersion.NONE;
+    }
+
+    @Override
+    void initHandshaker() {
+        // clean up
+        fragmenter = null;
     }
 
     @Override
-    void changeWriteCiphers(Authenticator writeAuthenticator,
-            CipherBox writeCipher) throws IOException {
+    void finishHandshake() {
+        // Nothing to do here currently.
+    }
 
-        encodeChangeCipherSpec();
+    @Override
+    void changeWriteCiphers(SSLWriteCipher writeCipher,
+            boolean useChangeCipherSpec) throws IOException {
+        if (useChangeCipherSpec) {
+            encodeChangeCipherSpec();
+        }
 
         prevWriteCipher.dispose();
 
-        this.prevWriteAuthenticator = this.writeAuthenticator;
         this.prevWriteCipher = this.writeCipher;
         this.prevWriteEpoch = this.writeEpoch;
 
-        this.writeAuthenticator = writeAuthenticator;
         this.writeCipher = writeCipher;
         this.writeEpoch++;
 
         this.isFirstAppOutputRecord = true;
 
         // set the epoch number
-        this.writeAuthenticator.setEpochNumber(this.writeEpoch);
+        this.writeCipher.authenticator.setEpochNumber(this.writeEpoch);
     }
 
     @Override
     void encodeAlert(byte level, byte description) throws IOException {
         RecordMemo memo = new RecordMemo();
 
-        memo.contentType = Record.ct_alert;
+        memo.contentType = ContentType.ALERT.id;
         memo.majorVersion = protocolVersion.major;
         memo.minorVersion = protocolVersion.minor;
         memo.encodeEpoch = writeEpoch;
         memo.encodeCipher = writeCipher;
-        memo.encodeAuthenticator = writeAuthenticator;
 
         memo.fragment = new byte[2];
         memo.fragment[0] = level;
@@ -114,7 +117,6 @@
     @Override
     void encodeHandshake(byte[] source,
             int offset, int length) throws IOException {
-
         if (firstMessage) {
             firstMessage = false;
         }
@@ -127,30 +129,53 @@
     }
 
     @Override
-    Ciphertext encode(ByteBuffer[] sources, int offset, int length,
+    Ciphertext encode(
+        ByteBuffer[] srcs, int srcsOffset, int srcsLength,
+        ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
+        return encode(srcs, srcsOffset, srcsLength, dsts[0]);
+    }
+
+    private Ciphertext encode(ByteBuffer[] sources, int offset, int length,
             ByteBuffer destination) throws IOException {
 
-        if (writeAuthenticator.seqNumOverflow()) {
-            if (debug != null && Debug.isOn("ssl")) {
-                System.out.println(Thread.currentThread().getName() +
-                    ", sequence number extremely close to overflow " +
+        if (writeCipher.authenticator.seqNumOverflow()) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.fine(
+                    "sequence number extremely close to overflow " +
                     "(2^64-1 packets). Closing connection.");
             }
 
             throw new SSLHandshakeException("sequence number overflow");
         }
 
-        // not apply to handshake message
-        int macLen = 0;
-        if (writeAuthenticator instanceof MAC) {
-            macLen = ((MAC)writeAuthenticator).MAClen();
+        // Don't process the incoming record until all of the buffered records
+        // get handled.  May need retransmission if no sources specified.
+        if (!isEmpty() || sources == null || sources.length == 0) {
+            Ciphertext ct = acquireCiphertext(destination);
+            if (ct != null) {
+                return ct;
+            }
         }
 
+        if (sources == null || sources.length == 0) {
+            return null;
+        }
+
+        int srcsRemains = 0;
+        for (int i = offset; i < offset + length; i++) {
+            srcsRemains += sources[i].remaining();
+        }
+
+        if (srcsRemains == 0) {
+            return null;
+        }
+
+        // not apply to handshake message
         int fragLen;
         if (packetSize > 0) {
             fragLen = Math.min(maxRecordSize, packetSize);
             fragLen = writeCipher.calculateFragmentSize(
-                    fragLen, macLen, headerSize);
+                    fragLen, headerSize);
 
             fragLen = Math.min(fragLen, Record.maxDataSize);
         } else {
@@ -183,44 +208,38 @@
         destination.limit(destination.position());
         destination.position(dstContent);
 
-        if ((debug != null) && Debug.isOn("record")) {
-            System.out.println(Thread.currentThread().getName() +
-                    ", WRITE: " + protocolVersion + " " +
-                    Record.contentName(Record.ct_application_data) +
+        if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+            SSLLogger.fine(
+                    "WRITE: " + protocolVersion + " " +
+                    ContentType.APPLICATION_DATA.name +
                     ", length = " + destination.remaining());
         }
 
         // Encrypt the fragment and wrap up a record.
-        long recordSN = encrypt(writeAuthenticator, writeCipher,
-                Record.ct_application_data, destination,
+        long recordSN = encrypt(writeCipher,
+                ContentType.APPLICATION_DATA.id, destination,
                 dstPos, dstLim, headerSize,
-                protocolVersion, true);
+                protocolVersion);
 
-        if ((debug != null) && Debug.isOn("packet")) {
+        if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
             ByteBuffer temporary = destination.duplicate();
             temporary.limit(temporary.position());
             temporary.position(dstPos);
-            Debug.printHex(
-                    "[Raw write]: length = " + temporary.remaining(),
-                    temporary);
+            SSLLogger.fine("Raw write", temporary);
         }
 
         // remain the limit unchanged
         destination.limit(dstLim);
 
-        return new Ciphertext(RecordType.RECORD_APPLICATION_DATA, recordSN);
+        return new Ciphertext(ContentType.APPLICATION_DATA.id,
+                SSLHandshake.NOT_APPLICABLE.id, recordSN);
     }
 
-    @Override
-    Ciphertext acquireCiphertext(ByteBuffer destination) throws IOException {
+    private Ciphertext acquireCiphertext(
+            ByteBuffer destination) throws IOException {
         if (alertMemos != null && !alertMemos.isEmpty()) {
             RecordMemo memo = alertMemos.pop();
 
-            int macLen = 0;
-            if (memo.encodeAuthenticator instanceof MAC) {
-                macLen = ((MAC)memo.encodeAuthenticator).MAClen();
-            }
-
             int dstPos = destination.position();
             int dstLim = destination.limit();
             int dstContent = dstPos + headerSize +
@@ -232,32 +251,32 @@
             destination.limit(destination.position());
             destination.position(dstContent);
 
-            if ((debug != null) && Debug.isOn("record")) {
-                System.out.println(Thread.currentThread().getName() +
-                        ", WRITE: " + protocolVersion + " " +
-                        Record.contentName(Record.ct_alert) +
+            if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+                SSLLogger.fine(
+                        "WRITE: " + protocolVersion + " " +
+                        ContentType.ALERT.name +
                         ", length = " + destination.remaining());
             }
 
             // Encrypt the fragment and wrap up a record.
-            long recordSN = encrypt(memo.encodeAuthenticator, memo.encodeCipher,
-                    Record.ct_alert, destination, dstPos, dstLim, headerSize,
+            long recordSN = encrypt(memo.encodeCipher,
+                    ContentType.ALERT.id,
+                    destination, dstPos, dstLim, headerSize,
                     ProtocolVersion.valueOf(memo.majorVersion,
-                            memo.minorVersion), true);
+                            memo.minorVersion));
 
-            if ((debug != null) && Debug.isOn("packet")) {
+            if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
                 ByteBuffer temporary = destination.duplicate();
                 temporary.limit(temporary.position());
                 temporary.position(dstPos);
-                Debug.printHex(
-                        "[Raw write]: length = " + temporary.remaining(),
-                        temporary);
+                SSLLogger.fine("Raw write", temporary);
             }
 
             // remain the limit unchanged
             destination.limit(dstLim);
 
-            return new Ciphertext(RecordType.RECORD_ALERT, recordSN);
+            return new Ciphertext(ContentType.ALERT.id,
+                    SSLHandshake.NOT_APPLICABLE.id, recordSN);
         }
 
         if (fragmenter != null) {
@@ -274,12 +293,6 @@
     }
 
     @Override
-    void initHandshaker() {
-        // clean up
-        fragmenter = null;
-    }
-
-    @Override
     void launchRetransmission() {
         // Note: Please don't retransmit if there are handshake messages
         // or alerts waiting in the queue.
@@ -295,8 +308,7 @@
         byte            majorVersion;
         byte            minorVersion;
         int             encodeEpoch;
-        CipherBox       encodeCipher;
-        Authenticator   encodeAuthenticator;
+        SSLWriteCipher  encodeCipher;
 
         byte[]          fragment;
     }
@@ -308,7 +320,8 @@
     }
 
     private final class DTLSFragmenter {
-        private LinkedList<RecordMemo> handshakeMemos = new LinkedList<>();
+        private final LinkedList<RecordMemo> handshakeMemos =
+                new LinkedList<>();
         private int acquireIndex = 0;
         private int messageSequence = 0;
         private boolean flightIsReady = false;
@@ -336,12 +349,11 @@
 
             RecordMemo memo = new RecordMemo();
 
-            memo.contentType = Record.ct_change_cipher_spec;
+            memo.contentType = ContentType.CHANGE_CIPHER_SPEC.id;
             memo.majorVersion = protocolVersion.major;
             memo.minorVersion = protocolVersion.minor;
             memo.encodeEpoch = writeEpoch;
             memo.encodeCipher = writeCipher;
-            memo.encodeAuthenticator = writeAuthenticator;
 
             memo.fragment = new byte[1];
             memo.fragment[0] = 1;
@@ -361,12 +373,11 @@
 
             HandshakeMemo memo = new HandshakeMemo();
 
-            memo.contentType = Record.ct_handshake;
+            memo.contentType = ContentType.HANDSHAKE.id;
             memo.majorVersion = protocolVersion.major;
             memo.minorVersion = protocolVersion.minor;
             memo.encodeEpoch = writeEpoch;
             memo.encodeCipher = writeCipher;
-            memo.encodeAuthenticator = writeAuthenticator;
 
             memo.handshakeType = buf[offset];
             memo.messageSequence = messageSequence++;
@@ -379,12 +390,12 @@
             handshakeHashing(memo, memo.fragment);
             handshakeMemos.add(memo);
 
-            if ((memo.handshakeType == HandshakeMessage.ht_client_hello) ||
-                (memo.handshakeType == HandshakeMessage.ht_hello_request) ||
+            if ((memo.handshakeType == SSLHandshake.CLIENT_HELLO.id) ||
+                (memo.handshakeType == SSLHandshake.HELLO_REQUEST.id) ||
                 (memo.handshakeType ==
-                        HandshakeMessage.ht_hello_verify_request) ||
-                (memo.handshakeType == HandshakeMessage.ht_server_hello_done) ||
-                (memo.handshakeType == HandshakeMessage.ht_finished)) {
+                        SSLHandshake.HELLO_VERIFY_REQUEST.id) ||
+                (memo.handshakeType == SSLHandshake.SERVER_HELLO_DONE.id) ||
+                (memo.handshakeType == SSLHandshake.FINISHED.id)) {
 
                 flightIsReady = true;
             }
@@ -401,22 +412,17 @@
 
             RecordMemo memo = handshakeMemos.get(acquireIndex);
             HandshakeMemo hsMemo = null;
-            if (memo.contentType == Record.ct_handshake) {
+            if (memo.contentType == ContentType.HANDSHAKE.id) {
                 hsMemo = (HandshakeMemo)memo;
             }
 
-            int macLen = 0;
-            if (memo.encodeAuthenticator instanceof MAC) {
-                macLen = ((MAC)memo.encodeAuthenticator).MAClen();
-            }
-
             // ChangeCipherSpec message is pretty small.  Don't worry about
             // the fragmentation of ChangeCipherSpec record.
             int fragLen;
             if (packetSize > 0) {
                 fragLen = Math.min(maxRecordSize, packetSize);
                 fragLen = memo.encodeCipher.calculateFragmentSize(
-                        fragLen, macLen, 25);   // 25: header size
+                        fragLen, 25);   // 25: header size
                                                 //   13: DTLS record
                                                 //   12: DTLS handshake message
                 fragLen = Math.min(fragLen, Record.maxDataSize);
@@ -459,27 +465,26 @@
             dstBuf.limit(dstBuf.position());
             dstBuf.position(dstContent);
 
-            if ((debug != null) && Debug.isOn("record")) {
-                System.out.println(Thread.currentThread().getName() +
-                        ", WRITE: " + protocolVersion + " " +
-                        Record.contentName(memo.contentType) +
+            if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+                SSLLogger.fine(
+                        "WRITE: " + protocolVersion + " " +
+                        ContentType.nameOf(memo.contentType) +
                         ", length = " + dstBuf.remaining());
             }
 
             // Encrypt the fragment and wrap up a record.
-            long recordSN = encrypt(memo.encodeAuthenticator, memo.encodeCipher,
+            long recordSN = encrypt(memo.encodeCipher,
                     memo.contentType, dstBuf,
                     dstPos, dstLim, headerSize,
                     ProtocolVersion.valueOf(memo.majorVersion,
-                            memo.minorVersion), true);
+                            memo.minorVersion));
 
-            if ((debug != null) && Debug.isOn("packet")) {
+            if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
                 ByteBuffer temporary = dstBuf.duplicate();
                 temporary.limit(temporary.position());
                 temporary.position(dstPos);
-                Debug.printHex(
-                        "[Raw write]: length = " + temporary.remaining(),
-                        temporary);
+                SSLLogger.fine(
+                        "Raw write (" + temporary.remaining() + ")", temporary);
             }
 
             // remain the limit unchanged
@@ -492,39 +497,23 @@
                     acquireIndex++;
                 }
 
-                return new Ciphertext(RecordType.valueOf(
-                        hsMemo.contentType, hsMemo.handshakeType), recordSN);
+                return new Ciphertext(hsMemo.contentType,
+                        hsMemo.handshakeType, recordSN);
             } else {
                 acquireIndex++;
-                return new Ciphertext(
-                        RecordType.RECORD_CHANGE_CIPHER_SPEC, recordSN);
+                return new Ciphertext(ContentType.CHANGE_CIPHER_SPEC.id,
+                        SSLHandshake.NOT_APPLICABLE.id, recordSN);
             }
         }
 
         private void handshakeHashing(HandshakeMemo hsFrag, byte[] hsBody) {
 
             byte hsType = hsFrag.handshakeType;
-            if ((hsType == HandshakeMessage.ht_hello_request) ||
-                (hsType == HandshakeMessage.ht_hello_verify_request)) {
-
+            if (!handshakeHash.isHashable(hsType)) {
                 // omitted from handshake hash computation
                 return;
             }
 
-            if ((hsFrag.messageSequence == 0) &&
-                (hsType == HandshakeMessage.ht_client_hello)) {
-
-                // omit initial ClientHello message
-                //
-                //  2: ClientHello.client_version
-                // 32: ClientHello.random
-                int sidLen = hsBody[34];
-
-                if (sidLen == 0) {      // empty session_id, initial handshake
-                    return;
-                }
-            }
-
             // calculate the DTLS header
             byte[] temporary = new byte[12];    // 12: handshake header size
 
@@ -550,17 +539,8 @@
             temporary[10] = temporary[2];
             temporary[11] = temporary[3];
 
-            if ((hsType != HandshakeMessage.ht_finished) &&
-                (hsType != HandshakeMessage.ht_certificate_verify)) {
-
-                handshakeHash.update(temporary, 0, 12);
-                handshakeHash.update(hsBody, 0, hsBody.length);
-            } else {
-                // Reserve until this handshake message has been processed.
-                handshakeHash.reserve(temporary, 0, 12);
-                handshakeHash.reserve(hsBody, 0, hsBody.length);
-            }
-
+            handshakeHash.deliver(temporary, 0, 12);
+            handshakeHash.deliver(hsBody, 0, hsBody.length);
         }
 
         boolean isEmpty() {