< prev index next >

src/java.base/share/classes/sun/security/ssl/SSLSocketInputRecord.java

Print this page

        

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.  Oracle designates this

@@ -25,48 +25,51 @@
 
 package sun.security.ssl;
 
 import java.io.*;
 import java.nio.*;
-
+import java.security.GeneralSecurityException;
+import java.util.ArrayList;
 import javax.crypto.BadPaddingException;
-
 import javax.net.ssl.*;
-
-import sun.security.util.HexDumpEncoder;
-
+import sun.security.ssl.SSLCipher.SSLReadCipher;
 
 /**
  * {@code InputRecord} implementation for {@code SSLSocket}.
  *
  * @author David Brownell
  */
 final class SSLSocketInputRecord extends InputRecord implements SSLRecord {
-    private OutputStream deliverStream = null;
-    private byte[] temporary = new byte[1024];
+    private InputStream is = null;
+    private OutputStream os = null;
+    private final byte[] temporary = new byte[1024];
 
     // used by handshake hash computation for handshake fragment
     private byte prevType = -1;
     private int hsMsgOff = 0;
     private int hsMsgLen = 0;
 
     private boolean formatVerified = false;     // SSLv2 ruled out?
 
+    // Cache for incomplete handshake messages.
+    private ByteBuffer handshakeBuffer = null;
+
     private boolean hasHeader = false;          // Had read the record header
 
-    SSLSocketInputRecord() {
-        this.readAuthenticator = MAC.TLS_NULL;
+    SSLSocketInputRecord(HandshakeHash handshakeHash) {
+        super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
     }
 
     @Override
-    int bytesInCompletePacket(InputStream is) throws IOException {
+    int bytesInCompletePacket() throws IOException {
 
         if (!hasHeader) {
             // read exactly one record
             int really = read(is, temporary, 0, headerSize);
             if (really < 0) {
-                throw new EOFException("SSL peer shut down incorrectly");
+                // EOF: peer shut down incorrectly
+                return -1;
             }
             hasHeader = true;
         }
 
         byte byteZero = temporary[0];

@@ -77,19 +80,21 @@
          * ignore the verifications steps, and jump right to the
          * determination.  Otherwise, try one last hueristic to
          * see if it's SSL/TLS.
          */
         if (formatVerified ||
-                (byteZero == ct_handshake) || (byteZero == ct_alert)) {
+                (byteZero == ContentType.HANDSHAKE.id) ||
+                (byteZero == ContentType.ALERT.id)) {
             /*
              * Last sanity check that it's not a wild record
              */
-            ProtocolVersion recordVersion =
-                    ProtocolVersion.valueOf(temporary[1], temporary[2]);
-
-            // check the record version
-            checkRecordVersion(recordVersion, false);
+            if (!ProtocolVersion.isNegotiable(
+                    temporary[1], temporary[2], false, false)) {
+                throw new SSLException("Unrecognized record version " +
+                        ProtocolVersion.nameOf(temporary[1], temporary[2]) +
+                        " , plaintext connection?");
+            }
 
             /*
              * Reasonably sure this is a V3, disable further checks.
              * We can't do the same in the v2 check below, because
              * read still needs to parse/handle the v2 clientHello.

@@ -110,15 +115,16 @@
              * Internals can warn about unsupported SSLv2
              */
             boolean isShort = ((byteZero & 0x80) != 0);
 
             if (isShort && ((temporary[2] == 1) || (temporary[2] == 4))) {
-                ProtocolVersion recordVersion =
-                        ProtocolVersion.valueOf(temporary[3], temporary[4]);
-
-                // check the record version
-                checkRecordVersion(recordVersion, true);
+                if (!ProtocolVersion.isNegotiable(
+                        temporary[3], temporary[4], false, false)) {
+                    throw new SSLException("Unrecognized record version " +
+                            ProtocolVersion.nameOf(temporary[3], temporary[4]) +
+                            " , plaintext connection?");
+                }
 
                 /*
                  * Client or Server Hello
                  */
                 //

@@ -138,14 +144,14 @@
         }
 
         return len;
     }
 
-    // destination.position() is zero.
+    // The input arguments are not used; records are read from {@code is}.
     @Override
-    Plaintext decode(InputStream is, ByteBuffer destination)
-            throws IOException, BadPaddingException {
+    Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset,
+            int srcsLength) throws IOException, BadPaddingException {
 
         if (isClosed) {
             return null;
         }
 

@@ -165,40 +171,48 @@
             /*
              * The first record must either be a handshake record or an
              * alert message. If it's not, it is either invalid or an
              * SSLv2 message.
              */
-            if ((temporary[0] != ct_handshake) &&
-                (temporary[0] != ct_alert)) {
-
-                plaintext = handleUnknownRecord(is, temporary, destination);
-            }
+            if ((temporary[0] != ContentType.HANDSHAKE.id) &&
+                (temporary[0] != ContentType.ALERT.id)) {
+                hasHeader = false;
+                return handleUnknownRecord(temporary);
         }
-
-        if (plaintext == null) {
-            plaintext = decodeInputRecord(is, temporary, destination);
         }
 
         // The record header should has comsumed.
         hasHeader = false;
+        return decodeInputRecord(temporary);
+    }
 
-        return plaintext;
+    @Override
+    void setReceiverStream(InputStream inputStream) {
+        this.is = inputStream;
     }
 
     @Override
     void setDeliverStream(OutputStream outputStream) {
-        this.deliverStream = outputStream;
+        this.os = outputStream;
     }
 
     // Note that destination may be null
-    private Plaintext decodeInputRecord(InputStream is, byte[] header,
-            ByteBuffer destination) throws IOException, BadPaddingException {
-
-        byte contentType = header[0];
-        byte majorVersion = header[1];
-        byte minorVersion = header[2];
-        int contentLen = ((header[3] & 0xFF) << 8) + (header[4] & 0xFF);
+    private Plaintext[] decodeInputRecord(
+            byte[] header) throws IOException, BadPaddingException {
+        byte contentType = header[0];                   // pos: 0
+        byte majorVersion = header[1];                  // pos: 1
+        byte minorVersion = header[2];                  // pos: 2
+        int contentLen = ((header[3] & 0xFF) << 8) +
+                           (header[4] & 0xFF);          // pos: 3, 4
+
+        if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+            SSLLogger.fine(
+                    "READ: " +
+                    ProtocolVersion.nameOf(majorVersion, minorVersion) +
+                    " " + ContentType.nameOf(contentType) + ", length = " +
+                    contentLen);
+        }
 
         //
         // Check for upper bound.
         //
         // Note: May check packetSize limit in the future.

@@ -208,14 +222,11 @@
         }
 
         //
         // Read a complete record.
         //
-        if (destination == null) {
-            destination = ByteBuffer.allocate(headerSize + contentLen);
-        }  // Otherwise, the destination buffer should have enough room.
-
+        ByteBuffer destination = ByteBuffer.allocate(headerSize + contentLen);
         int dstPos = destination.position();
         destination.put(temporary, 0, headerSize);
         while (contentLen > 0) {
             int howmuch = Math.min(temporary.length, contentLen);
             int really = read(is, temporary, 0, howmuch);

@@ -227,104 +238,117 @@
             contentLen -= howmuch;
         }
         destination.flip();
         destination.position(dstPos + headerSize);
 
-        if (debug != null && Debug.isOn("record")) {
-             System.out.println(Thread.currentThread().getName() +
-                    ", READ: " +
-                    ProtocolVersion.valueOf(majorVersion, minorVersion) +
-                    " " + Record.contentName(contentType) + ", length = " +
+        if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+            SSLLogger.fine(
+                    "READ: " +
+                    ProtocolVersion.nameOf(majorVersion, minorVersion) +
+                    " " + ContentType.nameOf(contentType) + ", length = " +
                     destination.remaining());
         }
 
         //
         // Decrypt the fragment
         //
-        ByteBuffer plaintext =
-            decrypt(readAuthenticator, readCipher, contentType, destination);
-
-        if ((contentType != ct_handshake) && (hsMsgOff != hsMsgLen)) {
+        ByteBuffer fragment;
+        try {
+            Plaintext plaintext =
+                    readCipher.decrypt(contentType, destination, null);
+            fragment = plaintext.fragment;
+            contentType = plaintext.contentType;
+        } catch (BadPaddingException bpe) {
+            throw bpe;
+        } catch (GeneralSecurityException gse) {
+            throw (SSLProtocolException)(new SSLProtocolException(
+                    "Unexpected exception")).initCause(gse);
+        }
+        if (contentType != ContentType.HANDSHAKE.id && hsMsgOff != hsMsgLen) {
             throw new SSLProtocolException(
                     "Expected to get a handshake fragment");
         }
 
         //
-        // handshake hashing
+        // parse handshake messages
         //
-        if (contentType == ct_handshake) {
-            int pltPos = plaintext.position();
-            int pltLim = plaintext.limit();
-            int frgPos = pltPos;
-            for (int remains = plaintext.remaining(); remains > 0;) {
-                int howmuch;
-                byte handshakeType;
-                if (hsMsgOff < hsMsgLen) {
-                    // a fragment of the handshake message
-                    howmuch = Math.min((hsMsgLen - hsMsgOff), remains);
-                    handshakeType = prevType;
-
-                    hsMsgOff += howmuch;
-                    if (hsMsgOff == hsMsgLen) {
-                        // Now is a complete handshake message.
-                        hsMsgOff = 0;
-                        hsMsgLen = 0;
-                    }
-                } else {    // hsMsgOff == hsMsgLen, a new handshake message
-                    handshakeType = plaintext.get();
-                    int handshakeLen = ((plaintext.get() & 0xFF) << 16) |
-                                       ((plaintext.get() & 0xFF) << 8) |
-                                        (plaintext.get() & 0xFF);
-                    plaintext.position(frgPos);
-                    if (remains < (handshakeLen + 1)) { // 1: handshake type
-                        // This handshake message is fragmented.
-                        prevType = handshakeType;
-                        hsMsgOff = remains - 4;         // 4: handshake header
-                        hsMsgLen = handshakeLen;
-                    }
-
-                    howmuch = Math.min(handshakeLen + 4, remains);
-                }
-
-                plaintext.limit(frgPos + howmuch);
-
-                if (handshakeType == HandshakeMessage.ht_hello_request) {
-                    // omitted from handshake hash computation
-                } else if ((handshakeType != HandshakeMessage.ht_finished) &&
-                    (handshakeType != HandshakeMessage.ht_certificate_verify)) {
-
-                    if (handshakeHash == null) {
-                        // used for cache only
-                        handshakeHash = new HandshakeHash(false);
-                    }
-                    handshakeHash.update(plaintext);
+        if (contentType == ContentType.HANDSHAKE.id) {
+            ByteBuffer handshakeFrag = fragment;
+            if ((handshakeBuffer != null) &&
+                    (handshakeBuffer.remaining() != 0)) {
+                ByteBuffer bb = ByteBuffer.wrap(new byte[
+                        handshakeBuffer.remaining() + fragment.remaining()]);
+                bb.put(handshakeBuffer);
+                bb.put(fragment);
+                handshakeFrag = bb.rewind();
+                handshakeBuffer = null;
+            }
+            // Split into whole handshake messages; cache any incomplete tail.
+            ArrayList<Plaintext> plaintexts = new ArrayList<>(5);
+            while (handshakeFrag.hasRemaining()) {
+                int remaining = handshakeFrag.remaining();
+                if (remaining < handshakeHeaderSize) {
+                    handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
+                    handshakeBuffer.put(handshakeFrag);
+                    handshakeBuffer.rewind();
+                    break;
+                }
+
+                handshakeFrag.mark();
+                // peek at the handshake type and the 24-bit body length
+                byte handshakeType = handshakeFrag.get();
+                int handshakeBodyLen = Record.getInt24(handshakeFrag);
+                handshakeFrag.reset();
+                int handshakeMessageLen =
+                        handshakeHeaderSize + handshakeBodyLen;
+                if (remaining < handshakeMessageLen) {
+                    handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
+                    handshakeBuffer.put(handshakeFrag);
+                    handshakeBuffer.rewind();
+                    break;
+                } else if (remaining == handshakeMessageLen) {
+                    if (handshakeHash.isHashable(handshakeType)) {
+                        handshakeHash.receive(handshakeFrag);
+                    }
+
+                    plaintexts.add(
+                        new Plaintext(contentType,
+                            majorVersion, minorVersion, -1, -1L, handshakeFrag)
+                    );
+                    break;
                 } else {
-                    // Reserve until this handshake message has been processed.
-                    if (handshakeHash == null) {
-                        // used for cache only
-                        handshakeHash = new HandshakeHash(false);
-                    }
-                    handshakeHash.reserve(plaintext);
+                    int fragPos = handshakeFrag.position();
+                    int fragLim = handshakeFrag.limit();
+                    int nextPos = fragPos + handshakeMessageLen;
+                    handshakeFrag.limit(nextPos);
+                    // view is now bounded to this one message for hash/slice
+                    if (handshakeHash.isHashable(handshakeType)) {
+                        handshakeHash.receive(handshakeFrag);
                 }
 
-                plaintext.position(frgPos + howmuch);
-                plaintext.limit(pltLim);
+                    plaintexts.add(
+                        new Plaintext(contentType, majorVersion, minorVersion,
+                            -1, -1L, handshakeFrag.slice())
+                    );
 
-                frgPos += howmuch;
-                remains -= howmuch;
+                    handshakeFrag.position(nextPos);
+                    handshakeFrag.limit(fragLim);
             }
-            plaintext.position(pltPos);
         }
 
-        return new Plaintext(contentType,
-                majorVersion, minorVersion, -1, -1L, plaintext);
-                // recordEpoch, recordSeq, plaintext);
+            return plaintexts.toArray(new Plaintext[0]);
     }
 
-    private Plaintext handleUnknownRecord(InputStream is, byte[] header,
-            ByteBuffer destination) throws IOException, BadPaddingException {
+        return new Plaintext[] {
+                new Plaintext(contentType,
+                    majorVersion, minorVersion, -1, -1L, fragment)
+                    // recordEpoch, recordSeq, plaintext);
+            };
+    }
 
+    private Plaintext[] handleUnknownRecord(
+            byte[] header) throws IOException, BadPaddingException {
         byte firstByte = header[0];
         byte thirdByte = header[2];
 
         // Does it look like a Version 2 client hello (V2ClientHello)?
         if (((firstByte & 0x80) != 0) && (thirdByte == 1)) {

@@ -345,34 +369,29 @@
                  * Looks like a V2 client hello, but not one saying
                  * "let's talk SSLv3".  So we need to send an SSLv2
                  * error message, one that's treated as fatal by
                  * clients (Otherwise we'll hang.)
                  */
-                deliverStream.write(SSLRecord.v2NoCipher);      // SSLv2Hello
+                os.write(SSLRecord.v2NoCipher);      // SSLv2Hello
 
-                if (debug != null) {
-                    if (Debug.isOn("record")) {
-                         System.out.println(Thread.currentThread().getName() +
+                if (SSLLogger.isOn) {
+                    if (SSLLogger.isOn("record")) {
+                         SSLLogger.fine(
                                 "Requested to negotiate unsupported SSLv2!");
                     }
 
-                    if (Debug.isOn("packet")) {
-                        Debug.printHex(
-                                "[Raw write]: length = " +
-                                SSLRecord.v2NoCipher.length,
-                                SSLRecord.v2NoCipher);
+                    if (SSLLogger.isOn("packet")) {
+                        SSLLogger.fine("Raw write", SSLRecord.v2NoCipher);
                     }
                 }
 
                 throw new SSLException("Unsupported SSL v2.0 ClientHello");
             }
 
             int msgLen = ((header[0] & 0x7F) << 8) | (header[1] & 0xFF);
 
-            if (destination == null) {
-                destination = ByteBuffer.allocate(headerSize + msgLen);
-            }
+            ByteBuffer destination = ByteBuffer.allocate(headerSize + msgLen);
             destination.put(temporary, 0, headerSize);
             msgLen -= 3;            // had read 3 bytes of content as header
             while (msgLen > 0) {
                 int howmuch = Math.min(temporary.length, msgLen);
                 int really = read(is, temporary, 0, howmuch);

@@ -389,27 +408,24 @@
              * If we can map this into a V3 ClientHello, read and
              * hash the rest of the V2 handshake, turn it into a
              * V3 ClientHello message, and pass it up.
              */
             destination.position(2);     // exclude the header
-
-            if (handshakeHash == null) {
-                // used for cache only
-                handshakeHash = new HandshakeHash(false);
-            }
-            handshakeHash.update(destination);
+            handshakeHash.receive(destination);
             destination.position(0);
 
             ByteBuffer converted = convertToClientHello(destination);
 
-            if (debug != null && Debug.isOn("packet")) {
-                 Debug.printHex(
+            if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
+                SSLLogger.fine(
                         "[Converted] ClientHello", converted);
             }
 
-            return new Plaintext(ct_handshake,
-                majorVersion, minorVersion, -1, -1L, converted);
+            return new Plaintext[] {
+                    new Plaintext(ContentType.HANDSHAKE.id,
+                    majorVersion, minorVersion, -1, -1L, converted)
+                };
         } else {
             if (((firstByte & 0x80) != 0) && (thirdByte == 4)) {
                 throw new SSLException("SSL V2.0 servers are not supported.");
             }
 

@@ -422,17 +438,19 @@
             byte[] buffer, int offset, int len) throws IOException {
         int n = 0;
         while (n < len) {
             int readLen = is.read(buffer, offset + n, len - n);
             if (readLen < 0) {
+                if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
+                    SSLLogger.fine("Raw read: EOF");
+                }
                 return -1;
             }
 
-            if (debug != null && Debug.isOn("packet")) {
-                 Debug.printHex(
-                        "[Raw read]: length = " + readLen,
-                        buffer, offset + n, readLen);
+            if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
+                ByteBuffer bb = ByteBuffer.wrap(buffer, offset + n, readLen);
+                SSLLogger.fine("Raw read", bb);
             }
 
             n += readLen;
         }
 
< prev index next >