1 /*
   2  * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ssl;
  27 
  28 import java.io.*;
  29 import java.nio.*;
  30 import java.util.*;
  31 
  32 import javax.crypto.BadPaddingException;
  33 
  34 import javax.net.ssl.*;
  35 
  36 import sun.security.util.HexDumpEncoder;
  37 import static sun.security.ssl.Ciphertext.RecordType;
  38 
  39 /**
  40  * DTLS {@code OutputRecord} implementation for {@code SSLEngine}.
  41  */
  42 final class DTLSOutputRecord extends OutputRecord implements DTLSRecord {
  43 
  44     private DTLSFragmenter fragmenter = null;
  45 
  46     int                 writeEpoch;
  47 
  48     int                 prevWriteEpoch;
  49     Authenticator       prevWriteAuthenticator;
  50     CipherBox           prevWriteCipher;
  51 
  52     private LinkedList<RecordMemo> alertMemos = new LinkedList<>();
  53 
  54     DTLSOutputRecord() {
  55         this.writeAuthenticator = new MAC(true);
  56 
  57         this.writeEpoch = 0;
  58         this.prevWriteEpoch = 0;
  59         this.prevWriteCipher = CipherBox.NULL;
  60         this.prevWriteAuthenticator = new MAC(true);
  61 
  62         this.packetSize = DTLSRecord.maxRecordSize;
  63         this.protocolVersion = ProtocolVersion.DEFAULT_DTLS;
  64     }
  65 
  66     @Override
  67     void changeWriteCiphers(Authenticator writeAuthenticator,
  68             CipherBox writeCipher) throws IOException {
  69 
  70         encodeChangeCipherSpec();
  71 
  72         prevWriteCipher.dispose();
  73 
  74         this.prevWriteAuthenticator = this.writeAuthenticator;
  75         this.prevWriteCipher = this.writeCipher;
  76         this.prevWriteEpoch = this.writeEpoch;
  77 
  78         this.writeAuthenticator = writeAuthenticator;
  79         this.writeCipher = writeCipher;
  80         this.writeEpoch++;
  81 
  82         this.isFirstAppOutputRecord = true;
  83 
  84         // set the epoch number
  85         this.writeAuthenticator.setEpochNumber(this.writeEpoch);
  86     }
  87 
  88     @Override
  89     void encodeAlert(byte level, byte description) throws IOException {
  90         RecordMemo memo = new RecordMemo();
  91 
  92         memo.contentType = Record.ct_alert;
  93         memo.majorVersion = protocolVersion.major;
  94         memo.minorVersion = protocolVersion.minor;
  95         memo.encodeEpoch = writeEpoch;
  96         memo.encodeCipher = writeCipher;
  97         memo.encodeAuthenticator = writeAuthenticator;
  98 
  99         memo.fragment = new byte[2];
 100         memo.fragment[0] = level;
 101         memo.fragment[1] = description;
 102 
 103         alertMemos.add(memo);
 104     }
 105 
 106     @Override
 107     void encodeChangeCipherSpec() throws IOException {
 108         if (fragmenter == null) {
 109            fragmenter = new DTLSFragmenter();
 110         }
 111         fragmenter.queueUpChangeCipherSpec();
 112     }
 113 
 114     @Override
 115     void encodeHandshake(byte[] source,
 116             int offset, int length) throws IOException {
 117 
 118         if (firstMessage) {
 119             firstMessage = false;
 120         }
 121 
 122         if (fragmenter == null) {
 123            fragmenter = new DTLSFragmenter();
 124         }
 125 
 126         fragmenter.queueUpHandshake(source, offset, length);
 127     }
 128 
 129     @Override
 130     Ciphertext encode(ByteBuffer[] sources, int offset, int length,
 131             ByteBuffer destination) throws IOException {
 132 
 133         if (writeAuthenticator.seqNumOverflow()) {
 134             if (debug != null && Debug.isOn("ssl")) {
 135                 System.out.println(Thread.currentThread().getName() +
 136                     ", sequence number extremely close to overflow " +
 137                     "(2^64-1 packets). Closing connection.");
 138             }
 139 
 140             throw new SSLHandshakeException("sequence number overflow");
 141         }
 142 
 143         // not apply to handshake message
 144         int macLen = 0;
 145         if (writeAuthenticator instanceof MAC) {
 146             macLen = ((MAC)writeAuthenticator).MAClen();
 147         }
 148 
 149         int fragLen;
 150         if (packetSize > 0) {
 151             fragLen = Math.min(maxRecordSize, packetSize);
 152             fragLen = writeCipher.calculateFragmentSize(
 153                     fragLen, macLen, headerSize);
 154 
 155             fragLen = Math.min(fragLen, Record.maxDataSize);
 156         } else {
 157             fragLen = Record.maxDataSize;
 158         }
 159 
 160         if (fragmentSize > 0) {
 161             fragLen = Math.min(fragLen, fragmentSize);
 162         }
 163 
 164         int dstPos = destination.position();
 165         int dstLim = destination.limit();
 166         int dstContent = dstPos + headerSize +
 167                                 writeCipher.getExplicitNonceSize();
 168         destination.position(dstContent);
 169 
 170         int remains = Math.min(fragLen, destination.remaining());
 171         fragLen = 0;
 172         int srcsLen = offset + length;
 173         for (int i = offset; (i < srcsLen) && (remains > 0); i++) {
 174             int amount = Math.min(sources[i].remaining(), remains);
 175             int srcLimit = sources[i].limit();
 176             sources[i].limit(sources[i].position() + amount);
 177             destination.put(sources[i]);
 178             sources[i].limit(srcLimit);         // restore the limit
 179             remains -= amount;
 180             fragLen += amount;
 181         }
 182 
 183         destination.limit(destination.position());
 184         destination.position(dstContent);
 185 
 186         if ((debug != null) && Debug.isOn("record")) {
 187             System.out.println(Thread.currentThread().getName() +
 188                     ", WRITE: " + protocolVersion + " " +
 189                     Record.contentName(Record.ct_application_data) +
 190                     ", length = " + destination.remaining());
 191         }
 192 
 193         // Encrypt the fragment and wrap up a record.
 194         long recordSN = encrypt(writeAuthenticator, writeCipher,
 195                 Record.ct_application_data, destination,
 196                 dstPos, dstLim, headerSize,
 197                 protocolVersion, true);
 198 
 199         if ((debug != null) && Debug.isOn("packet")) {
 200             ByteBuffer temporary = destination.duplicate();
 201             temporary.limit(temporary.position());
 202             temporary.position(dstPos);
 203             Debug.printHex(
 204                     "[Raw write]: length = " + temporary.remaining(),
 205                     temporary);
 206         }
 207 
 208         // remain the limit unchanged
 209         destination.limit(dstLim);
 210 
 211         return new Ciphertext(RecordType.RECORD_APPLICATION_DATA, recordSN);
 212     }
 213 
 214     @Override
 215     Ciphertext acquireCiphertext(ByteBuffer destination) throws IOException {
 216         if (alertMemos != null && !alertMemos.isEmpty()) {
 217             RecordMemo memo = alertMemos.pop();
 218 
 219             int macLen = 0;
 220             if (memo.encodeAuthenticator instanceof MAC) {
 221                 macLen = ((MAC)memo.encodeAuthenticator).MAClen();
 222             }
 223 
 224             int dstPos = destination.position();
 225             int dstLim = destination.limit();
 226             int dstContent = dstPos + headerSize +
 227                                 writeCipher.getExplicitNonceSize();
 228             destination.position(dstContent);
 229 
 230             destination.put(memo.fragment);
 231 
 232             destination.limit(destination.position());
 233             destination.position(dstContent);
 234 
 235             if ((debug != null) && Debug.isOn("record")) {
 236                 System.out.println(Thread.currentThread().getName() +
 237                         ", WRITE: " + protocolVersion + " " +
 238                         Record.contentName(Record.ct_alert) +
 239                         ", length = " + destination.remaining());
 240             }
 241 
 242             // Encrypt the fragment and wrap up a record.
 243             long recordSN = encrypt(memo.encodeAuthenticator, memo.encodeCipher,
 244                     Record.ct_alert, destination, dstPos, dstLim, headerSize,
 245                     ProtocolVersion.valueOf(memo.majorVersion,
 246                             memo.minorVersion), true);
 247 
 248             if ((debug != null) && Debug.isOn("packet")) {
 249                 ByteBuffer temporary = destination.duplicate();
 250                 temporary.limit(temporary.position());
 251                 temporary.position(dstPos);
 252                 Debug.printHex(
 253                         "[Raw write]: length = " + temporary.remaining(),
 254                         temporary);
 255             }
 256 
 257             // remain the limit unchanged
 258             destination.limit(dstLim);
 259 
 260             return new Ciphertext(RecordType.RECORD_ALERT, recordSN);
 261         }
 262 
 263         if (fragmenter != null) {
 264             return fragmenter.acquireCiphertext(destination);
 265         }
 266 
 267         return null;
 268     }
 269 
 270     @Override
 271     boolean isEmpty() {
 272         return ((fragmenter == null) || fragmenter.isEmpty()) &&
 273                ((alertMemos == null) || alertMemos.isEmpty());
 274     }
 275 
 276     @Override
 277     void initHandshaker() {
 278         // clean up
 279         fragmenter = null;
 280     }
 281 
 282     // buffered record fragment
 283     private static class RecordMemo {
 284         byte            contentType;
 285         byte            majorVersion;
 286         byte            minorVersion;
 287         int             encodeEpoch;
 288         CipherBox       encodeCipher;
 289         Authenticator   encodeAuthenticator;
 290 
 291         byte[]          fragment;
 292     }
 293 
 294     private static class HandshakeMemo extends RecordMemo {
 295         byte            handshakeType;
 296         int             messageSequence;
 297         int             acquireOffset;
 298     }
 299 
 300     private final class DTLSFragmenter {
 301         private LinkedList<RecordMemo> handshakeMemos = new LinkedList<>();
 302         private int acquireIndex = 0;
 303         private int messageSequence = 0;
 304         private boolean flightIsReady = false;
 305 
 306         // Per section 4.1.1, RFC 6347:
 307         //
 308         // If repeated retransmissions do not result in a response, and the
 309         // PMTU is unknown, subsequent retransmissions SHOULD back off to a
 310         // smaller record size, fragmenting the handshake message as
 311         // appropriate.
 312         //
 313         // In this implementation, two times of retransmits would be attempted
 314         // before backing off.  The back off is supported only if the packet
 315         // size is bigger than 256 bytes.
 316         private int retransmits = 2;            // attemps of retransmits
 317 
 318         void queueUpChangeCipherSpec() {
 319 
 320             // Cleanup if a new flight starts.
 321             if (flightIsReady) {
 322                 handshakeMemos.clear();
 323                 acquireIndex = 0;
 324                 flightIsReady = false;
 325             }
 326 
 327             RecordMemo memo = new RecordMemo();
 328 
 329             memo.contentType = Record.ct_change_cipher_spec;
 330             memo.majorVersion = protocolVersion.major;
 331             memo.minorVersion = protocolVersion.minor;
 332             memo.encodeEpoch = writeEpoch;
 333             memo.encodeCipher = writeCipher;
 334             memo.encodeAuthenticator = writeAuthenticator;
 335 
 336             memo.fragment = new byte[1];
 337             memo.fragment[0] = 1;
 338 
 339             handshakeMemos.add(memo);
 340         }
 341 
 342         void queueUpHandshake(byte[] buf,
 343                 int offset, int length) throws IOException {
 344 
 345             // Cleanup if a new flight starts.
 346             if (flightIsReady) {
 347                 handshakeMemos.clear();
 348                 acquireIndex = 0;
 349                 flightIsReady = false;
 350             }
 351 
 352             HandshakeMemo memo = new HandshakeMemo();
 353 
 354             memo.contentType = Record.ct_handshake;
 355             memo.majorVersion = protocolVersion.major;
 356             memo.minorVersion = protocolVersion.minor;
 357             memo.encodeEpoch = writeEpoch;
 358             memo.encodeCipher = writeCipher;
 359             memo.encodeAuthenticator = writeAuthenticator;
 360 
 361             memo.handshakeType = buf[offset];
 362             memo.messageSequence = messageSequence++;
 363             memo.acquireOffset = 0;
 364             memo.fragment = new byte[length - 4];       // 4: header size
 365                                                         //    1: HandshakeType
 366                                                         //    3: message length
 367             System.arraycopy(buf, offset + 4, memo.fragment, 0, length - 4);
 368 
 369             handshakeHashing(memo, memo.fragment);
 370             handshakeMemos.add(memo);
 371 
 372             if ((memo.handshakeType == HandshakeMessage.ht_client_hello) ||
 373                 (memo.handshakeType == HandshakeMessage.ht_hello_request) ||
 374                 (memo.handshakeType ==
 375                         HandshakeMessage.ht_hello_verify_request) ||
 376                 (memo.handshakeType == HandshakeMessage.ht_server_hello_done) ||
 377                 (memo.handshakeType == HandshakeMessage.ht_finished)) {
 378 
 379                 flightIsReady = true;
 380             }
 381         }
 382 
 383         Ciphertext acquireCiphertext(ByteBuffer dstBuf) throws IOException {
 384             if (isEmpty()) {
 385                 if (isRetransmittable()) {
 386                     setRetransmission();    // configure for retransmission
 387                 } else {
 388                     return null;
 389                 }
 390             }
 391 
 392             RecordMemo memo = handshakeMemos.get(acquireIndex);
 393             HandshakeMemo hsMemo = null;
 394             if (memo.contentType == Record.ct_handshake) {
 395                 hsMemo = (HandshakeMemo)memo;
 396             }
 397 
 398             int macLen = 0;
 399             if (memo.encodeAuthenticator instanceof MAC) {
 400                 macLen = ((MAC)memo.encodeAuthenticator).MAClen();
 401             }
 402 
 403             // ChangeCipherSpec message is pretty small.  Don't worry about
 404             // the fragmentation of ChangeCipherSpec record.
 405             int fragLen;
 406             if (packetSize > 0) {
 407                 fragLen = Math.min(maxRecordSize, packetSize);
 408                 fragLen = memo.encodeCipher.calculateFragmentSize(
 409                         fragLen, macLen, 25);   // 25: header size
 410                                                 //   13: DTLS record
 411                                                 //   12: DTLS handshake message
 412                 fragLen = Math.min(fragLen, Record.maxDataSize);
 413             } else {
 414                 fragLen = Record.maxDataSize;
 415             }
 416 
 417             if (fragmentSize > 0) {
 418                 fragLen = Math.min(fragLen, fragmentSize);
 419             }
 420 
 421             int dstPos = dstBuf.position();
 422             int dstLim = dstBuf.limit();
 423             int dstContent = dstPos + headerSize +
 424                                     memo.encodeCipher.getExplicitNonceSize();
 425             dstBuf.position(dstContent);
 426 
 427             if (hsMemo != null) {
 428                 fragLen = Math.min(fragLen,
 429                         (hsMemo.fragment.length - hsMemo.acquireOffset));
 430 
 431                 dstBuf.put(hsMemo.handshakeType);
 432                 dstBuf.put((byte)((hsMemo.fragment.length >> 16) & 0xFF));
 433                 dstBuf.put((byte)((hsMemo.fragment.length >> 8) & 0xFF));
 434                 dstBuf.put((byte)(hsMemo.fragment.length & 0xFF));
 435                 dstBuf.put((byte)((hsMemo.messageSequence >> 8) & 0xFF));
 436                 dstBuf.put((byte)(hsMemo.messageSequence & 0xFF));
 437                 dstBuf.put((byte)((hsMemo.acquireOffset >> 16) & 0xFF));
 438                 dstBuf.put((byte)((hsMemo.acquireOffset >> 8) & 0xFF));
 439                 dstBuf.put((byte)(hsMemo.acquireOffset & 0xFF));
 440                 dstBuf.put((byte)((fragLen >> 16) & 0xFF));
 441                 dstBuf.put((byte)((fragLen >> 8) & 0xFF));
 442                 dstBuf.put((byte)(fragLen & 0xFF));
 443                 dstBuf.put(hsMemo.fragment, hsMemo.acquireOffset, fragLen);
 444             } else {
 445                 fragLen = Math.min(fragLen, memo.fragment.length);
 446                 dstBuf.put(memo.fragment, 0, fragLen);
 447             }
 448 
 449             dstBuf.limit(dstBuf.position());
 450             dstBuf.position(dstContent);
 451 
 452             if ((debug != null) && Debug.isOn("record")) {
 453                 System.out.println(Thread.currentThread().getName() +
 454                         ", WRITE: " + protocolVersion + " " +
 455                         Record.contentName(memo.contentType) +
 456                         ", length = " + dstBuf.remaining());
 457             }
 458 
 459             // Encrypt the fragment and wrap up a record.
 460             long recordSN = encrypt(memo.encodeAuthenticator, memo.encodeCipher,
 461                     memo.contentType, dstBuf,
 462                     dstPos, dstLim, headerSize,
 463                     ProtocolVersion.valueOf(memo.majorVersion,
 464                             memo.minorVersion), true);
 465 
 466             if ((debug != null) && Debug.isOn("packet")) {
 467                 ByteBuffer temporary = dstBuf.duplicate();
 468                 temporary.limit(temporary.position());
 469                 temporary.position(dstPos);
 470                 Debug.printHex(
 471                         "[Raw write]: length = " + temporary.remaining(),
 472                         temporary);
 473             }
 474 
 475             // remain the limit unchanged
 476             dstBuf.limit(dstLim);
 477 
 478             // Reset the fragmentation offset.
 479             if (hsMemo != null) {
 480                 hsMemo.acquireOffset += fragLen;
 481                 if (hsMemo.acquireOffset == hsMemo.fragment.length) {
 482                     acquireIndex++;
 483                 }
 484 
 485                 return new Ciphertext(RecordType.valueOf(
 486                         hsMemo.contentType, hsMemo.handshakeType), recordSN);
 487             } else {
 488                 acquireIndex++;
 489                 return new Ciphertext(
 490                         RecordType.RECORD_CHANGE_CIPHER_SPEC, recordSN);
 491             }
 492         }
 493 
 494         private void handshakeHashing(HandshakeMemo hsFrag, byte[] hsBody) {
 495 
 496             byte hsType = hsFrag.handshakeType;
 497             if ((hsType == HandshakeMessage.ht_hello_request) ||
 498                 (hsType == HandshakeMessage.ht_hello_verify_request)) {
 499 
 500                 // omitted from handshake hash computation
 501                 return;
 502             }
 503 
 504             if ((hsFrag.messageSequence == 0) &&
 505                 (hsType == HandshakeMessage.ht_client_hello)) {
 506 
 507                 // omit initial ClientHello message
 508                 //
 509                 //  2: ClientHello.client_version
 510                 // 32: ClientHello.random
 511                 int sidLen = hsBody[34];
 512 
 513                 if (sidLen == 0) {      // empty session_id, initial handshake
 514                     return;
 515                 }
 516             }
 517 
 518             // calculate the DTLS header
 519             byte[] temporary = new byte[12];    // 12: handshake header size
 520 
 521             // Handshake.msg_type
 522             temporary[0] = hsFrag.handshakeType;
 523 
 524             // Handshake.length
 525             temporary[1] = (byte)((hsBody.length >> 16) & 0xFF);
 526             temporary[2] = (byte)((hsBody.length >> 8) & 0xFF);
 527             temporary[3] = (byte)(hsBody.length & 0xFF);
 528 
 529             // Handshake.message_seq
 530             temporary[4] = (byte)((hsFrag.messageSequence >> 8) & 0xFF);
 531             temporary[5] = (byte)(hsFrag.messageSequence & 0xFF);
 532 
 533             // Handshake.fragment_offset
 534             temporary[6] = 0;
 535             temporary[7] = 0;
 536             temporary[8] = 0;
 537 
 538             // Handshake.fragment_length
 539             temporary[9] = temporary[1];
 540             temporary[10] = temporary[2];
 541             temporary[11] = temporary[3];
 542 
 543             if ((hsType != HandshakeMessage.ht_finished) &&
 544                 (hsType != HandshakeMessage.ht_certificate_verify)) {
 545 
 546                 handshakeHash.update(temporary, 0, 12);
 547                 handshakeHash.update(hsBody, 0, hsBody.length);
 548             } else {
 549                 // Reserve until this handshake message has been processed.
 550                 handshakeHash.reserve(temporary, 0, 12);
 551                 handshakeHash.reserve(hsBody, 0, hsBody.length);
 552             }
 553 
 554         }
 555 
 556         boolean isEmpty() {
 557             if (!flightIsReady || handshakeMemos.isEmpty() ||
 558                     acquireIndex >= handshakeMemos.size()) {
 559                 return true;
 560             }
 561 
 562             return false;
 563         }
 564 
 565         boolean isRetransmittable() {
 566             return (flightIsReady && !handshakeMemos.isEmpty() &&
 567                                 (acquireIndex >= handshakeMemos.size()));
 568         }
 569 
 570         private void setRetransmission() {
 571             acquireIndex = 0;
 572             for (RecordMemo memo : handshakeMemos) {
 573                 if (memo instanceof HandshakeMemo) {
 574                     HandshakeMemo hmemo = (HandshakeMemo)memo;
 575                     hmemo.acquireOffset = 0;
 576                 }
 577             }
 578 
 579             // Shrink packet size if:
 580             // 1. maximum fragment size is allowed, in which case the packet
 581             //    size is configured bigger than maxRecordSize;
 582             // 2. maximum packet is bigger than 256 bytes;
 583             // 3. two times of retransmits have been attempted.
 584             if ((packetSize <= maxRecordSize) &&
 585                     (packetSize > 256) && ((retransmits--) <= 0)) {
 586 
 587                 // shrink packet size
 588                 shrinkPacketSize();
 589                 retransmits = 2;        // attemps of retransmits
 590             }
 591         }
 592 
 593         private void shrinkPacketSize() {
 594             packetSize = Math.max(256, packetSize / 2);
 595         }
 596     }
 597 }