< prev index next >

src/share/classes/sun/security/ssl/SSLSocketInputRecord.java

Print this page
rev 14406 : 8239798: SSLSocket closes socket both socket endpoints on a SocketTimeoutException
Reviewed-by: xuelei
Contributed-by: alexey@azul.com verghese@amazon.com

@@ -1,7 +1,8 @@
 /*
  * Copyright (c) 1996, 2020, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2020, Azul Systems, Inc. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.  Oracle designates this

@@ -24,10 +25,11 @@
  */
 
 package sun.security.ssl;
 
 import java.io.EOFException;
+import java.io.InterruptedIOException;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.nio.ByteBuffer;
 import java.security.GeneralSecurityException;

@@ -45,41 +47,35 @@
  * @author David Brownell
  */
 final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
     private InputStream is = null;
     private OutputStream os = null;
-    private final byte[] temporary = new byte[1024];
+    private final byte[] header = new byte[headerSize];
+    private int headerOff = 0;
+    // Cache for incomplete record body.
+    private ByteBuffer recordBody = ByteBuffer.allocate(1024);
 
     private boolean formatVerified = false;     // SSLv2 ruled out?
 
     // Cache for incomplete handshake messages.
     private ByteBuffer handshakeBuffer = null;
 
-    private boolean hasHeader = false;          // Had read the record header
-
     SSLSocketInputRecord(HandshakeHash handshakeHash) {
         super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
     }
 
     @Override
     int bytesInCompletePacket() throws IOException {
-        if (!hasHeader) {
-            // read exactly one record
+        // read header
             try {
-                int really = read(is, temporary, 0, headerSize);
-                if (really < 0) {
-                    // EOF: peer shut down incorrectly
-                    return -1;
-                }
+            readHeader();
             } catch (EOFException eofe) {
                 // The caller will handle EOF.
                 return -1;
             }
-            hasHeader = true;
-        }
 
-        byte byteZero = temporary[0];
+        byte byteZero = header[0];
         int len = 0;
 
         /*
          * If we have already verified previous packets, we can
          * ignore the verifications steps, and jump right to the

@@ -91,13 +87,13 @@
                 (byteZero == ContentType.ALERT.id)) {
             /*
              * Last sanity check that it's not a wild record
              */
             if (!ProtocolVersion.isNegotiable(
-                    temporary[1], temporary[2], false)) {
+                    header[1], header[2], false)) {
                 throw new SSLException("Unrecognized record version " +
-                        ProtocolVersion.nameOf(temporary[1], temporary[2]) +
+                        ProtocolVersion.nameOf(header[1], header[2]) +
                         " , plaintext connection?");
             }
 
             /*
              * Reasonably sure this is a V3, disable further checks.

@@ -107,27 +103,27 @@
             formatVerified = true;
 
             /*
              * One of the SSLv3/TLS message types.
              */
