1 /*
   2  * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 package jdk.incubator.http.internal.websocket;
  25 
  26 import org.testng.annotations.Test;
  27 
  28 import java.nio.ByteBuffer;
  29 import java.security.SecureRandom;
  30 import java.util.stream.IntStream;
  31 
  32 import static org.testng.Assert.assertEquals;
  33 import static jdk.incubator.http.internal.websocket.Frame.Masker.transferMasking;
  34 import static jdk.incubator.http.internal.websocket.TestSupport.forEachBufferPartition;
  35 import static jdk.incubator.http.internal.websocket.TestSupport.fullCopy;
  36 
  37 public class MaskerTest {
  38 
  39     private static final SecureRandom random = new SecureRandom();
  40 
  41     @Test
  42     public void stateless() {
  43         IntStream.iterate(0, i -> i + 1).limit(125).boxed()
  44                 .forEach(r -> {
  45                     int m = random.nextInt();
  46                     ByteBuffer src = createSourceBuffer(r);
  47                     ByteBuffer dst = createDestinationBuffer(r);
  48                     verify(src, dst, maskArray(m), 0,
  49                             () -> transferMasking(src, dst, m));
  50                 });
  51     }
  52 
  53     /*
  54      * Stateful masker to make sure setting a mask resets the state as if a new
  55      * Masker instance is created each time
  56      */
  57     private final Frame.Masker masker = new Frame.Masker();
  58 
  59     @Test
  60     public void stateful0() {
  61         // This size (17 = 8 + 8 + 1) should test all the stages
  62         // (galloping/slow) of masking good enough
  63         int N = 17;
  64         ByteBuffer src = createSourceBuffer(N);
  65         ByteBuffer dst = createDestinationBuffer(N);
  66         int mask = random.nextInt();
  67         forEachBufferPartition(src,
  68                 buffers -> {
  69                     int offset = 0;
  70                     masker.mask(mask);
  71                     int[] maskBytes = maskArray(mask);
  72                     for (ByteBuffer s : buffers) {
  73                         offset = verify(s, dst, maskBytes, offset,
  74                                 () -> masker.transferMasking(s, dst));
  75                     }
  76                 });
  77     }
  78 
  79     @Test
  80     public void stateful1() {
  81         int m = random.nextInt();
  82         masker.mask(m);
  83         ByteBuffer src = ByteBuffer.allocate(0);
  84         ByteBuffer dst = ByteBuffer.allocate(16);
  85         verify(src, dst, maskArray(m), 0,
  86                 () -> masker.transferMasking(src, dst));
  87     }
  88 
  89     private static int verify(ByteBuffer src,
  90                               ByteBuffer dst,
  91                               int[] maskBytes,
  92                               int offset,
  93                               Runnable masking) {
  94         ByteBuffer srcCopy = fullCopy(src);
  95         ByteBuffer dstCopy = fullCopy(dst);
  96         masking.run();
  97         int srcRemaining = srcCopy.remaining();
  98         int dstRemaining = dstCopy.remaining();
  99         int masked = Math.min(srcRemaining, dstRemaining);
 100         // 1. position check
 101         assertEquals(src.position(), srcCopy.position() + masked);
 102         assertEquals(dst.position(), dstCopy.position() + masked);
 103         // 2. masking check
 104         src.position(srcCopy.position());
 105         dst.position(dstCopy.position());
 106         for (; src.hasRemaining() && dst.hasRemaining();
 107              offset = (offset + 1) & 3) {
 108             assertEquals(dst.get(), src.get() ^ maskBytes[offset]);
 109         }
 110         // 3. corruption check
 111         // 3.1 src contents haven't changed
 112         int srcPosition = src.position();
 113         int srcLimit = src.limit();
 114         src.clear();
 115         srcCopy.clear();
 116         assertEquals(src, srcCopy);
 117         src.limit(srcLimit).position(srcPosition); // restore src
 118         // 3.2 dst leading and trailing regions' contents haven't changed
 119         int dstPosition = dst.position();
 120         int dstInitialPosition = dstCopy.position();
 121         int dstLimit = dst.limit();
 122         // leading
 123         dst.position(0).limit(dstInitialPosition);
 124         dstCopy.position(0).limit(dstInitialPosition);
 125         assertEquals(dst, dstCopy);
 126         // trailing
 127         dst.limit(dst.capacity()).position(dstLimit);
 128         dstCopy.limit(dst.capacity()).position(dstLimit);
 129         assertEquals(dst, dstCopy);
 130         // restore dst
 131         dst.position(dstPosition).limit(dstLimit);
 132         return offset;
 133     }
 134 
 135     private static ByteBuffer createSourceBuffer(int remaining) {
 136         int leading = random.nextInt(4);
 137         int trailing = random.nextInt(4);
 138         byte[] bytes = new byte[leading + remaining + trailing];
 139         random.nextBytes(bytes);
 140         return ByteBuffer.wrap(bytes).position(leading).limit(leading + remaining);
 141     }
 142 
 143     private static ByteBuffer createDestinationBuffer(int remaining) {
 144         int leading = random.nextInt(4);
 145         int trailing = random.nextInt(4);
 146         return ByteBuffer.allocate(leading + remaining + trailing)
 147                 .position(leading).limit(leading + remaining);
 148     }
 149 
 150     private static int[] maskArray(int mask) {
 151         return new int[]{
 152                 (byte) (mask >>> 24),
 153                 (byte) (mask >>> 16),
 154                 (byte) (mask >>>  8),
 155                 (byte) (mask >>>  0)
 156         };
 157     }
 158 }