diff -r 92cbbfc996f3 -r 56aaa6cb3693 src/java.base/share/classes/sun/security/ssl/SSLEngineInputRecord.java --- a/src/java.base/share/classes/sun/security/ssl/SSLEngineInputRecord.java Fri May 11 14:55:56 2018 -0700 +++ b/src/java.base/share/classes/sun/security/ssl/SSLEngineInputRecord.java Fri May 11 15:53:12 2018 -0700 @@ -1,5 +1,5 @@ /* - * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -27,13 +27,11 @@ import java.io.*; import java.nio.*; - +import java.security.GeneralSecurityException; +import java.util.ArrayList; import javax.crypto.BadPaddingException; - import javax.net.ssl.*; - -import sun.security.util.HexDumpEncoder; - +import sun.security.ssl.SSLCipher.SSLReadCipher; /** * {@code InputRecord} implementation for {@code SSLEngine}. @@ -46,27 +44,30 @@ private boolean formatVerified = false; // SSLv2 ruled out? - SSLEngineInputRecord() { - this.readAuthenticator = MAC.TLS_NULL; + // Cache for incomplete handshake messages. + private ByteBuffer handshakeBuffer = null; + + SSLEngineInputRecord(HandshakeHash handshakeHash) { + super(handshakeHash, SSLReadCipher.nullTlsReadCipher()); } @Override int estimateFragmentSize(int packetSize) { - int macLen = 0; - if (readAuthenticator instanceof MAC) { - macLen = ((MAC)readAuthenticator).MAClen(); - } - if (packetSize > 0) { - return readCipher.estimateFragmentSize( - packetSize, macLen, headerSize); + return readCipher.estimateFragmentSize(packetSize, headerSize); } else { return Record.maxDataSize; } } @Override - int bytesInCompletePacket(ByteBuffer packet) throws SSLException { + int bytesInCompletePacket( + ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException { + + return bytesInCompletePacket(srcs[srcsOffset]); + } + + private int bytesInCompletePacket(ByteBuffer packet) throws SSLException { /* * SSLv2 length field is in bytes 0/1 * SSLv3/TLS length field is in bytes 3/4 @@ -87,15 +88,19 @@ * see if it's SSL/TLS. */ if (formatVerified || - (byteZero == ct_handshake) || (byteZero == ct_alert)) { + (byteZero == ContentType.HANDSHAKE.id) || + (byteZero == ContentType.ALERT.id)) { /* * Last sanity check that it's not a wild record */ - ProtocolVersion recordVersion = ProtocolVersion.valueOf( - packet.get(pos + 1), packet.get(pos + 2)); - - // check the record version - checkRecordVersion(recordVersion, false); + byte majorVersion = packet.get(pos + 1); + byte minorVersion = packet.get(pos + 2); + if (!ProtocolVersion.isNegotiable( + majorVersion, minorVersion, false, false)) { + throw new SSLException("Unrecognized record version " + + ProtocolVersion.nameOf(majorVersion, minorVersion) + + " , plaintext connection?"); + } /* * Reasonably sure this is a V3, disable further checks. @@ -123,11 +128,14 @@ if (isShort && ((packet.get(pos + 2) == 1) || packet.get(pos + 2) == 4)) { - ProtocolVersion recordVersion = ProtocolVersion.valueOf( - packet.get(pos + 3), packet.get(pos + 4)); - - // check the record version - checkRecordVersion(recordVersion, true); + byte majorVersion = packet.get(pos + 3); + byte minorVersion = packet.get(pos + 4); + if (!ProtocolVersion.isNegotiable( + majorVersion, minorVersion, false, false)) { + throw new SSLException("Unrecognized record version " + + ProtocolVersion.nameOf(majorVersion, minorVersion) + + " , plaintext connection?"); + } /* * Client or Server Hello @@ -147,37 +155,29 @@ } @Override - void checkRecordVersion(ProtocolVersion recordVersion, - boolean allowSSL20Hello) throws SSLException { - - if (recordVersion.maybeDTLSProtocol()) { - throw new SSLException( - "Unrecognized record version " + recordVersion + - " , DTLS packet?"); - } + Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset, + int srcsLength) throws IOException, BadPaddingException { + if (srcs == null || srcs.length == 0 || srcsLength == 0) { + return new Plaintext[0]; + } else if (srcsLength == 1) { + return decode(srcs[srcsOffset]); + } else { + ByteBuffer packet = extract(srcs, + srcsOffset, srcsLength, SSLRecord.headerSize); - // Check if the record version is too old. - if ((recordVersion.v < ProtocolVersion.MIN.v)) { - // if it's not SSLv2, we're out of here. - if (!allowSSL20Hello || - (recordVersion.v != ProtocolVersion.SSL20Hello.v)) { - throw new SSLException( - "Unsupported record version " + recordVersion); - } + return decode(packet); } } - @Override - Plaintext decode(ByteBuffer packet) + private Plaintext[] decode(ByteBuffer packet) throws IOException, BadPaddingException { if (isClosed) { return null; } - if (debug != null && Debug.isOn("packet")) { - Debug.printHex( - "[Raw read]: length = " + packet.remaining(), packet); + if (SSLLogger.isOn && SSLLogger.isOn("packet")) { + SSLLogger.fine("Raw read", packet); } // The caller should have validated the record. @@ -191,7 +191,8 @@ */ int pos = packet.position(); byte byteZero = packet.get(pos); - if (byteZero != ct_handshake && byteZero != ct_alert) { + if (byteZero != ContentType.HANDSHAKE.id && + byteZero != ContentType.ALERT.id) { return handleUnknownRecord(packet); } } @@ -199,27 +200,24 @@ return decodeInputRecord(packet); } - private Plaintext decodeInputRecord(ByteBuffer packet) + private Plaintext[] decodeInputRecord(ByteBuffer packet) throws IOException, BadPaddingException { - // // The packet should be a complete record, or more. // - int srcPos = packet.position(); int srcLim = packet.limit(); byte contentType = packet.get(); // pos: 0 byte majorVersion = packet.get(); // pos: 1 byte minorVersion = packet.get(); // pos: 2 - int contentLen = ((packet.get() & 0xFF) << 8) + - (packet.get() & 0xFF); // pos: 3, 4 + int contentLen = Record.getInt16(packet); // pos: 3, 4 - if (debug != null && Debug.isOn("record")) { - System.out.println(Thread.currentThread().getName() + - ", READ: " + - ProtocolVersion.valueOf(majorVersion, minorVersion) + - " " + Record.contentName(contentType) + ", length = " + + if (SSLLogger.isOn && SSLLogger.isOn("record")) { + SSLLogger.fine( + "READ: " + + ProtocolVersion.nameOf(majorVersion, minorVersion) + + " " + ContentType.nameOf(contentType) + ", length = " + contentLen); } @@ -235,7 +233,7 @@ // // check for handshake fragment // - if ((contentType != ct_handshake) && (hsMsgOff != hsMsgLen)) { + if (contentType != ContentType.HANDSHAKE.id && hsMsgOff != hsMsgLen) { throw new SSLProtocolException( "Expected to get a handshake fragment"); } @@ -247,10 +245,17 @@ packet.limit(recLim); packet.position(srcPos + SSLRecord.headerSize); - ByteBuffer plaintext; + ByteBuffer fragment; try { - plaintext = - decrypt(readAuthenticator, readCipher, contentType, packet); + Plaintext plaintext = + readCipher.decrypt(contentType, packet, null); + fragment = plaintext.fragment; + contentType = plaintext.contentType; + } catch (BadPaddingException bpe) { + throw bpe; + } catch (GeneralSecurityException gse) { + throw (SSLProtocolException)(new SSLProtocolException( + "Unexpected exception")).initCause(gse); } finally { // comsume a complete record packet.limit(srcLim); @@ -258,81 +263,83 @@ } // - // handshake hashing + // parse handshake messages // - if (contentType == ct_handshake) { - int pltPos = plaintext.position(); - int pltLim = plaintext.limit(); - int frgPos = pltPos; - for (int remains = plaintext.remaining(); remains > 0;) { - int howmuch; - byte handshakeType; - if (hsMsgOff < hsMsgLen) { - // a fragment of the handshake message - howmuch = Math.min((hsMsgLen - hsMsgOff), remains); - handshakeType = prevType; + if (contentType == ContentType.HANDSHAKE.id) { + ByteBuffer handshakeFrag = fragment; + if ((handshakeBuffer != null) && + (handshakeBuffer.remaining() != 0)) { + ByteBuffer bb = ByteBuffer.wrap(new byte[ + handshakeBuffer.remaining() + fragment.remaining()]); + bb.put(handshakeBuffer); + bb.put(fragment); + handshakeFrag = bb.rewind(); + handshakeBuffer = null; + } - hsMsgOff += howmuch; - if (hsMsgOff == hsMsgLen) { - // Now is a complete handshake message. - hsMsgOff = 0; - hsMsgLen = 0; - } - } else { // hsMsgOff == hsMsgLen, a new handshake message - handshakeType = plaintext.get(); - int handshakeLen = ((plaintext.get() & 0xFF) << 16) | - ((plaintext.get() & 0xFF) << 8) | - (plaintext.get() & 0xFF); - plaintext.position(frgPos); - if (remains < (handshakeLen + 4)) { // 4: handshake header - // This handshake message is fragmented. - prevType = handshakeType; - hsMsgOff = remains - 4; // 4: handshake header - hsMsgLen = handshakeLen; - } - - howmuch = Math.min(handshakeLen + 4, remains); + ArrayList plaintexts = new ArrayList<>(5); + while (handshakeFrag.hasRemaining()) { + int remaining = handshakeFrag.remaining(); + if (remaining < handshakeHeaderSize) { + handshakeBuffer = ByteBuffer.wrap(new byte[remaining]); + handshakeBuffer.put(handshakeFrag); + handshakeBuffer.rewind(); + break; } - plaintext.limit(frgPos + howmuch); + handshakeFrag.mark(); + // skip the first byte: handshake type + byte handshakeType = handshakeFrag.get(); + int handshakeBodyLen = Record.getInt24(handshakeFrag); + handshakeFrag.reset(); + int handshakeMessageLen = + handshakeHeaderSize + handshakeBodyLen; + if (remaining < handshakeMessageLen) { + handshakeBuffer = ByteBuffer.wrap(new byte[remaining]); + handshakeBuffer.put(handshakeFrag); + handshakeBuffer.rewind(); + break; + } if (remaining == handshakeMessageLen) { + if (handshakeHash.isHashable(handshakeType)) { + handshakeHash.receive(handshakeFrag); + } - if (handshakeType == HandshakeMessage.ht_hello_request) { - // omitted from handshake hash computation - } else if ((handshakeType != HandshakeMessage.ht_finished) && - (handshakeType != HandshakeMessage.ht_certificate_verify)) { - - if (handshakeHash == null) { - // used for cache only - handshakeHash = new HandshakeHash(false); - } - handshakeHash.update(plaintext); + plaintexts.add( + new Plaintext(contentType, + majorVersion, minorVersion, -1, -1L, handshakeFrag) + ); + break; } else { - // Reserve until this handshake message has been processed. - if (handshakeHash == null) { - // used for cache only - handshakeHash = new HandshakeHash(false); + int fragPos = handshakeFrag.position(); + int fragLim = handshakeFrag.limit(); + int nextPos = fragPos + handshakeMessageLen; + handshakeFrag.limit(nextPos); + + if (handshakeHash.isHashable(handshakeType)) { + handshakeHash.receive(handshakeFrag); } - handshakeHash.reserve(plaintext); + + plaintexts.add( + new Plaintext(contentType, majorVersion, minorVersion, + -1, -1L, handshakeFrag.slice()) + ); + + handshakeFrag.position(nextPos); + handshakeFrag.limit(fragLim); } - - plaintext.position(frgPos + howmuch); - plaintext.limit(pltLim); - - frgPos += howmuch; - remains -= howmuch; } - plaintext.position(pltPos); + return plaintexts.toArray(new Plaintext[0]); } - return new Plaintext(contentType, - majorVersion, minorVersion, -1, -1L, plaintext); - // recordEpoch, recordSeq, plaintext); + return new Plaintext[] { + new Plaintext(contentType, + majorVersion, minorVersion, -1, -1L, fragment) + }; } - private Plaintext handleUnknownRecord(ByteBuffer packet) + private Plaintext[] handleUnknownRecord(ByteBuffer packet) throws IOException, BadPaddingException { - // // The packet should be a complete record. // @@ -363,8 +370,8 @@ * error message, one that's treated as fatal by * clients (Otherwise we'll hang.) */ - if (debug != null && Debug.isOn("record")) { - System.out.println(Thread.currentThread().getName() + + if (SSLLogger.isOn && SSLLogger.isOn("record")) { + SSLLogger.fine( "Requested to negotiate unsupported SSLv2!"); } @@ -380,23 +387,20 @@ * V3 ClientHello message, and pass it up. */ packet.position(srcPos + 2); // exclude the header - - if (handshakeHash == null) { - // used for cache only - handshakeHash = new HandshakeHash(false); - } - handshakeHash.update(packet); + handshakeHash.receive(packet); packet.position(srcPos); ByteBuffer converted = convertToClientHello(packet); - if (debug != null && Debug.isOn("packet")) { - Debug.printHex( + if (SSLLogger.isOn && SSLLogger.isOn("packet")) { + SSLLogger.fine( "[Converted] ClientHello", converted); } - return new Plaintext(ct_handshake, - majorVersion, minorVersion, -1, -1L, converted); + return new Plaintext[] { + new Plaintext(ContentType.HANDSHAKE.id, + majorVersion, minorVersion, -1, -1L, converted) + }; } else { if (((firstByte & 0x80) != 0) && (thirdByte == 4)) { throw new SSLException("SSL V2.0 servers are not supported."); @@ -405,5 +409,4 @@ throw new SSLException("Unsupported or unrecognized SSL message"); } } - }