--- /dev/null 2016-04-22 22:42:07.067228853 +0100 +++ new/src/java.httpclient/share/classes/java/net/http/MessageSender.java 2016-04-25 23:10:56.261374316 +0100 @@ -0,0 +1,397 @@ +/* + * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package java.net.http; + +import java.lang.reflect.Array; +import java.net.http.OutgoingMessage.Binary; +import java.net.http.OutgoingMessage.BinaryText; +import java.net.http.OutgoingMessage.CharacterText; +import java.net.http.OutgoingMessage.Close; +import java.net.http.OutgoingMessage.Ping; +import java.net.http.OutgoingMessage.Pong; +import java.net.http.OutgoingMessage.StreamedText; +import java.net.http.OutgoingMessage.Visitor; +import java.net.http.WebSocketFrame.HeaderWriter; +import java.net.http.WebSocketFrame.Masker; +import java.net.http.WebSocketFrame.Opcode; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CoderResult; +import java.security.SecureRandom; +import java.util.Iterator; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Flow; +import java.util.concurrent.Flow.Processor; +import java.util.concurrent.atomic.AtomicLong; + +import static java.lang.System.Logger.Level.TRACE; +import static java.net.http.Utils.logger; +import static java.net.http.WebSocketFrame.MAX_HEADER_SIZE_BYTES; +import static java.net.http.WebSocketFrame.Opcode.BINARY; +import static java.net.http.WebSocketFrame.Opcode.CLOSE; +import static java.net.http.WebSocketFrame.Opcode.CONTINUATION; +import static java.net.http.WebSocketFrame.Opcode.PING; +import static java.net.http.WebSocketFrame.Opcode.PONG; +import static java.net.http.WebSocketFrame.Opcode.TEXT; +import static java.util.Objects.requireNonNull; + +final class MessageSender implements Processor>, + Pair[], CompletableFuture>> { + + private final SecureRandom random = new SecureRandom(); + private final CharsetToolkit.Verifier verifier = new CharsetToolkit.Verifier(); + private final CharsetToolkit.Encoder encoder = new CharsetToolkit.Encoder(); + private final Masker masker = Masker.newInstance(); + private final HeaderWriter headerWriter = new HeaderWriter(); + private final SharedPool sharedHeaderBuffers + = new SharedPool<>(() -> ByteBuffer.allocateDirect(MAX_HEADER_SIZE_BYTES), 16); + private final SharedPool sharedPayloadBuffers + = new SharedPool<>(() -> ByteBuffer.allocateDirect(32768), 16); // FIXME: ensure different! + private final AtomicLong demand = new AtomicLong(); + + private Flow.Subscriber[], CompletableFuture>> subscriber; + private final SignalHandler handler; + + private final Visitor> processingVisitor + = createProcessingVisitor(); + + private volatile boolean closed; + private volatile Pair> pair; + private volatile Flow.Subscription subscription; + + MessageSender(Executor executor) { + this.handler = new SignalHandler(requireNonNull(executor), this::react); + } + + private void react() { + synchronized (this) { + while (pair != null && demand.get() > 0 && !closed) { + OutgoingMessage m = pair.first; + CompletableFuture cf = pair.second; + boolean processed = m.accept(processingVisitor, cf); + if (processed) { + pair = null; + long l = demand.decrementAndGet(); + assert l >= 0 : l; + subscription.request(1); + } + } + } + } + + @Override + public void subscribe(Flow.Subscriber[], CompletableFuture>> subscriber) { + this.subscriber = requireNonNull(subscriber); + subscriber.onSubscribe( + new Flow.Subscription() { + @Override + public void request(long n) { + if (n < 0) { + throw new IllegalArgumentException(String.valueOf(n)); + } + demand.accumulateAndGet(n, (p, i) -> p + i < 0 ? Long.MAX_VALUE : p + i); + handler.signal(); + } + + @Override + public void cancel() { + closed = true; + } + }); + } + + private void send(Shared[] buffers, CompletableFuture cf) { + Pair[], CompletableFuture> p = Pair.pair(buffers, cf); + logger.log(TRACE, "Passing ''{0}'' for writing", p); + subscriber.onNext(p); + } + + private Visitor> createProcessingVisitor() { + + // TODO: if we can modify buffers AND passed buffers are direct + // XOR and send (mask leak?) + // otherwise need to make copies with XORing here + + return new Visitor<>() { + + private Iterator streamIterator; + private CharSequence previousIteratorItem; + + private boolean messageSequence; + private boolean frameSequence; + + // In case of a long message that might result in more than a single + // frame on the wire + private boolean charactersProcessedFully = true; + private CharBuffer cb; + + @Override + public Boolean visit(CharacterText message, CompletableFuture attachment) { + return resumeCharSequence(message.characters, message.isLast, attachment); + } + + @Override + public Boolean visit(BinaryText message, CompletableFuture attachment) { + try { + int oldPos = message.bytes.position(); + verifier.verify(message.bytes, message.isLast); + message.bytes.position(oldPos); + } catch (CharacterCodingException e) { + throw new IllegalArgumentException( + "Malformed UTF-8 bytes", e); + } + if (message.isLast) { + verifier.reset(); + } + Shared[] buffers = createArray(2); + buffers[0] = sharedHeaderBuffers.get(); + buffers[1] = Shared.wrap(message.bytes); + final boolean fin = message.isLast; + final Opcode opcode = messageSequence ? CONTINUATION : BINARY; + final int payloadLen = message.bytes.remaining(); + final int mask = nextMask(); + + + int oldPos = message.bytes.position(); + masker.mask(mask).applyMask(message.bytes, message.bytes); + message.bytes.position(oldPos); + + headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask) + .write(buffers[0]); + buffers[0].flip(); + messageSequence = !message.isLast; + return true; + } + + @Override + public Boolean visit(StreamedText streamedText, CompletableFuture cf) { + throw new IllegalArgumentException("Not yet implemented"); +// if (streamIterator == null) { +// streamIterator = streamedText.characters.iterator(); +// previousIteratorItem = ""; +// } +// if (!charactersProcessedFully) { +// boolean isLast = !streamIterator.hasNext(); +// Boolean r = resumeCharSequence(previousIteratorItem, isLast, isLast ? cf : null); +// +// } +// if (!streamIterator.hasNext()) { +// return resumeCharSequence(previousIteratorItem, true, cf); +// } else { +// CharSequence tmp = previousIteratorItem; +// previousIteratorItem = streamIterator.next(); +// return resumeCharSequence(tmp, false, cf); +// } + } + + @Override + public Boolean visit(Binary message, CompletableFuture cf) { + Shared[] buffers = createArray(2); + buffers[0] = sharedHeaderBuffers.get(); + buffers[1] = Shared.wrap(message.bytes); + final boolean fin = message.isLast; + final Opcode opcode = messageSequence ? CONTINUATION : BINARY; + final int payloadLen = message.bytes.remaining(); + final int mask = nextMask(); + + int oldPos = message.bytes.position(); + masker.mask(mask).applyMask(message.bytes, message.bytes); + message.bytes.position(oldPos); + + headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask) + .write(buffers[0]); + buffers[0].flip(); + messageSequence = !message.isLast; + send(buffers, cf); + return true; + } + + @Override + public Boolean visit(Ping message, CompletableFuture cf) { + Shared[] buffers = createArray(2); + buffers[0] = sharedHeaderBuffers.get(); + buffers[1] = Shared.wrap(message.bytes); + final boolean fin = true; + final Opcode opcode = PING; + final int payloadLen = message.bytes.remaining(); + final int mask = nextMask(); + + int oldPos = message.bytes.position(); + masker.mask(mask).applyMask(message.bytes, message.bytes); + message.bytes.position(oldPos); + + headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask) + .write(buffers[0]); + buffers[0].flip(); + send(buffers, cf); + return true; + } + + @Override + public Boolean visit(Pong message, CompletableFuture cf) { + Shared[] buffers = createArray(2); + buffers[0] = sharedHeaderBuffers.get(); + buffers[1] = Shared.wrap(message.bytes); + final boolean fin = true; + final Opcode opcode = PONG; + final int payloadLen = message.bytes.remaining(); + final int mask = nextMask(); + + int oldPos = message.bytes.position(); + masker.mask(mask).applyMask(message.bytes, message.bytes); + message.bytes.position(oldPos); + + headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask) + .write(buffers[0]); + buffers[0].flip(); + send(buffers, cf); + return true; + } + + @Override + public Boolean visit(Close message, CompletableFuture cf) { + Shared[] buffers = createArray(2); + buffers[0] = sharedHeaderBuffers.get(); + buffers[1] = sharedPayloadBuffers.get(); + + assert buffers[1].remaining() >= 125 : buffers[1].remaining(); + + if (message.code != null) { + buffers[1].buffer().putChar((char) message.code.getCode()); + CoderResult r; + try { + r = new CharsetToolkit.Encoder().encode(CharBuffer.wrap(message.reason), + buffers[1].buffer(), true); + } catch (CharacterCodingException e) { + // Shouldn't happen, since the message should've been + // already checked + throw new InternalError(e); + } + if (r.isOverflow()) { + // TODO: make sure payload is at least 123 bytes, or use separate BB + throw new InternalError(); + } + } + buffers[1].flip(); + final int payloadLen = buffers[1].remaining(); + final int mask = nextMask(); + + int oldPos = buffers[1].position(); + masker.mask(mask).applyMask(buffers[1].buffer(), buffers[1].buffer()); + buffers[1].position(oldPos); + + headerWriter.fin(true).opcode(CLOSE).payloadLen(payloadLen).mask(mask) + .write(buffers[0]); + buffers[0].flip(); + send(buffers, cf); + return true; + } + + private Boolean resumeCharSequence(CharSequence characters, boolean isLast, + CompletableFuture cf) { + if (charactersProcessedFully) { + // A brand new message has arrived + cb = CharBuffer.wrap(characters); + charactersProcessedFully = false; + } + // Chop into frames in order to not decode the whole CharBuffer at once + // TODO: use them more savvy (i.e. slice) + Shared buf = sharedPayloadBuffers.get(); + CoderResult result; + try { + result = encoder.encode(cb, buf.buffer(), isLast); + if (result.isUnderflow()) { + // A payload for the last frame has just been encoded + charactersProcessedFully = true; + } + buf.flip(); + } catch (CharacterCodingException e) { + throw new IllegalArgumentException("Malformed UTF-16 characters", e); + } + if (isLast && charactersProcessedFully) { + // There will be no more messages in this sequence + encoder.reset(); + } + + int mask = nextMask(); + + int oldPos = buf.position(); + int oldLim = buf.limit(); + masker.mask(mask).applyMask(buf.buffer(), buf.buffer()); + buf.select(oldPos, oldLim); + + Shared[] buffers = createArray(2); + buffers[0] = sharedHeaderBuffers.get(); + buffers[1] = buf; + // The last frame of the last message MUST have 'fin' bit set + final boolean fin = isLast && charactersProcessedFully; + final Opcode opcode = !messageSequence && !frameSequence ? TEXT : CONTINUATION; + final int payloadLen = buf.remaining(); + headerWriter.fin(fin).opcode(opcode).payloadLen(payloadLen).mask(mask) + .write(buffers[0]); + buffers[0].flip(); + + messageSequence = !isLast; + + // If 'buf' was not enough to hold all the bytes encoded from + // 'cb', there will be more frames, and we'll return here + frameSequence = result.isOverflow(); + send(buffers, charactersProcessedFully ? cf : null); + return charactersProcessedFully; + } + + private int nextMask() { + return random.nextInt(); + } + }; + } + + @SuppressWarnings("unchecked") + private static Shared[] createArray(int size) { + return (Shared[]) Array.newInstance(Shared.class, size); + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = requireNonNull(subscription); + this.subscription.request(1); + } + + @Override + public void onNext(Pair> p) { + logger.log(TRACE, "->MessageSender.onNext(''{0}'')", p); + this.pair = requireNonNull(p); + handler.signal(); + logger.log(TRACE, "<-MessageSender.onNext", p); + } + + @Override + public void onError(Throwable throwable) { closed = true; } + + @Override + public void onComplete() { closed = true; } +}