1 /*
   2  * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General  License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General  License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General  License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 package java.net.http;
  26 
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import static java.lang.String.format;
import static java.net.http.WSFrame.Opcode.ofCode;
import static java.net.http.WSUtils.dump;
  32 
  33 /*
  34  * A collection of utilities for reading, writing, and masking frames.
  35  */
  36 final class WSFrame {
  37 
  38     private WSFrame() { }
  39 
  40     static final int MAX_HEADER_SIZE_BYTES = 2 + 8 + 4;
  41 
  42     enum Opcode {
  43 
  44         CONTINUATION   (0x0),
  45         TEXT           (0x1),
  46         BINARY         (0x2),
  47         NON_CONTROL_0x3(0x3),
  48         NON_CONTROL_0x4(0x4),
  49         NON_CONTROL_0x5(0x5),
  50         NON_CONTROL_0x6(0x6),
  51         NON_CONTROL_0x7(0x7),
  52         CLOSE          (0x8),
  53         PING           (0x9),
  54         PONG           (0xA),
  55         CONTROL_0xB    (0xB),
  56         CONTROL_0xC    (0xC),
  57         CONTROL_0xD    (0xD),
  58         CONTROL_0xE    (0xE),
  59         CONTROL_0xF    (0xF);
  60 
  61         private static final Opcode[] opcodes;
  62 
  63         static {
  64             Opcode[] values = values();
  65             opcodes = new Opcode[values.length];
  66             for (Opcode c : values) {
  67                 assert opcodes[c.code] == null
  68                         : WSUtils.dump(c, c.code, opcodes[c.code]);
  69                 opcodes[c.code] = c;
  70             }
  71         }
  72 
  73         private final byte code;
  74         private final char shiftedCode;
  75         private final String description;
  76 
  77         Opcode(int code) {
  78             this.code = (byte) code;
  79             this.shiftedCode = (char) (code << 8);
  80             this.description = format("%x (%s)", code, name());
  81         }
  82 
  83         boolean isControl() {
  84             return (code & 0x8) != 0;
  85         }
  86 
  87         static Opcode ofCode(int code) {
  88             return opcodes[code & 0xF];
  89         }
  90 
  91         @Override
  92         public String toString() {
  93             return description;
  94         }
  95     }
  96 
  97     /*
  98      * A utility to mask payload data.
  99      */
 100     static final class Masker {
 101 
 102         private final ByteBuffer acc = ByteBuffer.allocate(8);
 103         private final int[] maskBytes = new int[4];
 104         private int offset;
 105         private long maskLong;
 106 
 107         /*
 108          * Sets up the mask.
 109          */
 110         Masker mask(int value) {
 111             acc.clear().putInt(value).putInt(value).flip();
 112             for (int i = 0; i < maskBytes.length; i++) {
 113                 maskBytes[i] = acc.get(i);
 114             }
 115             offset = 0;
 116             maskLong = acc.getLong(0);
 117             return this;
 118         }
 119 
 120         /*
 121          * Reads as many bytes as possible from the given input buffer, writing
 122          * the resulting masked bytes to the given output buffer.
 123          *
 124          * src.remaining() <= dst.remaining() // TODO: do we need this restriction?
 125          * 'src' and 'dst' can be the same ByteBuffer
 126          */
 127         Masker applyMask(ByteBuffer src, ByteBuffer dst) {
 128             if (src.remaining() > dst.remaining()) {
 129                 throw new IllegalArgumentException(dump(src, dst));
 130             }
 131             begin(src, dst);
 132             loop(src, dst);
 133             end(src, dst);
 134             return this;
 135         }
 136 
 137         // Applying the remaining of the mask (strictly not more than 3 bytes)
 138         // byte-wise
 139         private void begin(ByteBuffer src, ByteBuffer dst) {
 140             if (offset > 0) {
 141                 for (int i = src.position(), j = dst.position();
 142                      offset < 4 && i <= src.limit() - 1 && j <= dst.limit() - 1;
 143                      i++, j++, offset++) {
 144                     dst.put(j, (byte) (src.get(i) ^ maskBytes[offset]));
 145                     dst.position(j + 1);
 146                     src.position(i + 1);
 147                 }
 148                 offset &= 3;
 149             }
 150         }
 151 
 152         private void loop(ByteBuffer src, ByteBuffer dst) {
 153             int i = src.position();
 154             int j = dst.position();
 155             final int srcLim = src.limit() - 8;
 156             final int dstLim = dst.limit() - 8;
 157             for (; i <= srcLim && j <= dstLim; i += 8, j += 8) {
 158                 dst.putLong(j, (src.getLong(i) ^ maskLong));
 159             }
 160             if (i > src.limit()) {
 161                 src.position(i - 8);
 162             } else {
 163                 src.position(i);
 164             }
 165             if (j > dst.limit()) {
 166                 dst.position(j - 8);
 167             } else {
 168                 dst.position(j);
 169             }
 170         }
 171 
 172         // Applying the mask to the remaining bytes byte-wise (don't make any
 173         // assumptions on how many, hopefully not more than 7 for 64bit arch)
 174         private void end(ByteBuffer src, ByteBuffer dst) {
 175             for (int i = src.position(), j = dst.position();
 176                  i <= src.limit() - 1 && j <= dst.limit() - 1;
 177                  i++, j++, offset = (offset + 1) & 3) { // offset cycle through 0..3
 178                 dst.put(j, (byte) (src.get(i) ^ maskBytes[offset]));
 179                 src.position(i + 1);
 180                 dst.position(j + 1);
 181             }
 182         }
 183     }
 184 
 185     /*
 186      * A builder of frame headers, capable of writing to a given buffer.
 187      *
 188      * The builder does not enforce any protocol-level rules, it simply writes
 189      * a header structure to the buffer. The order of calls to intermediate
 190      * methods is not significant.
 191      */
 192     static final class HeaderBuilder {
 193 
 194         private char firstChar;
 195         private long payloadLen;
 196         private int maskingKey;
 197         private boolean mask;
 198 
 199         HeaderBuilder fin(boolean value) {
 200             if (value) {
 201                 firstChar |=  0b10000000_00000000;
 202             } else {
 203                 firstChar &= ~0b10000000_00000000;
 204             }
 205             return this;
 206         }
 207 
 208         HeaderBuilder rsv1(boolean value) {
 209             if (value) {
 210                 firstChar |=  0b01000000_00000000;
 211             } else {
 212                 firstChar &= ~0b01000000_00000000;
 213             }
 214             return this;
 215         }
 216 
 217         HeaderBuilder rsv2(boolean value) {
 218             if (value) {
 219                 firstChar |=  0b00100000_00000000;
 220             } else {
 221                 firstChar &= ~0b00100000_00000000;
 222             }
 223             return this;
 224         }
 225 
 226         HeaderBuilder rsv3(boolean value) {
 227             if (value) {
 228                 firstChar |=  0b00010000_00000000;
 229             } else {
 230                 firstChar &= ~0b00010000_00000000;
 231             }
 232             return this;
 233         }
 234 
 235         HeaderBuilder opcode(Opcode value) {
 236             firstChar = (char) ((firstChar & 0xF0FF) | value.shiftedCode);
 237             return this;
 238         }
 239 
 240         HeaderBuilder payloadLen(long value) {
 241             payloadLen = value;
 242             firstChar &= 0b11111111_10000000; // Clear previous payload length leftovers
 243             if (payloadLen < 126) {
 244                 firstChar |= payloadLen;
 245             } else if (payloadLen < 65535) {
 246                 firstChar |= 126;
 247             } else {
 248                 firstChar |= 127;
 249             }
 250             return this;
 251         }
 252 
 253         HeaderBuilder mask(int value) {
 254             firstChar |= 0b00000000_10000000;
 255             maskingKey = value;
 256             mask = true;
 257             return this;
 258         }
 259 
 260         HeaderBuilder noMask() {
 261             firstChar &= ~0b00000000_10000000;
 262             mask = false;
 263             return this;
 264         }
 265 
 266         /*
 267          * Writes the header to the given buffer.
 268          *
 269          * The buffer must have at least MAX_HEADER_SIZE_BYTES remaining. The
 270          * buffer's position is incremented by the number of bytes written.
 271          */
 272         void build(ByteBuffer buffer) {
 273             buffer.putChar(firstChar);
 274             if (payloadLen >= 126) {
 275                 if (payloadLen < 65535) {
 276                     buffer.putChar((char) payloadLen);
 277                 } else {
 278                     buffer.putLong(payloadLen);
 279                 }
 280             }
 281             if (mask) {
 282                 buffer.putInt(maskingKey);
 283             }
 284         }
 285     }
 286 
 287     /*
 288      * A consumer of frame parts.
 289      *
 290      * Guaranteed to be called in the following order by the Frame.Reader:
 291      *
 292      *     fin rsv1 rsv2 rsv3 opcode mask payloadLength maskingKey? payloadData+ endFrame
 293      */
 294     interface Consumer {
 295 
 296         void fin(boolean value);
 297 
 298         void rsv1(boolean value);
 299 
 300         void rsv2(boolean value);
 301 
 302         void rsv3(boolean value);
 303 
 304         void opcode(Opcode value);
 305 
 306         void mask(boolean value);
 307 
 308         void payloadLen(long value);
 309 
 310         void maskingKey(int value);
 311 
 312         /*
 313          * Called when a part of the payload is ready to be consumed.
 314          *
 315          * Though may not yield a complete payload in a single invocation, i.e.
 316          *
 317          *     data.remaining() < payloadLen
 318          *
 319          * the sum of `data.remaining()` passed to all invocations of this
 320          * method will be equal to 'payloadLen', reported in
 321          * `void payloadLen(long value)`
 322          *
 323          * No unmasking is done.
 324          */
 325         void payloadData(WSShared<ByteBuffer> data, boolean isLast);
 326 
 327         void endFrame(); // TODO: remove (payloadData(isLast=true)) should be enough
 328     }
 329 
 330     /*
 331      * A Reader of Frames.
 332      *
 333      * No protocol-level rules are enforced, only frame structure.
 334      */
 335     static final class Reader {
 336 
 337         private static final int AWAITING_FIRST_BYTE  =  1;
 338         private static final int AWAITING_SECOND_BYTE =  2;
 339         private static final int READING_16_LENGTH    =  4;
 340         private static final int READING_64_LENGTH    =  8;
 341         private static final int READING_MASK         = 16;
 342         private static final int READING_PAYLOAD      = 32;
 343 
 344         // A private buffer used to simplify multi-byte integers reading
 345         private final ByteBuffer accumulator = ByteBuffer.allocate(8);
 346         private int state = AWAITING_FIRST_BYTE;
 347         private boolean mask;
 348         private long payloadLength;
 349 
 350         /*
 351          * Reads at most one frame from the given buffer invoking the consumer's
 352          * methods corresponding to the frame elements found.
 353          *
 354          * As much of the frame's payload, if any, is read. The buffers position
 355          * is updated to reflect the number of bytes read.
 356          *
 357          * Throws WSProtocolException if the frame is malformed.
 358          */
 359         void readFrame(WSShared<ByteBuffer> shared, Consumer consumer) {
 360             ByteBuffer input = shared.buffer();
 361             loop:
 362             while (true) {
 363                 byte b;
 364                 switch (state) {
 365                     case AWAITING_FIRST_BYTE:
 366                         if (!input.hasRemaining()) {
 367                             break loop;
 368                         }
 369                         b = input.get();
 370                         consumer.fin( (b & 0b10000000) != 0);
 371                         consumer.rsv1((b & 0b01000000) != 0);
 372                         consumer.rsv2((b & 0b00100000) != 0);
 373                         consumer.rsv3((b & 0b00010000) != 0);
 374                         consumer.opcode(ofCode(b));
 375                         state = AWAITING_SECOND_BYTE;
 376                         continue loop;
 377                     case AWAITING_SECOND_BYTE:
 378                         if (!input.hasRemaining()) {
 379                             break loop;
 380                         }
 381                         b = input.get();
 382                         consumer.mask(mask = (b & 0b10000000) != 0);
 383                         byte p1 = (byte) (b & 0b01111111);
 384                         if (p1 < 126) {
 385                             assert p1 >= 0 : p1;
 386                             consumer.payloadLen(payloadLength = p1);
 387                             state = mask ? READING_MASK : READING_PAYLOAD;
 388                         } else if (p1 < 127) {
 389                             state = READING_16_LENGTH;
 390                         } else {
 391                             state = READING_64_LENGTH;
 392                         }
 393                         continue loop;
 394                     case READING_16_LENGTH:
 395                         if (!input.hasRemaining()) {
 396                             break loop;
 397                         }
 398                         b = input.get();
 399                         if (accumulator.put(b).position() < 2) {
 400                             continue loop;
 401                         }
 402                         payloadLength = accumulator.flip().getChar();
 403                         if (payloadLength < 126) {
 404                             throw notMinimalEncoding(payloadLength, 2);
 405                         }
 406                         consumer.payloadLen(payloadLength);
 407                         accumulator.clear();
 408                         state = mask ? READING_MASK : READING_PAYLOAD;
 409                         continue loop;
 410                     case READING_64_LENGTH:
 411                         if (!input.hasRemaining()) {
 412                             break loop;
 413                         }
 414                         b = input.get();
 415                         if (accumulator.put(b).position() < 8) {
 416                             continue loop;
 417                         }
 418                         payloadLength = accumulator.flip().getLong();
 419                         if (payloadLength < 0) {
 420                             throw negativePayload(payloadLength);
 421                         } else if (payloadLength < 65535) {
 422                             throw notMinimalEncoding(payloadLength, 8);
 423                         }
 424                         consumer.payloadLen(payloadLength);
 425                         accumulator.clear();
 426                         state = mask ? READING_MASK : READING_PAYLOAD;
 427                         continue loop;
 428                     case READING_MASK:
 429                         if (!input.hasRemaining()) {
 430                             break loop;
 431                         }
 432                         b = input.get();
 433                         if (accumulator.put(b).position() != 4) {
 434                             continue loop;
 435                         }
 436                         consumer.maskingKey(accumulator.flip().getInt());
 437                         accumulator.clear();
 438                         state = READING_PAYLOAD;
 439                         continue loop;
 440                     case READING_PAYLOAD:
 441                         // This state does not require any bytes to be available
 442                         // in the input buffer in order to proceed
 443                         boolean fullyRead;
 444                         int limit;
 445                         if (payloadLength <= input.remaining()) {
 446                             limit = input.position() + (int) payloadLength;
 447                             payloadLength = 0;
 448                             fullyRead = true;
 449                         } else {
 450                             limit = input.limit();
 451                             payloadLength -= input.remaining();
 452                             fullyRead = false;
 453                         }
 454                         // FIXME: consider a case where payloadLen != 0,
 455                         // but input.remaining() == 0
 456                         //
 457                         // There shouldn't be an invocation of payloadData with
 458                         // an empty buffer, as it would be an artifact of
 459                         // reading
 460                         consumer.payloadData(shared.share(input.position(), limit), fullyRead);
 461                         // Update the position manually, since reading the
 462                         // payload doesn't advance buffer's position
 463                         input.position(limit);
 464                         if (fullyRead) {
 465                             consumer.endFrame();
 466                             state = AWAITING_FIRST_BYTE;
 467                         }
 468                         break loop;
 469                     default:
 470                         throw new InternalError(String.valueOf(state));
 471                 }
 472             }
 473         }
 474 
 475         private static WSProtocolException negativePayload(long payloadLength) {
 476             return new WSProtocolException
 477                     ("5.2.", format("Negative 64-bit payload length %s", payloadLength));
 478         }
 479 
 480         private static WSProtocolException notMinimalEncoding(long payloadLength, int numBytes) {
 481             return new WSProtocolException
 482                     ("5.2.", format("Payload length (%s) is not encoded with minimal number (%s) of bytes",
 483                             payloadLength, numBytes));
 484         }
 485     }
 486 }