1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 package jdk.incubator.http;
  25 
  26 import jdk.incubator.http.internal.websocket.RawChannel;
  27 import org.testng.annotations.Test;
  28 
  29 import java.io.IOException;
  30 import java.io.InputStream;
  31 import java.io.OutputStream;
  32 import java.io.UncheckedIOException;
  33 import java.net.ServerSocket;
  34 import java.net.Socket;
  35 import java.net.URI;
  36 import java.nio.ByteBuffer;
  37 import java.nio.channels.SelectionKey;
  38 import java.util.Random;
  39 import java.util.concurrent.CountDownLatch;
  40 import java.util.concurrent.TimeUnit;
  41 import java.util.concurrent.atomic.AtomicInteger;
  42 import java.util.concurrent.atomic.AtomicLong;
  43 
  44 import static jdk.incubator.http.HttpResponse.BodyHandler.discard;
  45 import static org.testng.Assert.assertEquals;
  46 
/*
 * This test exercises the mechanics of _independent_ reads and writes on a
 * RawChannel. It verifies that the underlying implementation can handle more
 * than one type of notification (OP_READ and OP_WRITE) at the same time.
 */
  52 public class RawChannelTest {
  53 
  54     private final AtomicLong clientWritten = new AtomicLong();
  55     private final AtomicLong serverWritten = new AtomicLong();
  56     private final AtomicLong clientRead = new AtomicLong();
  57     private final AtomicLong serverRead = new AtomicLong();
  58 
  59     /*
  60      * Since at this level we don't have any control over the low level socket
  61      * parameters, this latch ensures a write to the channel will stall at least
  62      * once (socket's send buffer filled up).
  63      */
  64     private final CountDownLatch writeStall = new CountDownLatch(1);
  65     private final CountDownLatch initialWriteStall = new CountDownLatch(1);
  66 
  67     /*
  68      * This one works similarly by providing means to ensure a read from the
  69      * channel will stall at least once (no more data available on the socket).
  70      */
  71     private final CountDownLatch readStall = new CountDownLatch(1);
  72     private final CountDownLatch initialReadStall = new CountDownLatch(1);
  73 
  74     private final AtomicInteger writeHandles = new AtomicInteger();
  75     private final AtomicInteger readHandles = new AtomicInteger();
  76 
  77     private final CountDownLatch exit = new CountDownLatch(1);
  78 
  79     @Test
  80     public void test() throws Exception {
  81         try (ServerSocket server = new ServerSocket(0)) {
  82             int port = server.getLocalPort();
  83             new TestServer(server).start();
  84 
  85             final RawChannel chan = channelOf(port);
  86             initialWriteStall.await();
  87 
  88             // It's very important not to forget the initial bytes, possibly
  89             // left from the HTTP thingy
  90             int initialBytes = chan.initialByteBuffer().remaining();
  91             print("RawChannel has %s initial bytes", initialBytes);
  92             clientRead.addAndGet(initialBytes);
  93 
  94             // tell the server we have read the initial bytes, so
  95             // that it makes sure there is something for us to
  96             // read next in case the initialBytes have already drained the
  97             // channel dry.
  98             initialReadStall.countDown();
  99 
 100             chan.registerEvent(new RawChannel.RawEvent() {
 101 
 102                 private final ByteBuffer reusableBuffer = ByteBuffer.allocate(32768);
 103 
 104                 @Override
 105                 public int interestOps() {
 106                     return SelectionKey.OP_WRITE;
 107                 }
 108 
 109                 @Override
 110                 public void handle() {
 111                     int i = writeHandles.incrementAndGet();
 112                     print("OP_WRITE #%s", i);
 113                     if (i > 3) { // Fill up the send buffer not more than 3 times
 114                         try {
 115                             chan.shutdownOutput();
 116                         } catch (IOException e) {
 117                             e.printStackTrace();
 118                         }
 119                         return;
 120                     }
 121                     long total = 0;
 122                     try {
 123                         long n;
 124                         do {
 125                             ByteBuffer[] array = {reusableBuffer.slice()};
 126                             n = chan.write(array, 0, 1);
 127                             total += n;
 128                         } while (n > 0);
 129                         print("OP_WRITE clogged SNDBUF with %s bytes", total);
 130                         clientWritten.addAndGet(total);
 131                         chan.registerEvent(this);
 132                         writeStall.countDown(); // signal send buffer is full
 133                     } catch (IOException e) {
 134                         throw new UncheckedIOException(e);
 135                     }
 136                 }
 137             });
 138 
 139             chan.registerEvent(new RawChannel.RawEvent() {
 140 
 141                 @Override
 142                 public int interestOps() {
 143                     return SelectionKey.OP_READ;
 144                 }
 145 
 146                 @Override
 147                 public void handle() {
 148                     int i = readHandles.incrementAndGet();
 149                     print("OP_READ #%s", i);
 150                     ByteBuffer read = null;
 151                     long total = 0;
 152                     while (true) {
 153                         try {
 154                             read = chan.read();
 155                         } catch (IOException e) {
 156                             e.printStackTrace();
 157                         }
 158                         if (read == null) {
 159                             print("OP_READ EOF");
 160                             break;
 161                         } else if (!read.hasRemaining()) {
 162                             print("OP_READ stall");
 163                             try {
 164                                 chan.registerEvent(this);
 165                             } catch (IOException e) {
 166                                 e.printStackTrace();
 167                             }
 168                             readStall.countDown();
 169                             break;
 170                         }
 171                         int r = read.remaining();
 172                         total += r;
 173                         clientRead.addAndGet(r);
 174                     }
 175                     print("OP_READ read %s bytes (%s total)", total, clientRead.get());
 176                 }
 177             });
 178             exit.await(); // All done, we need to compare results:
 179             assertEquals(clientRead.get(), serverWritten.get());
 180             assertEquals(serverRead.get(), clientWritten.get());
 181         }
 182     }
 183 
 184     private static RawChannel channelOf(int port) throws Exception {
 185         URI uri = URI.create("http://127.0.0.1:" + port + "/");
 186         print("raw channel to %s", uri.toString());
 187         HttpRequest req = HttpRequest.newBuilder(uri).build();
 188         HttpResponse<?> r = HttpClient.newHttpClient().send(req, discard(null));
 189         r.body();
 190         return ((HttpResponseImpl) r).rawChannel();
 191     }
 192 
 193     private class TestServer extends Thread { // Powered by Slowpokes
 194 
 195         private final ServerSocket server;
 196 
 197         TestServer(ServerSocket server) throws IOException {
 198             this.server = server;
 199         }
 200 
 201         @Override
 202         public void run() {
 203             try (Socket s = server.accept()) {
 204                 InputStream is = s.getInputStream();
 205                 OutputStream os = s.getOutputStream();
 206 
 207                 processHttp(is, os);
 208 
 209                 Thread reader = new Thread(() -> {
 210                     try {
 211                         long n = readSlowly(is);
 212                         print("Server read %s bytes", n);
 213                         serverRead.addAndGet(n);
 214                         s.shutdownInput();
 215                     } catch (Exception e) {
 216                         e.printStackTrace();
 217                     }
 218                 });
 219 
 220                 Thread writer = new Thread(() -> {
 221                     try {
 222                         long n = writeSlowly(os);
 223                         print("Server written %s bytes", n);
 224                         serverWritten.addAndGet(n);
 225                         s.shutdownOutput();
 226                     } catch (Exception e) {
 227                         e.printStackTrace();
 228                     }
 229                 });
 230 
 231                 reader.start();
 232                 writer.start();
 233 
 234                 reader.join();
 235                 writer.join();
 236             } catch (Exception e) {
 237                 e.printStackTrace();
 238             } finally {
 239                 exit.countDown();
 240             }
 241         }
 242 
 243         private void processHttp(InputStream is, OutputStream os)
 244                 throws IOException
 245         {
 246             os.write("HTTP/1.1 200 OK\r\nContent-length: 0\r\n\r\n".getBytes());
 247 
 248             // write some initial bytes
 249             byte[] initial = byteArrayOfSize(1024);
 250             os.write(initial);
 251             os.flush();
 252             serverWritten.addAndGet(initial.length);
 253             initialWriteStall.countDown();
 254 
 255             byte[] buf = new byte[1024];
 256             String s = "";
 257             while (true) {
 258                 int n = is.read(buf);
 259                 if (n <= 0) {
 260                     throw new RuntimeException("Unexpected end of request");
 261                 }
 262                 s = s + new String(buf, 0, n);
 263                 if (s.contains("\r\n\r\n")) {
 264                     break;
 265                 }
 266             }
 267         }
 268 
 269         private long writeSlowly(OutputStream os) throws Exception {
 270             byte[] first = byteArrayOfSize(1024);
 271             long total = first.length;
 272             os.write(first);
 273             os.flush();
 274 
 275             // wait until initial bytes were read
 276             initialReadStall.await();
 277 
 278             // make sure there is something to read, otherwise readStall
 279             // will never be counted down.
 280             first = byteArrayOfSize(1024);
 281             os.write(first);
 282             os.flush();
 283             total += first.length;
 284 
 285             // Let's wait for the signal from the raw channel that its read has
 286             // stalled, and then continue sending a bit more stuff
 287             readStall.await();
 288             for (int i = 0; i < 32; i++) {
 289                 byte[] b = byteArrayOfSize(1024);
 290                 os.write(b);
 291                 os.flush();
 292                 total += b.length;
 293                 TimeUnit.MILLISECONDS.sleep(1);
 294             }
 295             return total;
 296         }
 297 
 298         private long readSlowly(InputStream is) throws Exception {
 299             // Wait for the raw channel to fill up its send buffer
 300             writeStall.await();
 301             long overall = 0;
 302             byte[] array = new byte[1024];
 303             for (int n = 0; n != -1; n = is.read(array)) {
 304                 TimeUnit.MILLISECONDS.sleep(1);
 305                 overall += n;
 306             }
 307             return overall;
 308         }
 309     }
 310 
 311     private static void print(String format, Object... args) {
 312         System.out.println(Thread.currentThread() + ": " + String.format(format, args));
 313     }
 314 
 315     private static byte[] byteArrayOfSize(int bound) {
 316         return new byte[new Random().nextInt(1 + bound)];
 317     }
 318 }