/*
 * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
*/

package java.net.http;

import java.lang.reflect.Array;
import java.net.http.OutgoingMessage.Binary;
import java.net.http.OutgoingMessage.BinaryText;
import java.net.http.OutgoingMessage.CharacterText;
import java.net.http.OutgoingMessage.Close;
import java.net.http.OutgoingMessage.Ping;
import java.net.http.OutgoingMessage.Pong;
import java.net.http.OutgoingMessage.StreamedText;
import java.net.http.OutgoingMessage.Visitor;
import java.net.http.WebSocketFrame.HeaderWriter;
import java.net.http.WebSocketFrame.Masker;
import java.net.http.WebSocketFrame.Opcode;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CoderResult;
import java.security.SecureRandom;
import java.util.Iterator;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.concurrent.Flow.Processor;
import java.util.concurrent.atomic.AtomicLong;

import static java.lang.System.Logger.Level.TRACE;
import static java.net.http.Utils.logger;
import static java.net.http.WebSocketFrame.MAX_HEADER_SIZE_BYTES;
import static java.net.http.WebSocketFrame.Opcode.BINARY;
import static java.net.http.WebSocketFrame.Opcode.CLOSE;
import static java.net.http.WebSocketFrame.Opcode.CONTINUATION;
import static java.net.http.WebSocketFrame.Opcode.PING;
import static java.net.http.WebSocketFrame.Opcode.PONG;
import static java.net.http.WebSocketFrame.Opcode.TEXT;
import static java.util.Objects.requireNonNull;

/**
 * Turns outgoing WebSocket messages into masked frames.
 *
 * Upstream, this processor receives (message, future) pairs and requests
 * them one at a time: a new one is requested only after the current message
 * has been completely framed. Downstream, it publishes (buffers, future)
 * pairs in which buffers[0] holds the frame header and buffers[1] the masked
 * payload. A single text message may be split into several frames; only the
 * pair carrying the final frame of a message carries the message's future,
 * the others carry null.
 *
 * All framing happens in react(), which is serialized by synchronizing on
 * this instance. react() is handed to a SignalHandler together with the
 * executor passed to the constructor. NOTE(review): presumably the handler
 * runs react() on that executor whenever signal() is called; confirm
 * against SignalHandler.
 */
