< prev index next >

src/jdk.incubator.httpclient/share/classes/jdk/incubator/http/internal/websocket/WebSocketImpl.java

Print this page

        

@@ -24,280 +24,203 @@
  */
 
 package jdk.incubator.http.internal.websocket;
 
 import jdk.incubator.http.WebSocket;
+import jdk.incubator.http.internal.common.Demand;
 import jdk.incubator.http.internal.common.Log;
+import jdk.incubator.http.internal.common.MinimalFuture;
 import jdk.incubator.http.internal.common.Pair;
+import jdk.incubator.http.internal.common.SequentialScheduler;
+import jdk.incubator.http.internal.common.SequentialScheduler.DeferredCompleter;
+import jdk.incubator.http.internal.common.Utils;
 import jdk.incubator.http.internal.websocket.OpeningHandshake.Result;
 import jdk.incubator.http.internal.websocket.OutgoingMessage.Binary;
 import jdk.incubator.http.internal.websocket.OutgoingMessage.Close;
 import jdk.incubator.http.internal.websocket.OutgoingMessage.Context;
 import jdk.incubator.http.internal.websocket.OutgoingMessage.Ping;
 import jdk.incubator.http.internal.websocket.OutgoingMessage.Pong;
 import jdk.incubator.http.internal.websocket.OutgoingMessage.Text;
 
 import java.io.IOException;
+import java.lang.ref.Reference;
 import java.net.ProtocolException;
 import java.net.URI;
 import java.nio.ByteBuffer;
 import java.util.Queue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
 import static java.util.Objects.requireNonNull;
-import static java.util.concurrent.CompletableFuture.failedFuture;
+import static jdk.incubator.http.internal.common.MinimalFuture.failedFuture;
 import static jdk.incubator.http.internal.common.Pair.pair;
 import static jdk.incubator.http.internal.websocket.StatusCodes.CLOSED_ABNORMALLY;
 import static jdk.incubator.http.internal.websocket.StatusCodes.NO_STATUS_CODE;
 import static jdk.incubator.http.internal.websocket.StatusCodes.isLegalToSendFromClient;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.BINARY;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.CLOSE;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.ERROR;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.IDLE;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.OPEN;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.PING;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.PONG;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.TEXT;
+import static jdk.incubator.http.internal.websocket.WebSocketImpl.State.WAITING;
 
 /*
  * A WebSocket client.
  */
