< prev index next >
   1 /*
   2  * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 package java.net.http;
  26 
  27 import java.lang.reflect.Array;
  28 import java.net.http.OutgoingMessage.Binary;
  29 import java.net.http.OutgoingMessage.BinaryText;
  30 import java.net.http.OutgoingMessage.CharacterText;
  31 import java.net.http.OutgoingMessage.Close;
  32 import java.net.http.OutgoingMessage.Ping;
  33 import java.net.http.OutgoingMessage.Pong;
  34 import java.net.http.OutgoingMessage.StreamedText;
  35 import java.net.http.OutgoingMessage.Visitor;
  36 import java.net.http.WebSocketFrame.HeaderWriter;
  37 import java.net.http.WebSocketFrame.Masker;
  38 import java.net.http.WebSocketFrame.Opcode;
  39 import java.nio.ByteBuffer;
  40 import java.nio.CharBuffer;
  41 import java.nio.charset.CharacterCodingException;
  42 import java.nio.charset.CoderResult;
  43 import java.security.SecureRandom;
  44 import java.util.Iterator;
  45 import java.util.concurrent.CompletableFuture;
  46 import java.util.concurrent.Executor;
  47 import java.util.concurrent.Flow;
  48 import java.util.concurrent.Flow.Processor;
  49 import java.util.concurrent.atomic.AtomicLong;
  50 
  51 import static java.lang.System.Logger.Level.TRACE;
  52 import static java.net.http.Utils.logger;
  53 import static java.net.http.WebSocketFrame.MAX_HEADER_SIZE_BYTES;
  54 import static java.net.http.WebSocketFrame.Opcode.BINARY;
  55 import static java.net.http.WebSocketFrame.Opcode.CLOSE;
  56 import static java.net.http.WebSocketFrame.Opcode.CONTINUATION;
  57 import static java.net.http.WebSocketFrame.Opcode.PING;
  58 import static java.net.http.WebSocketFrame.Opcode.PONG;
  59 import static java.net.http.WebSocketFrame.Opcode.TEXT;
  60 import static java.util.Objects.requireNonNull;
  61 
  62 final class MessageSender implements Processor<Pair<OutgoingMessage, CompletableFuture<Void>>,
  63         Pair<Shared<ByteBuffer>[], CompletableFuture<Void>>> {
  64 
  65     private final SecureRandom random = new SecureRandom();
  66     private final CharsetToolkit.Verifier verifier = new CharsetToolkit.Verifier();
  67     private final CharsetToolkit.Encoder encoder = new CharsetToolkit.Encoder();
  68     private final Masker masker = Masker.newInstance();
  69     private final HeaderWriter headerWriter = new HeaderWriter();
  70     private final SharedPool<ByteBuffer> sharedHeaderBuffers
  71             = new SharedPool<>(() -> ByteBuffer.allocateDirect(MAX_HEADER_SIZE_BYTES), 16);
  72     private final SharedPool<ByteBuffer> sharedPayloadBuffers
  73             = new SharedPool<>(() -> ByteBuffer.allocateDirect(32768), 16); // FIXME: ensure different!
  74     private final AtomicLong demand = new AtomicLong();
  75 
  76     private Flow.Subscriber<? super Pair<Shared<ByteBuffer>[], CompletableFuture<Void>>> subscriber;
  77     private final SignalHandler handler;
  78 
  79     private final Visitor<Boolean, CompletableFuture<Void>> processingVisitor
  80             = createProcessingVisitor();
  81 
  82     private volatile boolean closed;
  83     private volatile Pair<OutgoingMessage, CompletableFuture<Void>> pair;
  84     private volatile Flow.Subscription subscription;
  85 
  86     MessageSender(Executor executor) {
  87         this.handler = new SignalHandler(requireNonNull(executor), this::react);
  88     }
  89 
  90     private void react() {
  91         synchronized (this) {
  92             while (pair != null && demand.get() > 0 && !closed) {
  93                 OutgoingMessage m = pair.first;
  94                 CompletableFuture<Void> cf = pair.second;
  95                 boolean processed = m.accept(processingVisitor, cf);
  96                 if (processed) {
  97                     pair = null;
  98                     long l = demand.decrementAndGet();
  99                     assert l >= 0 : l;
 100                     subscription.request(1);
 101                 }
 102             }
 103         }
 104     }
 105 
 106     @Override
 107     public void subscribe(Flow.Subscriber<? super Pair<Shared<ByteBuffer>[], CompletableFuture<Void>>> subscriber) {
 108         this.subscriber = requireNonNull(subscriber);
 109         subscriber.onSubscribe(
 110                 new Flow.Subscription() {
 111                     @Override
 112                     public void request(long n) {
 113                         if (n < 0) {
 114                             throw new IllegalArgumentException(String.valueOf(n));
 115                         }
 116                         demand.accumulateAndGet(n, (p, i) -> p + i < 0 ? Long.MAX_VALUE : p + i);
 117                         handler.signal();
 118                     }
 119 
 120                     @Override
 121                     public void cancel() {
 122                         closed = true;
 123                     }
 124                 });
 125     }
 126 
 127     private void send(Shared<ByteBuffer>[] buffers, CompletableFuture<Void> cf) {
 128         Pair<Shared<ByteBuffer>[], CompletableFuture<Void>> p = Pair.pair(buffers, cf);
 129         logger.log(TRACE, "Passing ''{0}'' for writing", p);
 130         subscriber.onNext(p);
 131     }
 132 
 133     private Visitor<Boolean, CompletableFuture<Void>> createProcessingVisitor() {
 134 
 135         // TODO: if we can modify buffers AND passed buffers are direct
 136         //           XOR and send (mask leak?)
 137         //       otherwise need to make copies with XORing here
 138 
 139         return new Visitor<>() {
 140 
 141             private Iterator<? extends CharSequence> streamIterator;
 142             private CharSequence previousIteratorItem;
 143 
 144             private boolean messageSequence;
 145             private boolean frameSequence;
 146 
 147             // In case of a long message that might result in more than a single
 148             // frame on the wire
 149             private boolean charactersProcessedFully = true;
 150             private CharBuffer cb;
 151 
 152             @Override
 153             public Boolean visit(CharacterText message, CompletableFuture<Void> attachment) {
 154                 return resumeCharSequence(message.characters, message.isLast, attachment);
 155             }
 156 
 157             @Override
 158             public Boolean visit(BinaryText message, CompletableFuture<Void> attachment) {
 159                 try {
 160                     int oldPos = message.bytes.position();
 161                     verifier.verify(message.bytes, message.isLast);
 162                     message.bytes.position(oldPos);
 163                 } catch (CharacterCodingException e) {
 164                     throw new IllegalArgumentException(
 165                             "Malformed UTF-8 bytes", e);
 166                 }
 167                 if (message.isLast) {
 168                     verifier.reset();
 169                 }
 170                 Shared<ByteBuffer>[] buffers = createArray(2);
 171                 buffers[0] = sharedHeaderBuffers.get();
 172                 buffers[1] = Shared.wrap(message.bytes);
 173                 final boolean fin = message.isLast;
 174                 final Opcode opcode = messageSequence ? CONTINUATION : BINARY;
 175                 final int payloadLen = message.bytes.remaining();
 176                 final int mask = nextMask();
 177 
 178 
 179                 int oldPos = message.bytes.position();
 180                 masker.mask(mask).applyMask(message.bytes, message.bytes);
 181                 message.bytes.position(oldPos);
 182 
 183                 headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask)
 184                         .write(buffers[0]);
 185                 buffers[0].flip();
 186                 messageSequence = !message.isLast;
 187                 return true;
 188             }
 189 
 190             @Override
 191             public Boolean visit(StreamedText streamedText, CompletableFuture<Void> cf) {
 192                 throw new IllegalArgumentException("Not yet implemented");
 193 //                if (streamIterator == null) {
 194 //                    streamIterator = streamedText.characters.iterator();
 195 //                    previousIteratorItem = "";
 196 //                }
 197 //                if (!charactersProcessedFully) {
 198 //                    boolean isLast = !streamIterator.hasNext();
 199 //                    Boolean r = resumeCharSequence(previousIteratorItem, isLast, isLast ? cf : null);
 200 //
 201 //                }
 202 //                if (!streamIterator.hasNext()) {
 203 //                    return resumeCharSequence(previousIteratorItem, true, cf);
 204 //                } else {
 205 //                    CharSequence tmp = previousIteratorItem;
 206 //                    previousIteratorItem = streamIterator.next();
 207 //                    return resumeCharSequence(tmp, false, cf);
 208 //                }
 209             }
 210 
 211             @Override
 212             public Boolean visit(Binary message, CompletableFuture<Void> cf) {
 213                 Shared<ByteBuffer>[] buffers = createArray(2);
 214                 buffers[0] = sharedHeaderBuffers.get();
 215                 buffers[1] = Shared.wrap(message.bytes);
 216                 final boolean fin = message.isLast;
 217                 final Opcode opcode = messageSequence ? CONTINUATION : BINARY;
 218                 final int payloadLen = message.bytes.remaining();
 219                 final int mask = nextMask();
 220 
 221                 int oldPos = message.bytes.position();
 222                 masker.mask(mask).applyMask(message.bytes, message.bytes);
 223                 message.bytes.position(oldPos);
 224 
 225                 headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask)
 226                         .write(buffers[0]);
 227                 buffers[0].flip();
 228                 messageSequence = !message.isLast;
 229                 send(buffers, cf);
 230                 return true;
 231             }
 232 
 233             @Override
 234             public Boolean visit(Ping message, CompletableFuture<Void> cf) {
 235                 Shared<ByteBuffer>[] buffers = createArray(2);
 236                 buffers[0] = sharedHeaderBuffers.get();
 237                 buffers[1] = Shared.wrap(message.bytes);
 238                 final boolean fin = true;
 239                 final Opcode opcode = PING;
 240                 final int payloadLen = message.bytes.remaining();
 241                 final int mask = nextMask();
 242 
 243                 int oldPos = message.bytes.position();
 244                 masker.mask(mask).applyMask(message.bytes, message.bytes);
 245                 message.bytes.position(oldPos);
 246 
 247                 headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask)
 248                         .write(buffers[0]);
 249                 buffers[0].flip();
 250                 send(buffers, cf);
 251                 return true;
 252             }
 253 
 254             @Override
 255             public Boolean visit(Pong message, CompletableFuture<Void> cf) {
 256                 Shared<ByteBuffer>[] buffers = createArray(2);
 257                 buffers[0] = sharedHeaderBuffers.get();
 258                 buffers[1] = Shared.wrap(message.bytes);
 259                 final boolean fin = true;
 260                 final Opcode opcode = PONG;
 261                 final int payloadLen = message.bytes.remaining();
 262                 final int mask = nextMask();
 263 
 264                 int oldPos = message.bytes.position();
 265                 masker.mask(mask).applyMask(message.bytes, message.bytes);
 266                 message.bytes.position(oldPos);
 267 
 268                 headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask)
 269                         .write(buffers[0]);
 270                 buffers[0].flip();
 271                 send(buffers, cf);
 272                 return true;
 273             }
 274 
 275             @Override
 276             public Boolean visit(Close message, CompletableFuture<Void> cf) {
 277                 Shared<ByteBuffer>[] buffers = createArray(2);
 278                 buffers[0] = sharedHeaderBuffers.get();
 279                 buffers[1] = sharedPayloadBuffers.get();
 280 
 281                 assert buffers[1].remaining() >= 125 : buffers[1].remaining();
 282 
 283                 if (message.code != null) {
 284                     buffers[1].buffer().putChar((char) message.code.getCode());
 285                     CoderResult r;
 286                     try {
 287                         r = new CharsetToolkit.Encoder().encode(CharBuffer.wrap(message.reason),
 288                                 buffers[1].buffer(), true);
 289                     } catch (CharacterCodingException e) {
 290                         // Shouldn't happen, since the message should've been
 291                         // already checked
 292                         throw new InternalError(e);
 293                     }
 294                     if (r.isOverflow()) {
 295                         // TODO: make sure payload is at least 123 bytes, or use separate BB
 296                         throw new InternalError();
 297                     }
 298                 }
 299                 buffers[1].flip();
 300                 final int payloadLen = buffers[1].remaining();
 301                 final int mask = nextMask();
 302 
 303                 int oldPos = buffers[1].position();
 304                 masker.mask(mask).applyMask(buffers[1].buffer(), buffers[1].buffer());
 305                 buffers[1].position(oldPos);
 306 
 307                 headerWriter.fin(true).opcode(CLOSE).payloadLen(payloadLen).mask(mask)
 308                         .write(buffers[0]);
 309                 buffers[0].flip();
 310                 send(buffers, cf);
 311                 return true;
 312             }
 313 
 314             private Boolean resumeCharSequence(CharSequence characters, boolean isLast,
 315                                                CompletableFuture<Void> cf) {
 316                 if (charactersProcessedFully) {
 317                     // A brand new message has arrived
 318                     cb = CharBuffer.wrap(characters);
 319                     charactersProcessedFully = false;
 320                 }
 321                 // Chop into frames in order to not decode the whole CharBuffer at once
 322                 // TODO: use them more savvy (i.e. slice)
 323                 Shared<ByteBuffer> buf = sharedPayloadBuffers.get();
 324                 CoderResult result;
 325                 try {
 326                     result = encoder.encode(cb, buf.buffer(), isLast);
 327                     if (result.isUnderflow()) {
 328                         // A payload for the last frame has just been encoded
 329                         charactersProcessedFully = true;
 330                     }
 331                     buf.flip();
 332                 } catch (CharacterCodingException e) {
 333                     throw new IllegalArgumentException("Malformed UTF-16 characters", e);
 334                 }
 335                 if (isLast && charactersProcessedFully) {
 336                     // There will be no more messages in this sequence
 337                     encoder.reset();
 338                 }
 339 
 340                 int mask = nextMask();
 341 
 342                 int oldPos = buf.position();
 343                 int oldLim = buf.limit();
 344                 masker.mask(mask).applyMask(buf.buffer(), buf.buffer());
 345                 buf.select(oldPos, oldLim);
 346 
 347                 Shared<ByteBuffer>[] buffers = createArray(2);
 348                 buffers[0] = sharedHeaderBuffers.get();
 349                 buffers[1] = buf;
 350                 // The last frame of the last message MUST have 'fin' bit set
 351                 final boolean fin = isLast && charactersProcessedFully;
 352                 final Opcode opcode = !messageSequence && !frameSequence ? TEXT : CONTINUATION;
 353                 final int payloadLen = buf.remaining();
 354                 headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask)
 355                         .write(buffers[0]);
 356                 buffers[0].flip();
 357 
 358                 messageSequence = !isLast;
 359 
 360                 // If 'buf' was not enough to hold all the bytes encoded from
 361                 // 'cb', there will be more frames, and we'll return here
 362                 frameSequence = result.isOverflow();
 363                 send(buffers, charactersProcessedFully ? cf : null);
 364                 return charactersProcessedFully;
 365             }
 366 
 367             private int nextMask() {
 368                 return random.nextInt();
 369             }
 370         };
 371     }
 372 
 373     @SuppressWarnings("unchecked")
 374     private static Shared<ByteBuffer>[] createArray(int size) {
 375         return (Shared<ByteBuffer>[]) Array.newInstance(Shared.class, size);
 376     }
 377 
 378     @Override
 379     public void onSubscribe(Flow.Subscription subscription) {
 380         this.subscription = requireNonNull(subscription);
 381         this.subscription.request(1);
 382     }
 383 
 384     @Override
 385     public void onNext(Pair<OutgoingMessage, CompletableFuture<Void>> p) {
 386         logger.log(TRACE, "->MessageSender.onNext(''{0}'')", p);
 387         this.pair = requireNonNull(p);
 388         handler.signal();
 389         logger.log(TRACE, "<-MessageSender.onNext", p);
 390     }
 391 
 392     @Override
 393     public void onError(Throwable throwable) { closed = true; }
 394 
 395     @Override
 396     public void onComplete() { closed = true; }
 397 }
< prev index next >