/* * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General License version 2 only, as * published by the Free Software Foundation. Oracle designates this * particular file as subject to the "Classpath" exception as provided * by Oracle in the LICENSE file that accompanied this code. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. */ package java.net.http; import java.nio.ByteBuffer; import static java.lang.String.format; import static java.lang.System.Logger.Level.TRACE; import static java.net.http.Utils.dump; import static java.net.http.Utils.logger; import static java.net.http.WebSocketFrame.Opcode.ofCode; // // Uninstantiable collection of tools for reading/writing/masking etc. // WebSocket frames // final class WebSocketFrame { // TODO: rename WSFrame static final int MAX_HEADER_SIZE_BYTES = 2 + 8 + 4; enum Opcode { CONTINUATION (0x0), TEXT (0x1), BINARY (0x2), NON_CONTROL_0x3(0x3), NON_CONTROL_0x4(0x4), NON_CONTROL_0x5(0x5), NON_CONTROL_0x6(0x6), NON_CONTROL_0x7(0x7), CLOSE (0x8), PING (0x9), PONG (0xA), CONTROL_0xB (0xB), CONTROL_0xC (0xC), CONTROL_0xD (0xD), CONTROL_0xE (0xE), CONTROL_0xF (0xF); private static final Opcode[] opcodes; static { Opcode[] values = values(); opcodes = new Opcode[values.length]; for (Opcode c : values) { assert opcodes[c.code] == null : Utils.dump(c, c.code, opcodes[c.code]); opcodes[c.code] = c; } } private final byte code; private final char shiftedCode; private final String description; Opcode(int code) { this.code = (byte) code; this.shiftedCode = (char) (code << 8); this.description = format("%x (%s)", code, name()); } boolean isControl() { return (code & 0x8) != 0; } static Opcode ofCode(int code) { return opcodes[code & 0xF]; } @Override public String toString() { return description; } } // Stateful tool to mask payload data incrementally abstract static class Masker { static Masker newInstance() { // TODO: fix module dependencies // return sun.misc.Unsafe.ADDRESS_SIZE == 8 ? new Masker64() : new Masker32(); return true ? new Masker64() : new Masker32(); } private Masker() { } private final int[] maskingKey = new int[4]; private int maskPos; private final ByteBuffer acc = ByteBuffer.allocate(4); Masker mask(int mask) { maskPos = 0; acc.clear().putInt(mask).flip(); for (int i = 0; i < 4; i++) { maskingKey[i] = acc.get(i); } return this; } // src.remaining() <= dst.remaining() // 'src' and 'dst' can be the same ByteBuffer Masker applyMask(ByteBuffer src, ByteBuffer dst) { if (src.remaining() > dst.remaining()) { throw new IllegalArgumentException(dump(src, dst)); } begin(src, dst); loop(src, dst); end(src, dst); return this; } // Applying the remaining of the mask (strictly not more than 3 bytes) // byte-wise private void begin(ByteBuffer src, ByteBuffer dst) { if (maskPos > 0) { for (int i = src.position(), j = dst.position(); maskPos < 4 && i + 1 <= src.limit() && j + 1 <= dst.limit(); i++, j++, maskPos++) { dst.put(j, (byte) (src.get(i) ^ maskingKey[maskPos])); dst.position(j + 1); src.position(i + 1); } maskPos &= 3; } } protected abstract void loop(ByteBuffer out, ByteBuffer in); // Applying the mask to the remaining bytes byte-wise (don't do any // assumptions on how many, hopefully not more than 7 for 64bit arch) private void end(ByteBuffer src, ByteBuffer dst) { for (int i = src.position(), j = dst.position(); i + 1 <= src.limit() && j + 1 <= dst.limit(); i++, j++, maskPos = (maskPos + 1) & 3) { // maskPos cycle through 0..3 dst.put(j, (byte) (src.get(i) ^ maskingKey[maskPos])); src.position(i + 1); dst.position(j + 1); } } Masker reset() { maskPos = 0; return this; } private static final class Masker64 extends Masker { private long mask; private final ByteBuffer acc2 = ByteBuffer.allocate(8); @Override Masker mask(int mask) { this.mask = acc2.clear().putInt(mask).putInt(mask).flip().getLong(); return super.mask(mask); } // One long at a time (yum! yum! yum!) @Override protected void loop(ByteBuffer src, ByteBuffer dst) { int i = src.position(); int j = dst.position(); for (; i + 8 <= src.limit() && j + 8 <= dst.limit(); i += 8, j += 8) { dst.putLong(j, (src.getLong(i) ^ mask)); } if (i > src.limit()) { src.position(i - 8); } else { src.position(i); } if (j > dst.limit()) { dst.position(j - 8); } else { dst.position(j); } } } private static final class Masker32 extends Masker { private int mask; @Override Masker mask(int mask) { this.mask = mask; return super.mask(mask); } // One int at a time @Override protected void loop(ByteBuffer src, ByteBuffer dst) { int i = src.position(); int j = dst.position(); for (; i + 4 <= src.limit() && j + 4 <= dst.limit(); i += 4, j += 4) { dst.putInt(j, (src.getInt(i) ^ mask)); } if (i > src.limit()) { src.position(i - 4); } else { src.position(i); } if (j > dst.limit()) { dst.position(j - 4); } else { dst.position(j); } } } } // The writer does not enforce any protocol-level rules, it simply writes a // header structure to a buffer. // The order of calls should not affect the outcome. Write, of course, // should be called last. static final class HeaderWriter { private char firstChar; private long payloadLen; private int maskingKey; private boolean mask; HeaderWriter fin(boolean value) { if (logger.isLoggable(TRACE)) { logger.log(TRACE, "Writing fin: {0}", value); } if (value) { firstChar |= 0b10000000_00000000; } else { firstChar &= ~0b10000000_00000000; } return this; } HeaderWriter rsv1(boolean value) { if (logger.isLoggable(TRACE)) { logger.log(TRACE, "Writing rsv1: {0}", value); } if (value) { firstChar |= 0b01000000_00000000; } else { firstChar &= ~0b01000000_00000000; } return this; } HeaderWriter rsv2(boolean value) { if (logger.isLoggable(TRACE)) { logger.log(TRACE, "Writing rsv2: {0}", value); } if (value) { firstChar |= 0b00100000_00000000; } else { firstChar &= ~0b00100000_00000000; } return this; } HeaderWriter rsv3(boolean value) { if (logger.isLoggable(TRACE)) { logger.log(TRACE, "Writing rsv3: {0}", value); } if (value) { firstChar |= 0b00010000_00000000; } else { firstChar &= ~0b00010000_00000000; } return this; } HeaderWriter opcode(Opcode value) { logger.log(TRACE, "Writing opcode: {0}", value); firstChar = (char) ((firstChar & 0xF0FF) | value.shiftedCode); return this; } HeaderWriter payloadLen(long value) { if (logger.isLoggable(TRACE)) { logger.log(TRACE, "Writing payloadLen: {0}", value); } payloadLen = value; firstChar &= 0b11111111_10000000; // Clear previous payload length leftovers if (payloadLen < 126) { firstChar |= payloadLen; } else if (payloadLen < 65535) { firstChar |= 126; } else { firstChar |= 127; } return this; } HeaderWriter mask(int value) { if (logger.isLoggable(TRACE)) { logger.log(TRACE, "Writing mask: 0x{0}", Integer.toHexString(value)); } firstChar |= 0b00000000_10000000; maskingKey = value; mask = true; return this; } HeaderWriter noMask() { logger.log(TRACE, "Writing noMask"); firstChar &= ~0b00000000_10000000; mask = false; return this; } // The method does not flip the buffer before returning control; the // buffer is expected to have at least MAX_HEADER_SIZE_BYTES remaining // bytes. void write(Shared buffer) { logger.log(TRACE, "Writing into buffer: ''{0}''", buffer); ByteBuffer b = buffer.buffer(); b.putChar(firstChar); if (payloadLen >= 126) { if (payloadLen < 65535) { b.putChar((char) payloadLen); } else { b.putLong(payloadLen); } } if (mask) { b.putInt(maskingKey); } } } // Guaranteed to be called in the following order (by the Reader): // // fin rsv1 rsv2 rsv3 opcode mask payloadLength maskingKey? payloadData+ endFrame interface Listener { void fin(boolean value); void rsv1(boolean value); void rsv2(boolean value); void rsv3(boolean value); void opcode(Opcode value); void mask(boolean value); void payloadLen(long value); void maskingKey(int value); // 1. May not yield a complete payload in the first invocation, i.e. // // data.remaining() < payloadLen // // However the sum of all (limit - position) passed to all invocations // of this method will be equal to 'payloadLen', reported in // // void payloadLen(long value) // // 2. No unmasking (if any) is done. void payloadData(Shared data, boolean isLast); void endFrame(); // TODO: remove (payloadData(isLast=true)) should be enough } // 1. The reader doesn't check any protocol-level rules, except for those // related to the frame structure. It simply invokes listener's methods // corresponding to structural elements it encounters. // // 2. Reads eagerly. In other words, delivers as much as possible payload in // a single invocation of payloadData. static final class Reader { private static final int AWAITING_FIRST_BYTE = 1; private static final int AWAITING_SECOND_BYTE = 2; private static final int READING_16_LENGTH = 4; private static final int READING_64_LENGTH = 8; private static final int READING_MASK = 16; private static final int READING_PAYLOAD = 32; // A private buffer used to simplify multi-byte integers reading private final ByteBuffer accumulator = ByteBuffer.allocate(8); private int state = AWAITING_FIRST_BYTE; private boolean mask; private long payloadLength; // 1. Reads frames byte-wise, at most one frame at a time. Buffer's // position is updated to reflect the progress. // // 2. WebSocketProtocolException is thrown in case there are errors in // the low-level structure of a frame. void readFrame(Shared shared, Listener listener) { ByteBuffer input = shared.buffer(); loop: while (true) { byte b; switch (state) { case AWAITING_FIRST_BYTE: if (!input.hasRemaining()) { break loop; } b = input.get(); listener.fin( (b & 0b10000000) != 0); listener.rsv1((b & 0b01000000) != 0); listener.rsv2((b & 0b00100000) != 0); listener.rsv3((b & 0b00010000) != 0); listener.opcode(ofCode(b)); state = AWAITING_SECOND_BYTE; continue loop; case AWAITING_SECOND_BYTE: if (!input.hasRemaining()) { break loop; } b = input.get(); listener.mask(mask = (b & 0b10000000) != 0); byte p1 = (byte) (b & 0b01111111); if (p1 < 126) { assert p1 >= 0 : p1; listener.payloadLen(payloadLength = p1); state = mask ? READING_MASK : READING_PAYLOAD; } else if (p1 < 127) { state = READING_16_LENGTH; } else { state = READING_64_LENGTH; } continue loop; case READING_16_LENGTH: if (!input.hasRemaining()) { break loop; } b = input.get(); if (accumulator.put(b).position() < 2) { continue loop; } payloadLength = accumulator.flip().getChar(); if (payloadLength < 126) { throw notMinimalEncoding(payloadLength, 2); } listener.payloadLen(payloadLength); accumulator.clear(); state = mask ? READING_MASK : READING_PAYLOAD; continue loop; case READING_64_LENGTH: if (!input.hasRemaining()) { break loop; } b = input.get(); if (accumulator.put(b).position() < 8) { continue loop; } payloadLength = accumulator.flip().getLong(); if (payloadLength < 0) { throw negativePayload(payloadLength); } else if (payloadLength < 65535) { throw notMinimalEncoding(payloadLength, 8); } listener.payloadLen(payloadLength); accumulator.clear(); state = mask ? READING_MASK : READING_PAYLOAD; continue loop; case READING_MASK: if (!input.hasRemaining()) { break loop; } b = input.get(); if (accumulator.put(b).position() != 4) { continue loop; } listener.maskingKey(accumulator.flip().getInt()); accumulator.clear(); state = READING_PAYLOAD; continue loop; case READING_PAYLOAD: // This state does not require any bytes to be available // in the input buffer in order to proceed boolean fullyRead; int limit; if (payloadLength <= input.remaining()) { limit = input.position() + (int) payloadLength; payloadLength = 0; fullyRead = true; } else { limit = input.limit(); payloadLength -= input.remaining(); fullyRead = false; } // FIXME: consider a case where payloadLen != 0, // but input.remaining() == 0 // // There shouldn't be an invocation of payloadData with an empty buffer, as it // would be an artifact of reading listener.payloadData(shared.share(input.position(), limit), fullyRead); // Update the position manually, since reading the // payload doesn't advance buffer's position input.position(limit); if (fullyRead) { listener.endFrame(); state = AWAITING_FIRST_BYTE; } break loop; default: throw new InternalError(String.valueOf(state)); } } } Reader reset() { state = AWAITING_FIRST_BYTE; accumulator.clear(); return this; } private static WebSocketProtocolException negativePayload(long payloadLength) { return new WebSocketProtocolException ("5.2.", format("Negative 64-bit payload length %s", payloadLength)); } private static WebSocketProtocolException notMinimalEncoding(long payloadLength, int numBytes) { return new WebSocketProtocolException ("5.2.", format("Payload length (%s) is not encoded with minimal number (%s) of bytes", payloadLength, numBytes)); } } private WebSocketFrame() { } }