-final class WebSocketImpl implements WebSocket {
+public final class WebSocketImpl implements WebSocket {
+
+    enum State {
+        OPEN,
+        IDLE,
+        WAITING,
+        TEXT,
+        BINARY,
+        PING,
+        PONG,
+        CLOSE,
+        ERROR;
+    }
+
+    private volatile boolean inputClosed;
+    private volatile boolean outputClosed;
+
+    private final AtomicReference<State> state = new AtomicReference<>(OPEN);
+
+    /* Components of calls to Listener's methods */
+    private MessagePart part;
+    private ByteBuffer binaryData;
+    private CharSequence text;
+    private int statusCode;
+    private String reason;
+    private final AtomicReference<Throwable> error = new AtomicReference<>();
 
     private final URI uri;
     private final String subprotocol;
-    private final RawChannel channel;
     private final Listener listener;
 
-    /*
-     * Whether or not Listener.onClose or Listener.onError has been already
-     * invoked. We keep track of this since only one of these methods is invoked
-     * and it is invoked at most once.
-     */
-    private boolean lastMethodInvoked;
     private final AtomicBoolean outstandingSend = new AtomicBoolean();
-    private final CooperativeHandler sendHandler =
-              new CooperativeHandler(this::sendFirst);
+    private final SequentialScheduler sendScheduler = new SequentialScheduler(new SendTask());
     private final Queue<Pair<OutgoingMessage, CompletableFuture<WebSocket>>>
             queue = new ConcurrentLinkedQueue<>();
     private final Context context = new OutgoingMessage.Context();
     private final Transmitter transmitter;
     private final Receiver receiver;
+    private final SequentialScheduler receiveScheduler = new SequentialScheduler(new ReceiveTask());
+    private final Demand demand = new Demand();
 
-    /*
-     * Whether or not the WebSocket has been closed. When a WebSocket has been
-     * closed it means that no further messages can be sent or received.
-     * A closure can be triggered by:
-     *
-     *   1. abort()
-     *   2. "Failing the WebSocket Connection" (i.e. a fatal error)
-     *   3. Completion of the Closing handshake
-     */
-    private final AtomicBoolean closed = new AtomicBoolean();
-
-    /*
-     * This lock is enforcing sequential ordering of invocations to listener's
-     * methods. It is supposed to be uncontended. The only contention that can
-     * happen is when onOpen, an asynchronous onError (not related to reading
-     * from the channel, e.g. an error from automatic Pong reply) or onClose
-     * (related to abort) happens. Since all of the above are one-shot actions,
-     * the said contention is insignificant.
-     */
-    private final Object lock = new Object();
-
-    private final CompletableFuture<?> closeReceived = new CompletableFuture<>();
-    private final CompletableFuture<?> closeSent = new CompletableFuture<>();
-
-    static CompletableFuture<WebSocket> newInstanceAsync(BuilderImpl b) {
+    public static CompletableFuture<WebSocket> newInstanceAsync(BuilderImpl b) {
         Function<Result, WebSocket> newWebSocket = r -> {
-            WebSocketImpl ws = new WebSocketImpl(b.getUri(),
+            WebSocket ws = newInstance(b.getUri(),
                                                  r.subprotocol,
-                                                 r.channel,
-                                                 b.getListener());
-            // The order of calls might cause a subtle effects, like CF will be
-            // returned from the buildAsync _after_ onOpen has been signalled.
-            // This means if onOpen is lengthy, it might cause some problems.
-            ws.signalOpen();
+                                       b.getListener(),
+                                       r.transport);
+            // Make sure we don't release the builder until this lambda
+            // has been executed. The builder has a strong reference to
+            // the HttpClientFacade, and we want to keep that live until
+            // after the raw channel is created and passed to WebSocketImpl.
+            Reference.reachabilityFence(b);
             return ws;
         };
         OpeningHandshake h;
         try {
             h = new OpeningHandshake(b);
-        } catch (IllegalArgumentException e) {
+        } catch (Throwable e) {
             return failedFuture(e);
         }
         return h.send().thenApply(newWebSocket);
     }
 
-    WebSocketImpl(URI uri,
+    /* Exposed for testing purposes */
+    static WebSocket newInstance(URI uri,
                   String subprotocol,
-                  RawChannel channel,
-                  Listener listener)
-    {
+                                 Listener listener,
+                                 TransportSupplier transport) {
+        WebSocketImpl ws = new WebSocketImpl(uri, subprotocol, listener, transport);
+        // This initialisation is outside of the constructor for the sake of
+        // safe publication of WebSocketImpl.this
+        ws.signalOpen();
+        return ws;
+    }
+
+    private WebSocketImpl(URI uri,
+                          String subprotocol,
+                          Listener listener,
+                          TransportSupplier transport) {
         this.uri = requireNonNull(uri);
         this.subprotocol = requireNonNull(subprotocol);
-        this.channel = requireNonNull(channel);
         this.listener = requireNonNull(listener);
-        this.transmitter = new Transmitter(channel);
-        this.receiver = new Receiver(messageConsumerOf(listener), channel);
-
-        // Set up the Closing Handshake action
-        CompletableFuture.allOf(closeReceived, closeSent)
-                .whenComplete((result, error) -> {
-                    try {
-                        channel.close();
-                    } catch (IOException e) {
-                        Log.logError(e);
-                    } finally {
-                        closed.set(true);
-                    }
-                });
-    }
-
-    /*
-     * This initialisation is outside of the constructor for the sake of
-     * safe publication.
-     */
-    private void signalOpen() {
-        synchronized (lock) {
-            // TODO: might hold lock longer than needed causing prolonged
-            // contention? substitute lock for ConcurrentLinkedQueue<Runnable>?
-            try {
-                listener.onOpen(this);
-            } catch (Exception e) {
-                signalError(e);
-            }
-        }
-    }
-
-    private void signalError(Throwable error) {
-        synchronized (lock) {
-            if (lastMethodInvoked) {
-                Log.logError(error);
-            } else {
-                lastMethodInvoked = true;
-                receiver.close();
-                try {
-                    listener.onError(this, error);
-                } catch (Exception e) {
-                    Log.logError(e);
-                }
-            }
-        }
-    }
-
-    /*
-     * Processes a Close event that came from the channel. Invoked at most once.
-     */
-    private void processClose(int statusCode, String reason) {
-        receiver.close();
-        try {
-            channel.shutdownInput();
-        } catch (IOException e) {
-            Log.logError(e);
-        }
-        boolean alreadyCompleted = !closeReceived.complete(null);
-        if (alreadyCompleted) {
-            // This CF is supposed to be completed only once, the first time a
-            // Close message is received. No further messages are pulled from
-            // the socket.
-            throw new InternalError();
-        }
-        int code;
-        if (statusCode == NO_STATUS_CODE || statusCode == CLOSED_ABNORMALLY) {
-            code = NORMAL_CLOSURE;
-        } else {
-            code = statusCode;
-        }
-        CompletionStage<?> readyToClose = signalClose(statusCode, reason);
-        if (readyToClose == null) {
-            readyToClose = CompletableFuture.completedFuture(null);
-        }
-        readyToClose.whenComplete((r, error) -> {
-            enqueueClose(new Close(code, ""))
-                    .whenComplete((r1, error1) -> {
-                        if (error1 != null) {
-                            Log.logError(error1);
-                        }
-                    });
-        });
-    }
-
-    /*
-     * Signals a Close event (might not correspond to anything happened on the
-     * channel, e.g. `abort()`).
-     */
-    private CompletionStage<?> signalClose(int statusCode, String reason) {
-        synchronized (lock) {
-            if (lastMethodInvoked) {
-                Log.logTrace("Close: {0}, ''{1}''", statusCode, reason);
-            } else {
-                lastMethodInvoked = true;
-                receiver.close();
-                try {
-                    return listener.onClose(this, statusCode, reason);
-                } catch (Exception e) {
-                    Log.logError(e);
-                }
-            }
-        }
-        return null;
+        this.transmitter = transport.transmitter();
+        this.receiver = transport.receiver(new SignallingMessageConsumer());
     }
 
     @Override
-    public CompletableFuture<WebSocket> sendText(CharSequence message,
-                                                 boolean isLast)
-    {
+    public CompletableFuture<WebSocket> sendText(CharSequence message, boolean isLast) {
         return enqueueExclusively(new Text(message, isLast));
     }
 
     @Override
-    public CompletableFuture<WebSocket> sendBinary(ByteBuffer message,
-                                                   boolean isLast)
-    {
+    public CompletableFuture<WebSocket> sendBinary(ByteBuffer message, boolean isLast) {
         return enqueueExclusively(new Binary(message, isLast));
     }
 
     @Override
     public CompletableFuture<WebSocket> sendPing(ByteBuffer message) {
-        return enqueueExclusively(new Ping(message));
+        return enqueue(new Ping(message));
     }
 
     @Override
     public CompletableFuture<WebSocket> sendPong(ByteBuffer message) {
-        return enqueueExclusively(new Pong(message));
+        return enqueue(new Pong(message));
     }
 
     @Override
-    public CompletableFuture<WebSocket> sendClose(int statusCode,
-                                                  String reason) {
+    public CompletableFuture<WebSocket> sendClose(int statusCode, String reason) {
         if (!isLegalToSendFromClient(statusCode)) {
             return failedFuture(
                     new IllegalArgumentException("statusCode: " + statusCode));
         }
         Close msg;
         try {
             msg = new Close(statusCode, reason);
         } catch (IllegalArgumentException e) {
             return failedFuture(e);
         }
+        outputClosed = true;
         return enqueueClose(msg);
     }
 
     /*
-     * Sends a Close message with the given contents and then shuts down the
-     * channel for writing since no more messages are expected to be sent after
-     * this. Invoked at most once.
+     * Sends a Close message, then shuts down the transmitter since no more
+     * messages are expected to be sent after this.
      */
     private CompletableFuture<WebSocket> enqueueClose(Close m) {
-        return enqueue(m).whenComplete((r, error) -> {
+        // TODO: MUST be a CF created once and shared across sendClose, otherwise
+        // a second sendClose may prematurely close the channel
+        return enqueue(m)
+                .orTimeout(60, TimeUnit.SECONDS)
+                .whenComplete((r, error) -> {
             try {
-                channel.shutdownOutput();
+                        transmitter.close();
+                    } catch (IOException e) {
+                        Log.logError(e);
+                    }
+                    if (error instanceof TimeoutException) {
+                        try {
+                            receiver.close();
             } catch (IOException e) {
                 Log.logError(e);
             }
-            boolean alreadyCompleted = !closeSent.complete(null);
-            if (alreadyCompleted) {
-                // Shouldn't happen as this callback must run at most once
-                throw new InternalError();
             }
         });
     }
 
     /*

@@ -305,201 +228,383 @@
      * fashion in respect to other messages accepted through this method. No
      * further messages will be accepted until the returned CompletableFuture
      * completes. This method is used to enforce "one outstanding send
      * operation" policy.
      */
-    private CompletableFuture<WebSocket> enqueueExclusively(OutgoingMessage m)
-    {
-        if (closed.get()) {
-            return failedFuture(new IllegalStateException("Closed"));
-        }
+    private CompletableFuture<WebSocket> enqueueExclusively(OutgoingMessage m) {
         if (!outstandingSend.compareAndSet(false, true)) {
-            return failedFuture(new IllegalStateException("Outstanding send"));
+            return failedFuture(new IllegalStateException("Send pending"));
         }
         return enqueue(m).whenComplete((r, e) -> outstandingSend.set(false));
     }
 
     private CompletableFuture<WebSocket> enqueue(OutgoingMessage m) {
-        CompletableFuture<WebSocket> cf = new CompletableFuture<>();
+        CompletableFuture<WebSocket> cf = new MinimalFuture<>();
         boolean added = queue.add(pair(m, cf));
         if (!added) {
             // The queue is supposed to be unbounded
             throw new InternalError();
         }
-        sendHandler.handle();
+        sendScheduler.runOrSchedule();
         return cf;
     }
 
     /*
-     * This is the main sending method. It may be run in different threads,
-     * but never concurrently.
+     * This is a message sending task. It pulls messages from the queue one by
+     * one and sends them. It may be run in different threads, but never
+     * concurrently.
      */
-    private void sendFirst(Runnable whenSent) {
+    private class SendTask implements SequentialScheduler.RestartableTask {
+
+        @Override
+        public void run(DeferredCompleter taskCompleter) {
         Pair<OutgoingMessage, CompletableFuture<WebSocket>> p = queue.poll();
         if (p == null) {
-            whenSent.run();
+                taskCompleter.complete();
             return;
         }
         OutgoingMessage message = p.first;
         CompletableFuture<WebSocket> cf = p.second;
         try {
-            message.contextualize(context);
+                if (!message.contextualize(context)) { // Do not send the message
+                    cf.complete(null);
+                    repeat(taskCompleter);
+                    return;
+                }
             Consumer<Exception> h = e -> {
                 if (e == null) {
                     cf.complete(WebSocketImpl.this);
                 } else {
                     cf.completeExceptionally(e);
                 }
-                sendHandler.handle();
-                whenSent.run();
+                    repeat(taskCompleter);
             };
             transmitter.send(message, h);
-        } catch (Exception t) {
+            } catch (Throwable t) {
             cf.completeExceptionally(t);
+                repeat(taskCompleter);
+            }
+        }
+
+        private void repeat(DeferredCompleter taskCompleter) {
+            taskCompleter.complete();
+            // More than a single message may have been enqueued while
+            // the task has been busy with the current message, but
+            // there is only a single signal recorded
+            sendScheduler.runOrSchedule();
         }
     }
 
     @Override
     public void request(long n) {
-        receiver.request(n);
+        if (demand.increase(n)) {
+            receiveScheduler.runOrSchedule();
+        }
     }
 
     @Override
     public String getSubprotocol() {
         return subprotocol;
     }
 
     @Override
-    public boolean isClosed() {
-        return closed.get();
+    public boolean isOutputClosed() {
+        return outputClosed;
     }
 
     @Override
-    public void abort() throws IOException {
-        try {
-            channel.close();
-        } finally {
-            closed.set(true);
-            signalClose(CLOSED_ABNORMALLY, "");
+    public boolean isInputClosed() {
+        return inputClosed;
         }
+
+    @Override
+    public void abort() {
+        inputClosed = true;
+        outputClosed = true;
+        receiveScheduler.stop();
+        close();
     }
 
     @Override
     public String toString() {
         return super.toString()
-                + "[" + (closed.get() ? "CLOSED" : "OPEN") + "]: " + uri
-                + (!subprotocol.isEmpty() ? ", subprotocol=" + subprotocol : "");
+                + "[uri=" + uri
+                + (!subprotocol.isEmpty() ? ", subprotocol=" + subprotocol : "")
+                + "]";
     }
 
-    private MessageStreamConsumer messageConsumerOf(Listener listener) {
-        // Synchronization performed here in every method is not for the sake of
-        // ordering invocations to this consumer, after all they are naturally
-        // ordered in the channel. The reason is to avoid an interference with
-        // any unrelated to the channel calls to onOpen, onClose and onError.
-        return new MessageStreamConsumer() {
+    /*
+     * The assumptions about order are as follows:
+     *
+     *     - state is never changed more than twice inside the `run` method:
+     *       x --(1)--> IDLE --(2)--> y (otherwise we're losing events, or
+     *       overwriting parts of messages creating a mess since there's no
+     *       queueing)
+     *     - OPEN is always the first state
+     *     - no messages are requested/delivered before onOpen is called (this
+     *       is implemented by making WebSocket instance accessible first in
+     *       onOpen)
+     *     - after the state has been observed as CLOSE/ERROR, the scheduler
+     *       is stopped
+     */
+    private class ReceiveTask extends SequentialScheduler.CompleteRestartableTask {
+
+        // Receiver only asked here and nowhere else because we must make sure
+        // onOpen is invoked first and no messages become pending before onOpen
+        // finishes
+
+        @Override
+        public void run() {
+            while (true) {
+                State s = state.get();
+                try {
+                    switch (s) {
+                        case OPEN:
+                            processOpen();
+                            tryChangeState(OPEN, IDLE);
+                            break;
+                        case TEXT:
+                            processText();
+                            tryChangeState(TEXT, IDLE);
+                            break;
+                        case BINARY:
+                            processBinary();
+                            tryChangeState(BINARY, IDLE);
+                            break;
+                        case PING:
+                            processPing();
+                            tryChangeState(PING, IDLE);
+                            break;
+                        case PONG:
+                            processPong();
+                            tryChangeState(PONG, IDLE);
+                            break;
+                        case CLOSE:
+                            processClose();
+                            return;
+                        case ERROR:
+                            processError();
+                            return;
+                        case IDLE:
+                            if (demand.tryDecrement()
+                                    && tryChangeState(IDLE, WAITING)) {
+                                receiver.request(1);
+                            }
+                            return;
+                        case WAITING:
+                            // For debugging spurious signalling: when there was a
+                            // signal, but apparently nothing has changed
+                            return;
+                        default:
+                            throw new InternalError(String.valueOf(s));
+                    }
+                } catch (Throwable t) {
+                    signalError(t);
+                    if (s == CLOSE || s == ERROR) {
+                        // CLOSE/ERROR are final: looping again would re-run
+                        // processClose/processError (and the listener) forever
+                        return;
+                    }
+                }
+            }
+        }
 
-            @Override
-            public void onText(MessagePart part, CharSequence data) {
-                receiver.acknowledge();
-                synchronized (WebSocketImpl.this.lock) {
-                    try {
-                        listener.onText(WebSocketImpl.this, data, part);
-                    } catch (Exception e) {
-                        signalError(e);
+        private void processError() throws IOException {
+            receiver.close();
+            receiveScheduler.stop();
+            Throwable err = error.get();
+            if (err instanceof FailWebSocketException) {
+                int code1 = ((FailWebSocketException) err).getStatusCode();
+                err = new ProtocolException().initCause(err);
+                enqueueClose(new Close(code1, ""))
+                        .whenComplete(
+                                (r, e) -> {
+                                    if (e != null) {
+                                        Log.logError(e);
                     }
+                                });
                 }
+            listener.onError(WebSocketImpl.this, err);
             }
 
-            @Override
-            public void onBinary(MessagePart part, ByteBuffer data) {
-                receiver.acknowledge();
-                synchronized (WebSocketImpl.this.lock) {
-                    try {
-                        listener.onBinary(WebSocketImpl.this, data.slice(), part);
-                    } catch (Exception e) {
-                        signalError(e);
+        private void processClose() throws IOException {
+            receiver.close();
+            receiveScheduler.stop();
+            CompletionStage<?> readyToClose;
+            readyToClose = listener.onClose(WebSocketImpl.this, statusCode, reason);
+            if (readyToClose == null) {
+                readyToClose = MinimalFuture.completedFuture(null);
                     }
+            int code;
+            if (statusCode == NO_STATUS_CODE || statusCode == CLOSED_ABNORMALLY) {
+                code = NORMAL_CLOSURE;
+            } else {
+                code = statusCode;
                 }
+            readyToClose.whenComplete((r, e) -> {
+                enqueueClose(new Close(code, ""))
+                        .whenComplete((r1, e1) -> {
+                            if (e1 != null) {
+                                Log.logError(e1);
+                            }
+                        });
+            });
             }
 
-            @Override
-            public void onPing(ByteBuffer data) {
-                receiver.acknowledge();
+        private void processPong() {
+            listener.onPong(WebSocketImpl.this, binaryData);
+        }
+
+        private void processPing() {
                 // Let's make a full copy of this tiny data. What we want here
                 // is to rule out a possibility the shared data we send might be
-                // corrupted the by processing in the listener.
-                ByteBuffer slice = data.slice();
-                ByteBuffer copy = ByteBuffer.allocate(data.remaining())
-                        .put(data)
+            // corrupted by processing in the listener.
+            ByteBuffer slice = binaryData.slice();
+            ByteBuffer copy = ByteBuffer.allocate(binaryData.remaining())
+                    .put(binaryData)
                         .flip();
                 // Non-exclusive send;
                 CompletableFuture<WebSocket> pongSent = enqueue(new Pong(copy));
                 pongSent.whenComplete(
-                        (r, error) -> {
-                            if (error != null) {
-                                WebSocketImpl.this.signalError(error);
+                    (r, e) -> {
+                        if (e != null) {
+                            signalError(Utils.getCompletionCause(e));
                             }
                         }
                 );
-                synchronized (WebSocketImpl.this.lock) {
-                    try {
                         listener.onPing(WebSocketImpl.this, slice);
-                    } catch (Exception e) {
-                        signalError(e);
+        }
+
+        private void processBinary() {
+            listener.onBinary(WebSocketImpl.this, binaryData, part);
+        }
+
+        private void processText() {
+            listener.onText(WebSocketImpl.this, text, part);
+        }
+
+        private void processOpen() {
+            listener.onOpen(WebSocketImpl.this);
+        }
+    }
+
+    private void signalOpen() {
+        receiveScheduler.runOrSchedule();
+    }
+
+    private void signalError(Throwable error) {
+        inputClosed = true;
+        outputClosed = true;
+        if (!this.error.compareAndSet(null, error) || !trySetState(ERROR)) {
+            Log.logError(error);
+        } else {
+            close();
                     }
                 }
+
+    private void close() {
+        try {
+            try {
+                receiver.close();
+            } finally {
+                transmitter.close();
+            }
+        } catch (Throwable t) {
+            Log.logError(t);
+        }
+    }
+
+    /*
+     * Signals a Close event (might not correspond to anything that happened on
+     * the channel, i.e. might be synthetic).
+     */
+    private void signalClose(int statusCode, String reason) {
+        inputClosed = true;
+        this.statusCode = statusCode;
+        this.reason = reason;
+        if (!trySetState(CLOSE)) {
+            Log.logTrace("Close: {0}, ''{1}''", statusCode, reason);
+        } else {
+            try {
+                receiver.close();
+            } catch (Throwable t) {
+                Log.logError(t);
+            }
             }
+    }
+
+    private class SignallingMessageConsumer implements MessageStreamConsumer {
 
             @Override
-            public void onPong(ByteBuffer data) {
+        public void onText(CharSequence data, MessagePart part) {
                 receiver.acknowledge();
-                synchronized (WebSocketImpl.this.lock) {
-                    try {
-                        listener.onPong(WebSocketImpl.this, data.slice());
-                    } catch (Exception e) {
-                        signalError(e);
+            text = data;
+            WebSocketImpl.this.part = part;
+            tryChangeState(WAITING, TEXT);
                     }
+
+        @Override
+        public void onBinary(ByteBuffer data, MessagePart part) {
+            receiver.acknowledge();
+            binaryData = data;
+            WebSocketImpl.this.part = part;
+            tryChangeState(WAITING, BINARY);
                 }
+
+        @Override
+        public void onPing(ByteBuffer data) {
+            receiver.acknowledge();
+            binaryData = data;
+            tryChangeState(WAITING, PING);
+        }
+
+        @Override
+        public void onPong(ByteBuffer data) {
+            receiver.acknowledge();
+            binaryData = data;
+            tryChangeState(WAITING, PONG);
             }
 
             @Override
             public void onClose(int statusCode, CharSequence reason) {
                 receiver.acknowledge();
-                processClose(statusCode, reason.toString());
+            signalClose(statusCode, reason.toString());
+        }
+
+        @Override
+        public void onComplete() {
+            receiver.acknowledge();
+            signalClose(CLOSED_ABNORMALLY, "");
             }
 
             @Override
-            public void onError(Exception error) {
-                // An signalError doesn't necessarily mean we must signalClose
-                // the WebSocket. However, if it's something the WebSocket
-                // Specification recognizes as a reason for "Failing the
-                // WebSocket Connection", then we must do so, but BEFORE
-                // notifying the Listener.
-                if (!(error instanceof FailWebSocketException)) {
+        public void onError(Throwable error) {
                     signalError(error);
-                } else {
-                    Exception ex = (Exception) new ProtocolException().initCause(error);
-                    int code = ((FailWebSocketException) error).getStatusCode();
-                    enqueueClose(new Close(code, ""))
-                            .whenComplete((r, e) -> {
-                                if (e != null) {
-                                    ex.addSuppressed(e);
                                 }
-                                try {
-                                    channel.close();
-                                } catch (IOException e1) {
-                                    ex.addSuppressed(e1);
-                                } finally {
-                                    closed.set(true);
                                 }
-                                signalError(ex);
-                            });
+
+    private boolean trySetState(State newState) {
+        while (true) {
+            State currentState = state.get();
+            if (currentState == ERROR || currentState == CLOSE) {
+                return false;
+            } else if (state.compareAndSet(currentState, newState)) {
+                receiveScheduler.runOrSchedule();
+                return true;
+            }
                 }
             }
 
-            @Override
-            public void onComplete() {
-                processClose(CLOSED_ABNORMALLY, "");
+    private boolean tryChangeState(State expectedState, State newState) {
+        State witness = state.compareAndExchange(expectedState, newState);
+        if (witness == expectedState) {
+            receiveScheduler.runOrSchedule();
+            return true;
+        }
+        // This should be the only reason for inability to change the state from
+        // IDLE to WAITING: the state has changed to terminal
+        if (witness != ERROR && witness != CLOSE) {
+            throw new InternalError();
             }
-        };
+        return false;
     }
 }
< prev index next >