1 /*
   2  * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package jdk.incubator.http.internal.frame;
  27 
  28 import jdk.incubator.http.internal.common.Log;
  29 import jdk.incubator.http.internal.common.Utils;
  30 
  31 import java.io.IOException;
  32 import java.lang.System.Logger.Level;
  33 import java.nio.ByteBuffer;
  34 import java.util.ArrayDeque;
  35 import java.util.ArrayList;
  36 import java.util.List;
  37 import java.util.Queue;
  38 
  39 /**
  40  * Frames Decoder
  41  * <p>
  42  * collect buffers until frame decoding is possible,
  43  * all decoded frames are passed to the FrameProcessor callback in order of decoding.
  44  *
  45  * It's a stateful class due to the fact that FramesDecoder stores buffers inside.
  46  * Should be allocated only the single instance per connection.
  47  */
  48 public class FramesDecoder {
  49 
  50     static final boolean DEBUG = Utils.DEBUG; // Revisit: temporary dev flag.
  51     static final System.Logger DEBUG_LOGGER =
  52             Utils.getDebugLogger("FramesDecoder"::toString, DEBUG);
  53 
  54     @FunctionalInterface
  55     public interface FrameProcessor {
  56         void processFrame(Http2Frame frame) throws IOException;
  57     }
  58 
  59     private final FrameProcessor frameProcessor;
  60     private final int maxFrameSize;
  61 
  62     private ByteBuffer currentBuffer; // current buffer either null or hasRemaining
  63 
  64     private final Queue<ByteBuffer> tailBuffers = new ArrayDeque<>();
  65     private int tailSize = 0;
  66 
  67     private boolean slicedToDataFrame = false;
  68 
  69     private final List<ByteBuffer> prepareToRelease = new ArrayList<>();
  70 
  71     // if true  - Frame Header was parsed (9 bytes consumed) and subsequent fields have meaning
  72     // otherwise - stopped at frames boundary
  73     private boolean frameHeaderParsed = false;
  74     private int frameLength;
  75     private int frameType;
  76     private int frameFlags;
  77     private int frameStreamid;
  78 
  79     /**
  80      * Creates Frame Decoder
  81      *
  82      * @param frameProcessor - callback for decoded frames
  83      */
  84     public FramesDecoder(FrameProcessor frameProcessor) {
  85         this(frameProcessor, 16 * 1024);
  86     }
  87 
  88     /**
  89      * Creates Frame Decoder
  90      * @param frameProcessor - callback for decoded frames
  91      * @param maxFrameSize - maxFrameSize accepted by this decoder
  92      */
  93     public FramesDecoder(FrameProcessor frameProcessor, int maxFrameSize) {
  94         this.frameProcessor = frameProcessor;
  95         this.maxFrameSize = Math.min(Math.max(16 * 1024, maxFrameSize), 16 * 1024 * 1024 - 1);
  96     }
  97 
  98     /** Threshold beyond which data is no longer copied into the current buffer,
  99      * if that buffer has enough unused space. */
 100     private static final int COPY_THRESHOLD = 8192;
 101 
 102     /**
 103      * Adds the data from the given buffer, and performs frame decoding if
 104      * possible.   Either 1) appends the data from the given buffer to the
 105      * current buffer ( if there is enough unused space ), or 2) adds it to the
 106      * next buffer in the queue.
 107      *
 108      * If there is enough data to perform frame decoding then, all buffers are
 109      * decoded and the FrameProcessor is invoked.
 110      */
 111     public void decode(ByteBuffer buffer) throws IOException {
 112         int remaining = buffer.remaining();
 113         DEBUG_LOGGER.log(Level.DEBUG, "decodes: %d", remaining);
 114         if (remaining > 0) {
 115             if (currentBuffer == null) {
 116                 currentBuffer = buffer;
 117             } else {
 118                 int limit = currentBuffer.limit();
 119                 int freeSpace = currentBuffer.capacity() - limit;
 120                 if (remaining <= COPY_THRESHOLD && freeSpace >= remaining) {
 121                     // append the new data to the unused space in the current buffer
 122                     ByteBuffer b = buffer;
 123                     int position = currentBuffer.position();
 124                     currentBuffer.position(limit);
 125                     currentBuffer.limit(limit + b.limit());
 126                     currentBuffer.put(b);
 127                     currentBuffer.position(position);
 128                     DEBUG_LOGGER.log(Level.DEBUG, "copied: %d", remaining);
 129                 } else {
 130                     DEBUG_LOGGER.log(Level.DEBUG, "added: %d", remaining);
 131                     tailBuffers.add(buffer);
 132                     tailSize += remaining;
 133                 }
 134             }
 135         }
 136         DEBUG_LOGGER.log(Level.DEBUG, "Tail size is now: %d, current=",
 137                 tailSize,
 138                 (currentBuffer == null ? 0 :
 139                    currentBuffer.remaining()));
 140         Http2Frame frame;
 141         while ((frame = nextFrame()) != null) {
 142             DEBUG_LOGGER.log(Level.DEBUG, "Got frame: %s", frame);
 143             frameProcessor.processFrame(frame);
 144             frameProcessed();
 145         }
 146     }
 147 
 148     private Http2Frame nextFrame() throws IOException {
 149         while (true) {
 150             if (currentBuffer == null) {
 151                 return null; // no data at all
 152             }
 153             long available = currentBuffer.remaining() + tailSize;
 154             if (!frameHeaderParsed) {
 155                 if (available >= Http2Frame.FRAME_HEADER_SIZE) {
 156                     parseFrameHeader();
 157                     if (frameLength > maxFrameSize) {
 158                         // connection error
 159                         return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 160                                 "Frame type("+frameType+") "
 161                                 +"length("+frameLength
 162                                 +") exceeds MAX_FRAME_SIZE("
 163                                 + maxFrameSize+")");
 164                     }
 165                     frameHeaderParsed = true;
 166                 } else {
 167                     DEBUG_LOGGER.log(Level.DEBUG,
 168                             "Not enough data to parse header, needs: %d, has: %d",
 169                             Http2Frame.FRAME_HEADER_SIZE, available);
 170                 }
 171             }
 172             available = currentBuffer == null ? 0 : currentBuffer.remaining() + tailSize;
 173             if ((frameLength == 0) ||
 174                     (currentBuffer != null && available >= frameLength)) {
 175                 Http2Frame frame = parseFrameBody();
 176                 frameHeaderParsed = false;
 177                 // frame == null means we have to skip this frame and try parse next
 178                 if (frame != null) {
 179                     return frame;
 180                 }
 181             } else {
 182                 DEBUG_LOGGER.log(Level.DEBUG,
 183                         "Not enough data to parse frame body, needs: %d,  has: %d",
 184                         frameLength, available);
 185                 return null;  // no data for the whole frame header
 186             }
 187         }
 188     }
 189 
 190     private void frameProcessed() {
 191         prepareToRelease.clear();
 192     }
 193 
 194     private void parseFrameHeader() throws IOException {
 195         int x = getInt();
 196         this.frameLength = x >> 8;
 197         this.frameType = x & 0xff;
 198         this.frameFlags = getByte();
 199         this.frameStreamid = getInt() & 0x7fffffff;
 200         // R: A reserved 1-bit field.  The semantics of this bit are undefined,
 201         // MUST be ignored when receiving.
 202     }
 203 
 204     // move next buffer from tailBuffers to currentBuffer if required
 205     private void nextBuffer() {
 206         if (!currentBuffer.hasRemaining()) {
 207             if (!slicedToDataFrame) {
 208                 prepareToRelease.add(currentBuffer);
 209             }
 210             slicedToDataFrame = false;
 211             currentBuffer = tailBuffers.poll();
 212             if (currentBuffer != null) {
 213                 tailSize -= currentBuffer.remaining();
 214             }
 215         }
 216     }
 217 
 218     public int getByte() {
 219         int res = currentBuffer.get() & 0xff;
 220         nextBuffer();
 221         return res;
 222     }
 223 
 224     public int getShort() {
 225         if (currentBuffer.remaining() >= 2) {
 226             int res = currentBuffer.getShort() & 0xffff;
 227             nextBuffer();
 228             return res;
 229         }
 230         int val = getByte();
 231         val = (val << 8) + getByte();
 232         return val;
 233     }
 234 
 235     public int getInt() {
 236         if (currentBuffer.remaining() >= 4) {
 237             int res = currentBuffer.getInt();
 238             nextBuffer();
 239             return res;
 240         }
 241         int val = getByte();
 242         val = (val << 8) + getByte();
 243         val = (val << 8) + getByte();
 244         val = (val << 8) + getByte();
 245         return val;
 246 
 247     }
 248 
 249     public byte[] getBytes(int n) {
 250         byte[] bytes = new byte[n];
 251         int offset = 0;
 252         while (n > 0) {
 253             int length = Math.min(n, currentBuffer.remaining());
 254             currentBuffer.get(bytes, offset, length);
 255             offset += length;
 256             n -= length;
 257             nextBuffer();
 258         }
 259         return bytes;
 260 
 261     }
 262 
 263     private List<ByteBuffer> getBuffers(boolean isDataFrame, int bytecount) {
 264         List<ByteBuffer> res = new ArrayList<>();
 265         while (bytecount > 0) {
 266             int remaining = currentBuffer.remaining();
 267             int extract = Math.min(remaining, bytecount);
 268             ByteBuffer extractedBuf;
 269             if (isDataFrame) {
 270                 extractedBuf = Utils.slice(currentBuffer, extract);
 271                 slicedToDataFrame = true;
 272             } else {
 273                 // Header frames here
 274                 // HPACK decoding should performed under lock and immediately after frame decoding.
 275                 // in that case it is safe to release original buffer,
 276                 // because of sliced buffer has a very short life
 277                 extractedBuf = Utils.slice(currentBuffer, extract);
 278             }
 279             res.add(extractedBuf);
 280             bytecount -= extract;
 281             nextBuffer();
 282         }
 283         return res;
 284     }
 285 
 286     public void skipBytes(int bytecount) {
 287         while (bytecount > 0) {
 288             int remaining = currentBuffer.remaining();
 289             int extract = Math.min(remaining, bytecount);
 290             currentBuffer.position(currentBuffer.position() + extract);
 291             bytecount -= remaining;
 292             nextBuffer();
 293         }
 294     }
 295 
 296     private Http2Frame parseFrameBody() throws IOException {
 297         assert frameHeaderParsed;
 298         switch (frameType) {
 299             case DataFrame.TYPE:
 300                 return parseDataFrame(frameLength, frameStreamid, frameFlags);
 301             case HeadersFrame.TYPE:
 302                 return parseHeadersFrame(frameLength, frameStreamid, frameFlags);
 303             case PriorityFrame.TYPE:
 304                 return parsePriorityFrame(frameLength, frameStreamid, frameFlags);
 305             case ResetFrame.TYPE:
 306                 return parseResetFrame(frameLength, frameStreamid, frameFlags);
 307             case SettingsFrame.TYPE:
 308                 return parseSettingsFrame(frameLength, frameStreamid, frameFlags);
 309             case PushPromiseFrame.TYPE:
 310                 return parsePushPromiseFrame(frameLength, frameStreamid, frameFlags);
 311             case PingFrame.TYPE:
 312                 return parsePingFrame(frameLength, frameStreamid, frameFlags);
 313             case GoAwayFrame.TYPE:
 314                 return parseGoAwayFrame(frameLength, frameStreamid, frameFlags);
 315             case WindowUpdateFrame.TYPE:
 316                 return parseWindowUpdateFrame(frameLength, frameStreamid, frameFlags);
 317             case ContinuationFrame.TYPE:
 318                 return parseContinuationFrame(frameLength, frameStreamid, frameFlags);
 319             default:
 320                 // RFC 7540 4.1
 321                 // Implementations MUST ignore and discard any frame that has a type that is unknown.
 322                 Log.logTrace("Unknown incoming frame type: {0}", frameType);
 323                 skipBytes(frameLength);
 324                 return null;
 325         }
 326     }
 327 
 328     private Http2Frame parseDataFrame(int frameLength, int streamid, int flags) {
 329         // non-zero stream
 330         if (streamid == 0) {
 331             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 332                                       "zero streamId for DataFrame");
 333         }
 334         int padLength = 0;
 335         if ((flags & DataFrame.PADDED) != 0) {
 336             padLength = getByte();
 337             if (padLength >= frameLength) {
 338                 return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 339                         "the length of the padding is the length of the frame payload or greater");
 340             }
 341             frameLength--;
 342         }
 343         DataFrame df = new DataFrame(streamid, flags,
 344                 getBuffers(true, frameLength - padLength), padLength);
 345         skipBytes(padLength);
 346         return df;
 347 
 348     }
 349 
 350     private Http2Frame parseHeadersFrame(int frameLength, int streamid, int flags) {
 351         // non-zero stream
 352         if (streamid == 0) {
 353             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 354                                       "zero streamId for HeadersFrame");
 355         }
 356         int padLength = 0;
 357         if ((flags & HeadersFrame.PADDED) != 0) {
 358             padLength = getByte();
 359             frameLength--;
 360         }
 361         boolean hasPriority = (flags & HeadersFrame.PRIORITY) != 0;
 362         boolean exclusive = false;
 363         int streamDependency = 0;
 364         int weight = 0;
 365         if (hasPriority) {
 366             int x = getInt();
 367             exclusive = (x & 0x80000000) != 0;
 368             streamDependency = x & 0x7fffffff;
 369             weight = getByte();
 370             frameLength -= 5;
 371         }
 372         if(frameLength < padLength) {
 373             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 374                     "Padding exceeds the size remaining for the header block");
 375         }
 376         HeadersFrame hf = new HeadersFrame(streamid, flags,
 377                 getBuffers(false, frameLength - padLength), padLength);
 378         skipBytes(padLength);
 379         if (hasPriority) {
 380             hf.setPriority(streamDependency, exclusive, weight);
 381         }
 382         return hf;
 383     }
 384 
 385     private Http2Frame parsePriorityFrame(int frameLength, int streamid, int flags) {
 386         // non-zero stream; no flags
 387         if (streamid == 0) {
 388             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 389                     "zero streamId for PriorityFrame");
 390         }
 391         if(frameLength != 5) {
 392             skipBytes(frameLength);
 393             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR, streamid,
 394                     "PriorityFrame length is "+ frameLength+", expected 5");
 395         }
 396         int x = getInt();
 397         int weight = getByte();
 398         return new PriorityFrame(streamid, x & 0x7fffffff, (x & 0x80000000) != 0, weight);
 399     }
 400 
 401     private Http2Frame parseResetFrame(int frameLength, int streamid, int flags) {
 402         // non-zero stream; no flags
 403         if (streamid == 0) {
 404             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 405                     "zero streamId for ResetFrame");
 406         }
 407         if(frameLength != 4) {
 408             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 409                     "ResetFrame length is "+ frameLength+", expected 4");
 410         }
 411         return new ResetFrame(streamid, getInt());
 412     }
 413 
 414     private Http2Frame parseSettingsFrame(int frameLength, int streamid, int flags) {
 415         // only zero stream
 416         if (streamid != 0) {
 417             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 418                     "non-zero streamId for SettingsFrame");
 419         }
 420         if ((SettingsFrame.ACK & flags) != 0 && frameLength > 0) {
 421             // RFC 7540 6.5
 422             // Receipt of a SETTINGS frame with the ACK flag set and a length
 423             // field value other than 0 MUST be treated as a connection error
 424             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 425                     "ACK SettingsFrame is not empty");
 426         }
 427         if (frameLength % 6 != 0) {
 428             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 429                     "invalid SettingsFrame size: "+frameLength);
 430         }
 431         SettingsFrame sf = new SettingsFrame(flags);
 432         int n = frameLength / 6;
 433         for (int i=0; i<n; i++) {
 434             int id = getShort();
 435             int val = getInt();
 436             if (id > 0 && id <= SettingsFrame.MAX_PARAM) {
 437                 // a known parameter. Ignore otherwise
 438                 sf.setParameter(id, val); // TODO parameters validation
 439             }
 440         }
 441         return sf;
 442     }
 443 
 444     private Http2Frame parsePushPromiseFrame(int frameLength, int streamid, int flags) {
 445         // non-zero stream
 446         if (streamid == 0) {
 447             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 448                     "zero streamId for PushPromiseFrame");
 449         }
 450         int padLength = 0;
 451         if ((flags & PushPromiseFrame.PADDED) != 0) {
 452             padLength = getByte();
 453             frameLength--;
 454         }
 455         int promisedStream = getInt() & 0x7fffffff;
 456         frameLength -= 4;
 457         if(frameLength < padLength) {
 458             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 459                     "Padding exceeds the size remaining for the PushPromiseFrame");
 460         }
 461         PushPromiseFrame ppf = new PushPromiseFrame(streamid, flags, promisedStream,
 462                 getBuffers(false, frameLength - padLength), padLength);
 463         skipBytes(padLength);
 464         return ppf;
 465     }
 466 
 467     private Http2Frame parsePingFrame(int frameLength, int streamid, int flags) {
 468         // only zero stream
 469         if (streamid != 0) {
 470             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 471                     "non-zero streamId for PingFrame");
 472         }
 473         if(frameLength != 8) {
 474             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 475                     "PingFrame length is "+ frameLength+", expected 8");
 476         }
 477         return new PingFrame(flags, getBytes(8));
 478     }
 479 
 480     private Http2Frame parseGoAwayFrame(int frameLength, int streamid, int flags) {
 481         // only zero stream; no flags
 482         if (streamid != 0) {
 483             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 484                     "non-zero streamId for GoAwayFrame");
 485         }
 486         if (frameLength < 8) {
 487             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 488                     "Invalid GoAway frame size");
 489         }
 490         int lastStream = getInt() & 0x7fffffff;
 491         int errorCode = getInt();
 492         byte[] debugData = getBytes(frameLength - 8);
 493         if (debugData.length > 0) {
 494             Log.logError("GoAway debugData " + new String(debugData));
 495         }
 496         return new GoAwayFrame(lastStream, errorCode, debugData);
 497     }
 498 
 499     private Http2Frame parseWindowUpdateFrame(int frameLength, int streamid, int flags) {
 500         // any stream; no flags
 501         if(frameLength != 4) {
 502             return new MalformedFrame(ErrorFrame.FRAME_SIZE_ERROR,
 503                     "WindowUpdateFrame length is "+ frameLength+", expected 4");
 504         }
 505         return new WindowUpdateFrame(streamid, getInt() & 0x7fffffff);
 506     }
 507 
 508     private Http2Frame parseContinuationFrame(int frameLength, int streamid, int flags) {
 509         // non-zero stream;
 510         if (streamid == 0) {
 511             return new MalformedFrame(ErrorFrame.PROTOCOL_ERROR,
 512                     "zero streamId for ContinuationFrame");
 513         }
 514         return new ContinuationFrame(streamid, flags, getBuffers(false, frameLength));
 515     }
 516 
 517 }