# HG changeset patch # User xuelei # Date 1583943248 -10800 # Wed Mar 11 19:14:08 2020 +0300 # Node ID 9bee6abd576956b863ff297ddfc2c7337f2b4e6b # Parent a364726474e637c7df7a347554eddfeaab256c5d 8239798: SSLSocket closes socket both socket endpoints on a SocketTimeoutException Reviewed-by: xuelei Contributed-by: alexey@azul.com verghese@amazon.com diff --git a/src/share/classes/sun/security/ssl/SSLSocketImpl.java b/src/share/classes/sun/security/ssl/SSLSocketImpl.java --- a/src/share/classes/sun/security/ssl/SSLSocketImpl.java +++ b/src/share/classes/sun/security/ssl/SSLSocketImpl.java @@ -434,6 +434,8 @@ if (!conContext.isNegotiated) { readHandshakeRecord(); } + } catch (InterruptedIOException iioe) { + handleException(iioe); } catch (IOException ioe) { throw conContext.fatal(Alert.HANDSHAKE_FAILURE, "Couldn't kickstart handshaking", ioe); @@ -1295,12 +1297,11 @@ } } catch (SSLException ssle) { throw ssle; + } catch (InterruptedIOException iioe) { + // don't change exception in case of timeouts or interrupts + throw iioe; } catch (IOException ioe) { - if (!(ioe instanceof SSLException)) { - throw new SSLException("readHandshakeRecord", ioe); - } else { - throw ioe; - } + throw new SSLException("readHandshakeRecord", ioe); } } @@ -1361,6 +1362,9 @@ } } catch (SSLException ssle) { throw ssle; + } catch (InterruptedIOException iioe) { + // don't change exception in case of timeouts or interrupts + throw iioe; } catch (IOException ioe) { if (!(ioe instanceof SSLException)) { throw new SSLException("readApplicationRecord", ioe); diff --git a/src/share/classes/sun/security/ssl/SSLSocketInputRecord.java b/src/share/classes/sun/security/ssl/SSLSocketInputRecord.java --- a/src/share/classes/sun/security/ssl/SSLSocketInputRecord.java +++ b/src/share/classes/sun/security/ssl/SSLSocketInputRecord.java @@ -1,5 +1,6 @@ /* * Copyright (c) 1996, 2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Azul Systems, Inc. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -26,6 +27,7 @@ package sun.security.ssl; import java.io.EOFException; +import java.io.InterruptedIOException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -47,37 +49,31 @@ final class SSLSocketInputRecord extends InputRecord implements SSLRecord { private InputStream is = null; private OutputStream os = null; - private final byte[] temporary = new byte[1024]; + private final byte[] header = new byte[headerSize]; + private int headerOff = 0; + // Cache for incomplete record body. + private ByteBuffer recordBody = ByteBuffer.allocate(1024); private boolean formatVerified = false; // SSLv2 ruled out? // Cache for incomplete handshake messages. private ByteBuffer handshakeBuffer = null; - private boolean hasHeader = false; // Had read the record header - SSLSocketInputRecord(HandshakeHash handshakeHash) { super(handshakeHash, SSLReadCipher.nullTlsReadCipher()); } @Override int bytesInCompletePacket() throws IOException { - if (!hasHeader) { - // read exactly one record - try { - int really = read(is, temporary, 0, headerSize); - if (really < 0) { - // EOF: peer shut down incorrectly - return -1; - } - } catch (EOFException eofe) { - // The caller will handle EOF. - return -1; - } - hasHeader = true; + // read header + try { + readHeader(); + } catch (EOFException eofe) { + // The caller will handle EOF. + return -1; } - byte byteZero = temporary[0]; + byte byteZero = header[0]; int len = 0; /* @@ -93,9 +89,9 @@ * Last sanity check that it's not a wild record */ if (!ProtocolVersion.isNegotiable( - temporary[1], temporary[2], false)) { + header[1], header[2], false)) { throw new SSLException("Unrecognized record version " + - ProtocolVersion.nameOf(temporary[1], temporary[2]) + + ProtocolVersion.nameOf(header[1], header[2]) + " , plaintext connection?"); } @@ -109,8 +105,8 @@ /* * One of the SSLv3/TLS message types. */ - len = ((temporary[3] & 0xFF) << 8) + - (temporary[4] & 0xFF) + headerSize; + len = ((header[3] & 0xFF) << 8) + + (header[4] & 0xFF) + headerSize; } else { /* * Must be SSLv2 or something unknown. @@ -121,11 +117,11 @@ */ boolean isShort = ((byteZero & 0x80) != 0); - if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) { + if (isShort && ((header[2] == 1) || (header[2] == 4))) { if (!ProtocolVersion.isNegotiable( - temporary[3], temporary[4], false)) { + header[3], header[4], false)) { throw new SSLException("Unrecognized record version " + - ProtocolVersion.nameOf(temporary[3], temporary[4]) + + ProtocolVersion.nameOf(header[3], header[4]) + " , plaintext connection?"); } @@ -138,9 +134,9 @@ // // int mask = (isShort ? 0x7F : 0x3F); // len = ((byteZero & mask) << 8) + - // (temporary[1] & 0xFF) + (isShort ? 2 : 3); + // (header[1] & 0xFF) + (isShort ? 2 : 3); // - len = ((byteZero & 0x7F) << 8) + (temporary[1] & 0xFF) + 2; + len = ((byteZero & 0x7F) << 8) + (header[1] & 0xFF) + 2; } else { // Gobblygook! throw new SSLException( @@ -160,34 +156,41 @@ return null; } - if (!hasHeader) { - // read exactly one record - int really = read(is, temporary, 0, headerSize); - if (really < 0) { - throw new EOFException("SSL peer shut down incorrectly"); - } - hasHeader = true; - } + // read header + readHeader(); + + Plaintext[] plaintext = null; + boolean cleanInBuffer = true; + try { + if (!formatVerified) { + formatVerified = true; - Plaintext plaintext = null; - if (!formatVerified) { - formatVerified = true; + /* + * The first record must either be a handshake record or an + * alert message. If it's not, it is either invalid or an + * SSLv2 message. + */ + if ((header[0] != ContentType.HANDSHAKE.id) && + (header[0] != ContentType.ALERT.id)) { + plaintext = handleUnknownRecord(); + } + } - /* - * The first record must either be a handshake record or an - * alert message. If it's not, it is either invalid or an - * SSLv2 message. - */ - if ((temporary[0] != ContentType.HANDSHAKE.id) && - (temporary[0] != ContentType.ALERT.id)) { - hasHeader = false; - return handleUnknownRecord(temporary); + // The record header should has consumed. + if (plaintext == null) { + plaintext = decodeInputRecord(); + } + } catch(InterruptedIOException e) { + // do not clean header and recordBody in case of Socket Timeout + cleanInBuffer = false; + throw e; + } finally { + if (cleanInBuffer) { + headerOff = 0; + recordBody.clear(); } } - - // The record header should has consumed. - hasHeader = false; - return decodeInputRecord(temporary); + return plaintext; } @Override @@ -200,9 +203,7 @@ this.os = outputStream; } - // Note that destination may be null - private Plaintext[] decodeInputRecord( - byte[] header) throws IOException, BadPaddingException { + private Plaintext[] decodeInputRecord() throws IOException, BadPaddingException { byte contentType = header[0]; // pos: 0 byte majorVersion = header[1]; // pos: 1 byte minorVersion = header[2]; // pos: 2 @@ -227,30 +228,27 @@ } // - // Read a complete record. + // Read a complete record and store in the recordBody + // recordBody is used to cache incoming record and restore in case of + // read operation timedout // - ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen); - int dstPos = destination.position(); - destination.put(temporary, 0, headerSize); - while (contentLen > 0) { - int howmuch = Math.min(temporary.length, contentLen); - int really = read(is, temporary, 0, howmuch); - if (really < 0) { - throw new EOFException("SSL peer shut down incorrectly"); + if (recordBody.position() == 0) { + if (recordBody.capacity() < contentLen) { + recordBody = ByteBuffer.allocate(contentLen); } - - destination.put(temporary, 0, howmuch); - contentLen -= howmuch; + recordBody.limit(contentLen); + } else { + contentLen = recordBody.remaining(); } - destination.flip(); - destination.position(dstPos + headerSize); + readFully(contentLen); + recordBody.flip(); if (SSLLogger.isOn && SSLLogger.isOn("record")) { SSLLogger.fine( "READ: " + ProtocolVersion.nameOf(majorVersion, minorVersion) + " " + ContentType.nameOf(contentType) + ", length = " + - destination.remaining()); + recordBody.remaining()); } // @@ -259,7 +257,7 @@ ByteBuffer fragment; try { Plaintext plaintext = - readCipher.decrypt(contentType, destination, null); + readCipher.decrypt(contentType, recordBody, null); fragment = plaintext.fragment; contentType = plaintext.contentType; } catch (BadPaddingException bpe) { @@ -361,8 +359,7 @@ }; } - private Plaintext[] handleUnknownRecord( - byte[] header) throws IOException, BadPaddingException { + private Plaintext[] handleUnknownRecord() throws IOException, BadPaddingException { byte firstByte = header[0]; byte thirdByte = header[2]; @@ -404,32 +401,29 @@ } int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF); - - ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen); - destination.put(temporary, 0, headerSize); + if (recordBody.position() == 0) { + if (recordBody.capacity() < (headerSize + msgLen)) { + recordBody = ByteBuffer.allocate(headerSize + msgLen); + } + recordBody.limit(headerSize + msgLen); + recordBody.put(header, 0, headerSize); + } else { + msgLen = recordBody.remaining(); + } msgLen -= 3; // had read 3 bytes of content as header - while (msgLen > 0) { - int howmuch = Math.min(temporary.length, msgLen); - int really = read(is, temporary, 0, howmuch); - if (really < 0) { - throw new EOFException("SSL peer shut down incorrectly"); - } - - destination.put(temporary, 0, howmuch); - msgLen -= howmuch; - } - destination.flip(); + readFully(msgLen); + recordBody.flip(); /* * If we can map this into a V3 ClientHello, read and * hash the rest of the V2 handshake, turn it into a * V3 ClientHello message, and pass it up. */ - destination.position(2); // exclude the header - handshakeHash.receive(destination); - destination.position(0); + recordBody.position(2); // exclude the header + handshakeHash.receive(recordBody); + recordBody.position(0); - ByteBuffer converted = convertToClientHello(destination); + ByteBuffer converted = convertToClientHello(recordBody); if (SSLLogger.isOn && SSLLogger.isOn("packet")) { SSLLogger.fine( @@ -449,28 +443,42 @@ } } - // Read the exact bytes of data, otherwise, return -1. - private static int read(InputStream is, - byte[] buffer, int offset, int len) throws IOException { - int n = 0; - while (n < len) { - int readLen = is.read(buffer, offset + n, len - n); - if (readLen < 0) { - if (SSLLogger.isOn && SSLLogger.isOn("packet")) { - SSLLogger.fine("Raw read: EOF"); - } - return -1; + // Read the exact bytes of data, otherwise, throw IOException. + private int readFully(int len) throws IOException { + int end = len + recordBody.position(); + int off = recordBody.position(); + try { + while (off < end) { + off += read(is, recordBody.array(), off, end - off); } + } finally { + recordBody.position(off); + } + return len; + } + // Read SSE record header, otherwise, throw IOException. + private int readHeader() throws IOException { + while (headerOff < headerSize) { + headerOff += read(is, header, headerOff, headerSize - headerOff); + } + return headerSize; + } + + private static int read(InputStream is, byte[] buf, int off, int len) throws IOException { + int readLen = is.read(buf, off, len); + if (readLen < 0) { if (SSLLogger.isOn && SSLLogger.isOn("packet")) { - ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen); - SSLLogger.fine("Raw read", bb); + SSLLogger.fine("Raw read: EOF"); } - - n += readLen; + throw new EOFException("SSL peer shut down incorrectly"); } - return n; + if (SSLLogger.isOn && SSLLogger.isOn("packet")) { + ByteBuffer bb = ByteBuffer.wrap(buf, off, readLen); + SSLLogger.fine("Raw read", bb); + } + return readLen; } // Try to use up the input stream without impact the performance too much. diff --git a/src/share/classes/sun/security/ssl/SSLTransport.java b/src/share/classes/sun/security/ssl/SSLTransport.java --- a/src/share/classes/sun/security/ssl/SSLTransport.java +++ b/src/share/classes/sun/security/ssl/SSLTransport.java @@ -27,6 +27,7 @@ import java.io.EOFException; import java.io.IOException; +import java.io.InterruptedIOException; import java.nio.ByteBuffer; import javax.crypto.AEADBadTagException; import javax.crypto.BadPaddingException; @@ -134,6 +135,9 @@ } catch (EOFException eofe) { // rethrow EOFException, the call will handle it if neede. throw eofe; + } catch (InterruptedIOException iioe) { + // don't close the Socket in case of timeouts or interrupts. + throw iioe; } catch (IOException ioe) { throw context.fatal(Alert.UNEXPECTED_MESSAGE, ioe); } diff --git a/test/sun/security/ssl/SSLSocketImpl/ClientTimeout.java b/test/sun/security/ssl/SSLSocketImpl/ClientTimeout.java --- a/test/sun/security/ssl/SSLSocketImpl/ClientTimeout.java +++ b/test/sun/security/ssl/SSLSocketImpl/ClientTimeout.java @@ -26,8 +26,7 @@ /* * @test - * @bug 4836493 - * @ignore need further evaluation + * @bug 4836493 8239798 * @summary Socket timeouts for SSLSockets causes data corruption. * @run main/othervm ClientTimeout */ diff --git a/test/sun/security/ssl/SSLSocketImpl/SSLExceptionForIOIssue.java b/test/sun/security/ssl/SSLSocketImpl/SSLExceptionForIOIssue.java --- a/test/sun/security/ssl/SSLSocketImpl/SSLExceptionForIOIssue.java +++ b/test/sun/security/ssl/SSLSocketImpl/SSLExceptionForIOIssue.java @@ -36,7 +36,7 @@ import javax.net.ssl.*; import java.io.*; -import java.net.InetAddress; +import java.net.*; public class SSLExceptionForIOIssue implements SSLContextTemplate { @@ -139,7 +139,7 @@ } catch (SSLProtocolException | SSLHandshakeException sslhe) { clientException = sslhe; System.err.println("unexpected client exception: " + sslhe); - } catch (SSLException ssle) { + } catch (SSLException | SocketTimeoutException ssle) { // the expected exception, ignore it System.err.println("expected client exception: " + ssle); } catch (Exception e) {