src/java.base/share/classes/sun/security/ssl/SSLEngineInputRecord.java
branchJDK-8145252-TLS13-branch
changeset 56542 56aaa6cb3693
parent 47216 71c04702a3d5
child 56694 aa54a1f8e426
equal deleted inserted replaced
56541:92cbbfc996f3 56542:56aaa6cb3693
     1 /*
     1 /*
     2  * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
     2  * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved.
     3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
     3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
     4  *
     4  *
     5  * This code is free software; you can redistribute it and/or modify it
     5  * This code is free software; you can redistribute it and/or modify it
     6  * under the terms of the GNU General Public License version 2 only, as
     6  * under the terms of the GNU General Public License version 2 only, as
     7  * published by the Free Software Foundation.  Oracle designates this
     7  * published by the Free Software Foundation.  Oracle designates this
    25 
    25 
    26 package sun.security.ssl;
    26 package sun.security.ssl;
    27 
    27 
    28 import java.io.*;
    28 import java.io.*;
    29 import java.nio.*;
    29 import java.nio.*;
    30 
    30 import java.security.GeneralSecurityException;
       
    31 import java.util.ArrayList;
    31 import javax.crypto.BadPaddingException;
    32 import javax.crypto.BadPaddingException;
    32 
       
    33 import javax.net.ssl.*;
    33 import javax.net.ssl.*;
    34 
    34 import sun.security.ssl.SSLCipher.SSLReadCipher;
    35 import sun.security.util.HexDumpEncoder;
       
    36 
       
    37 
    35 
    38 /**
    36 /**
    39  * {@code InputRecord} implementation for {@code SSLEngine}.
    37  * {@code InputRecord} implementation for {@code SSLEngine}.
    40  */
    38  */
    41 final class SSLEngineInputRecord extends InputRecord implements SSLRecord {
    39 final class SSLEngineInputRecord extends InputRecord implements SSLRecord {
    44     private int hsMsgOff = 0;
    42     private int hsMsgOff = 0;
    45     private int hsMsgLen = 0;
    43     private int hsMsgLen = 0;
    46 
    44 
    47     private boolean formatVerified = false;     // SSLv2 ruled out?
    45     private boolean formatVerified = false;     // SSLv2 ruled out?
    48 
    46 
    49     SSLEngineInputRecord() {
    47     // Cache for incomplete handshake messages.
    50         this.readAuthenticator = MAC.TLS_NULL;
    48     private ByteBuffer handshakeBuffer = null;
       
    49 
       
    50     SSLEngineInputRecord(HandshakeHash handshakeHash) {
       
    51         super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
    51     }
    52     }
    52 
    53 
    53     @Override
    54     @Override
    54     int estimateFragmentSize(int packetSize) {
    55     int estimateFragmentSize(int packetSize) {
    55         int macLen = 0;
       
    56         if (readAuthenticator instanceof MAC) {
       
    57             macLen = ((MAC)readAuthenticator).MAClen();
       
    58         }
       
    59 
       
    60         if (packetSize > 0) {
    56         if (packetSize > 0) {
    61             return readCipher.estimateFragmentSize(
    57             return readCipher.estimateFragmentSize(packetSize, headerSize);
    62                     packetSize, macLen, headerSize);
       
    63         } else {
    58         } else {
    64             return Record.maxDataSize;
    59             return Record.maxDataSize;
    65         }
    60         }
    66     }
    61     }
    67 
    62 
    68     @Override
    63     @Override
    69     int bytesInCompletePacket(ByteBuffer packet) throws SSLException {
    64     int bytesInCompletePacket(
       
    65         ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException {
       
    66 
       
    67         return bytesInCompletePacket(srcs[srcsOffset]);
       
    68     }
       
    69 
       
    70     private int bytesInCompletePacket(ByteBuffer packet) throws SSLException {
    70         /*
    71         /*
    71          * SSLv2 length field is in bytes 0/1
    72          * SSLv2 length field is in bytes 0/1
    72          * SSLv3/TLS length field is in bytes 3/4
    73          * SSLv3/TLS length field is in bytes 3/4
    73          */
    74          */
    74         if (packet.remaining() < 5) {
    75         if (packet.remaining() < 5) {
    85          * ignore the verifications steps, and jump right to the
    86          * ignore the verifications steps, and jump right to the
    86          * determination.  Otherwise, try one last hueristic to
    87          * determination.  Otherwise, try one last hueristic to
    87          * see if it's SSL/TLS.
    88          * see if it's SSL/TLS.
    88          */
    89          */
    89         if (formatVerified ||
    90         if (formatVerified ||
    90                 (byteZero == ct_handshake) || (byteZero == ct_alert)) {
    91                 (byteZero == ContentType.HANDSHAKE.id) ||
       
    92                 (byteZero == ContentType.ALERT.id)) {
    91             /*
    93             /*
    92              * Last sanity check that it's not a wild record
    94              * Last sanity check that it's not a wild record
    93              */
    95              */
    94             ProtocolVersion recordVersion = ProtocolVersion.valueOf(
    96             byte majorVersion = packet.get(pos + 1);
    95                                     packet.get(pos + 1), packet.get(pos + 2));
    97             byte minorVersion = packet.get(pos + 2);
    96 
    98             if (!ProtocolVersion.isNegotiable(
    97             // check the record version
    99                     majorVersion, minorVersion, false, false)) {
    98             checkRecordVersion(recordVersion, false);
   100                 throw new SSLException("Unrecognized record version " +
       
   101                         ProtocolVersion.nameOf(majorVersion, minorVersion) +
       
   102                         " , plaintext connection?");
       
   103             }
    99 
   104 
   100             /*
   105             /*
   101              * Reasonably sure this is a V3, disable further checks.
   106              * Reasonably sure this is a V3, disable further checks.
   102              * We can't do the same in the v2 check below, because
   107              * We can't do the same in the v2 check below, because
   103              * read still needs to parse/handle the v2 clientHello.
   108              * read still needs to parse/handle the v2 clientHello.
   121             boolean isShort = ((byteZero & 0x80) != 0);
   126             boolean isShort = ((byteZero & 0x80) != 0);
   122 
   127 
   123             if (isShort &&
   128             if (isShort &&
   124                     ((packet.get(pos + 2) == 1) || packet.get(pos + 2) == 4)) {
   129                     ((packet.get(pos + 2) == 1) || packet.get(pos + 2) == 4)) {
   125 
   130 
   126                 ProtocolVersion recordVersion = ProtocolVersion.valueOf(
   131                 byte majorVersion = packet.get(pos + 3);
   127                                     packet.get(pos + 3), packet.get(pos + 4));
   132                 byte minorVersion = packet.get(pos + 4);
   128 
   133                 if (!ProtocolVersion.isNegotiable(
   129                 // check the record version
   134                         majorVersion, minorVersion, false, false)) {
   130                 checkRecordVersion(recordVersion, true);
   135                     throw new SSLException("Unrecognized record version " +
       
   136                             ProtocolVersion.nameOf(majorVersion, minorVersion) +
       
   137                             " , plaintext connection?");
       
   138                 }
   131 
   139 
   132                 /*
   140                 /*
   133                  * Client or Server Hello
   141                  * Client or Server Hello
   134                  */
   142                  */
   135                 int mask = (isShort ? 0x7F : 0x3F);
   143                 int mask = (isShort ? 0x7F : 0x3F);
   145 
   153 
   146         return len;
   154         return len;
   147     }
   155     }
   148 
   156 
   149     @Override
   157     @Override
   150     void checkRecordVersion(ProtocolVersion recordVersion,
   158     Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset,
   151             boolean allowSSL20Hello) throws SSLException {
   159             int srcsLength) throws IOException, BadPaddingException {
   152 
   160         if (srcs == null || srcs.length == 0 || srcsLength == 0) {
   153         if (recordVersion.maybeDTLSProtocol()) {
   161             return new Plaintext[0];
   154             throw new SSLException(
   162         } else if (srcsLength == 1) {
   155                     "Unrecognized record version " + recordVersion +
   163             return decode(srcs[srcsOffset]);
   156                     " , DTLS packet?");
   164         } else {
   157         }
   165             ByteBuffer packet = extract(srcs,
   158 
   166                     srcsOffset, srcsLength, SSLRecord.headerSize);
   159         // Check if the record version is too old.
   167 
   160         if ((recordVersion.v < ProtocolVersion.MIN.v)) {
   168             return decode(packet);
   161             // if it's not SSLv2, we're out of here.
   169         }
   162             if (!allowSSL20Hello ||
   170     }
   163                     (recordVersion.v != ProtocolVersion.SSL20Hello.v)) {
   171 
   164                 throw new SSLException(
   172     private Plaintext[] decode(ByteBuffer packet)
   165                     "Unsupported record version " + recordVersion);
       
   166             }
       
   167         }
       
   168     }
       
   169 
       
   170     @Override
       
   171     Plaintext decode(ByteBuffer packet)
       
   172             throws IOException, BadPaddingException {
   173             throws IOException, BadPaddingException {
   173 
   174 
   174         if (isClosed) {
   175         if (isClosed) {
   175             return null;
   176             return null;
   176         }
   177         }
   177 
   178 
   178         if (debug != null && Debug.isOn("packet")) {
   179         if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
   179              Debug.printHex(
   180             SSLLogger.fine("Raw read", packet);
   180                     "[Raw read]: length = " + packet.remaining(), packet);
       
   181         }
   181         }
   182 
   182 
   183         // The caller should have validated the record.
   183         // The caller should have validated the record.
   184         if (!formatVerified) {
   184         if (!formatVerified) {
   185             formatVerified = true;
   185             formatVerified = true;
   189              * alert message. If it's not, it is either invalid or an
   189              * alert message. If it's not, it is either invalid or an
   190              * SSLv2 message.
   190              * SSLv2 message.
   191              */
   191              */
   192             int pos = packet.position();
   192             int pos = packet.position();
   193             byte byteZero = packet.get(pos);
   193             byte byteZero = packet.get(pos);
   194             if (byteZero != ct_handshake && byteZero != ct_alert) {
   194             if (byteZero != ContentType.HANDSHAKE.id &&
       
   195                     byteZero != ContentType.ALERT.id) {
   195                 return handleUnknownRecord(packet);
   196                 return handleUnknownRecord(packet);
   196             }
   197             }
   197         }
   198         }
   198 
   199 
   199         return decodeInputRecord(packet);
   200         return decodeInputRecord(packet);
   200     }
   201     }
   201 
   202 
   202     private Plaintext decodeInputRecord(ByteBuffer packet)
   203     private Plaintext[] decodeInputRecord(ByteBuffer packet)
   203             throws IOException, BadPaddingException {
   204             throws IOException, BadPaddingException {
   204 
       
   205         //
   205         //
   206         // The packet should be a complete record, or more.
   206         // The packet should be a complete record, or more.
   207         //
   207         //
   208 
       
   209         int srcPos = packet.position();
   208         int srcPos = packet.position();
   210         int srcLim = packet.limit();
   209         int srcLim = packet.limit();
   211 
   210 
   212         byte contentType = packet.get();                   // pos: 0
   211         byte contentType = packet.get();                   // pos: 0
   213         byte majorVersion = packet.get();                  // pos: 1
   212         byte majorVersion = packet.get();                  // pos: 1
   214         byte minorVersion = packet.get();                  // pos: 2
   213         byte minorVersion = packet.get();                  // pos: 2
   215         int contentLen = ((packet.get() & 0xFF) << 8) +
   214         int contentLen = Record.getInt16(packet);          // pos: 3, 4
   216                           (packet.get() & 0xFF);           // pos: 3, 4
   215 
   217 
   216         if (SSLLogger.isOn && SSLLogger.isOn("record")) {
   218         if (debug != null && Debug.isOn("record")) {
   217             SSLLogger.fine(
   219              System.out.println(Thread.currentThread().getName() +
   218                     "READ: " +
   220                     ", READ: " +
   219                     ProtocolVersion.nameOf(majorVersion, minorVersion) +
   221                     ProtocolVersion.valueOf(majorVersion, minorVersion) +
   220                     " " + ContentType.nameOf(contentType) + ", length = " +
   222                     " " + Record.contentName(contentType) + ", length = " +
       
   223                     contentLen);
   221                     contentLen);
   224         }
   222         }
   225 
   223 
   226         //
   224         //
   227         // Check for upper bound.
   225         // Check for upper bound.
   233         }
   231         }
   234 
   232 
   235         //
   233         //
   236         // check for handshake fragment
   234         // check for handshake fragment
   237         //
   235         //
   238         if ((contentType != ct_handshake) && (hsMsgOff != hsMsgLen)) {
   236         if (contentType != ContentType.HANDSHAKE.id && hsMsgOff != hsMsgLen) {
   239             throw new SSLProtocolException(
   237             throw new SSLProtocolException(
   240                     "Expected to get a handshake fragment");
   238                     "Expected to get a handshake fragment");
   241         }
   239         }
   242 
   240 
   243         //
   241         //
   245         //
   243         //
   246         int recLim = srcPos + SSLRecord.headerSize + contentLen;
   244         int recLim = srcPos + SSLRecord.headerSize + contentLen;
   247         packet.limit(recLim);
   245         packet.limit(recLim);
   248         packet.position(srcPos + SSLRecord.headerSize);
   246         packet.position(srcPos + SSLRecord.headerSize);
   249 
   247 
   250         ByteBuffer plaintext;
   248         ByteBuffer fragment;
   251         try {
   249         try {
   252             plaintext =
   250             Plaintext plaintext =
   253                 decrypt(readAuthenticator, readCipher, contentType, packet);
   251                     readCipher.decrypt(contentType, packet, null);
       
   252             fragment = plaintext.fragment;
       
   253             contentType = plaintext.contentType;
       
   254         } catch (BadPaddingException bpe) {
       
   255             throw bpe;
       
   256         } catch (GeneralSecurityException gse) {
       
   257             throw (SSLProtocolException)(new SSLProtocolException(
       
   258                     "Unexpected exception")).initCause(gse);
   254         } finally {
   259         } finally {
   255             // comsume a complete record
   260             // comsume a complete record
   256             packet.limit(srcLim);
   261             packet.limit(srcLim);
   257             packet.position(recLim);
   262             packet.position(recLim);
   258         }
   263         }
   259 
   264 
   260         //
   265         //
   261         // handshake hashing
   266         // parse handshake messages
   262         //
   267         //
   263         if (contentType == ct_handshake) {
   268         if (contentType == ContentType.HANDSHAKE.id) {
   264             int pltPos = plaintext.position();
   269             ByteBuffer handshakeFrag = fragment;
   265             int pltLim = plaintext.limit();
   270             if ((handshakeBuffer != null) &&
   266             int frgPos = pltPos;
   271                     (handshakeBuffer.remaining() != 0)) {
   267             for (int remains = plaintext.remaining(); remains > 0;) {
   272                 ByteBuffer bb = ByteBuffer.wrap(new byte[
   268                 int howmuch;
   273                         handshakeBuffer.remaining() + fragment.remaining()]);
   269                 byte handshakeType;
   274                 bb.put(handshakeBuffer);
   270                 if (hsMsgOff < hsMsgLen) {
   275                 bb.put(fragment);
   271                     // a fragment of the handshake message
   276                 handshakeFrag = bb.rewind();
   272                     howmuch = Math.min((hsMsgLen - hsMsgOff), remains);
   277                 handshakeBuffer = null;
   273                     handshakeType = prevType;
   278             }
   274 
   279 
   275                     hsMsgOff += howmuch;
   280             ArrayList<Plaintext> plaintexts = new ArrayList<>(5);
   276                     if (hsMsgOff == hsMsgLen) {
   281             while (handshakeFrag.hasRemaining()) {
   277                         // Now is a complete handshake message.
   282                 int remaining = handshakeFrag.remaining();
   278                         hsMsgOff = 0;
   283                 if (remaining < handshakeHeaderSize) {
   279                         hsMsgLen = 0;
   284                     handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
       
   285                     handshakeBuffer.put(handshakeFrag);
       
   286                     handshakeBuffer.rewind();
       
   287                     break;
       
   288                 }
       
   289 
       
   290                 handshakeFrag.mark();
       
   291                 // skip the first byte: handshake type
       
   292                 byte handshakeType = handshakeFrag.get();
       
   293                 int handshakeBodyLen = Record.getInt24(handshakeFrag);
       
   294                 handshakeFrag.reset();
       
   295                 int handshakeMessageLen =
       
   296                         handshakeHeaderSize + handshakeBodyLen;
       
   297                 if (remaining < handshakeMessageLen) {
       
   298                     handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
       
   299                     handshakeBuffer.put(handshakeFrag);
       
   300                     handshakeBuffer.rewind();
       
   301                     break;
       
   302                 } if (remaining == handshakeMessageLen) {
       
   303                     if (handshakeHash.isHashable(handshakeType)) {
       
   304                         handshakeHash.receive(handshakeFrag);
   280                     }
   305                     }
   281                 } else {    // hsMsgOff == hsMsgLen, a new handshake message
   306 
   282                     handshakeType = plaintext.get();
   307                     plaintexts.add(
   283                     int handshakeLen = ((plaintext.get() & 0xFF) << 16) |
   308                         new Plaintext(contentType,
   284                                        ((plaintext.get() & 0xFF) << 8) |
   309                             majorVersion, minorVersion, -1, -1L, handshakeFrag)
   285                                         (plaintext.get() & 0xFF);
   310                     );
   286                     plaintext.position(frgPos);
   311                     break;
   287                     if (remains < (handshakeLen + 4)) { // 4: handshake header
   312                 } else {
   288                         // This handshake message is fragmented.
   313                     int fragPos = handshakeFrag.position();
   289                         prevType = handshakeType;
   314                     int fragLim = handshakeFrag.limit();
   290                         hsMsgOff = remains - 4;         // 4: handshake header
   315                     int nextPos = fragPos + handshakeMessageLen;
   291                         hsMsgLen = handshakeLen;
   316                     handshakeFrag.limit(nextPos);
       
   317 
       
   318                     if (handshakeHash.isHashable(handshakeType)) {
       
   319                         handshakeHash.receive(handshakeFrag);
   292                     }
   320                     }
   293 
   321 
   294                     howmuch = Math.min(handshakeLen + 4, remains);
   322                     plaintexts.add(
       
   323                         new Plaintext(contentType, majorVersion, minorVersion,
       
   324                             -1, -1L, handshakeFrag.slice())
       
   325                     );
       
   326 
       
   327                     handshakeFrag.position(nextPos);
       
   328                     handshakeFrag.limit(fragLim);
   295                 }
   329                 }
   296 
   330             }
   297                 plaintext.limit(frgPos + howmuch);
   331 
   298 
   332             return plaintexts.toArray(new Plaintext[0]);
   299                 if (handshakeType == HandshakeMessage.ht_hello_request) {
   333         }
   300                     // omitted from handshake hash computation
   334 
   301                 } else if ((handshakeType != HandshakeMessage.ht_finished) &&
   335         return new Plaintext[] {
   302                     (handshakeType != HandshakeMessage.ht_certificate_verify)) {
   336             new Plaintext(contentType,
   303 
   337                 majorVersion, minorVersion, -1, -1L, fragment)
   304                     if (handshakeHash == null) {
   338         };
   305                         // used for cache only
   339     }
   306                         handshakeHash = new HandshakeHash(false);
   340 
   307                     }
   341     private Plaintext[] handleUnknownRecord(ByteBuffer packet)
   308                     handshakeHash.update(plaintext);
       
   309                 } else {
       
   310                     // Reserve until this handshake message has been processed.
       
   311                     if (handshakeHash == null) {
       
   312                         // used for cache only
       
   313                         handshakeHash = new HandshakeHash(false);
       
   314                     }
       
   315                     handshakeHash.reserve(plaintext);
       
   316                 }
       
   317 
       
   318                 plaintext.position(frgPos + howmuch);
       
   319                 plaintext.limit(pltLim);
       
   320 
       
   321                 frgPos += howmuch;
       
   322                 remains -= howmuch;
       
   323             }
       
   324 
       
   325             plaintext.position(pltPos);
       
   326         }
       
   327 
       
   328         return new Plaintext(contentType,
       
   329                 majorVersion, minorVersion, -1, -1L, plaintext);
       
   330                 // recordEpoch, recordSeq, plaintext);
       
   331     }
       
   332 
       
   333     private Plaintext handleUnknownRecord(ByteBuffer packet)
       
   334             throws IOException, BadPaddingException {
   342             throws IOException, BadPaddingException {
   335 
       
   336         //
   343         //
   337         // The packet should be a complete record.
   344         // The packet should be a complete record.
   338         //
   345         //
   339         int srcPos = packet.position();
   346         int srcPos = packet.position();
   340         int srcLim = packet.limit();
   347         int srcLim = packet.limit();
   361                  * Looks like a V2 client hello, but not one saying
   368                  * Looks like a V2 client hello, but not one saying
   362                  * "let's talk SSLv3".  So we need to send an SSLv2
   369                  * "let's talk SSLv3".  So we need to send an SSLv2
   363                  * error message, one that's treated as fatal by
   370                  * error message, one that's treated as fatal by
   364                  * clients (Otherwise we'll hang.)
   371                  * clients (Otherwise we'll hang.)
   365                  */
   372                  */
   366                 if (debug != null && Debug.isOn("record")) {
   373                 if (SSLLogger.isOn && SSLLogger.isOn("record")) {
   367                      System.out.println(Thread.currentThread().getName() +
   374                    SSLLogger.fine(
   368                             "Requested to negotiate unsupported SSLv2!");
   375                             "Requested to negotiate unsupported SSLv2!");
   369                 }
   376                 }
   370 
   377 
   371                 // hack code, the exception is caught in SSLEngineImpl
   378                 // hack code, the exception is caught in SSLEngineImpl
   372                 // so that SSLv2 error message can be delivered properly.
   379                 // so that SSLv2 error message can be delivered properly.
   378              * If we can map this into a V3 ClientHello, read and
   385              * If we can map this into a V3 ClientHello, read and
   379              * hash the rest of the V2 handshake, turn it into a
   386              * hash the rest of the V2 handshake, turn it into a
   380              * V3 ClientHello message, and pass it up.
   387              * V3 ClientHello message, and pass it up.
   381              */
   388              */
   382             packet.position(srcPos + 2);        // exclude the header
   389             packet.position(srcPos + 2);        // exclude the header
   383 
   390             handshakeHash.receive(packet);
   384             if (handshakeHash == null) {
       
   385                 // used for cache only
       
   386                 handshakeHash = new HandshakeHash(false);
       
   387             }
       
   388             handshakeHash.update(packet);
       
   389             packet.position(srcPos);
   391             packet.position(srcPos);
   390 
   392 
   391             ByteBuffer converted = convertToClientHello(packet);
   393             ByteBuffer converted = convertToClientHello(packet);
   392 
   394 
   393             if (debug != null && Debug.isOn("packet")) {
   395             if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
   394                  Debug.printHex(
   396                 SSLLogger.fine(
   395                         "[Converted] ClientHello", converted);
   397                         "[Converted] ClientHello", converted);
   396             }
   398             }
   397 
   399 
   398             return new Plaintext(ct_handshake,
   400             return new Plaintext[] {
   399                 majorVersion, minorVersion, -1, -1L, converted);
   401                     new Plaintext(ContentType.HANDSHAKE.id,
       
   402                     majorVersion, minorVersion, -1, -1L, converted)
       
   403                 };
   400         } else {
   404         } else {
   401             if (((firstByte & 0x80) != 0) && (thirdByte == 4)) {
   405             if (((firstByte & 0x80) != 0) && (thirdByte == 4)) {
   402                 throw new SSLException("SSL V2.0 servers are not supported.");
   406                 throw new SSLException("SSL V2.0 servers are not supported.");
   403             }
   407             }
   404 
   408 
   405             throw new SSLException("Unsupported or unrecognized SSL message");
   409             throw new SSLException("Unsupported or unrecognized SSL message");
   406         }
   410         }
   407     }
   411     }
   408 
       
   409 }
   412 }