final class MessageSender implements
        Processor<Pair<OutgoingMessage, CompletableFuture<Void>>,
                  Pair<Shared<ByteBuffer>[], CompletableFuture<Void>>> {

    private final SecureRandom random = new SecureRandom();
    private final CharsetToolkit.Verifier verifier = new CharsetToolkit.Verifier();
    private final CharsetToolkit.Encoder encoder = new CharsetToolkit.Encoder();
    private final Masker masker = Masker.newInstance();
    private final HeaderWriter headerWriter = new HeaderWriter();
    private final SharedPool<ByteBuffer> sharedHeaderBuffers =
            new SharedPool<>(() -> ByteBuffer.allocateDirect(MAX_HEADER_SIZE_BYTES), 16);
    private final SharedPool<ByteBuffer> sharedPayloadBuffers =
            new SharedPool<>(() -> ByteBuffer.allocateDirect(32768), 16); // FIXME: ensure different!
    // Frames requested by the downstream subscriber but not yet delivered;
    // one unit is consumed per frame published by send()
    private final AtomicLong demand = new AtomicLong();
    private Flow.Subscriber<? super Pair<Shared<ByteBuffer>[], CompletableFuture<Void>>> subscriber;
    private final SignalHandler handler;
    private final Visitor<Boolean, CompletableFuture<Void>> processingVisitor =
            createProcessingVisitor();
    private volatile boolean closed;
    // The message currently being framed; null while waiting for upstream
    private volatile Pair<OutgoingMessage, CompletableFuture<Void>> pair;
    private volatile Flow.Subscription subscription;

    /**
     * Creates a sender whose processing is driven through the given executor.
     *
     * @param executor the executor handed to the SignalHandler; must not be null
     * @throws NullPointerException if executor is null
     */
    MessageSender(Executor executor) {
        this.handler = new SignalHandler(requireNonNull(executor), this::react);
    }

    /*
     * Emits frames for the pending message while there is one, the
     * downstream subscriber still has outstanding demand, and the processor
     * has not been closed. Every visit that returns normally emits exactly
     * one frame via send() (which consumes one unit of demand). A visit
     * returns true once the whole message has been framed, at which point
     * the next message is requested from upstream.
     *
     * NOTE(review): if a visit throws (malformed text, StreamedText), the
     * exception escapes this method and the message's future is not
     * completed here.
     */
    private void react() {
        synchronized (this) {
            while (pair != null && demand.get() > 0 && !closed) {
                OutgoingMessage m = pair.first;
                CompletableFuture<Void> cf = pair.second;
                boolean processed = m.accept(processingVisitor, cf);
                if (processed) {
                    pair = null;
                    subscription.request(1);
                }
            }
        }
    }

    /**
     * Registers the downstream subscriber and hands it a subscription.
     * Requests are accumulated into the demand counter (saturating at
     * Long.MAX_VALUE) and trigger processing; cancellation closes this
     * processor.
     *
     * @throws NullPointerException if subscriber is null
     */
    @Override
    public void subscribe(
            Flow.Subscriber<? super Pair<Shared<ByteBuffer>[], CompletableFuture<Void>>> subscriber) {
        this.subscriber = requireNonNull(subscriber);
        subscriber.onSubscribe(
                new Flow.Subscription() {
                    @Override
                    public void request(long n) {
                        if (n < 0) {
                            throw new IllegalArgumentException(String.valueOf(n));
                        }
                        // Saturating add: an overflowed sum means "unbounded"
                        demand.accumulateAndGet(n, (p, i) -> p + i < 0 ? Long.MAX_VALUE : p + i);
                        handler.signal();
                    }

                    @Override
                    public void cancel() {
                        closed = true;
                    }
                });
    }

    /*
     * Publishes one frame downstream and consumes one unit of demand for it.
     * Demand is accounted per frame (not per message) because a long text
     * message can produce several onNext signals. 'cf' is the message's
     * future when this is its final frame, and null otherwise.
     */
    private void send(Shared<ByteBuffer>[] buffers, CompletableFuture<Void> cf) {
        long l = demand.decrementAndGet();
        assert l >= 0 : l;
        Pair<Shared<ByteBuffer>[], CompletableFuture<Void>> p = Pair.pair(buffers, cf);
        logger.log(TRACE, "Passing ''{0}'' for writing", p);
        subscriber.onNext(p);
    }

    /*
     * Creates the stateful visitor that frames each kind of outgoing
     * message. Each visit returns true when the message has been completely
     * framed, or false when more frames are needed for the same message.
     */
    private Visitor<Boolean, CompletableFuture<Void>> createProcessingVisitor() {
        // TODO: if we can modify buffers AND passed buffers are direct
        // XOR and send (mask leak?)
        // otherwise need to make copies with XORing here
        return new Visitor<>() {

            private Iterator<? extends CharSequence> streamIterator;
            private CharSequence previousIteratorItem;
            // True while a fragmented (isLast == false) message is in progress
            private boolean messageSequence;
            // True while the current character sequence spans several frames
            private boolean frameSequence;
            // In case of a long message that might result in more than a single
            // frame on the wire
            private boolean charactersProcessedFully = true;
            private CharBuffer cb;

            @Override
            public Boolean visit(CharacterText message, CompletableFuture<Void> attachment) {
                return resumeCharSequence(message.characters, message.isLast, attachment);
            }

            /*
             * Frames text that is already UTF-8 encoded. The bytes are
             * checked with the verifier first (their position is restored
             * afterwards), then sent as one masked TEXT (or CONTINUATION)
             * frame.
             */
            @Override
            public Boolean visit(BinaryText message, CompletableFuture<Void> attachment) {
                try {
                    int oldPos = message.bytes.position();
                    verifier.verify(message.bytes, message.isLast);
                    message.bytes.position(oldPos);
                } catch (CharacterCodingException e) {
                    throw new IllegalArgumentException(
                            "Malformed UTF-8 bytes", e);
                }
                if (message.isLast) {
                    verifier.reset();
                }
                // This is a text message, so its first frame is TEXT (the
                // original used BINARY and never sent the frame at all)
                final Opcode opcode = messageSequence ? CONTINUATION : TEXT;
                messageSequence = !message.isLast;
                sendMasked(message.bytes, message.isLast, opcode, attachment);
                return true;
            }

            @Override
            public Boolean visit(StreamedText streamedText, CompletableFuture<Void> cf) {
                throw new IllegalArgumentException("Not yet implemented");
                // if (streamIterator == null) {
                //     streamIterator = streamedText.characters.iterator();
                //     previousIteratorItem = "";
                // }
                // if (!charactersProcessedFully) {
                //     boolean isLast = !streamIterator.hasNext();
                //     Boolean r = resumeCharSequence(previousIteratorItem, isLast, isLast ? cf : null);
                //
                // }
                // if (!streamIterator.hasNext()) {
                //     return resumeCharSequence(previousIteratorItem, true, cf);
                // } else {
                //     CharSequence tmp = previousIteratorItem;
                //     previousIteratorItem = streamIterator.next();
                //     return resumeCharSequence(tmp, false, cf);
                // }
            }

            /*
             * Sends the bytes as one masked BINARY (or CONTINUATION) frame.
             */
            @Override
            public Boolean visit(Binary message, CompletableFuture<Void> cf) {
                final Opcode opcode = messageSequence ? CONTINUATION : BINARY;
                messageSequence = !message.isLast;
                sendMasked(message.bytes, message.isLast, opcode, cf);
                return true;
            }

            @Override
            public Boolean visit(Ping message, CompletableFuture<Void> cf) {
                sendMasked(message.bytes, true, PING, cf);
                return true;
            }

            @Override
            public Boolean visit(Pong message, CompletableFuture<Void> cf) {
                sendMasked(message.bytes, true, PONG, cf);
                return true;
            }

            /*
             * Builds a CLOSE frame. Its payload is empty when no code is
             * given; otherwise it is the code written as 2 bytes followed
             * by the encoded reason. A fresh encoder is used here.
             * NOTE(review): presumably that is so the stateful shared text
             * encoder is not disturbed mid-message.
             */
            @Override
            public Boolean visit(Close message, CompletableFuture<Void> cf) {
                Shared<ByteBuffer>[] buffers = createArray(2);
                buffers[0] = sharedHeaderBuffers.get();
                buffers[1] = sharedPayloadBuffers.get();
                assert buffers[1].remaining() >= 125 : buffers[1].remaining();
                if (message.code != null) {
                    buffers[1].buffer().putChar((char) message.code.getCode());
                    CoderResult r;
                    try {
                        r = new CharsetToolkit.Encoder().encode(CharBuffer.wrap(message.reason),
                                                                buffers[1].buffer(), true);
                    } catch (CharacterCodingException e) {
                        // Shouldn't happen, since the message should've been
                        // already checked
                        throw new InternalError(e);
                    }
                    if (r.isOverflow()) {
                        // TODO: make sure payload is at least 123 bytes, or use separate BB
                        throw new InternalError();
                    }
                }
                buffers[1].flip();
                final int payloadLen = buffers[1].remaining();
                final int mask = nextMask();
                int oldPos = buffers[1].position();
                masker.mask(mask).applyMask(buffers[1].buffer(), buffers[1].buffer());
                buffers[1].position(oldPos);
                headerWriter.fin(true).opcode(CLOSE).payloadLen(payloadLen).mask(mask)
                        .write(buffers[0]);
                buffers[0].flip();
                send(buffers, cf);
                return true;
            }

            /*
             * Masks 'payload' in place (its position is preserved), writes a
             * header with the given fin bit and opcode, and sends the pair as
             * a single frame. Callers update their sequence state before
             * calling this, as the original code did before send().
             */
            private void sendMasked(ByteBuffer payload, boolean fin, Opcode opcode,
                                    CompletableFuture<Void> cf) {
                Shared<ByteBuffer>[] buffers = createArray(2);
                buffers[0] = sharedHeaderBuffers.get();
                buffers[1] = Shared.wrap(payload);
                final int payloadLen = payload.remaining();
                final int mask = nextMask();
                int oldPos = payload.position();
                masker.mask(mask).applyMask(payload, payload);
                payload.position(oldPos);
                headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask)
                        .write(buffers[0]);
                buffers[0].flip();
                send(buffers, cf);
            }

            /*
             * Encodes as much of the current character sequence as fits into
             * one pooled payload buffer and sends it as one masked frame.
             * Returns true once all characters have been encoded (the final
             * frame carries 'cf'), or false if another frame is needed for
             * the same characters.
             */
            private Boolean resumeCharSequence(CharSequence characters,
                                               boolean isLast,
                                               CompletableFuture<Void> cf) {
                if (charactersProcessedFully) {
                    // A brand new message has arrived
                    cb = CharBuffer.wrap(characters);
                    charactersProcessedFully = false;
                }
                // Chop into frames in order to not decode the whole CharBuffer at once
                // TODO: use them more savvy (i.e. slice)
                Shared<ByteBuffer> buf = sharedPayloadBuffers.get();
                CoderResult result;
                try {
                    result = encoder.encode(cb, buf.buffer(), isLast);
                    if (result.isUnderflow()) {
                        // A payload for the last frame has just been encoded
                        charactersProcessedFully = true;
                    }
                    buf.flip();
                } catch (CharacterCodingException e) {
                    throw new IllegalArgumentException("Malformed UTF-16 characters", e);
                }
                if (isLast && charactersProcessedFully) {
                    // There will be no more messages in this sequence
                    encoder.reset();
                }
                int mask = nextMask();
                int oldPos = buf.position();
                int oldLim = buf.limit();
                masker.mask(mask).applyMask(buf.buffer(), buf.buffer());
                buf.select(oldPos, oldLim);
                Shared<ByteBuffer>[] buffers = createArray(2);
                buffers[0] = sharedHeaderBuffers.get();
                buffers[1] = buf;
                // The last frame of the last message MUST have 'fin' bit set
                final boolean fin = isLast && charactersProcessedFully;
                final Opcode opcode = !messageSequence && !frameSequence ? TEXT : CONTINUATION;
                final int payloadLen = buf.remaining();
                headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask)
                        .write(buffers[0]);
                buffers[0].flip();
                messageSequence = !isLast;
                // If 'buf' was not enough to hold all the bytes encoded from
                // 'cb', there will be more frames, and we'll return here
                frameSequence = result.isOverflow();
                send(buffers, charactersProcessedFully ? cf : null);
                return charactersProcessedFully;
            }

            // A fresh random masking key for every frame
            private int nextMask() {
                return random.nextInt();
            }
        };
    }

    // Generic array creation is not allowed in Java, hence reflection + cast
    @SuppressWarnings("unchecked")
    private static Shared<ByteBuffer>[] createArray(int size) {
        return (Shared<ByteBuffer>[]) Array.newInstance(Shared.class, size);
    }

    /**
     * Stores the upstream subscription and requests the first message.
     *
     * @throws NullPointerException if subscription is null
     */
    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        this.subscription = requireNonNull(subscription);
        this.subscription.request(1);
    }

    /**
     * Accepts the next message to frame and triggers processing.
     *
     * @throws NullPointerException if p is null
     */
    @Override
    public void onNext(Pair<OutgoingMessage, CompletableFuture<Void>> p) {
        logger.log(TRACE, "->MessageSender.onNext(''{0}'')", p);
        this.pair = requireNonNull(p);
        handler.signal();
        logger.log(TRACE, "<-MessageSender.onNext", p);
    }

    /** Marks this processor closed; no further frames are emitted. */
    @Override
    public void onError(Throwable throwable) {
        closed = true;
    }

    /** Marks this processor closed; no further frames are emitted. */
    @Override
    public void onComplete() {
        closed = true;
    }
}