1 /*
   2  * Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ssl;
  27 
  28 import java.io.*;
  29 import java.nio.*;
  30 import java.util.*;
  31 
  32 import javax.crypto.BadPaddingException;
  33 
  34 import javax.net.ssl.*;
  35 
  36 import sun.security.util.HexDumpEncoder;
  37 import static sun.security.ssl.Ciphertext.RecordType;
  38 
  39 /**
  40  * DTLS {@code OutputRecord} implementation for {@code SSLEngine}.
  41  */
  42 final class DTLSOutputRecord extends OutputRecord implements DTLSRecord {
  43 
  44     private DTLSFragmenter fragmenter = null;
  45 
  46     int                 writeEpoch;
  47 
  48     int                 prevWriteEpoch;
  49     Authenticator       prevWriteAuthenticator;
  50     CipherBox           prevWriteCipher;
  51 
  52     private LinkedList<RecordMemo> alertMemos = new LinkedList<>();
  53 
  54     DTLSOutputRecord() {
  55         this.writeAuthenticator = new MAC(true);
  56 
  57         this.writeEpoch = 0;
  58         this.prevWriteEpoch = 0;
  59         this.prevWriteCipher = CipherBox.NULL;
  60         this.prevWriteAuthenticator = new MAC(true);
  61 
  62         this.packetSize = DTLSRecord.maxRecordSize;
  63         this.protocolVersion = ProtocolVersion.DEFAULT_DTLS;
  64     }
  65 
  66     @Override
  67     void changeWriteCiphers(Authenticator writeAuthenticator,
  68             CipherBox writeCipher) throws IOException {
  69 
  70         encodeChangeCipherSpec();
  71 
  72         prevWriteCipher.dispose();
  73 
  74         this.prevWriteAuthenticator = this.writeAuthenticator;
  75         this.prevWriteCipher = this.writeCipher;
  76         this.prevWriteEpoch = this.writeEpoch;
  77 
  78         this.writeAuthenticator = writeAuthenticator;
  79         this.writeCipher = writeCipher;
  80         this.writeEpoch++;
  81 
  82         this.isFirstAppOutputRecord = true;
  83 
  84         // set the epoch number
  85         this.writeAuthenticator.setEpochNumber(this.writeEpoch);
  86     }
  87 
  88     @Override
  89     void encodeAlert(byte level, byte description) throws IOException {
  90         RecordMemo memo = new RecordMemo();
  91 
  92         memo.contentType = Record.ct_alert;
  93         memo.majorVersion = protocolVersion.major;
  94         memo.minorVersion = protocolVersion.minor;
  95         memo.encodeEpoch = writeEpoch;
  96         memo.encodeCipher = writeCipher;
  97         memo.encodeAuthenticator = writeAuthenticator;
  98 
  99         memo.fragment = new byte[2];
 100         memo.fragment[0] = level;
 101         memo.fragment[1] = description;
 102 
 103         alertMemos.add(memo);
 104     }
 105 
 106     @Override
 107     void encodeChangeCipherSpec() throws IOException {
 108         if (fragmenter == null) {
 109            fragmenter = new DTLSFragmenter();
 110         }
 111         fragmenter.queueUpChangeCipherSpec();
 112     }
 113 
 114     @Override
 115     void encodeHandshake(byte[] source,
 116             int offset, int length) throws IOException {
 117 
 118         if (firstMessage) {
 119             firstMessage = false;
 120         }
 121 
 122         if (fragmenter == null) {
 123            fragmenter = new DTLSFragmenter();
 124         }
 125 
 126         fragmenter.queueUpHandshake(source, offset, length);
 127     }
 128 
 129     @Override
 130     Ciphertext encode(ByteBuffer[] sources, int offset, int length,
 131             ByteBuffer destination) throws IOException {
 132 
 133         if (writeAuthenticator.seqNumOverflow()) {
 134             if (debug != null && Debug.isOn("ssl")) {
 135                 System.out.println(Thread.currentThread().getName() +
 136                     ", sequence number extremely close to overflow " +
 137                     "(2^64-1 packets). Closing connection.");
 138             }
 139 
 140             throw new SSLHandshakeException("sequence number overflow");
 141         }
 142 
 143         // not apply to handshake message
 144         int macLen = 0;
 145         if (writeAuthenticator instanceof MAC) {
 146             macLen = ((MAC)writeAuthenticator).MAClen();
 147         }
 148 
 149         int fragLen;
 150         if (packetSize > 0) {
 151             fragLen = Math.min(maxRecordSize, packetSize);
 152             fragLen = writeCipher.calculateFragmentSize(
 153                     fragLen, macLen, headerSize);
 154 
 155             fragLen = Math.min(fragLen, Record.maxDataSize);
 156         } else {
 157             fragLen = Record.maxDataSize;
 158         }
 159 
 160         if (fragmentSize > 0) {
 161             fragLen = Math.min(fragLen, fragmentSize);
 162         }
 163 
 164         int dstPos = destination.position();
 165         int dstLim = destination.limit();
 166         int dstContent = dstPos + headerSize +
 167                                 writeCipher.getExplicitNonceSize();
 168         destination.position(dstContent);
 169 
 170         int remains = Math.min(fragLen, destination.remaining());
 171         fragLen = 0;
 172         int srcsLen = offset + length;
 173         for (int i = offset; (i < srcsLen) && (remains > 0); i++) {
 174             int amount = Math.min(sources[i].remaining(), remains);
 175             int srcLimit = sources[i].limit();
 176             sources[i].limit(sources[i].position() + amount);
 177             destination.put(sources[i]);
 178             sources[i].limit(srcLimit);         // restore the limit
 179             remains -= amount;
 180             fragLen += amount;
 181         }
 182 
 183         destination.limit(destination.position());
 184         destination.position(dstContent);
 185 
 186         if ((debug != null) && Debug.isOn("record")) {
 187             System.out.println(Thread.currentThread().getName() +
 188                     ", WRITE: " + protocolVersion + " " +
 189                     Record.contentName(Record.ct_application_data) +
 190                     ", length = " + destination.remaining());
 191         }
 192 
 193         // Encrypt the fragment and wrap up a record.
 194         long recordSN = encrypt(writeAuthenticator, writeCipher,
 195                 Record.ct_application_data, destination,
 196                 dstPos, dstLim, headerSize,
 197                 protocolVersion, true);
 198 
 199         if ((debug != null) && Debug.isOn("packet")) {
 200             ByteBuffer temporary = destination.duplicate();
 201             temporary.limit(temporary.position());
 202             temporary.position(dstPos);
 203             Debug.printHex(
 204                     "[Raw write]: length = " + temporary.remaining(),
 205                     temporary);
 206         }
 207 
 208         // remain the limit unchanged
 209         destination.limit(dstLim);
 210 
 211         return new Ciphertext(RecordType.RECORD_APPLICATION_DATA, recordSN);
 212     }
 213 
 214     @Override
 215     Ciphertext acquireCiphertext(ByteBuffer destination) throws IOException {
 216         if (alertMemos != null && !alertMemos.isEmpty()) {
 217             RecordMemo memo = alertMemos.pop();
 218 
 219             int macLen = 0;
 220             if (memo.encodeAuthenticator instanceof MAC) {
 221                 macLen = ((MAC)memo.encodeAuthenticator).MAClen();
 222             }
 223 
 224             int dstPos = destination.position();
 225             int dstLim = destination.limit();
 226             int dstContent = dstPos + headerSize +
 227                                 writeCipher.getExplicitNonceSize();
 228             destination.position(dstContent);
 229 
 230             destination.put(memo.fragment);
 231 
 232             destination.limit(destination.position());
 233             destination.position(dstContent);
 234 
 235             if ((debug != null) && Debug.isOn("record")) {
 236                 System.out.println(Thread.currentThread().getName() +
 237                         ", WRITE: " + protocolVersion + " " +
 238                         Record.contentName(Record.ct_alert) +
 239                         ", length = " + destination.remaining());
 240             }
 241 
 242             // Encrypt the fragment and wrap up a record.
 243             long recordSN = encrypt(memo.encodeAuthenticator, memo.encodeCipher,
 244                     Record.ct_alert, destination, dstPos, dstLim, headerSize,
 245                     ProtocolVersion.valueOf(memo.majorVersion,
 246                             memo.minorVersion), true);
 247 
 248             if ((debug != null) && Debug.isOn("packet")) {
 249                 ByteBuffer temporary = destination.duplicate();
 250                 temporary.limit(temporary.position());
 251                 temporary.position(dstPos);
 252                 Debug.printHex(
 253                         "[Raw write]: length = " + temporary.remaining(),
 254                         temporary);
 255             }
 256 
 257             // remain the limit unchanged
 258             destination.limit(dstLim);
 259 
 260             return new Ciphertext(RecordType.RECORD_ALERT, recordSN);
 261         }
 262 
 263         if (fragmenter != null) {
 264             return fragmenter.acquireCiphertext(destination);
 265         }
 266 
 267         return null;
 268     }
 269 
 270     @Override
 271     boolean isEmpty() {
 272         return ((fragmenter == null) || fragmenter.isEmpty()) &&
 273                ((alertMemos == null) || alertMemos.isEmpty());
 274     }
 275 
 276     @Override
 277     void initHandshaker() {
 278         // clean up
 279         fragmenter = null;
 280     }
 281 
 282     @Override
 283     void launchRetransmission() {
 284         // Note: Please don't retransmit if there are handshake messages
 285         // or alerts waiting in the queue.
 286         if (((alertMemos == null) || alertMemos.isEmpty()) &&
 287                 (fragmenter != null) && fragmenter.isRetransmittable()) {
 288             fragmenter.setRetransmission();
 289         }
 290     }
 291 
 292     // buffered record fragment
 293     private static class RecordMemo {
 294         byte            contentType;
 295         byte            majorVersion;
 296         byte            minorVersion;
 297         int             encodeEpoch;
 298         CipherBox       encodeCipher;
 299         Authenticator   encodeAuthenticator;
 300 
 301         byte[]          fragment;
 302     }
 303 
 304     private static class HandshakeMemo extends RecordMemo {
 305         byte            handshakeType;
 306         int             messageSequence;
 307         int             acquireOffset;
 308     }
 309 
 310     private final class DTLSFragmenter {
 311         private LinkedList<RecordMemo> handshakeMemos = new LinkedList<>();
 312         private int acquireIndex = 0;
 313         private int messageSequence = 0;
 314         private boolean flightIsReady = false;
 315 
 316         // Per section 4.1.1, RFC 6347:
 317         //
 318         // If repeated retransmissions do not result in a response, and the
 319         // PMTU is unknown, subsequent retransmissions SHOULD back off to a
 320         // smaller record size, fragmenting the handshake message as
 321         // appropriate.
 322         //
 323         // In this implementation, two times of retransmits would be attempted
 324         // before backing off.  The back off is supported only if the packet
 325         // size is bigger than 256 bytes.
 326         private int retransmits = 2;            // attemps of retransmits
 327 
 328         void queueUpChangeCipherSpec() {
 329 
 330             // Cleanup if a new flight starts.
 331             if (flightIsReady) {
 332                 handshakeMemos.clear();
 333                 acquireIndex = 0;
 334                 flightIsReady = false;
 335             }
 336 
 337             RecordMemo memo = new RecordMemo();
 338 
 339             memo.contentType = Record.ct_change_cipher_spec;
 340             memo.majorVersion = protocolVersion.major;
 341             memo.minorVersion = protocolVersion.minor;
 342             memo.encodeEpoch = writeEpoch;
 343             memo.encodeCipher = writeCipher;
 344             memo.encodeAuthenticator = writeAuthenticator;
 345 
 346             memo.fragment = new byte[1];
 347             memo.fragment[0] = 1;
 348 
 349             handshakeMemos.add(memo);
 350         }
 351 
 352         void queueUpHandshake(byte[] buf,
 353                 int offset, int length) throws IOException {
 354 
 355             // Cleanup if a new flight starts.
 356             if (flightIsReady) {
 357                 handshakeMemos.clear();
 358                 acquireIndex = 0;
 359                 flightIsReady = false;
 360             }
 361 
 362             HandshakeMemo memo = new HandshakeMemo();
 363 
 364             memo.contentType = Record.ct_handshake;
 365             memo.majorVersion = protocolVersion.major;
 366             memo.minorVersion = protocolVersion.minor;
 367             memo.encodeEpoch = writeEpoch;
 368             memo.encodeCipher = writeCipher;
 369             memo.encodeAuthenticator = writeAuthenticator;
 370 
 371             memo.handshakeType = buf[offset];
 372             memo.messageSequence = messageSequence++;
 373             memo.acquireOffset = 0;
 374             memo.fragment = new byte[length - 4];       // 4: header size
 375                                                         //    1: HandshakeType
 376                                                         //    3: message length
 377             System.arraycopy(buf, offset + 4, memo.fragment, 0, length - 4);
 378 
 379             handshakeHashing(memo, memo.fragment);
 380             handshakeMemos.add(memo);
 381 
 382             if ((memo.handshakeType == HandshakeMessage.ht_client_hello) ||
 383                 (memo.handshakeType == HandshakeMessage.ht_hello_request) ||
 384                 (memo.handshakeType ==
 385                         HandshakeMessage.ht_hello_verify_request) ||
 386                 (memo.handshakeType == HandshakeMessage.ht_server_hello_done) ||
 387                 (memo.handshakeType == HandshakeMessage.ht_finished)) {
 388 
 389                 flightIsReady = true;
 390             }
 391         }
 392 
 393         Ciphertext acquireCiphertext(ByteBuffer dstBuf) throws IOException {
 394             if (isEmpty()) {
 395                 if (isRetransmittable()) {
 396                     setRetransmission();    // configure for retransmission
 397                 } else {
 398                     return null;
 399                 }
 400             }
 401 
 402             RecordMemo memo = handshakeMemos.get(acquireIndex);
 403             HandshakeMemo hsMemo = null;
 404             if (memo.contentType == Record.ct_handshake) {
 405                 hsMemo = (HandshakeMemo)memo;
 406             }
 407 
 408             int macLen = 0;
 409             if (memo.encodeAuthenticator instanceof MAC) {
 410                 macLen = ((MAC)memo.encodeAuthenticator).MAClen();
 411             }
 412 
 413             // ChangeCipherSpec message is pretty small.  Don't worry about
 414             // the fragmentation of ChangeCipherSpec record.
 415             int fragLen;
 416             if (packetSize > 0) {
 417                 fragLen = Math.min(maxRecordSize, packetSize);
 418                 fragLen = memo.encodeCipher.calculateFragmentSize(
 419                         fragLen, macLen, 25);   // 25: header size
 420                                                 //   13: DTLS record
 421                                                 //   12: DTLS handshake message
 422                 fragLen = Math.min(fragLen, Record.maxDataSize);
 423             } else {
 424                 fragLen = Record.maxDataSize;
 425             }
 426 
 427             if (fragmentSize > 0) {
 428                 fragLen = Math.min(fragLen, fragmentSize);
 429             }
 430 
 431             int dstPos = dstBuf.position();
 432             int dstLim = dstBuf.limit();
 433             int dstContent = dstPos + headerSize +
 434                                     memo.encodeCipher.getExplicitNonceSize();
 435             dstBuf.position(dstContent);
 436 
 437             if (hsMemo != null) {
 438                 fragLen = Math.min(fragLen,
 439                         (hsMemo.fragment.length - hsMemo.acquireOffset));
 440 
 441                 dstBuf.put(hsMemo.handshakeType);
 442                 dstBuf.put((byte)((hsMemo.fragment.length >> 16) & 0xFF));
 443                 dstBuf.put((byte)((hsMemo.fragment.length >> 8) & 0xFF));
 444                 dstBuf.put((byte)(hsMemo.fragment.length & 0xFF));
 445                 dstBuf.put((byte)((hsMemo.messageSequence >> 8) & 0xFF));
 446                 dstBuf.put((byte)(hsMemo.messageSequence & 0xFF));
 447                 dstBuf.put((byte)((hsMemo.acquireOffset >> 16) & 0xFF));
 448                 dstBuf.put((byte)((hsMemo.acquireOffset >> 8) & 0xFF));
 449                 dstBuf.put((byte)(hsMemo.acquireOffset & 0xFF));
 450                 dstBuf.put((byte)((fragLen >> 16) & 0xFF));
 451                 dstBuf.put((byte)((fragLen >> 8) & 0xFF));
 452                 dstBuf.put((byte)(fragLen & 0xFF));
 453                 dstBuf.put(hsMemo.fragment, hsMemo.acquireOffset, fragLen);
 454             } else {
 455                 fragLen = Math.min(fragLen, memo.fragment.length);
 456                 dstBuf.put(memo.fragment, 0, fragLen);
 457             }
 458 
 459             dstBuf.limit(dstBuf.position());
 460             dstBuf.position(dstContent);
 461 
 462             if ((debug != null) && Debug.isOn("record")) {
 463                 System.out.println(Thread.currentThread().getName() +
 464                         ", WRITE: " + protocolVersion + " " +
 465                         Record.contentName(memo.contentType) +
 466                         ", length = " + dstBuf.remaining());
 467             }
 468 
 469             // Encrypt the fragment and wrap up a record.
 470             long recordSN = encrypt(memo.encodeAuthenticator, memo.encodeCipher,
 471                     memo.contentType, dstBuf,
 472                     dstPos, dstLim, headerSize,
 473                     ProtocolVersion.valueOf(memo.majorVersion,
 474                             memo.minorVersion), true);
 475 
 476             if ((debug != null) && Debug.isOn("packet")) {
 477                 ByteBuffer temporary = dstBuf.duplicate();
 478                 temporary.limit(temporary.position());
 479                 temporary.position(dstPos);
 480                 Debug.printHex(
 481                         "[Raw write]: length = " + temporary.remaining(),
 482                         temporary);
 483             }
 484 
 485             // remain the limit unchanged
 486             dstBuf.limit(dstLim);
 487 
 488             // Reset the fragmentation offset.
 489             if (hsMemo != null) {
 490                 hsMemo.acquireOffset += fragLen;
 491                 if (hsMemo.acquireOffset == hsMemo.fragment.length) {
 492                     acquireIndex++;
 493                 }
 494 
 495                 return new Ciphertext(RecordType.valueOf(
 496                         hsMemo.contentType, hsMemo.handshakeType), recordSN);
 497             } else {
 498                 acquireIndex++;
 499                 return new Ciphertext(
 500                         RecordType.RECORD_CHANGE_CIPHER_SPEC, recordSN);
 501             }
 502         }
 503 
 504         private void handshakeHashing(HandshakeMemo hsFrag, byte[] hsBody) {
 505 
 506             byte hsType = hsFrag.handshakeType;
 507             if ((hsType == HandshakeMessage.ht_hello_request) ||
 508                 (hsType == HandshakeMessage.ht_hello_verify_request)) {
 509 
 510                 // omitted from handshake hash computation
 511                 return;
 512             }
 513 
 514             if ((hsFrag.messageSequence == 0) &&
 515                 (hsType == HandshakeMessage.ht_client_hello)) {
 516 
 517                 // omit initial ClientHello message
 518                 //
 519                 //  2: ClientHello.client_version
 520                 // 32: ClientHello.random
 521                 int sidLen = hsBody[34];
 522 
 523                 if (sidLen == 0) {      // empty session_id, initial handshake
 524                     return;
 525                 }
 526             }
 527 
 528             // calculate the DTLS header
 529             byte[] temporary = new byte[12];    // 12: handshake header size
 530 
 531             // Handshake.msg_type
 532             temporary[0] = hsFrag.handshakeType;
 533 
 534             // Handshake.length
 535             temporary[1] = (byte)((hsBody.length >> 16) & 0xFF);
 536             temporary[2] = (byte)((hsBody.length >> 8) & 0xFF);
 537             temporary[3] = (byte)(hsBody.length & 0xFF);
 538 
 539             // Handshake.message_seq
 540             temporary[4] = (byte)((hsFrag.messageSequence >> 8) & 0xFF);
 541             temporary[5] = (byte)(hsFrag.messageSequence & 0xFF);
 542 
 543             // Handshake.fragment_offset
 544             temporary[6] = 0;
 545             temporary[7] = 0;
 546             temporary[8] = 0;
 547 
 548             // Handshake.fragment_length
 549             temporary[9] = temporary[1];
 550             temporary[10] = temporary[2];
 551             temporary[11] = temporary[3];
 552 
 553             if ((hsType != HandshakeMessage.ht_finished) &&
 554                 (hsType != HandshakeMessage.ht_certificate_verify)) {
 555 
 556                 handshakeHash.update(temporary, 0, 12);
 557                 handshakeHash.update(hsBody, 0, hsBody.length);
 558             } else {
 559                 // Reserve until this handshake message has been processed.
 560                 handshakeHash.reserve(temporary, 0, 12);
 561                 handshakeHash.reserve(hsBody, 0, hsBody.length);
 562             }
 563 
 564         }
 565 
 566         boolean isEmpty() {
 567             if (!flightIsReady || handshakeMemos.isEmpty() ||
 568                     acquireIndex >= handshakeMemos.size()) {
 569                 return true;
 570             }
 571 
 572             return false;
 573         }
 574 
 575         boolean isRetransmittable() {
 576             return (flightIsReady && !handshakeMemos.isEmpty() &&
 577                                 (acquireIndex >= handshakeMemos.size()));
 578         }
 579 
 580         private void setRetransmission() {
 581             acquireIndex = 0;
 582             for (RecordMemo memo : handshakeMemos) {
 583                 if (memo instanceof HandshakeMemo) {
 584                     HandshakeMemo hmemo = (HandshakeMemo)memo;
 585                     hmemo.acquireOffset = 0;
 586                 }
 587             }
 588 
 589             // Shrink packet size if:
 590             // 1. maximum fragment size is allowed, in which case the packet
 591             //    size is configured bigger than maxRecordSize;
 592             // 2. maximum packet is bigger than 256 bytes;
 593             // 3. two times of retransmits have been attempted.
 594             if ((packetSize <= maxRecordSize) &&
 595                     (packetSize > 256) && ((retransmits--) <= 0)) {
 596 
 597                 // shrink packet size
 598                 shrinkPacketSize();
 599                 retransmits = 2;        // attemps of retransmits
 600             }
 601         }
 602 
 603         private void shrinkPacketSize() {
 604             packetSize = Math.max(256, packetSize / 2);
 605         }
 606     }
 607 }