1 /*
   2  * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package jdk.incubator.http.internal.websocket;
  27 
  28 import jdk.incubator.http.WebSocket;
  29 import jdk.incubator.http.internal.common.Log;
  30 import jdk.incubator.http.internal.common.Pair;
  31 import jdk.incubator.http.internal.websocket.OpeningHandshake.Result;
  32 import jdk.incubator.http.internal.websocket.OutgoingMessage.Binary;
  33 import jdk.incubator.http.internal.websocket.OutgoingMessage.Close;
  34 import jdk.incubator.http.internal.websocket.OutgoingMessage.Context;
  35 import jdk.incubator.http.internal.websocket.OutgoingMessage.Ping;
  36 import jdk.incubator.http.internal.websocket.OutgoingMessage.Pong;
  37 import jdk.incubator.http.internal.websocket.OutgoingMessage.Text;
  38 
  39 import java.io.IOException;
  40 import java.net.ProtocolException;
  41 import java.net.URI;
  42 import java.nio.ByteBuffer;
  43 import java.util.Queue;
  44 import java.util.concurrent.CompletableFuture;
  45 import java.util.concurrent.CompletionStage;
  46 import java.util.concurrent.ConcurrentLinkedQueue;
  47 import java.util.concurrent.atomic.AtomicBoolean;
  48 import java.util.function.Consumer;
  49 import java.util.function.Function;
  50 
  51 import static java.util.Objects.requireNonNull;
  52 import static java.util.concurrent.CompletableFuture.failedFuture;
  53 import static jdk.incubator.http.internal.common.Pair.pair;
  54 import static jdk.incubator.http.internal.websocket.StatusCodes.CLOSED_ABNORMALLY;
  55 import static jdk.incubator.http.internal.websocket.StatusCodes.NO_STATUS_CODE;
  56 import static jdk.incubator.http.internal.websocket.StatusCodes.isLegalToSendFromClient;
  57 
  58 /*
  59  * A WebSocket client.
  60  */
  61 final class WebSocketImpl implements WebSocket {
  62 
  63     private final URI uri;
  64     private final String subprotocol;
  65     private final RawChannel channel;
  66     private final Listener listener;
  67 
  68     /*
  69      * Whether or not Listener.onClose or Listener.onError has been already
  70      * invoked. We keep track of this since only one of these methods is invoked
  71      * and it is invoked at most once.
  72      */
  73     private boolean lastMethodInvoked;
  74     private final AtomicBoolean outstandingSend = new AtomicBoolean();
  75     private final CooperativeHandler sendHandler =
  76               new CooperativeHandler(this::sendFirst);
  77     private final Queue<Pair<OutgoingMessage, CompletableFuture<WebSocket>>>
  78             queue = new ConcurrentLinkedQueue<>();
  79     private final Context context = new OutgoingMessage.Context();
  80     private final Transmitter transmitter;
  81     private final Receiver receiver;
  82 
  83     /*
  84      * Whether or not the WebSocket has been closed. When a WebSocket has been
  85      * closed it means that no further messages can be sent or received.
  86      * A closure can be triggered by:
  87      *
  88      *   1. abort()
  89      *   2. "Failing the WebSocket Connection" (i.e. a fatal error)
  90      *   3. Completion of the Closing handshake
  91      */
  92     private final AtomicBoolean closed = new AtomicBoolean();
  93 
  94     /*
  95      * This lock is enforcing sequential ordering of invocations to listener's
  96      * methods. It is supposed to be uncontended. The only contention that can
  97      * happen is when onOpen, an asynchronous onError (not related to reading
  98      * from the channel, e.g. an error from automatic Pong reply) or onClose
  99      * (related to abort) happens. Since all of the above are one-shot actions,
 100      * the said contention is insignificant.
 101      */
 102     private final Object lock = new Object();
 103 
 104     private final CompletableFuture<?> closeReceived = new CompletableFuture<>();
 105     private final CompletableFuture<?> closeSent = new CompletableFuture<>();
 106 
 107     static CompletableFuture<WebSocket> newInstanceAsync(BuilderImpl b) {
 108         Function<Result, WebSocket> newWebSocket = r -> {
 109             WebSocketImpl ws = new WebSocketImpl(b.getUri(),
 110                                                  r.subprotocol,
 111                                                  r.channel,
 112                                                  b.getListener());
 113             // The order of calls might cause a subtle effects, like CF will be
 114             // returned from the buildAsync _after_ onOpen has been signalled.
 115             // This means if onOpen is lengthy, it might cause some problems.
 116             ws.signalOpen();
 117             return ws;
 118         };
 119         OpeningHandshake h;
 120         try {
 121             h = new OpeningHandshake(b);
 122         } catch (IllegalArgumentException e) {
 123             return failedFuture(e);
 124         }
 125         return h.send().thenApply(newWebSocket);
 126     }
 127 
 128     WebSocketImpl(URI uri,
 129                   String subprotocol,
 130                   RawChannel channel,
 131                   Listener listener)
 132     {
 133         this.uri = requireNonNull(uri);
 134         this.subprotocol = requireNonNull(subprotocol);
 135         this.channel = requireNonNull(channel);
 136         this.listener = requireNonNull(listener);
 137         this.transmitter = new Transmitter(channel);
 138         this.receiver = new Receiver(messageConsumerOf(listener), channel);
 139 
 140         // Set up the Closing Handshake action
 141         CompletableFuture.allOf(closeReceived, closeSent)
 142                 .whenComplete((result, error) -> {
 143                     try {
 144                         channel.close();
 145                     } catch (IOException e) {
 146                         Log.logError(e);
 147                     } finally {
 148                         closed.set(true);
 149                     }
 150                 });
 151     }
 152 
 153     /*
 154      * This initialisation is outside of the constructor for the sake of
 155      * safe publication.
 156      */
 157     private void signalOpen() {
 158         synchronized (lock) {
 159             // TODO: might hold lock longer than needed causing prolonged
 160             // contention? substitute lock for ConcurrentLinkedQueue<Runnable>?
 161             try {
 162                 listener.onOpen(this);
 163             } catch (Exception e) {
 164                 signalError(e);
 165             }
 166         }
 167     }
 168 
 169     private void signalError(Throwable error) {
 170         synchronized (lock) {
 171             if (lastMethodInvoked) {
 172                 Log.logError(error);
 173             } else {
 174                 lastMethodInvoked = true;
 175                 receiver.close();
 176                 try {
 177                     listener.onError(this, error);
 178                 } catch (Exception e) {
 179                     Log.logError(e);
 180                 }
 181             }
 182         }
 183     }
 184 
 185     /*
 186      * Processes a Close event that came from the channel. Invoked at most once.
 187      */
 188     private void processClose(int statusCode, String reason) {
 189         receiver.close();
 190         try {
 191             channel.shutdownInput();
 192         } catch (IOException e) {
 193             Log.logError(e);
 194         }
 195         boolean alreadyCompleted = !closeReceived.complete(null);
 196         if (alreadyCompleted) {
 197             // This CF is supposed to be completed only once, the first time a
 198             // Close message is received. No further messages are pulled from
 199             // the socket.
 200             throw new InternalError();
 201         }
 202         int code;
 203         if (statusCode == NO_STATUS_CODE || statusCode == CLOSED_ABNORMALLY) {
 204             code = NORMAL_CLOSURE;
 205         } else {
 206             code = statusCode;
 207         }
 208         CompletionStage<?> readyToClose = signalClose(statusCode, reason);
 209         if (readyToClose == null) {
 210             readyToClose = CompletableFuture.completedFuture(null);
 211         }
 212         readyToClose.whenComplete((r, error) -> {
 213             enqueueClose(new Close(code, ""))
 214                     .whenComplete((r1, error1) -> {
 215                         if (error1 != null) {
 216                             Log.logError(error1);
 217                         }
 218                     });
 219         });
 220     }
 221 
 222     /*
 223      * Signals a Close event (might not correspond to anything happened on the
 224      * channel, e.g. `abort()`).
 225      */
 226     private CompletionStage<?> signalClose(int statusCode, String reason) {
 227         synchronized (lock) {
 228             if (lastMethodInvoked) {
 229                 Log.logTrace("Close: {0}, ''{1}''", statusCode, reason);
 230             } else {
 231                 lastMethodInvoked = true;
 232                 receiver.close();
 233                 try {
 234                     return listener.onClose(this, statusCode, reason);
 235                 } catch (Exception e) {
 236                     Log.logError(e);
 237                 }
 238             }
 239         }
 240         return null;
 241     }
 242 
 243     @Override
 244     public CompletableFuture<WebSocket> sendText(CharSequence message,
 245                                                  boolean isLast)
 246     {
 247         return enqueueExclusively(new Text(message, isLast));
 248     }
 249 
 250     @Override
 251     public CompletableFuture<WebSocket> sendBinary(ByteBuffer message,
 252                                                    boolean isLast)
 253     {
 254         return enqueueExclusively(new Binary(message, isLast));
 255     }
 256 
 257     @Override
 258     public CompletableFuture<WebSocket> sendPing(ByteBuffer message) {
 259         return enqueueExclusively(new Ping(message));
 260     }
 261 
 262     @Override
 263     public CompletableFuture<WebSocket> sendPong(ByteBuffer message) {
 264         return enqueueExclusively(new Pong(message));
 265     }
 266 
 267     @Override
 268     public CompletableFuture<WebSocket> sendClose(int statusCode,
 269                                                   String reason) {
 270         if (!isLegalToSendFromClient(statusCode)) {
 271             return failedFuture(
 272                     new IllegalArgumentException("statusCode: " + statusCode));
 273         }
 274         Close msg;
 275         try {
 276             msg = new Close(statusCode, reason);
 277         } catch (IllegalArgumentException e) {
 278             return failedFuture(e);
 279         }
 280         return enqueueClose(msg);
 281     }
 282 
 283     /*
 284      * Sends a Close message with the given contents and then shuts down the
 285      * channel for writing since no more messages are expected to be sent after
 286      * this. Invoked at most once.
 287      */
 288     private CompletableFuture<WebSocket> enqueueClose(Close m) {
 289         return enqueue(m).whenComplete((r, error) -> {
 290             try {
 291                 channel.shutdownOutput();
 292             } catch (IOException e) {
 293                 Log.logError(e);
 294             }
 295             boolean alreadyCompleted = !closeSent.complete(null);
 296             if (alreadyCompleted) {
 297                 // Shouldn't happen as this callback must run at most once
 298                 throw new InternalError();
 299             }
 300         });
 301     }
 302 
 303     /*
 304      * Accepts the given message into the outgoing queue in a mutually-exclusive
 305      * fashion in respect to other messages accepted through this method. No
 306      * further messages will be accepted until the returned CompletableFuture
 307      * completes. This method is used to enforce "one outstanding send
 308      * operation" policy.
 309      */
 310     private CompletableFuture<WebSocket> enqueueExclusively(OutgoingMessage m)
 311     {
 312         if (closed.get()) {
 313             return failedFuture(new IllegalStateException("Closed"));
 314         }
 315         if (!outstandingSend.compareAndSet(false, true)) {
 316             return failedFuture(new IllegalStateException("Outstanding send"));
 317         }
 318         return enqueue(m).whenComplete((r, e) -> outstandingSend.set(false));
 319     }
 320 
 321     private CompletableFuture<WebSocket> enqueue(OutgoingMessage m) {
 322         CompletableFuture<WebSocket> cf = new CompletableFuture<>();
 323         boolean added = queue.add(pair(m, cf));
 324         if (!added) {
 325             // The queue is supposed to be unbounded
 326             throw new InternalError();
 327         }
 328         sendHandler.handle();
 329         return cf;
 330     }
 331 
 332     /*
 333      * This is the main sending method. It may be run in different threads,
 334      * but never concurrently.
 335      */
 336     private void sendFirst(Runnable whenSent) {
 337         Pair<OutgoingMessage, CompletableFuture<WebSocket>> p = queue.poll();
 338         if (p == null) {
 339             whenSent.run();
 340             return;
 341         }
 342         OutgoingMessage message = p.first;
 343         CompletableFuture<WebSocket> cf = p.second;
 344         try {
 345             message.contextualize(context);
 346             Consumer<Exception> h = e -> {
 347                 if (e == null) {
 348                     cf.complete(WebSocketImpl.this);
 349                 } else {
 350                     cf.completeExceptionally(e);
 351                 }
 352                 sendHandler.handle();
 353                 whenSent.run();
 354             };
 355             transmitter.send(message, h);
 356         } catch (Exception t) {
 357             cf.completeExceptionally(t);
 358         }
 359     }
 360 
 361     @Override
 362     public void request(long n) {
 363         receiver.request(n);
 364     }
 365 
 366     @Override
 367     public String getSubprotocol() {
 368         return subprotocol;
 369     }
 370 
 371     @Override
 372     public boolean isClosed() {
 373         return closed.get();
 374     }
 375 
 376     @Override
 377     public void abort() throws IOException {
 378         try {
 379             channel.close();
 380         } finally {
 381             closed.set(true);
 382             signalClose(CLOSED_ABNORMALLY, "");
 383         }
 384     }
 385 
 386     @Override
 387     public String toString() {
 388         return super.toString()
 389                 + "[" + (closed.get() ? "CLOSED" : "OPEN") + "]: " + uri
 390                 + (!subprotocol.isEmpty() ? ", subprotocol=" + subprotocol : "");
 391     }
 392 
 393     private MessageStreamConsumer messageConsumerOf(Listener listener) {
 394         // Synchronization performed here in every method is not for the sake of
 395         // ordering invocations to this consumer, after all they are naturally
 396         // ordered in the channel. The reason is to avoid an interference with
 397         // any unrelated to the channel calls to onOpen, onClose and onError.
 398         return new MessageStreamConsumer() {
 399 
 400             @Override
 401             public void onText(MessagePart part, CharSequence data) {
 402                 receiver.acknowledge();
 403                 synchronized (WebSocketImpl.this.lock) {
 404                     try {
 405                         listener.onText(WebSocketImpl.this, data, part);
 406                     } catch (Exception e) {
 407                         signalError(e);
 408                     }
 409                 }
 410             }
 411 
 412             @Override
 413             public void onBinary(MessagePart part, ByteBuffer data) {
 414                 receiver.acknowledge();
 415                 synchronized (WebSocketImpl.this.lock) {
 416                     try {
 417                         listener.onBinary(WebSocketImpl.this, data.slice(), part);
 418                     } catch (Exception e) {
 419                         signalError(e);
 420                     }
 421                 }
 422             }
 423 
 424             @Override
 425             public void onPing(ByteBuffer data) {
 426                 receiver.acknowledge();
 427                 // Let's make a full copy of this tiny data. What we want here
 428                 // is to rule out a possibility the shared data we send might be
 429                 // corrupted the by processing in the listener.
 430                 ByteBuffer slice = data.slice();
 431                 ByteBuffer copy = ByteBuffer.allocate(data.remaining())
 432                         .put(data)
 433                         .flip();
 434                 // Non-exclusive send;
 435                 CompletableFuture<WebSocket> pongSent = enqueue(new Pong(copy));
 436                 pongSent.whenComplete(
 437                         (r, error) -> {
 438                             if (error != null) {
 439                                 WebSocketImpl.this.signalError(error);
 440                             }
 441                         }
 442                 );
 443                 synchronized (WebSocketImpl.this.lock) {
 444                     try {
 445                         listener.onPing(WebSocketImpl.this, slice);
 446                     } catch (Exception e) {
 447                         signalError(e);
 448                     }
 449                 }
 450             }
 451 
 452             @Override
 453             public void onPong(ByteBuffer data) {
 454                 receiver.acknowledge();
 455                 synchronized (WebSocketImpl.this.lock) {
 456                     try {
 457                         listener.onPong(WebSocketImpl.this, data.slice());
 458                     } catch (Exception e) {
 459                         signalError(e);
 460                     }
 461                 }
 462             }
 463 
 464             @Override
 465             public void onClose(int statusCode, CharSequence reason) {
 466                 receiver.acknowledge();
 467                 processClose(statusCode, reason.toString());
 468             }
 469 
 470             @Override
 471             public void onError(Exception error) {
 472                 // An signalError doesn't necessarily mean we must signalClose
 473                 // the WebSocket. However, if it's something the WebSocket
 474                 // Specification recognizes as a reason for "Failing the
 475                 // WebSocket Connection", then we must do so, but BEFORE
 476                 // notifying the Listener.
 477                 if (!(error instanceof FailWebSocketException)) {
 478                     signalError(error);
 479                 } else {
 480                     Exception ex = (Exception) new ProtocolException().initCause(error);
 481                     int code = ((FailWebSocketException) error).getStatusCode();
 482                     enqueueClose(new Close(code, ""))
 483                             .whenComplete((r, e) -> {
 484                                 if (e != null) {
 485                                     ex.addSuppressed(e);
 486                                 }
 487                                 try {
 488                                     channel.close();
 489                                 } catch (IOException e1) {
 490                                     ex.addSuppressed(e1);
 491                                 } finally {
 492                                     closed.set(true);
 493                                 }
 494                                 signalError(ex);
 495                             });
 496                 }
 497             }
 498 
 499             @Override
 500             public void onComplete() {
 501                 processClose(CLOSED_ABNORMALLY, "");
 502             }
 503         };
 504     }
 505 }