/* * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. Oracle designates this * particular file as subject to the "Classpath" exception as provided * by Oracle in the LICENSE file that accompanied this code. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. */ package java.net.http; import java.io.IOException; import java.io.UncheckedIOException; import java.net.ProtocolException; import java.net.http.WebSocket.Listener; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.util.Optional; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import static java.lang.System.Logger.Level.ERROR; import static java.lang.System.Logger.Level.TRACE; import static java.net.http.Utils.EMPTY_BYTE_BUFFER; import static java.net.http.Utils.logger; import static java.util.Objects.requireNonNull; // // Serves the whole incoming sequence from reading bytes from the channel to // delivering message to the user // final class WSReceiver { private final Listener listener; private final WebSocket webSocket; private final Supplier> buffersSupplier; private final RawChannel channel; private final RawChannel.NonBlockingEvent channelEvent; private final SignalHandler handler; private final AtomicLong demand = new AtomicLong(); private final AtomicBoolean readable = new AtomicBoolean(); private boolean started; private volatile boolean closed; private final WebSocketFrame.Reader reader = new WebSocketFrame.Reader(); private final WSTranslator translator; private Shared buf = Shared.wrap(EMPTY_BYTE_BUFFER); private Shared data; WSReceiver(Listener listener, WebSocket webSocket, Executor executor, RawChannel channel, Supplier> buffersSupplier) { this.listener = requireNonNull(listener, "listener"); this.webSocket = requireNonNull(webSocket, "webSocket"); requireNonNull(executor, "executor"); this.channel = requireNonNull(channel, "channel"); this.buffersSupplier = requireNonNull(buffersSupplier, "buffersSupplier"); handler = new SignalHandler(executor, this::react); channelEvent = createChannelEvent(); this.translator = new WSTranslator(createIncomingMessageListener()); } private void react() { synchronized (this) { while (demand.get() > 0 && !closed) { try { if (data == null) { if (!getData()) { break; } } reader.readFrame(data, translator); if (!data.hasRemaining()) { data.dispose(); data = null; } // In case of exception we don't need to clean any state, since // it's the terminal condition anyway. Nothing will be retried. } catch (WebSocketProtocolException e) { // Translate into ProtocolException closeExceptionally(new ProtocolException().initCause(e)); } catch (Exception e) { closeExceptionally(e); } } } } long request(long n) { // TODO: protect from stack bloating when using same thread executor (& test it) long newDemand = demand.accumulateAndGet(n, (p, i) -> p + i < 0 ? Long.MAX_VALUE : p + i); handler.signal(); assert newDemand >= 0 : newDemand; return newDemand; } private boolean getData() throws IOException { if (!readable.get()) { return false; } if (!buf.hasRemaining()) { buf.dispose(); buf = buffersSupplier.get(); assert buf.hasRemaining() : buf; } int oldPosition = buf.position(); int oldLimit = buf.limit(); int numRead = channel.read(buf.buffer()); if (numRead > 0) { data = buf.share(oldPosition, oldPosition + numRead); buf.select(buf.limit(), oldLimit); // Move window to the free region return true; } else if (numRead == 0) { readable.set(false); channel.registerEvent(channelEvent); return false; } else { assert numRead < 0 : numRead; throw new WebSocketProtocolException ("7.2.1.", "Stream ended before a Close frame has been received"); } } void start() { synchronized (this) { if (started) { throw new IllegalStateException("Already started"); } started = true; try { channel.registerEvent(channelEvent); } catch (IOException e) { throw new UncheckedIOException(e); } try { listener.onOpen(webSocket); } catch (Exception e) { closeExceptionally(new RuntimeException("onOpen threw an exception", e)); } } } private void close() { closed = true; } private void closeExceptionally(Throwable error) { close(); try { listener.onError(webSocket, error); } catch (Exception e) { logger.log(ERROR, "onError threw an exception", e); } } private WSReceivedMessages createIncomingMessageListener() { return new WSReceivedMessages() { @Override public void onText(WebSocket.MessagePart part, DisposableText data) { decrementDemand(); CompletionStage cs; try { cs = listener.onText(webSocket, data, part); } catch (Exception e) { closeExceptionally(new RuntimeException("onText threw an exception", e)); return; } follow(cs, data, "onText"); } @Override public void onBinary(WebSocket.MessagePart part, Shared data) { decrementDemand(); CompletionStage cs; try { cs = listener.onBinary(webSocket, data.buffer(), part); } catch (Exception e) { closeExceptionally(new RuntimeException("onBinary threw an exception", e)); return; } follow(cs, data, "onBinary"); } @Override public void onPing(Shared data) { decrementDemand(); CompletionStage cs; try { cs = listener.onPing(webSocket, data.buffer()); } catch (Exception e) { closeExceptionally(new RuntimeException("onPing threw an exception", e)); return; } follow(cs, data, "onPing"); } @Override public void onPong(Shared data) { decrementDemand(); CompletionStage cs; try { cs = listener.onPong(webSocket, data.buffer()); } catch (Exception e) { closeExceptionally(new RuntimeException("onPong threw an exception", e)); return; } follow(cs, data, "onPong"); } @Override public void onClose(WebSocket.CloseCode code, CharSequence reason) { decrementDemand(); try { close(); listener.onClose(webSocket, Optional.ofNullable(code), reason.toString()); } catch (Exception e) { logger.log(ERROR, "onClose threw an exception", e); } } }; } private void follow(CompletionStage cs, Disposable d, String source) { if (cs == null) { logger.log(TRACE, "CompletionStage is null, disposing {0}", d); d.dispose(); } else { logger.log(TRACE, "When {0} completes, the {1} will be disposed", cs, d); cs.whenComplete((whatever, error) -> { logger.log(TRACE, "{0} has completed, error={1}; disposing {2}", cs, error, d); if (error != null) { String m = "CompletionStage returned by " + source + " completed exceptionally"; closeExceptionally(new RuntimeException(m, error)); } d.dispose(); }); } } private void decrementDemand() { long newDemand = demand.decrementAndGet(); assert newDemand >= 0 : newDemand; } private RawChannel.NonBlockingEvent createChannelEvent() { return new RawChannel.NonBlockingEvent() { @Override public int interestOps() { return SelectionKey.OP_READ; } @Override public void handle() { logger.log(TRACE, "The channel {0} can be read", channel); boolean wasNotReadable = readable.compareAndSet(false, true); assert wasNotReadable; handler.signal(); } @Override public String toString() { return "Read readiness event [" + channel + "]"; } }; } }