--- old/src/java.base/share/classes/sun/security/ssl/SSLSocketInputRecord.java 2018-05-11 15:06:05.777490200 -0700 +++ new/src/java.base/share/classes/sun/security/ssl/SSLSocketInputRecord.java 2018-05-11 15:06:05.281723400 -0700 @@ -1,5 +1,5 @@ /* - * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -27,13 +27,11 @@ import java.io.*; import java.nio.*; - +import java.security.GeneralSecurityException; +import java.util.ArrayList; import javax.crypto.BadPaddingException; - import javax.net.ssl.*; - -import sun.security.util.HexDumpEncoder; - +import sun.security.ssl.SSLCipher.SSLReadCipher; /** * {@code InputRecord} implementation for {@code SSLSocket}. @@ -41,8 +39,9 @@ * @author David Brownell */ final class SSLSocketInputRecord extends InputRecord implements SSLRecord { - private OutputStream deliverStream = null; - private byte[] temporary = new byte[1024]; + private InputStream is = null; + private OutputStream os = null; + private final byte[] temporary = new byte[1024]; // used by handshake hash computation for handshake fragment private byte prevType = -1; @@ -51,20 +50,24 @@ private boolean formatVerified = false; // SSLv2 ruled out? + // Cache for incomplete handshake messages. + private ByteBuffer handshakeBuffer = null; + private boolean hasHeader = false; // Had read the record header - SSLSocketInputRecord() { - this.readAuthenticator = MAC.TLS_NULL; + SSLSocketInputRecord(HandshakeHash handshakeHash) { + super(handshakeHash, SSLReadCipher.nullTlsReadCipher()); } @Override - int bytesInCompletePacket(InputStream is) throws IOException { + int bytesInCompletePacket() throws IOException { if (!hasHeader) { // read exactly one record int really = read(is, temporary, 0, headerSize); if (really < 0) { - throw new EOFException("SSL peer shut down incorrectly"); + // EOF: peer shut down incorrectly + return -1; } hasHeader = true; } @@ -79,15 +82,17 @@ * see if it's SSL/TLS. */ if (formatVerified || - (byteZero == ct_handshake) || (byteZero == ct_alert)) { + (byteZero == ContentType.HANDSHAKE.id) || + (byteZero == ContentType.ALERT.id)) { /* * Last sanity check that it's not a wild record */ - ProtocolVersion recordVersion = - ProtocolVersion.valueOf(temporary[1], temporary[2]); - - // check the record version - checkRecordVersion(recordVersion, false); + if (!ProtocolVersion.isNegotiable( + temporary[1], temporary[2], false, false)) { + throw new SSLException("Unrecognized record version " + + ProtocolVersion.nameOf(temporary[1], temporary[2]) + + " , plaintext connection?"); + } /* * Reasonably sure this is a V3, disable further checks. @@ -112,11 +117,12 @@ boolean isShort = ((byteZero & 0x80) != 0); if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) { - ProtocolVersion recordVersion = - ProtocolVersion.valueOf(temporary[3], temporary[4]); - - // check the record version - checkRecordVersion(recordVersion, true); + if (!ProtocolVersion.isNegotiable( + temporary[3], temporary[4], false, false)) { + throw new SSLException("Unrecognized record version " + + ProtocolVersion.nameOf(temporary[3], temporary[4]) + + " , plaintext connection?"); + } /* * Client or Server Hello @@ -140,10 +146,10 @@ return len; } - // destination.position() is zero. + // Note that the input arguments are not used actually. @Override - Plaintext decode(InputStream is, ByteBuffer destination) - throws IOException, BadPaddingException { + Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset, + int srcsLength) throws IOException, BadPaddingException { if (isClosed) { return null; @@ -167,36 +173,44 @@ * alert message. If it's not, it is either invalid or an * SSLv2 message. */ - if ((temporary[0] != ct_handshake) && - (temporary[0] != ct_alert)) { - - plaintext = handleUnknownRecord(is, temporary, destination); + if ((temporary[0] != ContentType.HANDSHAKE.id) && + (temporary[0] != ContentType.ALERT.id)) { + hasHeader = false; + return handleUnknownRecord(temporary); } } - if (plaintext == null) { - plaintext = decodeInputRecord(is, temporary, destination); - } - // The record header should has comsumed. hasHeader = false; + return decodeInputRecord(temporary); + } - return plaintext; + @Override + void setReceiverStream(InputStream inputStream) { + this.is = inputStream; } @Override void setDeliverStream(OutputStream outputStream) { - this.deliverStream = outputStream; + this.os = outputStream; } // Note that destination may be null - private Plaintext decodeInputRecord(InputStream is, byte[] header, - ByteBuffer destination) throws IOException, BadPaddingException { - - byte contentType = header[0]; - byte majorVersion = header[1]; - byte minorVersion = header[2]; - int contentLen = ((header[3] & 0xFF) << 8) + (header[4] & 0xFF); + private Plaintext[] decodeInputRecord( + byte[] header) throws IOException, BadPaddingException { + byte contentType = header[0]; // pos: 0 + byte majorVersion = header[1]; // pos: 1 + byte minorVersion = header[2]; // pos: 2 + int contentLen = ((header[3] & 0xFF) << 8) + + (header[4] & 0xFF); // pos: 3, 4 + + if (SSLLogger.isOn && SSLLogger.isOn("record")) { + SSLLogger.fine( + "READ: " + + ProtocolVersion.nameOf(majorVersion, minorVersion) + + " " + ContentType.nameOf(contentType) + ", length = " + + contentLen); + } // // Check for upper bound. @@ -210,10 +224,7 @@ // // Read a complete record. // - if (destination == null) { - destination = ByteBuffer.allocate(headerSize + contentLen); - } // Otherwise, the destination buffer should have enough room. - + ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen); int dstPos = destination.position(); destination.put(temporary, 0, headerSize); while (contentLen > 0) { @@ -229,100 +240,113 @@ destination.flip(); destination.position(dstPos + headerSize); - if (debug != null && Debug.isOn("record")) { - System.out.println(Thread.currentThread().getName() + - ", READ: " + - ProtocolVersion.valueOf(majorVersion, minorVersion) + - " " + Record.contentName(contentType) + ", length = " + + if (SSLLogger.isOn && SSLLogger.isOn("record")) { + SSLLogger.fine( + "READ: " + + ProtocolVersion.nameOf(majorVersion, minorVersion) + + " " + ContentType.nameOf(contentType) + ", length = " + destination.remaining()); } // // Decrypt the fragment // - ByteBuffer plaintext = - decrypt(readAuthenticator, readCipher, contentType, destination); - - if ((contentType != ct_handshake) && (hsMsgOff != hsMsgLen)) { + ByteBuffer fragment; + try { + Plaintext plaintext = + readCipher.decrypt(contentType, destination, null); + fragment = plaintext.fragment; + contentType = plaintext.contentType; + } catch (BadPaddingException bpe) { + throw bpe; + } catch (GeneralSecurityException gse) { + throw (SSLProtocolException)(new SSLProtocolException( + "Unexpected exception")).initCause(gse); + } + if (contentType != ContentType.HANDSHAKE.id && hsMsgOff != hsMsgLen) { throw new SSLProtocolException( "Expected to get a handshake fragment"); } // - // handshake hashing + // parse handshake messages // - if (contentType == ct_handshake) { - int pltPos = plaintext.position(); - int pltLim = plaintext.limit(); - int frgPos = pltPos; - for (int remains = plaintext.remaining(); remains > 0;) { - int howmuch; - byte handshakeType; - if (hsMsgOff < hsMsgLen) { - // a fragment of the handshake message - howmuch = Math.min((hsMsgLen - hsMsgOff), remains); - handshakeType = prevType; - - hsMsgOff += howmuch; - if (hsMsgOff == hsMsgLen) { - // Now is a complete handshake message. - hsMsgOff = 0; - hsMsgLen = 0; - } - } else { // hsMsgOff == hsMsgLen, a new handshake message - handshakeType = plaintext.get(); - int handshakeLen = ((plaintext.get() & 0xFF) << 16) | - ((plaintext.get() & 0xFF) << 8) | - (plaintext.get() & 0xFF); - plaintext.position(frgPos); - if (remains < (handshakeLen + 1)) { // 1: handshake type - // This handshake message is fragmented. - prevType = handshakeType; - hsMsgOff = remains - 4; // 4: handshake header - hsMsgLen = handshakeLen; - } - - howmuch = Math.min(handshakeLen + 4, remains); + if (contentType == ContentType.HANDSHAKE.id) { + ByteBuffer handshakeFrag = fragment; + if ((handshakeBuffer != null) && + (handshakeBuffer.remaining() != 0)) { + ByteBuffer bb = ByteBuffer.wrap(new byte[ + handshakeBuffer.remaining() + fragment.remaining()]); + bb.put(handshakeBuffer); + bb.put(fragment); + handshakeFrag = bb.rewind(); + handshakeBuffer = null; + } + + ArrayList plaintexts = new ArrayList<>(5); + while (handshakeFrag.hasRemaining()) { + int remaining = handshakeFrag.remaining(); + if (remaining < handshakeHeaderSize) { + handshakeBuffer = ByteBuffer.wrap(new byte[remaining]); + handshakeBuffer.put(handshakeFrag); + handshakeBuffer.rewind(); + break; } - plaintext.limit(frgPos + howmuch); - - if (handshakeType == HandshakeMessage.ht_hello_request) { - // omitted from handshake hash computation - } else if ((handshakeType != HandshakeMessage.ht_finished) && - (handshakeType != HandshakeMessage.ht_certificate_verify)) { - - if (handshakeHash == null) { - // used for cache only - handshakeHash = new HandshakeHash(false); + handshakeFrag.mark(); + // skip the first byte: handshake type + byte handshakeType = handshakeFrag.get(); + int handshakeBodyLen = Record.getInt24(handshakeFrag); + handshakeFrag.reset(); + int handshakeMessageLen = + handshakeHeaderSize + handshakeBodyLen; + if (remaining < handshakeMessageLen) { + handshakeBuffer = ByteBuffer.wrap(new byte[remaining]); + handshakeBuffer.put(handshakeFrag); + handshakeBuffer.rewind(); + break; + } if (remaining == handshakeMessageLen) { + if (handshakeHash.isHashable(handshakeType)) { + handshakeHash.receive(handshakeFrag); } - handshakeHash.update(plaintext); + + plaintexts.add( + new Plaintext(contentType, + majorVersion, minorVersion, -1, -1L, handshakeFrag) + ); + break; } else { - // Reserve until this handshake message has been processed. - if (handshakeHash == null) { - // used for cache only - handshakeHash = new HandshakeHash(false); + int fragPos = handshakeFrag.position(); + int fragLim = handshakeFrag.limit(); + int nextPos = fragPos + handshakeMessageLen; + handshakeFrag.limit(nextPos); + + if (handshakeHash.isHashable(handshakeType)) { + handshakeHash.receive(handshakeFrag); } - handshakeHash.reserve(plaintext); - } - plaintext.position(frgPos + howmuch); - plaintext.limit(pltLim); + plaintexts.add( + new Plaintext(contentType, majorVersion, minorVersion, + -1, -1L, handshakeFrag.slice()) + ); - frgPos += howmuch; - remains -= howmuch; + handshakeFrag.position(nextPos); + handshakeFrag.limit(fragLim); + } } - plaintext.position(pltPos); + + return plaintexts.toArray(new Plaintext[0]); } - return new Plaintext(contentType, - majorVersion, minorVersion, -1, -1L, plaintext); - // recordEpoch, recordSeq, plaintext); + return new Plaintext[] { + new Plaintext(contentType, + majorVersion, minorVersion, -1, -1L, fragment) + // recordEpoch, recordSeq, plaintext); + }; } - private Plaintext handleUnknownRecord(InputStream is, byte[] header, - ByteBuffer destination) throws IOException, BadPaddingException { - + private Plaintext[] handleUnknownRecord( + byte[] header) throws IOException, BadPaddingException { byte firstByte = header[0]; byte thirdByte = header[2]; @@ -347,19 +371,16 @@ * error message, one that's treated as fatal by * clients (Otherwise we'll hang.) */ - deliverStream.write(SSLRecord.v2NoCipher); // SSLv2Hello + os.write(SSLRecord.v2NoCipher); // SSLv2Hello - if (debug != null) { - if (Debug.isOn("record")) { - System.out.println(Thread.currentThread().getName() + + if (SSLLogger.isOn) { + if (SSLLogger.isOn("record")) { + SSLLogger.fine( "Requested to negotiate unsupported SSLv2!"); } - if (Debug.isOn("packet")) { - Debug.printHex( - "[Raw write]: length = " + - SSLRecord.v2NoCipher.length, - SSLRecord.v2NoCipher); + if (SSLLogger.isOn("packet")) { + SSLLogger.fine("Raw write", SSLRecord.v2NoCipher); } } @@ -368,9 +389,7 @@ int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF); - if (destination == null) { - destination = ByteBuffer.allocate(headerSize + msgLen); - } + ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen); destination.put(temporary, 0, headerSize); msgLen -= 3; // had read 3 bytes of content as header while (msgLen > 0) { @@ -391,23 +410,20 @@ * V3 ClientHello message, and pass it up. */ destination.position(2); // exclude the header - - if (handshakeHash == null) { - // used for cache only - handshakeHash = new HandshakeHash(false); - } - handshakeHash.update(destination); + handshakeHash.receive(destination); destination.position(0); ByteBuffer converted = convertToClientHello(destination); - if (debug != null && Debug.isOn("packet")) { - Debug.printHex( + if (SSLLogger.isOn && SSLLogger.isOn("packet")) { + SSLLogger.fine( "[Converted] ClientHello", converted); } - return new Plaintext(ct_handshake, - majorVersion, minorVersion, -1, -1L, converted); + return new Plaintext[] { + new Plaintext(ContentType.HANDSHAKE.id, + majorVersion, minorVersion, -1, -1L, converted) + }; } else { if (((firstByte & 0x80) != 0) && (thirdByte == 4)) { throw new SSLException("SSL V2.0 servers are not supported."); @@ -424,13 +440,15 @@ while (n < len) { int readLen = is.read(buffer, offset + n, len - n); if (readLen < 0) { + if (SSLLogger.isOn && SSLLogger.isOn("packet")) { + SSLLogger.fine("Raw read: EOF"); + } return -1; } - if (debug != null && Debug.isOn("packet")) { - Debug.printHex( - "[Raw read]: length = " + readLen, - buffer, offset + n, readLen); + if (SSLLogger.isOn && SSLLogger.isOn("packet")) { + ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen); + SSLLogger.fine("Raw read", bb); } n += readLen;