23 * questions. |
23 * questions. |
24 */ |
24 */ |
25 |
25 |
26 package sun.security.ssl; |
26 package sun.security.ssl; |
27 |
27 |
28 import java.io.*; |
28 import java.io.IOException; |
29 import java.nio.*; |
29 import java.nio.ByteBuffer; |
30 |
30 import java.security.GeneralSecurityException; |
|
31 import java.util.ArrayList; |
31 import javax.crypto.BadPaddingException; |
32 import javax.crypto.BadPaddingException; |
32 |
33 import javax.net.ssl.SSLException; |
33 import javax.net.ssl.*; |
34 import javax.net.ssl.SSLHandshakeException; |
34 |
35 import javax.net.ssl.SSLProtocolException; |
35 import sun.security.util.HexDumpEncoder; |
36 import sun.security.ssl.SSLCipher.SSLReadCipher; |
36 |
37 import sun.security.ssl.KeyUpdate.KeyUpdateMessage; |
|
38 import sun.security.ssl.KeyUpdate.KeyUpdateRequest; |
37 |
39 |
38 /** |
40 /** |
39 * {@code InputRecord} implementation for {@code SSLEngine}. |
41 * {@code InputRecord} implementation for {@code SSLEngine}. |
40 */ |
42 */ |
41 final class SSLEngineInputRecord extends InputRecord implements SSLRecord { |
43 final class SSLEngineInputRecord extends InputRecord implements SSLRecord { |
42 // used by handshake hash computation for handshake fragment |
|
43 private byte prevType = -1; |
|
44 private int hsMsgOff = 0; |
|
45 private int hsMsgLen = 0; |
|
46 |
|
47 private boolean formatVerified = false; // SSLv2 ruled out? |
44 private boolean formatVerified = false; // SSLv2 ruled out? |
48 |
45 |
49 SSLEngineInputRecord() { |
46 // Cache for incomplete handshake messages. |
50 this.readAuthenticator = MAC.TLS_NULL; |
47 private ByteBuffer handshakeBuffer = null; |
|
48 |
|
49 SSLEngineInputRecord(HandshakeHash handshakeHash) { |
|
50 super(handshakeHash, SSLReadCipher.nullTlsReadCipher()); |
51 } |
51 } |
52 |
52 |
53 @Override |
53 @Override |
54 int estimateFragmentSize(int packetSize) { |
54 int estimateFragmentSize(int packetSize) { |
55 int macLen = 0; |
|
56 if (readAuthenticator instanceof MAC) { |
|
57 macLen = ((MAC)readAuthenticator).MAClen(); |
|
58 } |
|
59 |
|
60 if (packetSize > 0) { |
55 if (packetSize > 0) { |
61 return readCipher.estimateFragmentSize( |
56 return readCipher.estimateFragmentSize(packetSize, headerSize); |
62 packetSize, macLen, headerSize); |
|
63 } else { |
57 } else { |
64 return Record.maxDataSize; |
58 return Record.maxDataSize; |
65 } |
59 } |
66 } |
60 } |
67 |
61 |
68 @Override |
62 @Override |
69 int bytesInCompletePacket(ByteBuffer packet) throws SSLException { |
63 int bytesInCompletePacket( |
|
64 ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException { |
|
65 |
|
66 return bytesInCompletePacket(srcs[srcsOffset]); |
|
67 } |
|
68 |
|
69 private int bytesInCompletePacket(ByteBuffer packet) throws SSLException { |
70 /* |
70 /* |
71 * SSLv2 length field is in bytes 0/1 |
71 * SSLv2 length field is in bytes 0/1 |
72 * SSLv3/TLS length field is in bytes 3/4 |
72 * SSLv3/TLS length field is in bytes 3/4 |
73 */ |
73 */ |
74 if (packet.remaining() < 5) { |
74 if (packet.remaining() < 5) { |
81 int len = 0; |
81 int len = 0; |
82 |
82 |
83 /* |
83 /* |
84 * If we have already verified previous packets, we can |
84 * If we have already verified previous packets, we can |
85 * ignore the verifications steps, and jump right to the |
85 * ignore the verifications steps, and jump right to the |
86 * determination. Otherwise, try one last hueristic to |
86 * determination. Otherwise, try one last heuristic to |
87 * see if it's SSL/TLS. |
87 * see if it's SSL/TLS. |
88 */ |
88 */ |
89 if (formatVerified || |
89 if (formatVerified || |
90 (byteZero == ct_handshake) || (byteZero == ct_alert)) { |
90 (byteZero == ContentType.HANDSHAKE.id) || |
|
91 (byteZero == ContentType.ALERT.id)) { |
91 /* |
92 /* |
92 * Last sanity check that it's not a wild record |
93 * Last sanity check that it's not a wild record |
93 */ |
94 */ |
94 ProtocolVersion recordVersion = ProtocolVersion.valueOf( |
95 byte majorVersion = packet.get(pos + 1); |
95 packet.get(pos + 1), packet.get(pos + 2)); |
96 byte minorVersion = packet.get(pos + 2); |
96 |
97 if (!ProtocolVersion.isNegotiable( |
97 // check the record version |
98 majorVersion, minorVersion, false, false)) { |
98 checkRecordVersion(recordVersion, false); |
99 throw new SSLException("Unrecognized record version " + |
|
100 ProtocolVersion.nameOf(majorVersion, minorVersion) + |
|
101 " , plaintext connection?"); |
|
102 } |
99 |
103 |
100 /* |
104 /* |
101 * Reasonably sure this is a V3, disable further checks. |
105 * Reasonably sure this is a V3, disable further checks. |
102 * We can't do the same in the v2 check below, because |
106 * We can't do the same in the v2 check below, because |
103 * read still needs to parse/handle the v2 clientHello. |
107 * read still needs to parse/handle the v2 clientHello. |
145 |
152 |
146 return len; |
153 return len; |
147 } |
154 } |
148 |
155 |
149 @Override |
156 @Override |
150 void checkRecordVersion(ProtocolVersion recordVersion, |
157 Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset, |
151 boolean allowSSL20Hello) throws SSLException { |
158 int srcsLength) throws IOException, BadPaddingException { |
152 |
159 if (srcs == null || srcs.length == 0 || srcsLength == 0) { |
153 if (recordVersion.maybeDTLSProtocol()) { |
160 return new Plaintext[0]; |
154 throw new SSLException( |
161 } else if (srcsLength == 1) { |
155 "Unrecognized record version " + recordVersion + |
162 return decode(srcs[srcsOffset]); |
156 " , DTLS packet?"); |
163 } else { |
157 } |
164 ByteBuffer packet = extract(srcs, |
158 |
165 srcsOffset, srcsLength, SSLRecord.headerSize); |
159 // Check if the record version is too old. |
166 |
160 if ((recordVersion.v < ProtocolVersion.MIN.v)) { |
167 return decode(packet); |
161 // if it's not SSLv2, we're out of here. |
168 } |
162 if (!allowSSL20Hello || |
169 } |
163 (recordVersion.v != ProtocolVersion.SSL20Hello.v)) { |
170 |
164 throw new SSLException( |
171 private Plaintext[] decode(ByteBuffer packet) |
165 "Unsupported record version " + recordVersion); |
|
166 } |
|
167 } |
|
168 } |
|
169 |
|
170 @Override |
|
171 Plaintext decode(ByteBuffer packet) |
|
172 throws IOException, BadPaddingException { |
172 throws IOException, BadPaddingException { |
173 |
173 |
174 if (isClosed) { |
174 if (isClosed) { |
175 return null; |
175 return null; |
176 } |
176 } |
177 |
177 |
178 if (debug != null && Debug.isOn("packet")) { |
178 if (SSLLogger.isOn && SSLLogger.isOn("packet")) { |
179 Debug.printHex( |
179 SSLLogger.fine("Raw read", packet); |
180 "[Raw read]: length = " + packet.remaining(), packet); |
|
181 } |
180 } |
182 |
181 |
183 // The caller should have validated the record. |
182 // The caller should have validated the record. |
184 if (!formatVerified) { |
183 if (!formatVerified) { |
185 formatVerified = true; |
184 formatVerified = true; |
189 * alert message. If it's not, it is either invalid or an |
188 * alert message. If it's not, it is either invalid or an |
190 * SSLv2 message. |
189 * SSLv2 message. |
191 */ |
190 */ |
192 int pos = packet.position(); |
191 int pos = packet.position(); |
193 byte byteZero = packet.get(pos); |
192 byte byteZero = packet.get(pos); |
194 if (byteZero != ct_handshake && byteZero != ct_alert) { |
193 if (byteZero != ContentType.HANDSHAKE.id && |
|
194 byteZero != ContentType.ALERT.id) { |
195 return handleUnknownRecord(packet); |
195 return handleUnknownRecord(packet); |
196 } |
196 } |
197 } |
197 } |
198 |
198 |
199 return decodeInputRecord(packet); |
199 return decodeInputRecord(packet); |
200 } |
200 } |
201 |
201 |
202 private Plaintext decodeInputRecord(ByteBuffer packet) |
202 private Plaintext[] decodeInputRecord(ByteBuffer packet) |
203 throws IOException, BadPaddingException { |
203 throws IOException, BadPaddingException { |
204 |
|
205 // |
204 // |
206 // The packet should be a complete record, or more. |
205 // The packet should be a complete record, or more. |
207 // |
206 // |
208 |
|
209 int srcPos = packet.position(); |
207 int srcPos = packet.position(); |
210 int srcLim = packet.limit(); |
208 int srcLim = packet.limit(); |
211 |
209 |
212 byte contentType = packet.get(); // pos: 0 |
210 byte contentType = packet.get(); // pos: 0 |
213 byte majorVersion = packet.get(); // pos: 1 |
211 byte majorVersion = packet.get(); // pos: 1 |
214 byte minorVersion = packet.get(); // pos: 2 |
212 byte minorVersion = packet.get(); // pos: 2 |
215 int contentLen = ((packet.get() & 0xFF) << 8) + |
213 int contentLen = Record.getInt16(packet); // pos: 3, 4 |
216 (packet.get() & 0xFF); // pos: 3, 4 |
214 |
217 |
215 if (SSLLogger.isOn && SSLLogger.isOn("record")) { |
218 if (debug != null && Debug.isOn("record")) { |
216 SSLLogger.fine( |
219 System.out.println(Thread.currentThread().getName() + |
217 "READ: " + |
220 ", READ: " + |
218 ProtocolVersion.nameOf(majorVersion, minorVersion) + |
221 ProtocolVersion.valueOf(majorVersion, minorVersion) + |
219 " " + ContentType.nameOf(contentType) + ", length = " + |
222 " " + Record.contentName(contentType) + ", length = " + |
|
223 contentLen); |
220 contentLen); |
224 } |
221 } |
225 |
222 |
226 // |
223 // |
227 // Check for upper bound. |
224 // Check for upper bound. |
231 throw new SSLProtocolException( |
228 throw new SSLProtocolException( |
232 "Bad input record size, TLSCiphertext.length = " + contentLen); |
229 "Bad input record size, TLSCiphertext.length = " + contentLen); |
233 } |
230 } |
234 |
231 |
235 // |
232 // |
236 // check for handshake fragment |
|
237 // |
|
238 if ((contentType != ct_handshake) && (hsMsgOff != hsMsgLen)) { |
|
239 throw new SSLProtocolException( |
|
240 "Expected to get a handshake fragment"); |
|
241 } |
|
242 |
|
243 // |
|
244 // Decrypt the fragment |
233 // Decrypt the fragment |
245 // |
234 // |
246 int recLim = srcPos + SSLRecord.headerSize + contentLen; |
235 int recLim = srcPos + SSLRecord.headerSize + contentLen; |
247 packet.limit(recLim); |
236 packet.limit(recLim); |
248 packet.position(srcPos + SSLRecord.headerSize); |
237 packet.position(srcPos + SSLRecord.headerSize); |
249 |
238 |
250 ByteBuffer plaintext; |
239 ByteBuffer fragment; |
251 try { |
240 try { |
252 plaintext = |
241 Plaintext plaintext = |
253 decrypt(readAuthenticator, readCipher, contentType, packet); |
242 readCipher.decrypt(contentType, packet, null); |
|
243 fragment = plaintext.fragment; |
|
244 contentType = plaintext.contentType; |
|
245 } catch (BadPaddingException bpe) { |
|
246 throw bpe; |
|
247 } catch (GeneralSecurityException gse) { |
|
248 throw (SSLProtocolException)(new SSLProtocolException( |
|
249 "Unexpected exception")).initCause(gse); |
254 } finally { |
250 } finally { |
255 // comsume a complete record |
251 // consume a complete record |
256 packet.limit(srcLim); |
252 packet.limit(srcLim); |
257 packet.position(recLim); |
253 packet.position(recLim); |
258 } |
254 } |
259 |
255 |
260 // |
256 // |
261 // handshake hashing |
257 // check for handshake fragment |
262 // |
258 // |
263 if (contentType == ct_handshake) { |
259 if (contentType != ContentType.HANDSHAKE.id && |
264 int pltPos = plaintext.position(); |
260 handshakeBuffer != null && handshakeBuffer.hasRemaining()) { |
265 int pltLim = plaintext.limit(); |
261 throw new SSLProtocolException( |
266 int frgPos = pltPos; |
262 "Expecting a handshake fragment, but received " + |
267 for (int remains = plaintext.remaining(); remains > 0;) { |
263 ContentType.nameOf(contentType)); |
268 int howmuch; |
264 } |
269 byte handshakeType; |
265 |
270 if (hsMsgOff < hsMsgLen) { |
266 // |
271 // a fragment of the handshake message |
267 // parse handshake messages |
272 howmuch = Math.min((hsMsgLen - hsMsgOff), remains); |
268 // |
273 handshakeType = prevType; |
269 if (contentType == ContentType.HANDSHAKE.id) { |
274 |
270 ByteBuffer handshakeFrag = fragment; |
275 hsMsgOff += howmuch; |
271 if ((handshakeBuffer != null) && |
276 if (hsMsgOff == hsMsgLen) { |
272 (handshakeBuffer.remaining() != 0)) { |
277 // Now is a complete handshake message. |
273 ByteBuffer bb = ByteBuffer.wrap(new byte[ |
278 hsMsgOff = 0; |
274 handshakeBuffer.remaining() + fragment.remaining()]); |
279 hsMsgLen = 0; |
275 bb.put(handshakeBuffer); |
|
276 bb.put(fragment); |
|
277 handshakeFrag = bb.rewind(); |
|
278 handshakeBuffer = null; |
|
279 } |
|
280 |
|
281 ArrayList<Plaintext> plaintexts = new ArrayList<>(5); |
|
282 while (handshakeFrag.hasRemaining()) { |
|
283 int remaining = handshakeFrag.remaining(); |
|
284 if (remaining < handshakeHeaderSize) { |
|
285 handshakeBuffer = ByteBuffer.wrap(new byte[remaining]); |
|
286 handshakeBuffer.put(handshakeFrag); |
|
287 handshakeBuffer.rewind(); |
|
288 break; |
|
289 } |
|
290 |
|
291 handshakeFrag.mark(); |
|
292 // skip the first byte: handshake type |
|
293 byte handshakeType = handshakeFrag.get(); |
|
294 int handshakeBodyLen = Record.getInt24(handshakeFrag); |
|
295 handshakeFrag.reset(); |
|
296 int handshakeMessageLen = |
|
297 handshakeHeaderSize + handshakeBodyLen; |
|
298 if (remaining < handshakeMessageLen) { |
|
299 handshakeBuffer = ByteBuffer.wrap(new byte[remaining]); |
|
300 handshakeBuffer.put(handshakeFrag); |
|
301 handshakeBuffer.rewind(); |
|
302 break; |
|
303 } if (remaining == handshakeMessageLen) { |
|
304 if (handshakeHash.isHashable(handshakeType)) { |
|
305 handshakeHash.receive(handshakeFrag); |
280 } |
306 } |
281 } else { // hsMsgOff == hsMsgLen, a new handshake message |
307 |
282 handshakeType = plaintext.get(); |
308 plaintexts.add( |
283 int handshakeLen = ((plaintext.get() & 0xFF) << 16) | |
309 new Plaintext(contentType, |
284 ((plaintext.get() & 0xFF) << 8) | |
310 majorVersion, minorVersion, -1, -1L, handshakeFrag) |
285 (plaintext.get() & 0xFF); |
311 ); |
286 plaintext.position(frgPos); |
312 break; |
287 if (remains < (handshakeLen + 4)) { // 4: handshake header |
313 } else { |
288 // This handshake message is fragmented. |
314 int fragPos = handshakeFrag.position(); |
289 prevType = handshakeType; |
315 int fragLim = handshakeFrag.limit(); |
290 hsMsgOff = remains - 4; // 4: handshake header |
316 int nextPos = fragPos + handshakeMessageLen; |
291 hsMsgLen = handshakeLen; |
317 handshakeFrag.limit(nextPos); |
|
318 |
|
319 if (handshakeHash.isHashable(handshakeType)) { |
|
320 handshakeHash.receive(handshakeFrag); |
292 } |
321 } |
293 |
322 |
294 howmuch = Math.min(handshakeLen + 4, remains); |
323 plaintexts.add( |
|
324 new Plaintext(contentType, majorVersion, minorVersion, |
|
325 -1, -1L, handshakeFrag.slice()) |
|
326 ); |
|
327 |
|
328 handshakeFrag.position(nextPos); |
|
329 handshakeFrag.limit(fragLim); |
295 } |
330 } |
296 |
331 } |
297 plaintext.limit(frgPos + howmuch); |
332 |
298 |
333 return plaintexts.toArray(new Plaintext[0]); |
299 if (handshakeType == HandshakeMessage.ht_hello_request) { |
334 } |
300 // omitted from handshake hash computation |
335 |
301 } else if ((handshakeType != HandshakeMessage.ht_finished) && |
336 // KeyLimit check during application data. |
302 (handshakeType != HandshakeMessage.ht_certificate_verify)) { |
337 // atKeyLimit() inactive when limits not checked, tc set when limits |
303 |
338 // are active. |
304 if (handshakeHash == null) { |
339 |
305 // used for cache only |
340 if (readCipher.atKeyLimit()) { |
306 handshakeHash = new HandshakeHash(false); |
341 if (SSLLogger.isOn && SSLLogger.isOn("ssl")) { |
307 } |
342 SSLLogger.fine("KeyUpdate: triggered, read side."); |
308 handshakeHash.update(plaintext); |
343 } |
309 } else { |
344 |
310 // Reserve until this handshake message has been processed. |
345 PostHandshakeContext p = new PostHandshakeContext(tc); |
311 if (handshakeHash == null) { |
346 KeyUpdate.handshakeProducer.produce(p, |
312 // used for cache only |
347 new KeyUpdateMessage(p, KeyUpdateRequest.REQUESTED)); |
313 handshakeHash = new HandshakeHash(false); |
348 } |
314 } |
349 |
315 handshakeHash.reserve(plaintext); |
350 return new Plaintext[] { |
316 } |
351 new Plaintext(contentType, |
317 |
352 majorVersion, minorVersion, -1, -1L, fragment) |
318 plaintext.position(frgPos + howmuch); |
353 }; |
319 plaintext.limit(pltLim); |
354 } |
320 |
355 |
321 frgPos += howmuch; |
356 private Plaintext[] handleUnknownRecord(ByteBuffer packet) |
322 remains -= howmuch; |
|
323 } |
|
324 |
|
325 plaintext.position(pltPos); |
|
326 } |
|
327 |
|
328 return new Plaintext(contentType, |
|
329 majorVersion, minorVersion, -1, -1L, plaintext); |
|
330 // recordEpoch, recordSeq, plaintext); |
|
331 } |
|
332 |
|
333 private Plaintext handleUnknownRecord(ByteBuffer packet) |
|
334 throws IOException, BadPaddingException { |
357 throws IOException, BadPaddingException { |
335 |
|
336 // |
358 // |
337 // The packet should be a complete record. |
359 // The packet should be a complete record. |
338 // |
360 // |
339 int srcPos = packet.position(); |
361 int srcPos = packet.position(); |
340 int srcLim = packet.limit(); |
362 int srcLim = packet.limit(); |
378 * If we can map this into a V3 ClientHello, read and |
400 * If we can map this into a V3 ClientHello, read and |
379 * hash the rest of the V2 handshake, turn it into a |
401 * hash the rest of the V2 handshake, turn it into a |
380 * V3 ClientHello message, and pass it up. |
402 * V3 ClientHello message, and pass it up. |
381 */ |
403 */ |
382 packet.position(srcPos + 2); // exclude the header |
404 packet.position(srcPos + 2); // exclude the header |
383 |
405 handshakeHash.receive(packet); |
384 if (handshakeHash == null) { |
|
385 // used for cache only |
|
386 handshakeHash = new HandshakeHash(false); |
|
387 } |
|
388 handshakeHash.update(packet); |
|
389 packet.position(srcPos); |
406 packet.position(srcPos); |
390 |
407 |
391 ByteBuffer converted = convertToClientHello(packet); |
408 ByteBuffer converted = convertToClientHello(packet); |
392 |
409 |
393 if (debug != null && Debug.isOn("packet")) { |
410 if (SSLLogger.isOn && SSLLogger.isOn("packet")) { |
394 Debug.printHex( |
411 SSLLogger.fine( |
395 "[Converted] ClientHello", converted); |
412 "[Converted] ClientHello", converted); |
396 } |
413 } |
397 |
414 |
398 return new Plaintext(ct_handshake, |
415 return new Plaintext[] { |
399 majorVersion, minorVersion, -1, -1L, converted); |
416 new Plaintext(ContentType.HANDSHAKE.id, |
|
417 majorVersion, minorVersion, -1, -1L, converted) |
|
418 }; |
400 } else { |
419 } else { |
401 if (((firstByte & 0x80) != 0) && (thirdByte == 4)) { |
420 if (((firstByte & 0x80) != 0) && (thirdByte == 4)) { |
402 throw new SSLException("SSL V2.0 servers are not supported."); |
421 throw new SSLException("SSL V2.0 servers are not supported."); |
403 } |
422 } |
404 |
423 |
405 throw new SSLException("Unsupported or unrecognized SSL message"); |
424 throw new SSLException("Unsupported or unrecognized SSL message"); |
406 } |
425 } |
407 } |
426 } |
408 |
|
409 } |
427 } |