-            len = ((temporary[3] & 0xFF) << 8) +
-                   (temporary[4] & 0xFF) + headerSize;
+            len = ((header[3] & 0xFF) << 8) +
+                    (header[4] & 0xFF) + headerSize;
         } else {
             /*
              * Must be SSLv2 or something unknown.
              * Check if it's short (2 bytes) or
              * long (3) header.
              *
              * Internals can warn about unsupported SSLv2
              */
             boolean isShort = ((byteZero & 0x80) != 0);
 
-            if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) {
+            if (isShort && ((header[2] == 1) || (header[2] == 4))) {
                 if (!ProtocolVersion.isNegotiable(
-                        temporary[3], temporary[4], false)) {
+                        header[3], header[4], false)) {
                     throw new SSLException("Unrecognized record version " +
-                            ProtocolVersion.nameOf(temporary[3], temporary[4]) +
+                            ProtocolVersion.nameOf(header[3], header[4]) +
                             " , plaintext connection?");
                 }
 
                 /*
                  * Client or Server Hello

@@ -136,13 +132,13 @@
                 // Short header is using here.  We reverse the code here
                 // in case it is used in the future.
                 //
                 // int mask = (isShort ? 0x7F : 0x3F);
                 // len = ((byteZero & mask) << 8) +
-                //        (temporary[1] & 0xFF) + (isShort ? 2 : 3);
+                //        (header[1] & 0xFF) + (isShort ? 2 : 3);
                 //
-                len = ((byteZero & 0x7F) << 8) + (temporary[1] & 0xFF) + 2;
+                len = ((byteZero & 0x7F) << 8) + (header[1] & 0xFF) + 2;
             } else {
                 // Gobblygook!
                 throw new SSLException(
                         "Unrecognized SSL message, plaintext connection?");
             }

@@ -158,38 +154,45 @@
 
         if (isClosed) {
             return null;
         }
 
-        if (!hasHeader) {
-            // read exactly one record
-            int really = read(is, temporary, 0, headerSize);
-            if (really < 0) {
-                throw new EOFException("SSL peer shut down incorrectly");
-            }
-            hasHeader = true;
-        }
+        // read header
+        readHeader();
 
-        Plaintext plaintext = null;
+        Plaintext[] plaintext = null;
+        boolean cleanInBuffer = true;
+        try {
         if (!formatVerified) {
             formatVerified = true;
 
             /*
              * The first record must either be a handshake record or an
              * alert message. If it's not, it is either invalid or an
              * SSLv2 message.
              */
-            if ((temporary[0] != ContentType.HANDSHAKE.id) &&
-                (temporary[0] != ContentType.ALERT.id)) {
-                hasHeader = false;
-                return handleUnknownRecord(temporary);
+                if ((header[0] != ContentType.HANDSHAKE.id) &&
+                        (header[0] != ContentType.ALERT.id)) {
+                    plaintext = handleUnknownRecord();
             }
         }
 
         // The record header should has consumed.
-        hasHeader = false;
-        return decodeInputRecord(temporary);
+            if (plaintext == null) {
+                plaintext = decodeInputRecord();
+            }
+        } catch(InterruptedIOException e) {
+            // do not clean header and recordBody in case of Socket Timeout
+            cleanInBuffer = false;
+            throw e;
+        } finally {
+            if (cleanInBuffer) {
+                headerOff = 0;
+                recordBody.clear();
+            }
+        }
+        return plaintext;
     }
 
     @Override
     void setReceiverStream(InputStream inputStream) {
         this.is = inputStream;

@@ -198,13 +201,11 @@
     @Override
     void setDeliverStream(OutputStream outputStream) {
         this.os = outputStream;
     }
 
-    // Note that destination may be null
-    private Plaintext[] decodeInputRecord(
-            byte[] header) throws IOException, BadPaddingException {
+    private Plaintext[] decodeInputRecord() throws IOException, BadPaddingException {
         byte contentType = header[0];                   // pos: 0
         byte majorVersion = header[1];                  // pos: 1
         byte minorVersion = header[2];                  // pos: 2
         int contentLen = ((header[3] & 0xFF) << 8) +
                            (header[4] & 0xFF);          // pos: 3, 4

@@ -225,43 +226,40 @@
             throw new SSLProtocolException(
                 "Bad input record size, TLSCiphertext.length = " + contentLen);
         }
 
         //
-        // Read a complete record.
-        //
-        ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen);
-        int dstPos = destination.position();
-        destination.put(temporary, 0, headerSize);
-        while (contentLen > 0) {
-            int howmuch = Math.min(temporary.length, contentLen);
-            int really = read(is, temporary, 0, howmuch);
-            if (really < 0) {
-                throw new EOFException("SSL peer shut down incorrectly");
+        // Read a complete record and store in the recordBody
+        // recordBody caches the incoming record so that reading can resume
+        // where it left off if the read operation times out
+        //
+        if (recordBody.position() == 0) {
+            if (recordBody.capacity() < contentLen) {
+                recordBody = ByteBuffer.allocate(contentLen);
             }
-
-            destination.put(temporary, 0, howmuch);
-            contentLen -= howmuch;
+            recordBody.limit(contentLen);
+        } else {
+            contentLen = recordBody.remaining();
         }
-        destination.flip();
-        destination.position(dstPos + headerSize);
+        readFully(contentLen);
+        recordBody.flip();
 
         if (SSLLogger.isOn && SSLLogger.isOn("record")) {
             SSLLogger.fine(
                     "READ: " +
                     ProtocolVersion.nameOf(majorVersion, minorVersion) +
                     " " + ContentType.nameOf(contentType) + ", length = " +
-                    destination.remaining());
+                    recordBody.remaining());
         }
 
         //
         // Decrypt the fragment
         //
         ByteBuffer fragment;
         try {
             Plaintext plaintext =
-                    readCipher.decrypt(contentType, destination, null);
+                    readCipher.decrypt(contentType, recordBody, null);
             fragment = plaintext.fragment;
             contentType = plaintext.contentType;
         } catch (BadPaddingException bpe) {
             throw bpe;
         } catch (GeneralSecurityException gse) {

@@ -359,12 +357,11 @@
                 new Plaintext(contentType,
                     majorVersion, minorVersion, -1, -1L, fragment)
             };
     }
 
-    private Plaintext[] handleUnknownRecord(
-            byte[] header) throws IOException, BadPaddingException {
+    private Plaintext[] handleUnknownRecord() throws IOException, BadPaddingException {
         byte firstByte = header[0];
         byte thirdByte = header[2];
 
         // Does it look like a Version 2 client hello (V2ClientHello)?
         if (((firstByte & 0x80) != 0) && (thirdByte == 1)) {

@@ -402,36 +399,33 @@
 
                 throw new SSLException("Unsupported SSL v2.0 ClientHello");
             }
 
             int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF);
-
-            ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen);
-            destination.put(temporary, 0, headerSize);
-            msgLen -= 3;            // had read 3 bytes of content as header
-            while (msgLen > 0) {
-                int howmuch = Math.min(temporary.length, msgLen);
-                int really = read(is, temporary, 0, howmuch);
-                if (really < 0) {
-                    throw new EOFException("SSL peer shut down incorrectly");
+            if (recordBody.position() == 0) {
+                if (recordBody.capacity() < (headerSize + msgLen)) {
+                    recordBody = ByteBuffer.allocate(headerSize + msgLen);
                 }
-
-                destination.put(temporary, 0, howmuch);
-                msgLen -= howmuch;
+                recordBody.limit(headerSize + msgLen);
+                recordBody.put(header, 0, headerSize);
+            } else {
+                msgLen = recordBody.remaining();
             }
-            destination.flip();
+            msgLen -= 3;            // had read 3 bytes of content as header
+            readFully(msgLen);
+            recordBody.flip();
 
             /*
              * If we can map this into a V3 ClientHello, read and
              * hash the rest of the V2 handshake, turn it into a
              * V3 ClientHello message, and pass it up.
              */
-            destination.position(2);     // exclude the header
-            handshakeHash.receive(destination);
-            destination.position(0);
+            recordBody.position(2);     // exclude the header
+            handshakeHash.receive(recordBody);
+            recordBody.position(0);
 
-            ByteBuffer converted = convertToClientHello(destination);
+            ByteBuffer converted = convertToClientHello(recordBody);
 
             if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
                 SSLLogger.fine(
                         "[Converted] ClientHello", converted);
             }

@@ -447,32 +441,46 @@
 
             throw new SSLException("Unsupported or unrecognized SSL message");
         }
     }
 
