1 /*
   2  * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 package java.net.http;
  26 
  27 import java.io.IOException;
  28 import java.io.UncheckedIOException;
  29 import java.net.ProtocolException;
  30 import java.net.http.WebSocket.Listener;
  31 import java.nio.ByteBuffer;
  32 import java.nio.CharBuffer;
  33 import java.nio.channels.SelectionKey;
  34 import java.util.Optional;
  35 import java.util.concurrent.CompletionStage;
  36 import java.util.concurrent.Executor;
  37 import java.util.concurrent.atomic.AtomicBoolean;
  38 import java.util.concurrent.atomic.AtomicLong;
  39 import java.util.function.Supplier;
  40 
  41 import static java.lang.System.Logger.Level.ERROR;
  42 import static java.net.http.WSUtils.EMPTY_BYTE_BUFFER;
  43 import static java.net.http.WSUtils.logger;
  44 
  45 /*
  46  * Receives incoming data from the channel and converts it into a sequence of
  47  * messages, which are then passed to the listener.
  48  */
  49 final class WSReceiver {
  50 
  51     private final Listener listener;
  52     private final WebSocket webSocket;
  53     private final Supplier<WSShared<ByteBuffer>> buffersSupplier =
  54             new WSSharedPool<>(() -> ByteBuffer.allocateDirect(32768), 2);
  55     private final RawChannel channel;
  56     private final RawChannel.NonBlockingEvent channelEvent;
  57     private final WSSignalHandler handler;
  58     private final AtomicLong demand = new AtomicLong();
  59     private final AtomicBoolean readable = new AtomicBoolean();
  60     private boolean started;
  61     private volatile boolean closed;
  62     private final WSFrame.Reader reader = new WSFrame.Reader();
  63     private final WSFrameConsumer frameConsumer;
  64     private WSShared<ByteBuffer> buf = WSShared.wrap(EMPTY_BYTE_BUFFER);
  65     private WSShared<ByteBuffer> data; // TODO: initialize with leftovers from the RawChannel
  66 
  67     WSReceiver(Listener listener, WebSocket webSocket, Executor executor,
  68                RawChannel channel) {
  69         this.listener = listener;
  70         this.webSocket = webSocket;
  71         this.channel = channel;
  72         handler = new WSSignalHandler(executor, this::react);
  73         channelEvent = createChannelEvent();
  74         this.frameConsumer = new WSFrameConsumer(new MessageConsumer());
  75     }
  76 
  77     private void react() {
  78         synchronized (this) {
  79             while (demand.get() > 0 && !closed) {
  80                 try {
  81                     if (data == null) {
  82                         if (!getData()) {
  83                             break;
  84                         }
  85                     }
  86                     reader.readFrame(data, frameConsumer);
  87                     if (!data.hasRemaining()) {
  88                         data.dispose();
  89                         data = null;
  90                     }
  91                     // In case of exception we don't need to clean any state,
  92                     // since it's the terminal condition anyway. Nothing will be
  93                     // retried.
  94                 } catch (WSProtocolException e) {
  95                     // Translate into ProtocolException
  96                     closeExceptionally(new ProtocolException().initCause(e));
  97                 } catch (Exception e) {
  98                     closeExceptionally(e);
  99                 }
 100             }
 101         }
 102     }
 103 
 104     void request(long n) {
 105         long newDemand = demand.accumulateAndGet(n, (p, i) -> p + i < 0 ? Long.MAX_VALUE : p + i);
 106         handler.signal();
 107         assert newDemand >= 0 : newDemand;
 108     }
 109 
 110     private boolean getData() throws IOException {
 111         if (!readable.get()) {
 112             return false;
 113         }
 114         if (!buf.hasRemaining()) {
 115             buf.dispose();
 116             buf = buffersSupplier.get();
 117             assert buf.hasRemaining() : buf;
 118         }
 119         int oldPosition = buf.position();
 120         int oldLimit = buf.limit();
 121         int numRead = channel.read(buf.buffer());
 122         if (numRead > 0) {
 123             data = buf.share(oldPosition, oldPosition + numRead);
 124             buf.select(buf.limit(), oldLimit); // Move window to the free region
 125             return true;
 126         } else if (numRead == 0) {
 127             readable.set(false);
 128             channel.registerEvent(channelEvent);
 129             return false;
 130         } else {
 131             assert numRead < 0 : numRead;
 132             throw new WSProtocolException
 133                     ("7.2.1.", "Stream ended before a Close frame has been received");
 134         }
 135     }
 136 
 137     void start() {
 138         synchronized (this) {
 139             if (started) {
 140                 throw new IllegalStateException("Already started");
 141             }
 142             started = true;
 143             try {
 144                 channel.registerEvent(channelEvent);
 145             } catch (IOException e) {
 146                 throw new UncheckedIOException(e);
 147             }
 148             try {
 149                 listener.onOpen(webSocket);
 150             } catch (Exception e) {
 151                 closeExceptionally(new RuntimeException("onOpen threw an exception", e));
 152             }
 153         }
 154     }
 155 
 156     private void close() { // TODO: move to WS.java
 157         closed = true;
 158     }
 159 
 160     private void closeExceptionally(Throwable error) {  // TODO: move to WS.java
 161         close();
 162         try {
 163             listener.onError(webSocket, error);
 164         } catch (Exception e) {
 165             logger.log(ERROR, "onError threw an exception", e);
 166         }
 167     }
 168 
 169     private final class MessageConsumer implements WSMessageConsumer {
 170 
 171         @Override
 172         public void onText(WebSocket.MessagePart part, WSShared<CharBuffer> data) {
 173             decrementDemand();
 174             CompletionStage<?> cs;
 175             try {
 176                 cs = listener.onText(webSocket, data.buffer(), part);
 177             } catch (Exception e) {
 178                 closeExceptionally(new RuntimeException("onText threw an exception", e));
 179                 return;
 180             }
 181             follow(cs, data, "onText");
 182         }
 183 
 184         @Override
 185         public void onBinary(WebSocket.MessagePart part, WSShared<ByteBuffer> data) {
 186             decrementDemand();
 187             CompletionStage<?> cs;
 188             try {
 189                 cs = listener.onBinary(webSocket, data.buffer(), part);
 190             } catch (Exception e) {
 191                 closeExceptionally(new RuntimeException("onBinary threw an exception", e));
 192                 return;
 193             }
 194             follow(cs, data, "onBinary");
 195         }
 196 
 197         @Override
 198         public void onPing(WSShared<ByteBuffer> data) {
 199             decrementDemand();
 200             CompletionStage<?> cs;
 201             try {
 202                 cs = listener.onPing(webSocket, data.buffer());
 203             } catch (Exception e) {
 204                 closeExceptionally(new RuntimeException("onPing threw an exception", e));
 205                 return;
 206             }
 207             follow(cs, data, "onPing");
 208         }
 209 
 210         @Override
 211         public void onPong(WSShared<ByteBuffer> data) {
 212             decrementDemand();
 213             CompletionStage<?> cs;
 214             try {
 215                 cs = listener.onPong(webSocket, data.buffer());
 216             } catch (Exception e) {
 217                 closeExceptionally(new RuntimeException("onPong threw an exception", e));
 218                 return;
 219             }
 220             follow(cs, data, "onPong");
 221         }
 222 
 223         @Override
 224         public void onClose(WebSocket.CloseCode code, CharSequence reason) {
 225             decrementDemand();
 226             try {
 227                 close();
 228                 listener.onClose(webSocket, Optional.ofNullable(code), reason.toString());
 229             } catch (Exception e) {
 230                 logger.log(ERROR, "onClose threw an exception", e);
 231             }
 232         }
 233     }
 234 
 235     private void follow(CompletionStage<?> cs, WSDisposable d, String source) {
 236         if (cs == null) {
 237             d.dispose();
 238         } else {
 239             cs.whenComplete((whatever, error) -> {
 240                 if (error != null) {
 241                     String m = "CompletionStage returned by " + source + " completed exceptionally";
 242                     closeExceptionally(new RuntimeException(m, error));
 243                 }
 244                 d.dispose();
 245             });
 246         }
 247     }
 248 
 249     private void decrementDemand() {
 250         long newDemand = demand.decrementAndGet();
 251         assert newDemand >= 0 : newDemand;
 252     }
 253 
 254     private RawChannel.NonBlockingEvent createChannelEvent() {
 255         return new RawChannel.NonBlockingEvent() {
 256 
 257             @Override
 258             public int interestOps() {
 259                 return SelectionKey.OP_READ;
 260             }
 261 
 262             @Override
 263             public void handle() {
 264                 boolean wasNotReadable = readable.compareAndSet(false, true);
 265                 assert wasNotReadable;
 266                 handler.signal();
 267             }
 268 
 269             @Override
 270             public String toString() {
 271                 return "Read readiness event [" + channel + "]";
 272             }
 273         };
 274     }
 275 }