1 /*
   2  * Copyright (c) 1996, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ssl;
  27 
  28 import java.io.*;
  29 import java.nio.*;
  30 import java.security.GeneralSecurityException;
  31 import java.util.ArrayList;
  32 import javax.crypto.BadPaddingException;
  33 import javax.net.ssl.*;
  34 import sun.security.ssl.SSLCipher.SSLReadCipher;
  35 
  36 /**
  37  * {@code InputRecord} implementation for {@code SSLEngine}.
  38  */
  39 final class SSLEngineInputRecord extends InputRecord implements SSLRecord {
  40     // used by handshake hash computation for handshake fragment
  41     private byte prevType = -1;
  42     private int hsMsgOff = 0;
  43     private int hsMsgLen = 0;
  44 
  45     private boolean formatVerified = false;     // SSLv2 ruled out?
  46 
  47     // Cache for incomplete handshake messages.
  48     private ByteBuffer handshakeBuffer = null;
  49 
  50     SSLEngineInputRecord(HandshakeHash handshakeHash) {
  51         super(handshakeHash, SSLReadCipher.nullTlsReadCipher());
  52     }
  53 
  54     @Override
  55     int estimateFragmentSize(int packetSize) {
  56         if (packetSize > 0) {
  57             return readCipher.estimateFragmentSize(packetSize, headerSize);
  58         } else {
  59             return Record.maxDataSize;
  60         }
  61     }
  62 
  63     @Override
  64     int bytesInCompletePacket(
  65         ByteBuffer[] srcs, int srcsOffset, int srcsLength) throws IOException {
  66 
  67         return bytesInCompletePacket(srcs[srcsOffset]);
  68     }
  69 
  70     private int bytesInCompletePacket(ByteBuffer packet) throws SSLException {
  71         /*
  72          * SSLv2 length field is in bytes 0/1
  73          * SSLv3/TLS length field is in bytes 3/4
  74          */
  75         if (packet.remaining() < 5) {
  76             return -1;
  77         }
  78 
  79         int pos = packet.position();
  80         byte byteZero = packet.get(pos);
  81 
  82         int len = 0;
  83 
  84         /*
  85          * If we have already verified previous packets, we can
  86          * ignore the verifications steps, and jump right to the
  87          * determination.  Otherwise, try one last hueristic to
  88          * see if it's SSL/TLS.
  89          */
  90         if (formatVerified ||
  91                 (byteZero == ContentType.HANDSHAKE.id) ||
  92                 (byteZero == ContentType.ALERT.id)) {
  93             /*
  94              * Last sanity check that it's not a wild record
  95              */
  96             byte majorVersion = packet.get(pos + 1);
  97             byte minorVersion = packet.get(pos + 2);
  98             if (!ProtocolVersion.isNegotiable(
  99                     majorVersion, minorVersion, false, false)) {
 100                 throw new SSLException("Unrecognized record version " +
 101                         ProtocolVersion.nameOf(majorVersion, minorVersion) +
 102                         " , plaintext connection?");
 103             }
 104 
 105             /*
 106              * Reasonably sure this is a V3, disable further checks.
 107              * We can't do the same in the v2 check below, because
 108              * read still needs to parse/handle the v2 clientHello.
 109              */
 110             formatVerified = true;
 111 
 112             /*
 113              * One of the SSLv3/TLS message types.
 114              */
 115             len = ((packet.get(pos + 3) & 0xFF) << 8) +
 116                    (packet.get(pos + 4) & 0xFF) + headerSize;
 117 
 118         } else {
 119             /*
 120              * Must be SSLv2 or something unknown.
 121              * Check if it's short (2 bytes) or
 122              * long (3) header.
 123              *
 124              * Internals can warn about unsupported SSLv2
 125              */
 126             boolean isShort = ((byteZero & 0x80) != 0);
 127 
 128             if (isShort &&
 129                     ((packet.get(pos + 2) == 1) || packet.get(pos + 2) == 4)) {
 130 
 131                 byte majorVersion = packet.get(pos + 3);
 132                 byte minorVersion = packet.get(pos + 4);
 133                 if (!ProtocolVersion.isNegotiable(
 134                         majorVersion, minorVersion, false, false)) {
 135                     throw new SSLException("Unrecognized record version " +
 136                             ProtocolVersion.nameOf(majorVersion, minorVersion) +
 137                             " , plaintext connection?");
 138                 }
 139 
 140                 /*
 141                  * Client or Server Hello
 142                  */
 143                 int mask = (isShort ? 0x7F : 0x3F);
 144                 len = ((byteZero & mask) << 8) +
 145                         (packet.get(pos + 1) & 0xFF) + (isShort ? 2 : 3);
 146 
 147             } else {
 148                 // Gobblygook!
 149                 throw new SSLException(
 150                         "Unrecognized SSL message, plaintext connection?");
 151             }
 152         }
 153 
 154         return len;
 155     }
 156 
 157     @Override
 158     Plaintext[] decode(ByteBuffer[] srcs, int srcsOffset,
 159             int srcsLength) throws IOException, BadPaddingException {
 160         if (srcs == null || srcs.length == 0 || srcsLength == 0) {
 161             return new Plaintext[0];
 162         } else if (srcsLength == 1) {
 163             return decode(srcs[srcsOffset]);
 164         } else {
 165             ByteBuffer packet = extract(srcs,
 166                     srcsOffset, srcsLength, SSLRecord.headerSize);
 167 
 168             return decode(packet);
 169         }
 170     }
 171 
 172     private Plaintext[] decode(ByteBuffer packet)
 173             throws IOException, BadPaddingException {
 174 
 175         if (isClosed) {
 176             return null;
 177         }
 178 
 179         if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
 180             SSLLogger.fine("Raw read", packet);
 181         }
 182 
 183         // The caller should have validated the record.
 184         if (!formatVerified) {
 185             formatVerified = true;
 186 
 187             /*
 188              * The first record must either be a handshake record or an
 189              * alert message. If it's not, it is either invalid or an
 190              * SSLv2 message.
 191              */
 192             int pos = packet.position();
 193             byte byteZero = packet.get(pos);
 194             if (byteZero != ContentType.HANDSHAKE.id &&
 195                     byteZero != ContentType.ALERT.id) {
 196                 return handleUnknownRecord(packet);
 197             }
 198         }
 199 
 200         return decodeInputRecord(packet);
 201     }
 202 
 203     private Plaintext[] decodeInputRecord(ByteBuffer packet)
 204             throws IOException, BadPaddingException {
 205         //
 206         // The packet should be a complete record, or more.
 207         //
 208         int srcPos = packet.position();
 209         int srcLim = packet.limit();
 210 
 211         byte contentType = packet.get();                   // pos: 0
 212         byte majorVersion = packet.get();                  // pos: 1
 213         byte minorVersion = packet.get();                  // pos: 2
 214         int contentLen = Record.getInt16(packet);          // pos: 3, 4
 215 
 216         if (SSLLogger.isOn && SSLLogger.isOn("record")) {
 217             SSLLogger.fine(
 218                     "READ: " +
 219                     ProtocolVersion.nameOf(majorVersion, minorVersion) +
 220                     " " + ContentType.nameOf(contentType) + ", length = " +
 221                     contentLen);
 222         }
 223 
 224         //
 225         // Check for upper bound.
 226         //
 227         // Note: May check packetSize limit in the future.
 228         if (contentLen < 0 || contentLen > maxLargeRecordSize - headerSize) {
 229             throw new SSLProtocolException(
 230                 "Bad input record size, TLSCiphertext.length = " + contentLen);
 231         }
 232 
 233         //
 234         // check for handshake fragment
 235         //
 236         if (contentType != ContentType.HANDSHAKE.id && hsMsgOff != hsMsgLen) {
 237             throw new SSLProtocolException(
 238                     "Expected to get a handshake fragment");
 239         }
 240 
 241         //
 242         // Decrypt the fragment
 243         //
 244         int recLim = srcPos + SSLRecord.headerSize + contentLen;
 245         packet.limit(recLim);
 246         packet.position(srcPos + SSLRecord.headerSize);
 247 
 248         ByteBuffer fragment;
 249         try {
 250             Plaintext plaintext =
 251                     readCipher.decrypt(contentType, packet, null);
 252             fragment = plaintext.fragment;
 253             contentType = plaintext.contentType;
 254         } catch (BadPaddingException bpe) {
 255             throw bpe;
 256         } catch (GeneralSecurityException gse) {
 257             throw (SSLProtocolException)(new SSLProtocolException(
 258                     "Unexpected exception")).initCause(gse);
 259         } finally {
 260             // comsume a complete record
 261             packet.limit(srcLim);
 262             packet.position(recLim);
 263         }
 264 
 265         //
 266         // parse handshake messages
 267         //
 268         if (contentType == ContentType.HANDSHAKE.id) {
 269             ByteBuffer handshakeFrag = fragment;
 270             if ((handshakeBuffer != null) &&
 271                     (handshakeBuffer.remaining() != 0)) {
 272                 ByteBuffer bb = ByteBuffer.wrap(new byte[
 273                         handshakeBuffer.remaining() + fragment.remaining()]);
 274                 bb.put(handshakeBuffer);
 275                 bb.put(fragment);
 276                 handshakeFrag = bb.rewind();
 277                 handshakeBuffer = null;
 278             }
 279 
 280             ArrayList<Plaintext> plaintexts = new ArrayList<>(5);
 281             while (handshakeFrag.hasRemaining()) {
 282                 int remaining = handshakeFrag.remaining();
 283                 if (remaining < handshakeHeaderSize) {
 284                     handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
 285                     handshakeBuffer.put(handshakeFrag);
 286                     handshakeBuffer.rewind();
 287                     break;
 288                 }
 289 
 290                 handshakeFrag.mark();
 291                 // skip the first byte: handshake type
 292                 byte handshakeType = handshakeFrag.get();
 293                 int handshakeBodyLen = Record.getInt24(handshakeFrag);
 294                 handshakeFrag.reset();
 295                 int handshakeMessageLen =
 296                         handshakeHeaderSize + handshakeBodyLen;
 297                 if (remaining < handshakeMessageLen) {
 298                     handshakeBuffer = ByteBuffer.wrap(new byte[remaining]);
 299                     handshakeBuffer.put(handshakeFrag);
 300                     handshakeBuffer.rewind();
 301                     break;
 302                 } if (remaining == handshakeMessageLen) {
 303                     if (handshakeHash.isHashable(handshakeType)) {
 304                         handshakeHash.receive(handshakeFrag);
 305                     }
 306 
 307                     plaintexts.add(
 308                         new Plaintext(contentType,
 309                             majorVersion, minorVersion, -1, -1L, handshakeFrag)
 310                     );
 311                     break;
 312                 } else {
 313                     int fragPos = handshakeFrag.position();
 314                     int fragLim = handshakeFrag.limit();
 315                     int nextPos = fragPos + handshakeMessageLen;
 316                     handshakeFrag.limit(nextPos);
 317 
 318                     if (handshakeHash.isHashable(handshakeType)) {
 319                         handshakeHash.receive(handshakeFrag);
 320                     }
 321 
 322                     plaintexts.add(
 323                         new Plaintext(contentType, majorVersion, minorVersion,
 324                             -1, -1L, handshakeFrag.slice())
 325                     );
 326 
 327                     handshakeFrag.position(nextPos);
 328                     handshakeFrag.limit(fragLim);
 329                 }
 330             }
 331 
 332             return plaintexts.toArray(new Plaintext[0]);
 333         }
 334 
 335         return new Plaintext[] {
 336             new Plaintext(contentType,
 337                 majorVersion, minorVersion, -1, -1L, fragment)
 338         };
 339     }
 340 
 341     private Plaintext[] handleUnknownRecord(ByteBuffer packet)
 342             throws IOException, BadPaddingException {
 343         //
 344         // The packet should be a complete record.
 345         //
 346         int srcPos = packet.position();
 347         int srcLim = packet.limit();
 348 
 349         byte firstByte = packet.get(srcPos);
 350         byte thirdByte = packet.get(srcPos + 2);
 351 
 352         // Does it look like a Version 2 client hello (V2ClientHello)?
 353         if (((firstByte & 0x80) != 0) && (thirdByte == 1)) {
 354             /*
 355              * If SSLv2Hello is not enabled, throw an exception.
 356              */
 357             if (helloVersion != ProtocolVersion.SSL20Hello) {
 358                 throw new SSLHandshakeException("SSLv2Hello is not enabled");
 359             }
 360 
 361             byte majorVersion = packet.get(srcPos + 3);
 362             byte minorVersion = packet.get(srcPos + 4);
 363 
 364             if ((majorVersion == ProtocolVersion.SSL20Hello.major) &&
 365                 (minorVersion == ProtocolVersion.SSL20Hello.minor)) {
 366 
 367                 /*
 368                  * Looks like a V2 client hello, but not one saying
 369                  * "let's talk SSLv3".  So we need to send an SSLv2
 370                  * error message, one that's treated as fatal by
 371                  * clients (Otherwise we'll hang.)
 372                  */
 373                 if (SSLLogger.isOn && SSLLogger.isOn("record")) {
 374                    SSLLogger.fine(
 375                             "Requested to negotiate unsupported SSLv2!");
 376                 }
 377 
 378                 // hack code, the exception is caught in SSLEngineImpl
 379                 // so that SSLv2 error message can be delivered properly.
 380                 throw new UnsupportedOperationException(        // SSLv2Hello
 381                         "Unsupported SSL v2.0 ClientHello");
 382             }
 383 
 384             /*
 385              * If we can map this into a V3 ClientHello, read and
 386              * hash the rest of the V2 handshake, turn it into a
 387              * V3 ClientHello message, and pass it up.
 388              */
 389             packet.position(srcPos + 2);        // exclude the header
 390             handshakeHash.receive(packet);
 391             packet.position(srcPos);
 392 
 393             ByteBuffer converted = convertToClientHello(packet);
 394 
 395             if (SSLLogger.isOn && SSLLogger.isOn("packet")) {
 396                 SSLLogger.fine(
 397                         "[Converted] ClientHello", converted);
 398             }
 399 
 400             return new Plaintext[] {
 401                     new Plaintext(ContentType.HANDSHAKE.id,
 402                     majorVersion, minorVersion, -1, -1L, converted)
 403                 };
 404         } else {
 405             if (((firstByte & 0x80) != 0) && (thirdByte == 4)) {
 406                 throw new SSLException("SSL V2.0 servers are not supported.");
 407             }
 408 
 409             throw new SSLException("Unsupported or unrecognized SSL message");
 410         }
 411     }
 412 }