< prev index next >

src/java.base/share/classes/sun/security/ssl/SSLEngineInputRecord.java

Print this page

        

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.  Oracle designates this

@@ -25,17 +25,15 @@
 
 package sun.security.ssl;
 
 import java.io.*;
 import java.nio.*;
-
+import java.security.GeneralSecurityException;
+import java.util.ArrayList;
 import javax.crypto.BadPaddingException;
-
 import javax.net.ssl.*;
-
-import sun.security.util.HexDumpEncoder;
-
+import sun.security.ssl.SSLCipher.SSLReadCipher;
 
 /**
  * {@code InputRecord} implementation for {@code SSLEngine}.
  */
 final class SSLEngineInputRecord extends InputRecord implements SSLRecord {

@@ -44,31 +42,34 @@
     private int hsMsgOff = 0;
     private int hsMsgLen = 0;
 
     private boolean formatVerified = false;     // SSLv2 ruled out?
 
-    SSLEngineInputRecord() {
-        this.readAuthenticator = MAC.TLS_NULL;
+    // Leftover bytes of a handshake message that spans record boundaries.
+    private ByteBuffer handshakeBuffer = null;
+
+    SSLEngineInputRecord(HandshakeHash handshakeHash) {
+        super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
     }
 
     @Override
     int estimateFragmentSize(int packetSize) {
-        int macLen = 0;
-        if (readAuthenticator instanceof MAC) {
-            macLen = ((MAC)readAuthenticator).MAClen();
-        }
-
         if (packetSize > 0) {
-            return readCipher.estimateFragmentSize(
-                    packetSize, macLen, headerSize);
+            return readCipher.estimateFragmentSize(packetSize, headerSize);
         } else {
             return Record.maxDataSize;
         }
     }
 
     @Override
-    int bytesInCompletePacket(ByteBuffer packet) throws SSLException {
+    int bytesInCompletePacket(
+        ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException {
+
+        return bytesInCompletePacket(srcs[srcsOffset]);
+    }
+
+    private int bytesInCompletePacket(ByteBuffer packet) throws SSLException {
         /*
          * SSLv2 length field is in bytes 0/1
          * SSLv3/TLS length field is in bytes 3/4
          */
         if (packet.remaining() < 5) {

@@ -85,19 +86,23 @@
-         * ignore the verifications steps, and jump right to the
-         * determination.  Otherwise, try one last hueristic to
+         * ignore the verification steps, and jump right to the
+         * determination.  Otherwise, try one last heuristic to
          * see if it's SSL/TLS.
          */
         if (formatVerified ||
-                (byteZero == ct_handshake) || (byteZero == ct_alert)) {
+                (byteZero == ContentType.HANDSHAKE.id) ||
+                (byteZero == ContentType.ALERT.id)) {
             /*
              * Last sanity check that it's not a wild record
              */
-            ProtocolVersion recordVersion = ProtocolVersion.valueOf(
-                                    packet.get(pos + 1), packet.get(pos + 2));
-
-            // check the record version
-            checkRecordVersion(recordVersion, false);
+            byte majorVersion = packet.get(pos + 1);
+            byte minorVersion = packet.get(pos + 2);
+            if (!ProtocolVersion.isNegotiable(
+                    majorVersion, minorVersion, false, false)) {
+                throw new SSLException("Unrecognized record version " +
+                        ProtocolVersion.nameOf(majorVersion, minorVersion) +
+                        " , plaintext connection?");
+            }
 
             /*
              * Reasonably sure this is a V3, disable further checks.
              * We can't do the same in the v2 check below, because
              * read still needs to parse/handle the v2 clientHello.

@@ -121,15 +126,18 @@
             boolean isShort = ((byteZero & 0x80) != 0);
 
             if (isShort &&
                     ((packet.get(pos + 2) == 1) || packet.get(pos + 2) == 4)) {
 
-                ProtocolVersion recordVersion = ProtocolVersion.valueOf(
-                                    packet.get(pos + 3), packet.get(pos + 4));
-
-                // check the record version
-                checkRecordVersion(recordVersion, true);
+                byte majorVersion = packet.get(pos + 3);
+                byte minorVersion = packet.get(pos + 4);
+                if (!ProtocolVersion.isNegotiable(
+                        majorVersion, minorVersion, false, false)) {
+                    throw new SSLException("Unrecognized record version " +
+                            ProtocolVersion.nameOf(majorVersion, minorVersion) +
+                            " , plaintext connection?");
+                }
 
                 /*
                  * Client or Server Hello
                  */
                 int mask = (isShort ? 0x7F : 0x3F);

@@ -145,41 +153,33 @@
 
         return len;
     }
 
     @Override
-    void checkRecordVersion(ProtocolVersion recordVersion,
-            boolean allowSSL20Hello) throws SSLException {
-
-        if (recordVersion.maybeDTLSProtocol()) {
-            throw new SSLException(
-                    "Unrecognized record version " + recordVersion +
-                    " , DTLS packet?");
-        }
+    Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset,
+            int srcsLength) throws IOException, BadPaddingException {
+        if (srcs == null || srcs.length == 0 || srcsLength == 0) {
+            return new Plaintext[0];
+        } else if (srcsLength == 1) {
+            return decode(srcs[srcsOffset]);
+        } else {
+            ByteBuffer packet = extract(srcs,
+                    srcsOffset, srcsLength, SSLRecord.headerSize);
 
-        // Check if the record version is too old.
-        if ((recordVersion.v < ProtocolVersion.MIN.v)) {
-            // if it's not SSLv2, we're out of here.
-            if (!allowSSL20Hello ||
-                    (recordVersion.v != ProtocolVersion.SSL20Hello.v)) {
-                throw new SSLException(
-                    "Unsupported record version " + recordVersion);
-            }
+            return decode(packet);
         }
     }
 
-    @Override
-    Plaintext decode(ByteBuffer packet)
+    private Plaintext[] decode(ByteBuffer packet)
             throws IOException, BadPaddingException {
 
         if (isClosed) {
             return null;
         }
 
-        if (debug != null && Debug.isOn("packet")) {
-             Debug.printHex(
-                    "[Raw read]: length = " + packet.remaining(), packet);
+        if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
+            SSLLogger.fine("Raw read", packet);
         }
 
         // The caller should have validated the record.
         if (!formatVerified) {
             formatVerified = true;

@@ -189,39 +189,37 @@
              * alert message. If it's not, it is either invalid or an
              * SSLv2 message.
              */
             int pos = packet.position();
             byte byteZero = packet.get(pos);
-            if (byteZero != ct_handshake && byteZero != ct_alert) {
+            if (byteZero != ContentType.HANDSHAKE.id &&
+                    byteZero != ContentType.ALERT.id) {
                 return handleUnknownRecord(packet);
             }
         }
 
         return decodeInputRecord(packet);
     }
 
-    private Plaintext decodeInputRecord(ByteBuffer packet)
+    private Plaintext[] decodeInputRecord(ByteBuffer packet)
             throws IOException, BadPaddingException {
-
         //
         // The packet should be a complete record, or more.
         //
-
         int srcPos = packet.position();
         int srcLim = packet.limit();
 
         byte contentType = packet.get();                   // pos: 0
         byte majorVersion = packet.get();                  // pos: 1
         byte minorVersion = packet.get();                  // pos: 2
-        int contentLen = ((packet.get() & 0xFF) << 8) +
-                          (packet.get() & 0xFF);           // pos: 3, 4
+        int contentLen = Record.getInt16(packet);          // pos: 3, 4
 
-        if (debug != null && Debug.isOn("record")) {
-             System.out.println(Thread.currentThread().getName() +
-                    ", READ: " +
-                    ProtocolVersion.valueOf(majorVersion, minorVersion) +
-                    " " + Record.contentName(contentType) + ", length = " +
+        if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+            SSLLogger.fine(
+                    "READ: " +
+                    ProtocolVersion.nameOf(majorVersion, minorVersion) +
+                    " " + ContentType.nameOf(contentType) + ", length = " +
                     contentLen);
         }
 
         //
         // Check for upper bound.

@@ -233,11 +231,11 @@
         }
 
         //
         // check for handshake fragment
         //
-        if ((contentType != ct_handshake) && (hsMsgOff != hsMsgLen)) {
+        if (contentType != ContentType.HANDSHAKE.id && hsMsgOff != hsMsgLen) {
             throw new SSLProtocolException(
                     "Expected to get a handshake fragment");
         }
 
         //

@@ -245,96 +243,105 @@
         //
         int recLim = srcPos + SSLRecord.headerSize + contentLen;
         packet.limit(recLim);
         packet.position(srcPos + SSLRecord.headerSize);
 
-        ByteBuffer plaintext;
+        ByteBuffer fragment;
         try {
-            plaintext =
-                decrypt(readAuthenticator, readCipher, contentType, packet);
+            Plaintext plaintext =
+                    readCipher.decrypt(contentType, packet, null);
+            fragment = plaintext.fragment;
+            contentType = plaintext.contentType;
+        } catch (BadPaddingException bpe) {
+            throw bpe;
+        } catch (GeneralSecurityException gse) {
+            throw (SSLProtocolException)(new SSLProtocolException(
+                    "Unexpected exception")).initCause(gse);
         } finally {
-            // comsume a complete record
+            // consume a complete record
             packet.limit(srcLim);
             packet.position(recLim);
         }
 
         //
-        // handshake hashing
+        // parse handshake messages, joining fragments split across records
         //
-        if (contentType == ct_handshake) {
-            int pltPos = plaintext.position();
-            int pltLim = plaintext.limit();
-            int frgPos = pltPos;
-            for (int remains = plaintext.remaining(); remains > 0;) {
-                int howmuch;
-                byte handshakeType;
-                if (hsMsgOff < hsMsgLen) {
-                    // a fragment of the handshake message
-                    howmuch = Math.min((hsMsgLen - hsMsgOff), remains);
-                    handshakeType = prevType;
-
-                    hsMsgOff += howmuch;
-                    if (hsMsgOff == hsMsgLen) {
-                        // Now is a complete handshake message.
-                        hsMsgOff = 0;
-                        hsMsgLen = 0;
-                    }
-                } else {    // hsMsgOff == hsMsgLen, a new handshake message
-                    handshakeType = plaintext.get();
-                    int handshakeLen = ((plaintext.get() & 0xFF) << 16) |
-                                       ((plaintext.get() & 0xFF) << 8) |
-                                        (plaintext.get() & 0xFF);
-                    plaintext.position(frgPos);
-                    if (remains < (handshakeLen + 4)) { // 4: handshake header
-                        // This handshake message is fragmented.
-                        prevType = handshakeType;
-                        hsMsgOff = remains - 4;         // 4: handshake header
-                        hsMsgLen = handshakeLen;
-                    }
-
-                    howmuch = Math.min(handshakeLen + 4, remains);
-                }
-
-                plaintext.limit(frgPos + howmuch);
-
-                if (handshakeType == HandshakeMessage.ht_hello_request) {
-                    // omitted from handshake hash computation
-                } else if ((handshakeType != HandshakeMessage.ht_finished) &&
-                    (handshakeType != HandshakeMessage.ht_certificate_verify)) {
-
-                    if (handshakeHash == null) {
-                        // used for cache only
-                        handshakeHash = new HandshakeHash(false);
-                    }
-                    handshakeHash.update(plaintext);
+        if (contentType == ContentType.HANDSHAKE.id) {
+            ByteBuffer handshakeFrag = fragment;
+            if ((handshakeBuffer != null) &&
+                    (handshakeBuffer.remaining() != 0)) {
+                ByteBuffer bb = ByteBuffer.wrap(new byte[
+                        handshakeBuffer.remaining() + fragment.remaining()]);
+                bb.put(handshakeBuffer);
+                bb.put(fragment);
+                handshakeFrag = bb.rewind();
+                handshakeBuffer = null;
+            }
+
+            ArrayList<Plaintext> plaintexts = new ArrayList<>(5);
+            while (handshakeFrag.hasRemaining()) {
+                int remaining = handshakeFrag.remaining();
+                if (remaining < handshakeHeaderSize) {
+                    handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
+                    handshakeBuffer.put(handshakeFrag);
+                    handshakeBuffer.rewind();
+                    break;
+                }
+
+                handshakeFrag.mark();
+                // peek at the handshake header (type and body length)
+                byte handshakeType = handshakeFrag.get();
+                int handshakeBodyLen = Record.getInt24(handshakeFrag);
+                handshakeFrag.reset();
+                int handshakeMessageLen =
+                        handshakeHeaderSize + handshakeBodyLen;
+                if (remaining < handshakeMessageLen) {
+                    handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
+                    handshakeBuffer.put(handshakeFrag);
+                    handshakeBuffer.rewind();
+                    break;
+                } else if (remaining == handshakeMessageLen) {
+                    if (handshakeHash.isHashable(handshakeType)) {
+                        handshakeHash.receive(handshakeFrag);
+                    }
+
+                    plaintexts.add(
+                        new Plaintext(contentType,
+                            majorVersion, minorVersion, -1, -1L, handshakeFrag)
+                    );
+                    break;
                 } else {
-                    // Reserve until this handshake message has been processed.
-                    if (handshakeHash == null) {
-                        // used for cache only
-                        handshakeHash = new HandshakeHash(false);
-                    }
-                    handshakeHash.reserve(plaintext);
-                }
+                    int fragPos = handshakeFrag.position();
+                    int fragLim = handshakeFrag.limit();
+                    int nextPos = fragPos + handshakeMessageLen;
+                    handshakeFrag.limit(nextPos);
+
+                    if (handshakeHash.isHashable(handshakeType)) {
+                        handshakeHash.receive(handshakeFrag);
+                    }
 
-                plaintext.position(frgPos + howmuch);
-                plaintext.limit(pltLim);
+                    plaintexts.add(
+                        new Plaintext(contentType, majorVersion, minorVersion,
+                            -1, -1L, handshakeFrag.slice())
+                    );
 
-                frgPos += howmuch;
-                remains -= howmuch;
+                    handshakeFrag.position(nextPos);
+                    handshakeFrag.limit(fragLim);
+                }
             }
 
-            plaintext.position(pltPos);
+            return plaintexts.toArray(new Plaintext[0]);
         }
 
-        return new Plaintext(contentType,
-                majorVersion, minorVersion, -1, -1L, plaintext);
-                // recordEpoch, recordSeq, plaintext);
+        return new Plaintext[] {
+            new Plaintext(contentType,
+                majorVersion, minorVersion, -1, -1L, fragment)
+        };
     }
 
-    private Plaintext handleUnknownRecord(ByteBuffer packet)
+    private Plaintext[] handleUnknownRecord(ByteBuffer packet)
             throws IOException, BadPaddingException {
-
         //
         // The packet should be a complete record.
         //
         int srcPos = packet.position();
         int srcLim = packet.limit();

@@ -361,12 +368,12 @@
                  * Looks like a V2 client hello, but not one saying
                  * "let's talk SSLv3".  So we need to send an SSLv2
                  * error message, one that's treated as fatal by
                  * clients (Otherwise we'll hang.)
                  */
-                if (debug != null && Debug.isOn("record")) {
-                     System.out.println(Thread.currentThread().getName() +
+                if (SSLLogger.isOn && SSLLogger.isOn("record")) {
+                    SSLLogger.fine(
                             "Requested to negotiate unsupported SSLv2!");
                 }
 
                 // hack code, the exception is caught in SSLEngineImpl
                 // so that SSLv2 error message can be delivered properly.

@@ -378,32 +385,28 @@
              * If we can map this into a V3 ClientHello, read and
              * hash the rest of the V2 handshake, turn it into a
              * V3 ClientHello message, and pass it up.
              */
             packet.position(srcPos + 2);        // exclude the header
-
-            if (handshakeHash == null) {
-                // used for cache only
-                handshakeHash = new HandshakeHash(false);
-            }
-            handshakeHash.update(packet);
+            handshakeHash.receive(packet);
             packet.position(srcPos);
 
             ByteBuffer converted = convertToClientHello(packet);
 
-            if (debug != null && Debug.isOn("packet")) {
-                 Debug.printHex(
+            if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
+                SSLLogger.fine(
                         "[Converted] ClientHello", converted);
             }
 
-            return new Plaintext(ct_handshake,
-                majorVersion, minorVersion, -1, -1L, converted);
+            return new Plaintext[] {
+                new Plaintext(ContentType.HANDSHAKE.id,
+                    majorVersion, minorVersion, -1, -1L, converted)
+            };
         } else {
             if (((firstByte & 0x80) != 0) && (thirdByte == 4)) {
                 throw new SSLException("SSL V2.0 servers are not supported.");
             }
 
             throw new SSLException("Unsupported or unrecognized SSL message");
         }
     }
-
 }
< prev index next >