< prev index next >
src/java.base/share/classes/sun/security/ssl/SSLSocketInputRecord.java
Print this page
@@ -1,7 +1,7 @@
/*
- * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
@@ -25,48 +25,51 @@
package sun.security.ssl;
import java.io.*;
import java.nio.*;
-
+import java.security.GeneralSecurityException;
+import java.util.ArrayList;
import javax.crypto.BadPaddingException;
-
import javax.net.ssl.*;
-
-import sun.security.util.HexDumpEncoder;
-
+import sun.security.ssl.SSLCipher.SSLReadCipher;
/**
* {@code InputRecord} implementation for {@code SSLSocket}.
*
* @author David Brownell
*/
final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
- private OutputStream deliverStream = null;
- private byte[] temporary = new byte[1024];
+ private InputStream is = null;
+ private OutputStream os = null;
+ private final byte[] temporary = new byte[1024];
// used by handshake hash computation for handshake fragment
private byte prevType = -1;
private int hsMsgOff = 0;
private int hsMsgLen = 0;
private boolean formatVerified = false; // SSLv2 ruled out?
+ // Cache for incomplete handshake messages.
+ private ByteBuffer handshakeBuffer = null;
+
private boolean hasHeader = false; // Had read the record header
- SSLSocketInputRecord() {
- this.readAuthenticator = MAC.TLS_NULL;
+ SSLSocketInputRecord(HandshakeHash handshakeHash) {
+ super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
}
@Override
- int bytesInCompletePacket(InputStream is) throws IOException {
+ int bytesInCompletePacket() throws IOException {
if (!hasHeader) {
// read exactly one record
int really = read(is, temporary, 0, headerSize);
if (really < 0) {
- throw new EOFException("SSL peer shut down incorrectly");
+ // EOF: peer shut down incorrectly
+ return -1;
}
hasHeader = true;
}
byte byteZero = temporary[0];
@@ -77,19 +80,21 @@
* ignore the verifications steps, and jump right to the
* determination. Otherwise, try one last hueristic to
* see if it's SSL/TLS.
*/
if (formatVerified ||
- (byteZero == ct_handshake) || (byteZero == ct_alert)) {
+ (byteZero == ContentType.HANDSHAKE.id) ||
+ (byteZero == ContentType.ALERT.id)) {
/*
* Last sanity check that it's not a wild record
*/
- ProtocolVersion recordVersion =
- ProtocolVersion.valueOf(temporary[1], temporary[2]);
-
- // check the record version
- checkRecordVersion(recordVersion, false);
+ if (!ProtocolVersion.isNegotiable(
+ temporary[1], temporary[2], false, false)) {
+ throw new SSLException("Unrecognized record version " +
+ ProtocolVersion.nameOf(temporary[1], temporary[2]) +
+ " , plaintext connection?");
+ }
/*
* Reasonably sure this is a V3, disable further checks.
* We can't do the same in the v2 check below, because
* read still needs to parse/handle the v2 clientHello.
@@ -110,15 +115,16 @@
* Internals can warn about unsupported SSLv2
*/
boolean isShort = ((byteZero & 0x80) != 0);
if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) {
- ProtocolVersion recordVersion =
- ProtocolVersion.valueOf(temporary[3], temporary[4]);
-
- // check the record version
- checkRecordVersion(recordVersion, true);
+ if (!ProtocolVersion.isNegotiable(
+ temporary[3], temporary[4], false, false)) {
+ throw new SSLException("Unrecognized record version " +
+ ProtocolVersion.nameOf(temporary[3], temporary[4]) +
+ " , plaintext connection?");
+ }
/*
* Client or Server Hello
*/
//
@@ -138,14 +144,14 @@
}
return len;
}
- // destination.position() is zero.
+ // Note that the input arguments are not used actually.
@Override
- Plaintext decode(InputStream is, ByteBuffer destination)
- throws IOException, BadPaddingException {
+ Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset,
+ int srcsLength) throws IOException, BadPaddingException {
if (isClosed) {
return null;
}
@@ -165,40 +171,48 @@
/*
* The first record must either be a handshake record or an
* alert message. If it's not, it is either invalid or an
* SSLv2 message.
*/
- if ((temporary[0] != ct_handshake) &&
- (temporary[0] != ct_alert)) {
-
- plaintext = handleUnknownRecord(is, temporary, destination);
- }
+ if ((temporary[0] != ContentType.HANDSHAKE.id) &&
+ (temporary[0] != ContentType.ALERT.id)) {
+ hasHeader = false;
+ return handleUnknownRecord(temporary);
}
-
- if (plaintext == null) {
- plaintext = decodeInputRecord(is, temporary, destination);
}
// The record header should has comsumed.
hasHeader = false;
+ return decodeInputRecord(temporary);
+ }
- return plaintext;
+ @Override
+ void setReceiverStream(InputStream inputStream) {
+ this.is = inputStream;
}
@Override
void setDeliverStream(OutputStream outputStream) {
- this.deliverStream = outputStream;
+ this.os = outputStream;
}
// Note that destination may be null
- private Plaintext decodeInputRecord(InputStream is, byte[] header,
- ByteBuffer destination) throws IOException, BadPaddingException {
-
- byte contentType = header[0];
- byte majorVersion = header[1];
- byte minorVersion = header[2];
- int contentLen = ((header[3] & 0xFF) << 8) + (header[4] & 0xFF);
+ private Plaintext[] decodeInputRecord(
+ byte[] header) throws IOException, BadPaddingException {
+ byte contentType = header[0]; // pos: 0
+ byte majorVersion = header[1]; // pos: 1
+ byte minorVersion = header[2]; // pos: 2
+ int contentLen = ((header[3] & 0xFF) << 8) +
+ (header[4] & 0xFF); // pos: 3, 4
+
+ if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+ SSLLogger.fine(
+ "READ: " +
+ ProtocolVersion.nameOf(majorVersion, minorVersion) +
+ " " + ContentType.nameOf(contentType) + ", length = " +
+ contentLen);
+ }
//
// Check for upper bound.
//
// Note: May check packetSize limit in the future.
@@ -208,14 +222,11 @@
}
//
// Read a complete record.
//
- if (destination == null) {
- destination = ByteBuffer.allocate(headerSize + contentLen);
- } // Otherwise, the destination buffer should have enough room.
-
+ ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen);
int dstPos = destination.position();
destination.put(temporary, 0, headerSize);
while (contentLen > 0) {
int howmuch = Math.min(temporary.length, contentLen);
int really = read(is, temporary, 0, howmuch);
@@ -227,104 +238,117 @@
contentLen -= howmuch;
}
destination.flip();
destination.position(dstPos + headerSize);
- if (debug != null && Debug.isOn("record")) {
- System.out.println(Thread.currentThread().getName() +
- ", READ: " +
- ProtocolVersion.valueOf(majorVersion, minorVersion) +
- " " + Record.contentName(contentType) + ", length = " +
+ if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+ SSLLogger.fine(
+ "READ: " +
+ ProtocolVersion.nameOf(majorVersion, minorVersion) +
+ " " + ContentType.nameOf(contentType) + ", length = " +
destination.remaining());
}
//
// Decrypt the fragment
//
- ByteBuffer plaintext =
- decrypt(readAuthenticator, readCipher, contentType, destination);
-
- if ((contentType != ct_handshake) && (hsMsgOff != hsMsgLen)) {
+ ByteBuffer fragment;
+ try {
+ Plaintext plaintext =
+ readCipher.decrypt(contentType, destination, null);
+ fragment = plaintext.fragment;
+ contentType = plaintext.contentType;
+ } catch (BadPaddingException bpe) {
+ throw bpe;
+ } catch (GeneralSecurityException gse) {
+ throw (SSLProtocolException)(new SSLProtocolException(
+ "Unexpected exception")).initCause(gse);
+ }
+ if (contentType != ContentType.HANDSHAKE.id && hsMsgOff != hsMsgLen) {
throw new SSLProtocolException(
"Expected to get a handshake fragment");
}
//
- // handshake hashing
+ // parse handshake messages
//
- if (contentType == ct_handshake) {
- int pltPos = plaintext.position();
- int pltLim = plaintext.limit();
- int frgPos = pltPos;
- for (int remains = plaintext.remaining(); remains > 0;) {
- int howmuch;
- byte handshakeType;
- if (hsMsgOff < hsMsgLen) {
- // a fragment of the handshake message
- howmuch = Math.min((hsMsgLen - hsMsgOff), remains);
- handshakeType = prevType;
-
- hsMsgOff += howmuch;
- if (hsMsgOff == hsMsgLen) {
- // Now is a complete handshake message.
- hsMsgOff = 0;
- hsMsgLen = 0;
- }
- } else { // hsMsgOff == hsMsgLen, a new handshake message
- handshakeType = plaintext.get();
- int handshakeLen = ((plaintext.get() & 0xFF) << 16) |
- ((plaintext.get() & 0xFF) << 8) |
- (plaintext.get() & 0xFF);
- plaintext.position(frgPos);
- if (remains < (handshakeLen + 1)) { // 1: handshake type
- // This handshake message is fragmented.
- prevType = handshakeType;
- hsMsgOff = remains - 4; // 4: handshake header
- hsMsgLen = handshakeLen;
- }
-
- howmuch = Math.min(handshakeLen + 4, remains);
- }
-
- plaintext.limit(frgPos + howmuch);
-
- if (handshakeType == HandshakeMessage.ht_hello_request) {
- // omitted from handshake hash computation
- } else if ((handshakeType != HandshakeMessage.ht_finished) &&
- (handshakeType != HandshakeMessage.ht_certificate_verify)) {
-
- if (handshakeHash == null) {
- // used for cache only
- handshakeHash = new HandshakeHash(false);
- }
- handshakeHash.update(plaintext);
+ if (contentType == ContentType.HANDSHAKE.id) {
+ ByteBuffer handshakeFrag = fragment;
+ if ((handshakeBuffer != null) &&
+ (handshakeBuffer.remaining() != 0)) {
+ ByteBuffer bb = ByteBuffer.wrap(new byte[
+ handshakeBuffer.remaining() + fragment.remaining()]);
+ bb.put(handshakeBuffer);
+ bb.put(fragment);
+ handshakeFrag = bb.rewind();
+ handshakeBuffer = null;
+ }
+
+ ArrayList<Plaintext> plaintexts = new ArrayList<>(5);
+ while (handshakeFrag.hasRemaining()) {
+ int remaining = handshakeFrag.remaining();
+ if (remaining < handshakeHeaderSize) {
+ handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
+ handshakeBuffer.put(handshakeFrag);
+ handshakeBuffer.rewind();
+ break;
+ }
+
+ handshakeFrag.mark();
+ // skip the first byte: handshake type
+ byte handshakeType = handshakeFrag.get();
+ int handshakeBodyLen = Record.getInt24(handshakeFrag);
+ handshakeFrag.reset();
+ int handshakeMessageLen =
+ handshakeHeaderSize + handshakeBodyLen;
+ if (remaining < handshakeMessageLen) {
+ handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
+ handshakeBuffer.put(handshakeFrag);
+ handshakeBuffer.rewind();
+ break;
+ } if (remaining == handshakeMessageLen) {
+ if (handshakeHash.isHashable(handshakeType)) {
+ handshakeHash.receive(handshakeFrag);
+ }
+
+ plaintexts.add(
+ new Plaintext(contentType,
+ majorVersion, minorVersion, -1, -1L, handshakeFrag)
+ );
+ break;
} else {
- // Reserve until this handshake message has been processed.
- if (handshakeHash == null) {
- // used for cache only
- handshakeHash = new HandshakeHash(false);
- }
- handshakeHash.reserve(plaintext);
+ int fragPos = handshakeFrag.position();
+ int fragLim = handshakeFrag.limit();
+ int nextPos = fragPos + handshakeMessageLen;
+ handshakeFrag.limit(nextPos);
+
+ if (handshakeHash.isHashable(handshakeType)) {
+ handshakeHash.receive(handshakeFrag);
}
- plaintext.position(frgPos + howmuch);
- plaintext.limit(pltLim);
+ plaintexts.add(
+ new Plaintext(contentType, majorVersion, minorVersion,
+ -1, -1L, handshakeFrag.slice())
+ );
- frgPos += howmuch;
- remains -= howmuch;
+ handshakeFrag.position(nextPos);
+ handshakeFrag.limit(fragLim);
}
- plaintext.position(pltPos);
}
- return new Plaintext(contentType,
- majorVersion, minorVersion, -1, -1L, plaintext);
- // recordEpoch, recordSeq, plaintext);
+ return plaintexts.toArray(new Plaintext[0]);
}
- private Plaintext handleUnknownRecord(InputStream is, byte[] header,
- ByteBuffer destination) throws IOException, BadPaddingException {
+ return new Plaintext[] {
+ new Plaintext(contentType,
+ majorVersion, minorVersion, -1, -1L, fragment)
+ // recordEpoch, recordSeq, plaintext);
+ };
+ }
+ private Plaintext[] handleUnknownRecord(
+ byte[] header) throws IOException, BadPaddingException {
byte firstByte = header[0];
byte thirdByte = header[2];
// Does it look like a Version 2 client hello (V2ClientHello)?
if (((firstByte & 0x80) != 0) && (thirdByte == 1)) {
@@ -345,34 +369,29 @@
* Looks like a V2 client hello, but not one saying
* "let's talk SSLv3". So we need to send an SSLv2
* error message, one that's treated as fatal by
* clients (Otherwise we'll hang.)
*/
- deliverStream.write(SSLRecord.v2NoCipher); // SSLv2Hello
+ os.write(SSLRecord.v2NoCipher); // SSLv2Hello
- if (debug != null) {
- if (Debug.isOn("record")) {
- System.out.println(Thread.currentThread().getName() +
+ if (SSLLogger.isOn) {
+ if (SSLLogger.isOn("record")) {
+ SSLLogger.fine(
"Requested to negotiate unsupported SSLv2!");
}
- if (Debug.isOn("packet")) {
- Debug.printHex(
- "[Raw write]: length = " +
- SSLRecord.v2NoCipher.length,
- SSLRecord.v2NoCipher);
+ if (SSLLogger.isOn("packet")) {
+ SSLLogger.fine("Raw write", SSLRecord.v2NoCipher);
}
}
throw new SSLException("Unsupported SSL v2.0 ClientHello");
}
int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF);
- if (destination == null) {
- destination = ByteBuffer.allocate(headerSize + msgLen);
- }
+ ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen);
destination.put(temporary, 0, headerSize);
msgLen -= 3; // had read 3 bytes of content as header
while (msgLen > 0) {
int howmuch = Math.min(temporary.length, msgLen);
int really = read(is, temporary, 0, howmuch);
@@ -389,27 +408,24 @@
* If we can map this into a V3 ClientHello, read and
* hash the rest of the V2 handshake, turn it into a
* V3 ClientHello message, and pass it up.
*/
destination.position(2); // exclude the header
-
- if (handshakeHash == null) {
- // used for cache only
- handshakeHash = new HandshakeHash(false);
- }
- handshakeHash.update(destination);
+ handshakeHash.receive(destination);
destination.position(0);
ByteBuffer converted = convertToClientHello(destination);
- if (debug != null && Debug.isOn("packet")) {
- Debug.printHex(
+ if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
+ SSLLogger.fine(
"[Converted] ClientHello", converted);
}
- return new Plaintext(ct_handshake,
- majorVersion, minorVersion, -1, -1L, converted);
+ return new Plaintext[] {
+ new Plaintext(ContentType.HANDSHAKE.id,
+ majorVersion, minorVersion, -1, -1L, converted)
+ };
} else {
if (((firstByte & 0x80) != 0) && (thirdByte == 4)) {
throw new SSLException("SSL V2.0 servers are not supported.");
}
@@ -422,17 +438,19 @@
byte[] buffer, int offset, int len) throws IOException {
int n = 0;
while (n < len) {
int readLen = is.read(buffer, offset + n, len - n);
if (readLen < 0) {
+ if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
+ SSLLogger.fine("Raw read: EOF");
+ }
return -1;
}
- if (debug != null && Debug.isOn("packet")) {
- Debug.printHex(
- "[Raw read]: length = " + readLen,
- buffer, offset + n, readLen);
+ if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
+ ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen);
+ SSLLogger.fine("Raw read", bb);
}
n += readLen;
}
< prev index next >