1 /*
   2  * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package java.net.http;
  27 
  28 import java.net.http.WSFrame.Opcode;
  29 import java.net.http.WebSocket.MessagePart;
  30 import java.nio.ByteBuffer;
  31 import java.nio.CharBuffer;
  32 import java.nio.charset.CharacterCodingException;
  33 import java.util.concurrent.atomic.AtomicInteger;
  34 
  35 import static java.lang.String.format;
  36 import static java.lang.System.Logger.Level.TRACE;
  37 import static java.net.http.WSUtils.dump;
  38 import static java.net.http.WSUtils.logger;
  39 import static java.net.http.WebSocket.CloseCode.NOT_CONSISTENT;
  40 import static java.net.http.WebSocket.CloseCode.of;
  41 import static java.util.Objects.requireNonNull;
  42 
  43 /*
  44  * Consumes frame parts and notifies a message consumer, when there is
  45  * sufficient data to produce a message, or part thereof.
  46  *
  47  * Data consumed but not yet translated is accumulated until it's sufficient to
  48  * form a message.
  49  */
  50 final class WSFrameConsumer implements WSFrame.Consumer {
  51 
  52     private final AtomicInteger invocationOrder = new AtomicInteger();
  53 
  54     private final WSMessageConsumer output;
  55     private final WSCharsetToolkit.Decoder decoder = new WSCharsetToolkit.Decoder();
  56     private boolean fin;
  57     private Opcode opcode, originatingOpcode;
  58     private MessagePart part = MessagePart.WHOLE;
  59     private long payloadLen;
  60     private WSShared<ByteBuffer> binaryData;
  61 
  62     WSFrameConsumer(WSMessageConsumer output) {
  63         this.output = requireNonNull(output);
  64     }
  65 
  66     @Override
  67     public void fin(boolean value) {
  68         assert invocationOrder.compareAndSet(0, 1) : dump(invocationOrder, value);
  69         if (logger.isLoggable(TRACE)) {
  70             // Checked for being loggable because of autoboxing of 'value'
  71             logger.log(TRACE, "Reading fin: {0}", value);
  72         }
  73         fin = value;
  74     }
  75 
  76     @Override
  77     public void rsv1(boolean value) {
  78         assert invocationOrder.compareAndSet(1, 2) : dump(invocationOrder, value);
  79         if (logger.isLoggable(TRACE)) {
  80             logger.log(TRACE, "Reading rsv1: {0}", value);
  81         }
  82         if (value) {
  83             throw new WSProtocolException("5.2.", "rsv1 bit is set unexpectedly");
  84         }
  85     }
  86 
  87     @Override
  88     public void rsv2(boolean value) {
  89         assert invocationOrder.compareAndSet(2, 3) : dump(invocationOrder, value);
  90         if (logger.isLoggable(TRACE)) {
  91             logger.log(TRACE, "Reading rsv2: {0}", value);
  92         }
  93         if (value) {
  94             throw new WSProtocolException("5.2.", "rsv2 bit is set unexpectedly");
  95         }
  96     }
  97 
  98     @Override
  99     public void rsv3(boolean value) {
 100         assert invocationOrder.compareAndSet(3, 4) : dump(invocationOrder, value);
 101         if (logger.isLoggable(TRACE)) {
 102             logger.log(TRACE, "Reading rsv3: {0}", value);
 103         }
 104         if (value) {
 105             throw new WSProtocolException("5.2.", "rsv3 bit is set unexpectedly");
 106         }
 107     }
 108 
 109     @Override
 110     public void opcode(Opcode v) {
 111         assert invocationOrder.compareAndSet(4, 5) : dump(invocationOrder, v);
 112         logger.log(TRACE, "Reading opcode: {0}", v);
 113         if (v == Opcode.PING || v == Opcode.PONG || v == Opcode.CLOSE) {
 114             if (!fin) {
 115                 throw new WSProtocolException("5.5.", "A fragmented control frame " + v);
 116             }
 117             opcode = v;
 118         } else if (v == Opcode.TEXT || v == Opcode.BINARY) {
 119             if (originatingOpcode != null) {
 120                 throw new WSProtocolException
 121                         ("5.4.", format("An unexpected frame %s (fin=%s)", v, fin));
 122             }
 123             opcode = v;
 124             if (!fin) {
 125                 originatingOpcode = v;
 126             }
 127         } else if (v == Opcode.CONTINUATION) {
 128             if (originatingOpcode == null) {
 129                 throw new WSProtocolException
 130                         ("5.4.", format("An unexpected frame %s (fin=%s)", v, fin));
 131             }
 132             opcode = v;
 133         } else {
 134             throw new WSProtocolException("5.2.", "An unknown opcode " + v);
 135         }
 136     }
 137 
 138     @Override
 139     public void mask(boolean value) {
 140         assert invocationOrder.compareAndSet(5, 6) : dump(invocationOrder, value);
 141         if (logger.isLoggable(TRACE)) {
 142             logger.log(TRACE, "Reading mask: {0}", value);
 143         }
 144         if (value) {
 145             throw new WSProtocolException
 146                     ("5.1.", "Received a masked frame from the server");
 147         }
 148     }
 149 
 150     @Override
 151     public void payloadLen(long value) {
 152         assert invocationOrder.compareAndSet(6, 7) : dump(invocationOrder, value);
 153         if (logger.isLoggable(TRACE)) {
 154             logger.log(TRACE, "Reading payloadLen: {0}", value);
 155         }
 156         if (opcode.isControl()) {
 157             if (value > 125) {
 158                 throw new WSProtocolException
 159                         ("5.5.", format("A control frame %s has a payload length of %s",
 160                                 opcode, value));
 161             }
 162             assert Opcode.CLOSE.isControl();
 163             if (opcode == Opcode.CLOSE && value == 1) {
 164                 throw new WSProtocolException
 165                         ("5.5.1.", "A Close frame's status code is only 1 byte long");
 166             }
 167         }
 168         payloadLen = value;
 169     }
 170 
 171     @Override
 172     public void maskingKey(int value) {
 173         assert false : dump(invocationOrder, value);
 174     }
 175 
 176     @Override
 177     public void payloadData(WSShared<ByteBuffer> data, boolean isLast) {
 178         assert invocationOrder.compareAndSet(7, isLast ? 8 : 7)
 179                 : dump(invocationOrder, data, isLast);
 180         if (logger.isLoggable(TRACE)) {
 181             logger.log(TRACE, "Reading payloadData: data={0}, isLast={1}", data, isLast);
 182         }
 183         if (opcode.isControl()) {
 184             if (binaryData != null) {
 185                 binaryData.put(data);
 186                 data.dispose();
 187             } else if (!isLast) {
 188                 // The first chunk of the message
 189                 int remaining = data.remaining();
 190                 // It shouldn't be 125, otherwise the next chunk will be of size
 191                 // 0, which is not what Reader promises to deliver (eager
 192                 // reading)
 193                 assert remaining < 125 : dump(remaining);
 194                 WSShared<ByteBuffer> b = WSShared.wrap(ByteBuffer.allocate(125)).put(data);
 195                 data.dispose();
 196                 binaryData = b; // Will be disposed by the user
 197             } else {
 198                 // The only chunk; will be disposed by the user
 199                 binaryData = data.position(data.limit()); // FIXME: remove this hack
 200             }
 201         } else {
 202             part = determinePart(isLast);
 203             boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT;
 204             if (!text) {
 205                 output.onBinary(part, data);
 206             } else {
 207                 boolean binaryNonEmpty = data.hasRemaining();
 208                 WSShared<CharBuffer> textData;
 209                 try {
 210                     textData = decoder.decode(data, part == MessagePart.WHOLE || part == MessagePart.LAST);
 211                 } catch (CharacterCodingException e) {
 212                     throw new WSProtocolException
 213                             ("5.6.", "Invalid UTF-8 sequence in frame " + opcode, NOT_CONSISTENT, e);
 214                 }
 215                 if (!(binaryNonEmpty && !textData.hasRemaining())) {
 216                     // If there's a binary data, that result in no text, then we
 217                     // don't deliver anything
 218                     output.onText(part, textData);
 219                 }
 220             }
 221         }
 222     }
 223 
 224     @Override
 225     public void endFrame() {
 226         assert invocationOrder.compareAndSet(8, 0) : dump(invocationOrder);
 227         if (opcode.isControl()) {
 228             binaryData.flip();
 229         }
 230         switch (opcode) {
 231             case CLOSE:
 232                 WebSocket.CloseCode cc;
 233                 String reason;
 234                 if (payloadLen == 0) {
 235                     cc = null;
 236                     reason = "";
 237                 } else {
 238                     ByteBuffer b = binaryData.buffer();
 239                     int len = b.remaining();
 240                     assert 2 <= len && len <= 125 : dump(len, payloadLen);
 241                     try {
 242                         cc = of(b.getChar());
 243                         reason = WSCharsetToolkit.decode(b).toString();
 244                     } catch (IllegalArgumentException e) {
 245                         throw new WSProtocolException
 246                                 ("5.5.1", "Incorrect status code", e);
 247                     } catch (CharacterCodingException e) {
 248                         throw new WSProtocolException
 249                                 ("5.5.1", "Close reason is a malformed UTF-8 sequence", e);
 250                     }
 251                 }
 252                 binaryData.dispose(); // Manual dispose
 253                 output.onClose(cc, reason);
 254                 break;
 255             case PING:
 256                 output.onPing(binaryData);
 257                 binaryData = null;
 258                 break;
 259             case PONG:
 260                 output.onPong(binaryData);
 261                 binaryData = null;
 262                 break;
 263             default:
 264                 assert opcode == Opcode.TEXT || opcode == Opcode.BINARY
 265                         || opcode == Opcode.CONTINUATION : dump(opcode);
 266                 if (fin) {
 267                     // It is always the last chunk:
 268                     // either TEXT(FIN=TRUE)/BINARY(FIN=TRUE) or CONT(FIN=TRUE)
 269                     originatingOpcode = null;
 270                 }
 271                 break;
 272         }
 273         payloadLen = 0;
 274         opcode = null;
 275     }
 276 
 277     private MessagePart determinePart(boolean isLast) {
 278         boolean lastChunk = fin && isLast;
 279         switch (part) {
 280             case LAST:
 281             case WHOLE:
 282                 return lastChunk ? MessagePart.WHOLE : MessagePart.FIRST;
 283             case FIRST:
 284             case PART:
 285                 return lastChunk ? MessagePart.LAST : MessagePart.PART;
 286             default:
 287                 throw new InternalError(String.valueOf(part));
 288         }
 289     }
 290 }