< prev index next >
src/share/classes/sun/security/ssl/SSLSocketInputRecord.java
Print this page
rev 14406 : 8239798: SSLSocket closes socket both socket endpoints on a SocketTimeoutException
Reviewed-by: xuelei
Contributed-by: alexey@azul.com verghese@amazon.com
@@ -1,7 +1,8 @@
/*
* Copyright (c) 1996, 2020, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2020, Azul Systems, Inc. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
@@ -24,10 +25,11 @@
*/
package sun.security.ssl;
import java.io.EOFException;
+import java.io.InterruptedIOException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
@@ -45,41 +47,35 @@
* @author David Brownell
*/
final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
private InputStream is = null;
private OutputStream os = null;
- private final byte[] temporary = new byte[1024];
+ private final byte[] header = new byte[headerSize];
+ private int headerOff = 0;
+ // Cache for incomplete record body.
+ private ByteBuffer recordBody = ByteBuffer.allocate(1024);
private boolean formatVerified = false; // SSLv2 ruled out?
// Cache for incomplete handshake messages.
private ByteBuffer handshakeBuffer = null;
- private boolean hasHeader = false; // Had read the record header
-
SSLSocketInputRecord(HandshakeHash handshakeHash) {
super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
}
@Override
int bytesInCompletePacket() throws IOException {
- if (!hasHeader) {
- // read exactly one record
+ // read header
try {
- int really = read(is, temporary, 0, headerSize);
- if (really < 0) {
- // EOF: peer shut down incorrectly
- return -1;
- }
+ readHeader();
} catch (EOFException eofe) {
// The caller will handle EOF.
return -1;
}
- hasHeader = true;
- }
- byte byteZero = temporary[0];
+ byte byteZero = header[0];
int len = 0;
/*
* If we have already verified previous packets, we can
* ignore the verifications steps, and jump right to the
@@ -91,13 +87,13 @@
(byteZero == ContentType.ALERT.id)) {
/*
* Last sanity check that it's not a wild record
*/
if (!ProtocolVersion.isNegotiable(
- temporary[1], temporary[2], false)) {
+ header[1], header[2], false)) {
throw new SSLException("Unrecognized record version " +
- ProtocolVersion.nameOf(temporary[1], temporary[2]) +
+ ProtocolVersion.nameOf(header[1], header[2]) +
" , plaintext connection?");
}
/*
* Reasonably sure this is a V3, disable further checks.
@@ -107,27 +103,27 @@
formatVerified = true;
/*
* One of the SSLv3/TLS message types.
*/
- len = ((temporary[3] & 0xFF) << 8) +
- (temporary[4] & 0xFF) + headerSize;
+ len = ((header[3] & 0xFF) << 8) +
+ (header[4] & 0xFF) + headerSize;
} else {
/*
* Must be SSLv2 or something unknown.
* Check if it's short (2 bytes) or
* long (3) header.
*
* Internals can warn about unsupported SSLv2
*/
boolean isShort = ((byteZero & 0x80) != 0);
- if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) {
+ if (isShort && ((header[2] == 1) || (header[2] == 4))) {
if (!ProtocolVersion.isNegotiable(
- temporary[3], temporary[4], false)) {
+ header[3], header[4], false)) {
throw new SSLException("Unrecognized record version " +
- ProtocolVersion.nameOf(temporary[3], temporary[4]) +
+ ProtocolVersion.nameOf(header[3], header[4]) +
" , plaintext connection?");
}
/*
* Client or Server Hello
@@ -136,13 +132,13 @@
// Only the short header is used here. We keep the general form
// below, commented out, in case it is needed in the future.
//
// int mask = (isShort ? 0x7F : 0x3F);
// len = ((byteZero & mask) << 8) +
- // (temporary[1] & 0xFF) + (isShort ? 2 : 3);
+ // (header[1] & 0xFF) + (isShort ? 2 : 3);
//
- len = ((byteZero & 0x7F) << 8) + (temporary[1] & 0xFF) + 2;
+ len = ((byteZero & 0x7F) << 8) + (header[1] & 0xFF) + 2;
} else {
// Gobblygook!
throw new SSLException(
"Unrecognized SSL message, plaintext connection?");
}
@@ -158,38 +154,45 @@
if (isClosed) {
return null;
}
- if (!hasHeader) {
- // read exactly one record
- int really = read(is, temporary, 0, headerSize);
- if (really < 0) {
- throw new EOFException("SSL peer shut down incorrectly");
- }
- hasHeader = true;
- }
+ // read header
+ readHeader();
- Plaintext plaintext = null;
+ Plaintext[] plaintext = null;
+ boolean cleanInBuffer = true;
+ try {
if (!formatVerified) {
formatVerified = true;
/*
* The first record must either be a handshake record or an
* alert message. If it's not, it is either invalid or an
* SSLv2 message.
*/
- if ((temporary[0] != ContentType.HANDSHAKE.id) &&
- (temporary[0] != ContentType.ALERT.id)) {
- hasHeader = false;
- return handleUnknownRecord(temporary);
+ if ((header[0] != ContentType.HANDSHAKE.id) &&
+ (header[0] != ContentType.ALERT.id)) {
+ plaintext = handleUnknownRecord();
}
}
// The record header should have been consumed.
- hasHeader = false;
- return decodeInputRecord(temporary);
+ if (plaintext == null) {
+ plaintext = decodeInputRecord();
+ }
+ } catch(InterruptedIOException e) {
+ // Do not reset header and recordBody on an interrupted read (e.g. a socket timeout).
+ cleanInBuffer = false;
+ throw e;
+ } finally {
+ if (cleanInBuffer) {
+ headerOff = 0;
+ recordBody.clear();
+ }
+ }
+ return plaintext;
}
@Override
void setReceiverStream(InputStream inputStream) {
this.is = inputStream;
@@ -198,13 +201,11 @@
@Override
void setDeliverStream(OutputStream outputStream) {
this.os = outputStream;
}
- // Note that destination may be null
- private Plaintext[] decodeInputRecord(
- byte[] header) throws IOException, BadPaddingException {
+ private Plaintext[] decodeInputRecord() throws IOException, BadPaddingException {
byte contentType = header[0]; // pos: 0
byte majorVersion = header[1]; // pos: 1
byte minorVersion = header[2]; // pos: 2
int contentLen = ((header[3] & 0xFF) << 8) +
(header[4] & 0xFF); // pos: 3, 4
@@ -225,43 +226,40 @@
throw new SSLProtocolException(
"Bad input record size, TLSCiphertext.length = " + contentLen);
}
//
- // Read a complete record.
- //
- ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen);
- int dstPos = destination.position();
- destination.put(temporary, 0, headerSize);
- while (contentLen > 0) {
- int howmuch = Math.min(temporary.length, contentLen);
- int really = read(is, temporary, 0, howmuch);
- if (really < 0) {
- throw new EOFException("SSL peer shut down incorrectly");
+ // Read a complete record and store it in recordBody.
+ // recordBody caches the incoming record so that reading can be
+ // resumed if a read operation times out.
+ //
+ if (recordBody.position() == 0) {
+ if (recordBody.capacity() < contentLen) {
+ recordBody = ByteBuffer.allocate(contentLen);
}
-
- destination.put(temporary, 0, howmuch);
- contentLen -= howmuch;
+ recordBody.limit(contentLen);
+ } else {
+ contentLen = recordBody.remaining();
}
- destination.flip();
- destination.position(dstPos + headerSize);
+ readFully(contentLen);
+ recordBody.flip();
if (SSLLogger.isOn && SSLLogger.isOn("record")) {
SSLLogger.fine(
"READ: " +
ProtocolVersion.nameOf(majorVersion, minorVersion) +
" " + ContentType.nameOf(contentType) + ", length = " +
- destination.remaining());
+ recordBody.remaining());
}
//
// Decrypt the fragment
//
ByteBuffer fragment;
try {
Plaintext plaintext =
- readCipher.decrypt(contentType, destination, null);
+ readCipher.decrypt(contentType, recordBody, null);
fragment = plaintext.fragment;
contentType = plaintext.contentType;
} catch (BadPaddingException bpe) {
throw bpe;
} catch (GeneralSecurityException gse) {
@@ -359,12 +357,11 @@
new Plaintext(contentType,
majorVersion, minorVersion, -1, -1L, fragment)
};
}
- private Plaintext[] handleUnknownRecord(
- byte[] header) throws IOException, BadPaddingException {
+ private Plaintext[] handleUnknownRecord() throws IOException, BadPaddingException {
byte firstByte = header[0];
byte thirdByte = header[2];
// Does it look like a Version 2 client hello (V2ClientHello)?
if (((firstByte & 0x80) != 0) && (thirdByte == 1)) {
@@ -402,36 +399,33 @@
throw new SSLException("Unsupported SSL v2.0 ClientHello");
}
int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF);
-
- ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen);
- destination.put(temporary, 0, headerSize);
- msgLen -= 3; // had read 3 bytes of content as header
- while (msgLen > 0) {
- int howmuch = Math.min(temporary.length, msgLen);
- int really = read(is, temporary, 0, howmuch);
- if (really < 0) {
- throw new EOFException("SSL peer shut down incorrectly");
+ if (recordBody.position() == 0) {
+ if (recordBody.capacity() < (headerSize + msgLen)) {
+ recordBody = ByteBuffer.allocate(headerSize + msgLen);
}
-
- destination.put(temporary, 0, howmuch);
- msgLen -= howmuch;
+ recordBody.limit(headerSize + msgLen);
+ recordBody.put(header, 0, headerSize);
+ } else {
+ msgLen = recordBody.remaining();
}
- destination.flip();
+ msgLen -= 3; // had read 3 bytes of content as header
+ readFully(msgLen);
+ recordBody.flip();
/*
* If we can map this into a V3 ClientHello, read and
* hash the rest of the V2 handshake, turn it into a
* V3 ClientHello message, and pass it up.
*/
- destination.position(2); // exclude the header
- handshakeHash.receive(destination);
- destination.position(0);
+ recordBody.position(2); // exclude the header
+ handshakeHash.receive(recordBody);
+ recordBody.position(0);
- ByteBuffer converted = convertToClientHello(destination);
+ ByteBuffer converted = convertToClientHello(recordBody);
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
SSLLogger.fine(
"[Converted] ClientHello", converted);
}
@@ -447,32 +441,46 @@
throw new SSLException("Unsupported or unrecognized SSL message");
}
}
- // Read the exact bytes of data, otherwise, return -1.
- private static int read(InputStream is,
- byte[] buffer, int offset, int len) throws IOException {
- int n = 0;
- while (n < len) {
- int readLen = is.read(buffer, offset + n, len - n);
+ // Read exactly len bytes of data into recordBody; otherwise, throw IOException.
+ private int readFully(int len) throws IOException {
+ int end = len + recordBody.position();
+ int off = recordBody.position();
+ try {
+ while (off < end) {
+ off += read(is, recordBody.array(), off, end - off);
+ }
+ } finally {
+ recordBody.position(off);
+ }
+ return len;
+ }
+
+ // Read the SSL record header; otherwise, throw IOException.
+ private int readHeader() throws IOException {
+ while (headerOff < headerSize) {
+ headerOff += read(is, header, headerOff, headerSize - headerOff);
+ }
+ return headerSize;
+ }
+
+ private static int read(InputStream is, byte[] buf, int off, int len) throws IOException {
+ int readLen = is.read(buf, off, len);
if (readLen < 0) {
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
SSLLogger.fine("Raw read: EOF");
}
- return -1;
+ throw new EOFException("SSL peer shut down incorrectly");
}
if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
- ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen);
+ ByteBuffer bb = ByteBuffer.wrap(buf, off, readLen);
SSLLogger.fine("Raw read", bb);
}
-
- n += readLen;
- }
-
- return n;
+ return readLen;
}
// Try to use up the input stream without impacting performance too much.
void deplete(boolean tryToRead) throws IOException {
int remaining = is.available();
< prev index next >