1 /*
   2  * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package jdk.incubator.http.internal.websocket;
  27 
  28 import jdk.incubator.http.WebSocket.MessagePart;
  29 import jdk.incubator.http.internal.common.Log;
  30 import jdk.incubator.http.internal.websocket.Frame.Opcode;
  31 
  32 import java.nio.ByteBuffer;
  33 import java.nio.CharBuffer;
  34 import java.nio.charset.CharacterCodingException;
  35 
  36 import static java.lang.String.format;
  37 import static java.nio.charset.StandardCharsets.UTF_8;
  38 import static java.util.Objects.requireNonNull;
  39 import static jdk.incubator.http.internal.common.Utils.dump;
  40 import static jdk.incubator.http.internal.websocket.StatusCodes.NO_STATUS_CODE;
  41 import static jdk.incubator.http.internal.websocket.StatusCodes.isLegalToReceiveFromServer;
  42 
  43 /*
  44  * Consumes frame parts and notifies a message consumer, when there is
  45  * sufficient data to produce a message, or part thereof.
  46  *
  47  * Data consumed but not yet translated is accumulated until it's sufficient to
  48  * form a message.
  49  */
  50 /* Non-final for testing purposes only */
  51 class FrameConsumer implements Frame.Consumer {
  52 
  53     private final MessageStreamConsumer output;
  54     private final UTF8AccumulatingDecoder decoder = new UTF8AccumulatingDecoder();
  55     private boolean fin;
  56     private Opcode opcode, originatingOpcode;
  57     private MessagePart part = MessagePart.WHOLE;
  58     private long payloadLen;
  59     private long unconsumedPayloadLen;
  60     private ByteBuffer binaryData;
  61 
  62     FrameConsumer(MessageStreamConsumer output) {
  63         this.output = requireNonNull(output);
  64     }
  65 
  66     /* Exposed for testing purposes only */
  67     MessageStreamConsumer getOutput() {
  68         return output;
  69     }
  70 
  71     @Override
  72     public void fin(boolean value) {
  73         Log.logTrace("Reading fin: {0}", value);
  74         fin = value;
  75     }
  76 
  77     @Override
  78     public void rsv1(boolean value) {
  79         Log.logTrace("Reading rsv1: {0}", value);
  80         if (value) {
  81             throw new FailWebSocketException("Unexpected rsv1 bit");
  82         }
  83     }
  84 
  85     @Override
  86     public void rsv2(boolean value) {
  87         Log.logTrace("Reading rsv2: {0}", value);
  88         if (value) {
  89             throw new FailWebSocketException("Unexpected rsv2 bit");
  90         }
  91     }
  92 
  93     @Override
  94     public void rsv3(boolean value) {
  95         Log.logTrace("Reading rsv3: {0}", value);
  96         if (value) {
  97             throw new FailWebSocketException("Unexpected rsv3 bit");
  98         }
  99     }
 100 
 101     @Override
 102     public void opcode(Opcode v) {
 103         Log.logTrace("Reading opcode: {0}", v);
 104         if (v == Opcode.PING || v == Opcode.PONG || v == Opcode.CLOSE) {
 105             if (!fin) {
 106                 throw new FailWebSocketException("Fragmented control frame  " + v);
 107             }
 108             opcode = v;
 109         } else if (v == Opcode.TEXT || v == Opcode.BINARY) {
 110             if (originatingOpcode != null) {
 111                 throw new FailWebSocketException(
 112                         format("Unexpected frame %s (fin=%s)", v, fin));
 113             }
 114             opcode = v;
 115             if (!fin) {
 116                 originatingOpcode = v;
 117             }
 118         } else if (v == Opcode.CONTINUATION) {
 119             if (originatingOpcode == null) {
 120                 throw new FailWebSocketException(
 121                         format("Unexpected frame %s (fin=%s)", v, fin));
 122             }
 123             opcode = v;
 124         } else {
 125             throw new FailWebSocketException("Unknown opcode " + v);
 126         }
 127     }
 128 
 129     @Override
 130     public void mask(boolean value) {
 131         Log.logTrace("Reading mask: {0}", value);
 132         if (value) {
 133             throw new FailWebSocketException("Masked frame received");
 134         }
 135     }
 136 
 137     @Override
 138     public void payloadLen(long value) {
 139         Log.logTrace("Reading payloadLen: {0}", value);
 140         if (opcode.isControl()) {
 141             if (value > 125) {
 142                 throw new FailWebSocketException(
 143                         format("%s's payload length %s", opcode, value));
 144             }
 145             assert Opcode.CLOSE.isControl();
 146             if (opcode == Opcode.CLOSE && value == 1) {
 147                 throw new FailWebSocketException("Incomplete status code");
 148             }
 149         }
 150         payloadLen = value;
 151         unconsumedPayloadLen = value;
 152     }
 153 
 154     @Override
 155     public void maskingKey(int value) {
 156         // `FrameConsumer.mask(boolean)` is where a masked frame is detected and
 157         // reported on; `FrameConsumer.mask(boolean)` MUST be invoked before
 158         // this method;
 159         // So this method (`maskingKey`) is not supposed to be invoked while
 160         // reading a frame that has came from the server. If this method is
 161         // invoked, then it's an error in implementation, thus InternalError
 162         throw new InternalError();
 163     }
 164 
 165     @Override
 166     public void payloadData(ByteBuffer data) {
 167         Log.logTrace("Reading payloadData: data={0}", data);
 168         unconsumedPayloadLen -= data.remaining();
 169         boolean isLast = unconsumedPayloadLen == 0;
 170         if (opcode.isControl()) {
 171             if (binaryData != null) { // An intermediate or the last chunk
 172                 binaryData.put(data);
 173             } else if (!isLast) { // The first chunk
 174                 int remaining = data.remaining();
 175                 // It shouldn't be 125, otherwise the next chunk will be of size
 176                 // 0, which is not what Reader promises to deliver (eager
 177                 // reading)
 178                 assert remaining < 125 : dump(remaining);
 179                 binaryData = ByteBuffer.allocate(125).put(data);
 180             } else { // The only chunk
 181                 binaryData = ByteBuffer.allocate(data.remaining()).put(data);
 182             }
 183         } else {
 184             part = determinePart(isLast);
 185             boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT;
 186             if (!text) {
 187                 output.onBinary(part, data.slice());
 188                 data.position(data.limit()); // Consume
 189             } else {
 190                 boolean binaryNonEmpty = data.hasRemaining();
 191                 CharBuffer textData;
 192                 try {
 193                     textData = decoder.decode(data, part == MessagePart.WHOLE || part == MessagePart.LAST);
 194                 } catch (CharacterCodingException e) {
 195                     throw new FailWebSocketException(
 196                             "Invalid UTF-8 in frame " + opcode, StatusCodes.NOT_CONSISTENT)
 197                             .initCause(e);
 198                 }
 199                 if (!(binaryNonEmpty && !textData.hasRemaining())) {
 200                     // If there's a binary data, that result in no text, then we
 201                     // don't deliver anything
 202                     output.onText(part, textData);
 203                 }
 204             }
 205         }
 206     }
 207 
 208     @Override
 209     public void endFrame() {
 210         if (opcode.isControl()) {
 211             binaryData.flip();
 212         }
 213         switch (opcode) {
 214             case CLOSE:
 215                 char statusCode = NO_STATUS_CODE;
 216                 String reason = "";
 217                 if (payloadLen != 0) {
 218                     int len = binaryData.remaining();
 219                     assert 2 <= len && len <= 125 : dump(len, payloadLen);
 220                     statusCode = binaryData.getChar();
 221                     if (!isLegalToReceiveFromServer(statusCode)) {
 222                         throw new FailWebSocketException(
 223                                 "Illegal status code: " + statusCode);
 224                     }
 225                     try {
 226                         reason = UTF_8.newDecoder().decode(binaryData).toString();
 227                     } catch (CharacterCodingException e) {
 228                         throw new FailWebSocketException("Illegal close reason")
 229                                 .initCause(e);
 230                     }
 231                 }
 232                 output.onClose(statusCode, reason);
 233                 break;
 234             case PING:
 235                 output.onPing(binaryData);
 236                 binaryData = null;
 237                 break;
 238             case PONG:
 239                 output.onPong(binaryData);
 240                 binaryData = null;
 241                 break;
 242             default:
 243                 assert opcode == Opcode.TEXT || opcode == Opcode.BINARY
 244                         || opcode == Opcode.CONTINUATION : dump(opcode);
 245                 if (fin) {
 246                     // It is always the last chunk:
 247                     // either TEXT(FIN=TRUE)/BINARY(FIN=TRUE) or CONT(FIN=TRUE)
 248                     originatingOpcode = null;
 249                 }
 250                 break;
 251         }
 252         payloadLen = 0;
 253         opcode = null;
 254     }
 255 
 256     private MessagePart determinePart(boolean isLast) {
 257         boolean lastChunk = fin && isLast;
 258         switch (part) {
 259             case LAST:
 260             case WHOLE:
 261                 return lastChunk ? MessagePart.WHOLE : MessagePart.FIRST;
 262             case FIRST:
 263             case PART:
 264                 return lastChunk ? MessagePart.LAST : MessagePart.PART;
 265             default:
 266                 throw new InternalError(String.valueOf(part));
 267         }
 268     }
 269 }