-    // Read the exact bytes of data, otherwise, return -1.
-    private static int read(InputStream is,
-            byte[] buffer, int offset, int len) throws IOException {
-        int n = 0;
-        while (n < len) {
-            int readLen = is.read(buffer, offset + n, len - n);
+    // Read the exact bytes of data, otherwise, throw IOException.
+    private int readFully(int len) throws IOException {
+        int end = len + recordBody.position();
+        int off = recordBody.position();
+        try {
+            while (off < end) {
+                off += read(is, recordBody.array(), off, end - off);
+            }
+        } finally {
+            recordBody.position(off);
+        }
+        return len;
+    }
+
+    // Read the SSL record header, otherwise, throw IOException.
+    private int readHeader() throws IOException {
+        while (headerOff < headerSize) {
+            headerOff += read(is, header, headerOff, headerSize - headerOff);
+        }
+        return headerSize;
+    }
+
+    private static int read(InputStream is, byte[] buf, int off, int len)  throws IOException {
+        int readLen = is.read(buf, off, len);
             if (readLen < 0) {
                 if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
                     SSLLogger.fine("Raw read: EOF");
                 }
-                return -1;
+            throw new EOFException("SSL peer shut down incorrectly");
             }
 
             if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
-                ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen);
+            ByteBuffer bb = ByteBuffer.wrap(buf, off, readLen);
                 SSLLogger.fine("Raw read", bb);
             }
-
-            n += readLen;
-        }
-
-        return n;
+        return readLen;
     }
 
     // Try to use up the input stream without impact the performance too much.
     void deplete(boolean tryToRead) throws IOException {
         int remaining = is.available();
< prev index next >