1 /*
   2  * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package jdk.incubator.http.internal.frame;
  27 
  28 import jdk.incubator.http.internal.common.ByteBufferReference;
  29 import jdk.incubator.http.internal.common.Log;
  30 import jdk.incubator.http.internal.common.Utils;
  31 
  32 import java.io.IOException;
  33 import java.nio.ByteBuffer;
  34 import java.util.ArrayDeque;
  35 import java.util.ArrayList;
  36 import java.util.List;
  37 
  38 /**
  39  * Frames Decoder
  40  * <p>
  41  * collect buffers until frame decoding is possible,
  42  * all decoded frames are passed to the FrameProcessor callback in order of decoding.
  43  *
  44  * It's a stateful class due to the fact that FramesDecoder stores buffers inside.
  45  * Should be allocated only the single instance per connection.
  46  */
  47 public class FramesDecoder {
  48 
  49 
  50 
  51     @FunctionalInterface
  52     public interface FrameProcessor {
  53         void processFrame(Http2Frame frame) throws IOException;
  54     }
  55 
  56     private final FrameProcessor frameProcessor;
  57     private final int maxFrameSize;
  58 
  59     private ByteBufferReference currentBuffer; // current buffer either null or hasRemaining
  60 
  61     private final java.util.Queue<ByteBufferReference> tailBuffers = new ArrayDeque<>();
  62     private int tailSize = 0;
  63 
  64     private boolean slicedToDataFrame = false;
  65 
  66     private final List<ByteBufferReference> prepareToRelease = new ArrayList<>();
  67 
  68     // if true  - Frame Header was parsed (9 bytes consumed) and subsequent fields have meaning
  69     // otherwise - stopped at frames boundary
  70     private boolean frameHeaderParsed = false;
  71     private int frameLength;
  72     private int frameType;
  73     private int frameFlags;
  74     private int frameStreamid;
  75 
  76     /**
  77      * Creates Frame Decoder
  78      *
  79      * @param frameProcessor - callback for decoded frames
  80      */
  81     public FramesDecoder(FrameProcessor frameProcessor) {
  82         this(frameProcessor, 16 * 1024);
  83     }
  84 
  85     /**
  86      * Creates Frame Decoder
  87      * @param frameProcessor - callback for decoded frames
  88      * @param maxFrameSize - maxFrameSize accepted by this decoder
  89      */
  90     public FramesDecoder(FrameProcessor frameProcessor, int maxFrameSize) {
  91         this.frameProcessor = frameProcessor;
  92         this.maxFrameSize = Math.min(Math.max(16 * 1024, maxFrameSize), 16 * 1024 * 1024 - 1);
  93     }
  94 
  95     /**
  96      * put next buffer into queue,
  97      * if frame decoding is possible - decode all buffers and invoke FrameProcessor
  98      *
  99      * @param buffer
 100      * @throws IOException
 101      */
 102     public void decode(ByteBufferReference buffer) throws IOException {
 103         int remaining = buffer.get().remaining();
 104         if (remaining > 0) {
 105             if (currentBuffer == null) {
 106                 currentBuffer = buffer;
 107             } else {
 108                 tailBuffers.add(buffer);
 109                 tailSize += remaining;
 110             }
 111         }
 112         Http2Frame frame;
 113         while ((frame = nextFrame()) != null) {
 114             frameProcessor.processFrame(frame);
 115             frameProcessed();
 116         }
 117     }
 118 
 119     private Http2Frame nextFrame() throws IOException {
 120         while (true) {
 121             if (currentBuffer == null) {
 122                 return null; // no data at all
 123             }
 124             if (!frameHeaderParsed) {
 125                 if (currentBuffer.get().remaining() + tailSize >= Http2Frame.FRAME_HEADER_SIZE) {
 126                     parseFrameHeader();
 127                     if (frameLength > maxFrameSize) {
 128                         // connection error
 129                         return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 130                                 "Frame type("+frameType+") " +"length("+frameLength+") exceeds MAX_FRAME_SIZE("+ maxFrameSize+")");
 131                     }
 132                     frameHeaderParsed = true;
 133                 } else {
 134                     return null; // no data for frame header
 135                 }
 136             }
 137             if ((frameLength == 0) ||
 138                     (currentBuffer != null && currentBuffer.get().remaining() + tailSize >= frameLength)) {
 139                 Http2Frame frame = parseFrameBody();
 140                 frameHeaderParsed = false;
 141                 // frame == null means we have to skip this frame and try parse next
 142                 if (frame != null) {
 143                     return frame;
 144                 }
 145             } else {
 146                 return null;  // no data for the whole frame header
 147             }
 148         }
 149     }
 150 
 151     private void frameProcessed() {
 152         prepareToRelease.forEach(ByteBufferReference::clear);
 153         prepareToRelease.clear();
 154     }
 155 
 156     private void parseFrameHeader() throws IOException {
 157         int x = getInt();
 158         this.frameLength = x >> 8;
 159         this.frameType = x & 0xff;
 160         this.frameFlags = getByte();
 161         this.frameStreamid = getInt() & 0x7fffffff;
 162         // R: A reserved 1-bit field.  The semantics of this bit are undefined,
 163         // MUST be ignored when receiving.
 164     }
 165 
 166     // move next buffer from tailBuffers to currentBuffer if required
 167     private void nextBuffer() {
 168         if (!currentBuffer.get().hasRemaining()) {
 169             if (!slicedToDataFrame) {
 170                 prepareToRelease.add(currentBuffer);
 171             }
 172             slicedToDataFrame = false;
 173             currentBuffer = tailBuffers.poll();
 174             if (currentBuffer != null) {
 175                 tailSize -= currentBuffer.get().remaining();
 176             }
 177         }
 178     }
 179 
 180     public int getByte() {
 181         ByteBuffer buf = currentBuffer.get();
 182         int res = buf.get() & 0xff;
 183         nextBuffer();
 184         return res;
 185     }
 186 
 187     public int getShort() {
 188         ByteBuffer buf = currentBuffer.get();
 189         if (buf.remaining() >= 2) {
 190             int res = buf.getShort() & 0xffff;
 191             nextBuffer();
 192             return res;
 193         }
 194         int val = getByte();
 195         val = (val << 8) + getByte();
 196         return val;
 197     }
 198 
 199     public int getInt() {
 200         ByteBuffer buf = currentBuffer.get();
 201         if (buf.remaining() >= 4) {
 202             int res = buf.getInt();
 203             nextBuffer();
 204             return res;
 205         }
 206         int val = getByte();
 207         val = (val << 8) + getByte();
 208         val = (val << 8) + getByte();
 209         val = (val << 8) + getByte();
 210         return val;
 211 
 212     }
 213 
 214     public byte[] getBytes(int n) {
 215         byte[] bytes = new byte[n];
 216         int offset = 0;
 217         while (n > 0) {
 218             ByteBuffer buf = currentBuffer.get();
 219             int length = Math.min(n, buf.remaining());
 220             buf.get(bytes, offset, length);
 221             offset += length;
 222             n -= length;
 223             nextBuffer();
 224         }
 225         return bytes;
 226 
 227     }
 228 
 229     private ByteBufferReference[] getBuffers(boolean isDataFrame, int bytecount) {
 230         List<ByteBufferReference> res = new ArrayList<>();
 231         while (bytecount > 0) {
 232             ByteBuffer buf = currentBuffer.get();
 233             int remaining = buf.remaining();
 234             int extract = Math.min(remaining, bytecount);
 235             ByteBuffer extractedBuf;
 236             if (isDataFrame) {
 237                 extractedBuf = Utils.slice(buf, extract);
 238                 slicedToDataFrame = true;
 239             } else {
 240                 // Header frames here
 241                 // HPACK decoding should performed under lock and immediately after frame decoding.
 242                 // in that case it is safe to release original buffer,
 243                 // because of sliced buffer has a very short life
 244                 extractedBuf = Utils.slice(buf, extract);
 245             }
 246             res.add(ByteBufferReference.of(extractedBuf));
 247             bytecount -= extract;
 248             nextBuffer();
 249         }
 250         return res.toArray(new ByteBufferReference[0]);
 251     }
 252 
 253     public void skipBytes(int bytecount) {
 254         while (bytecount > 0) {
 255             ByteBuffer buf = currentBuffer.get();
 256             int remaining = buf.remaining();
 257             int extract = Math.min(remaining, bytecount);
 258             buf.position(buf.position() + extract);
 259             bytecount -= remaining;
 260             nextBuffer();
 261         }
 262     }
 263 
 264     private Http2Frame parseFrameBody() throws IOException {
 265         assert frameHeaderParsed;
 266         switch (frameType) {
 267             case DataFrame.TYPE:
 268                 return parseDataFrame(frameLength, frameStreamid, frameFlags);
 269             case HeadersFrame.TYPE:
 270                 return parseHeadersFrame(frameLength, frameStreamid, frameFlags);
 271             case PriorityFrame.TYPE:
 272                 return parsePriorityFrame(frameLength, frameStreamid, frameFlags);
 273             case ResetFrame.TYPE:
 274                 return parseResetFrame(frameLength, frameStreamid, frameFlags);
 275             case SettingsFrame.TYPE:
 276                 return parseSettingsFrame(frameLength, frameStreamid, frameFlags);
 277             case PushPromiseFrame.TYPE:
 278                 return parsePushPromiseFrame(frameLength, frameStreamid, frameFlags);
 279             case PingFrame.TYPE:
 280                 return parsePingFrame(frameLength, frameStreamid, frameFlags);
 281             case GoAwayFrame.TYPE:
 282                 return parseGoAwayFrame(frameLength, frameStreamid, frameFlags);
 283             case WindowUpdateFrame.TYPE:
 284                 return parseWindowUpdateFrame(frameLength, frameStreamid, frameFlags);
 285             case ContinuationFrame.TYPE:
 286                 return parseContinuationFrame(frameLength, frameStreamid, frameFlags);
 287             default:
 288                 // RFC 7540 4.1
 289                 // Implementations MUST ignore and discard any frame that has a type that is unknown.
 290                 Log.logTrace("Unknown incoming frame type: {0}", frameType);
 291                 skipBytes(frameLength);
 292                 return null;
 293         }
 294     }
 295 
 296     private Http2Frame parseDataFrame(int frameLength, int streamid, int flags) {
 297         // non-zero stream
 298         if (streamid == 0) {
 299             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR, "zero streamId for DataFrame");
 300         }
 301         int padLength = 0;
 302         if ((flags & DataFrame.PADDED) != 0) {
 303             padLength = getByte();
 304             if(padLength >= frameLength) {
 305                 return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 306                         "the length of the padding is the length of the frame payload or greater");
 307             }
 308             frameLength--;
 309         }
 310         DataFrame df = new DataFrame(streamid, flags,
 311                 getBuffers(true, frameLength - padLength), padLength);
 312         skipBytes(padLength);
 313         return df;
 314 
 315     }
 316 
 317     private Http2Frame parseHeadersFrame(int frameLength, int streamid, int flags) {
 318         // non-zero stream
 319         if (streamid == 0) {
 320             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR, "zero streamId for HeadersFrame");
 321         }
 322         int padLength = 0;
 323         if ((flags & HeadersFrame.PADDED) != 0) {
 324             padLength = getByte();
 325             frameLength--;
 326         }
 327         boolean hasPriority = (flags & HeadersFrame.PRIORITY) != 0;
 328         boolean exclusive = false;
 329         int streamDependency = 0;
 330         int weight = 0;
 331         if (hasPriority) {
 332             int x = getInt();
 333             exclusive = (x & 0x80000000) != 0;
 334             streamDependency = x & 0x7fffffff;
 335             weight = getByte();
 336             frameLength -= 5;
 337         }
 338         if(frameLength < padLength) {
 339             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 340                     "Padding exceeds the size remaining for the header block");
 341         }
 342         HeadersFrame hf = new HeadersFrame(streamid, flags,
 343                 getBuffers(false, frameLength - padLength), padLength);
 344         skipBytes(padLength);
 345         if (hasPriority) {
 346             hf.setPriority(streamDependency, exclusive, weight);
 347         }
 348         return hf;
 349     }
 350 
 351     private Http2Frame parsePriorityFrame(int frameLength, int streamid, int flags) {
 352         // non-zero stream; no flags
 353         if (streamid == 0) {
 354             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 355                     "zero streamId for PriorityFrame");
 356         }
 357         if(frameLength != 5) {
 358             skipBytes(frameLength);
 359             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR, streamid,
 360                     "PriorityFrame length is "+ frameLength+", expected 5");
 361         }
 362         int x = getInt();
 363         int weight = getByte();
 364         return new PriorityFrame(streamid, x & 0x7fffffff, (x & 0x80000000) != 0, weight);
 365     }
 366 
 367     private Http2Frame parseResetFrame(int frameLength, int streamid, int flags) {
 368         // non-zero stream; no flags
 369         if (streamid == 0) {
 370             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 371                     "zero streamId for ResetFrame");
 372         }
 373         if(frameLength != 4) {
 374             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 375                     "ResetFrame length is "+ frameLength+", expected 4");
 376         }
 377         return new ResetFrame(streamid, getInt());
 378     }
 379 
 380     private Http2Frame parseSettingsFrame(int frameLength, int streamid, int flags) {
 381         // only zero stream
 382         if (streamid != 0) {
 383             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 384                     "non-zero streamId for SettingsFrame");
 385         }
 386         if ((SettingsFrame.ACK & flags) != 0 && frameLength > 0) {
 387             // RFC 7540 6.5
 388             // Receipt of a SETTINGS frame with the ACK flag set and a length
 389             // field value other than 0 MUST be treated as a connection error
 390             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 391                     "ACK SettingsFrame is not empty");
 392         }
 393         if (frameLength % 6 != 0) {
 394             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 395                     "invalid SettingsFrame size: "+frameLength);
 396         }
 397         SettingsFrame sf = new SettingsFrame(flags);
 398         int n = frameLength / 6;
 399         for (int i=0; i<n; i++) {
 400             int id = getShort();
 401             int val = getInt();
 402             if (id > 0 && id <= SettingsFrame.MAX_PARAM) {
 403                 // a known parameter. Ignore otherwise
 404                 sf.setParameter(id, val); // TODO parameters validation
 405             }
 406         }
 407         return sf;
 408     }
 409 
 410     private Http2Frame parsePushPromiseFrame(int frameLength, int streamid, int flags) {
 411         // non-zero stream
 412         if (streamid == 0) {
 413             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 414                     "zero streamId for PushPromiseFrame");
 415         }
 416         int padLength = 0;
 417         if ((flags & PushPromiseFrame.PADDED) != 0) {
 418             padLength = getByte();
 419             frameLength--;
 420         }
 421         int promisedStream = getInt() & 0x7fffffff;
 422         frameLength -= 4;
 423         if(frameLength < padLength) {
 424             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 425                     "Padding exceeds the size remaining for the PushPromiseFrame");
 426         }
 427         PushPromiseFrame ppf = new PushPromiseFrame(streamid, flags, promisedStream,
 428                 getBuffers(false, frameLength - padLength), padLength);
 429         skipBytes(padLength);
 430         return ppf;
 431     }
 432 
 433     private Http2Frame parsePingFrame(int frameLength, int streamid, int flags) {
 434         // only zero stream
 435         if (streamid != 0) {
 436             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 437                     "non-zero streamId for PingFrame");
 438         }
 439         if(frameLength != 8) {
 440             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 441                     "PingFrame length is "+ frameLength+", expected 8");
 442         }
 443         return new PingFrame(flags, getBytes(8));
 444     }
 445 
 446     private Http2Frame parseGoAwayFrame(int frameLength, int streamid, int flags) {
 447         // only zero stream; no flags
 448         if (streamid != 0) {
 449             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 450                     "non-zero streamId for GoAwayFrame");
 451         }
 452         if (frameLength < 8) {
 453             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 454                     "Invalid GoAway frame size");
 455         }
 456         int lastStream = getInt() & 0x7fffffff;
 457         int errorCode = getInt();
 458         byte[] debugData = getBytes(frameLength - 8);
 459         if (debugData.length > 0) {
 460             Log.logError("GoAway debugData " + new String(debugData));
 461         }
 462         return new GoAwayFrame(lastStream, errorCode, debugData);
 463     }
 464 
 465     private Http2Frame parseWindowUpdateFrame(int frameLength, int streamid, int flags) {
 466         // any stream; no flags
 467         if(frameLength != 4) {
 468             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 469                     "WindowUpdateFrame length is "+ frameLength+", expected 4");
 470         }
 471         return new WindowUpdateFrame(streamid, getInt() & 0x7fffffff);
 472     }
 473 
 474     private Http2Frame parseContinuationFrame(int frameLength, int streamid, int flags) {
 475         // non-zero stream;
 476         if (streamid == 0) {
 477             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 478                     "zero streamId for ContinuationFrame");
 479         }
 480         return new ContinuationFrame(streamid, flags, getBuffers(false, frameLength));
 481     }
 482 
 483 }