1 /* 2 * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 package jdk.incubator.http; 25 26 import jdk.incubator.http.internal.websocket.RawChannel; 27 import org.testng.annotations.Test; 28 29 import java.io.IOException; 30 import java.io.InputStream; 31 import java.io.OutputStream; 32 import java.io.UncheckedIOException; 33 import java.net.ServerSocket; 34 import java.net.Socket; 35 import java.net.URI; 36 import java.nio.ByteBuffer; 37 import java.nio.channels.SelectionKey; 38 import java.util.Random; 39 import java.util.concurrent.CountDownLatch; 40 import java.util.concurrent.TimeUnit; 41 import java.util.concurrent.atomic.AtomicInteger; 42 import java.util.concurrent.atomic.AtomicLong; 43 44 import static jdk.incubator.http.HttpResponse.BodyHandler.discard; 45 import static org.testng.Assert.assertEquals; 46 47 /* 48 * This test exercises mechanics of _independent_ reads and writes on the 49 * RawChannel. It verifies that the underlying implementation can manage more 50 * than a single type of notifications at the same time. 51 */ 52 public class RawChannelTest { 53 54 private final AtomicLong clientWritten = new AtomicLong(); 55 private final AtomicLong serverWritten = new AtomicLong(); 56 private final AtomicLong clientRead = new AtomicLong(); 57 private final AtomicLong serverRead = new AtomicLong(); 58 59 /* 60 * Since at this level we don't have any control over the low level socket 61 * parameters, this latch ensures a write to the channel will stall at least 62 * once (socket's send buffer filled up). 63 */ 64 private final CountDownLatch writeStall = new CountDownLatch(1); 65 private final CountDownLatch initialWriteStall = new CountDownLatch(1); 66 67 /* 68 * This one works similarly by providing means to ensure a read from the 69 * channel will stall at least once (no more data available on the socket). 70 */ 71 private final CountDownLatch readStall = new CountDownLatch(1); 72 private final CountDownLatch initialReadStall = new CountDownLatch(1); 73 74 private final AtomicInteger writeHandles = new AtomicInteger(); 75 private final AtomicInteger readHandles = new AtomicInteger(); 76 77 private final CountDownLatch exit = new CountDownLatch(1); 78 79 @Test 80 public void test() throws Exception { 81 try (ServerSocket server = new ServerSocket(0)) { 82 int port = server.getLocalPort(); 83 new TestServer(server).start(); 84 85 final RawChannel chan = channelOf(port); 86 initialWriteStall.await(); 87 88 // It's very important not to forget the initial bytes, possibly 89 // left from the HTTP thingy 90 int initialBytes = chan.initialByteBuffer().remaining(); 91 print("RawChannel has %s initial bytes", initialBytes); 92 clientRead.addAndGet(initialBytes); 93 94 // tell the server we have read the initial bytes, so 95 // that it makes sure there is something for us to 96 // read next in case the initialBytes have already drained the 97 // channel dry. 98 initialReadStall.countDown(); 99 100 chan.registerEvent(new RawChannel.RawEvent() { 101 102 private final ByteBuffer reusableBuffer = ByteBuffer.allocate(32768); 103 104 @Override 105 public int interestOps() { 106 return SelectionKey.OP_WRITE; 107 } 108 109 @Override 110 public void handle() { 111 int i = writeHandles.incrementAndGet(); 112 print("OP_WRITE #%s", i); 113 if (i > 3) { // Fill up the send buffer not more than 3 times 114 try { 115 chan.shutdownOutput(); 116 } catch (IOException e) { 117 e.printStackTrace(); 118 } 119 return; 120 } 121 long total = 0; 122 try { 123 long n; 124 do { 125 ByteBuffer[] array = {reusableBuffer.slice()}; 126 n = chan.write(array, 0, 1); 127 total += n; 128 } while (n > 0); 129 print("OP_WRITE clogged SNDBUF with %s bytes", total); 130 clientWritten.addAndGet(total); 131 chan.registerEvent(this); 132 writeStall.countDown(); // signal send buffer is full 133 } catch (IOException e) { 134 throw new UncheckedIOException(e); 135 } 136 } 137 }); 138 139 chan.registerEvent(new RawChannel.RawEvent() { 140 141 @Override 142 public int interestOps() { 143 return SelectionKey.OP_READ; 144 } 145 146 @Override 147 public void handle() { 148 int i = readHandles.incrementAndGet(); 149 print("OP_READ #%s", i); 150 ByteBuffer read = null; 151 long total = 0; 152 while (true) { 153 try { 154 read = chan.read(); 155 } catch (IOException e) { 156 e.printStackTrace(); 157 } 158 if (read == null) { 159 print("OP_READ EOF"); 160 break; 161 } else if (!read.hasRemaining()) { 162 print("OP_READ stall"); 163 try { 164 chan.registerEvent(this); 165 } catch (IOException e) { 166 e.printStackTrace(); 167 } 168 readStall.countDown(); 169 break; 170 } 171 int r = read.remaining(); 172 total += r; 173 clientRead.addAndGet(r); 174 } 175 print("OP_READ read %s bytes (%s total)", total, clientRead.get()); 176 } 177 }); 178 exit.await(); // All done, we need to compare results: 179 assertEquals(clientRead.get(), serverWritten.get()); 180 assertEquals(serverRead.get(), clientWritten.get()); 181 } 182 } 183 184 private static RawChannel channelOf(int port) throws Exception { 185 URI uri = URI.create("http://127.0.0.1:" + port + "/"); 186 print("raw channel to %s", uri.toString()); 187 HttpRequest req = HttpRequest.newBuilder(uri).build(); 188 HttpResponse<?> r = HttpClient.newHttpClient().send(req, discard(null)); 189 r.body(); 190 return ((HttpResponseImpl) r).rawChannel(); 191 } 192 193 private class TestServer extends Thread { // Powered by Slowpokes 194 195 private final ServerSocket server; 196 197 TestServer(ServerSocket server) throws IOException { 198 this.server = server; 199 } 200 201 @Override 202 public void run() { 203 try (Socket s = server.accept()) { 204 InputStream is = s.getInputStream(); 205 OutputStream os = s.getOutputStream(); 206 207 processHttp(is, os); 208 209 Thread reader = new Thread(() -> { 210 try { 211 long n = readSlowly(is); 212 print("Server read %s bytes", n); 213 serverRead.addAndGet(n); 214 s.shutdownInput(); 215 } catch (Exception e) { 216 e.printStackTrace(); 217 } 218 }); 219 220 Thread writer = new Thread(() -> { 221 try { 222 long n = writeSlowly(os); 223 print("Server written %s bytes", n); 224 serverWritten.addAndGet(n); 225 s.shutdownOutput(); 226 } catch (Exception e) { 227 e.printStackTrace(); 228 } 229 }); 230 231 reader.start(); 232 writer.start(); 233 234 reader.join(); 235 writer.join(); 236 } catch (Exception e) { 237 e.printStackTrace(); 238 } finally { 239 exit.countDown(); 240 } 241 } 242 243 private void processHttp(InputStream is, OutputStream os) 244 throws IOException 245 { 246 os.write("HTTP/1.1 200 OK\r\nContent-length: 0\r\n\r\n".getBytes()); 247 248 // write some initial bytes 249 byte[] initial = byteArrayOfSize(1024); 250 os.write(initial); 251 os.flush(); 252 serverWritten.addAndGet(initial.length); 253 initialWriteStall.countDown(); 254 255 byte[] buf = new byte[1024]; 256 String s = ""; 257 while (true) { 258 int n = is.read(buf); 259 if (n <= 0) { 260 throw new RuntimeException("Unexpected end of request"); 261 } 262 s = s + new String(buf, 0, n); 263 if (s.contains("\r\n\r\n")) { 264 break; 265 } 266 } 267 } 268 269 private long writeSlowly(OutputStream os) throws Exception { 270 byte[] first = byteArrayOfSize(1024); 271 long total = first.length; 272 os.write(first); 273 os.flush(); 274 275 // wait until initial bytes were read 276 initialReadStall.await(); 277 278 // make sure there is something to read, otherwise readStall 279 // will never be counted down. 280 first = byteArrayOfSize(1024); 281 os.write(first); 282 os.flush(); 283 total += first.length; 284 285 // Let's wait for the signal from the raw channel that its read has 286 // stalled, and then continue sending a bit more stuff 287 readStall.await(); 288 for (int i = 0; i < 32; i++) { 289 byte[] b = byteArrayOfSize(1024); 290 os.write(b); 291 os.flush(); 292 total += b.length; 293 TimeUnit.MILLISECONDS.sleep(1); 294 } 295 return total; 296 } 297 298 private long readSlowly(InputStream is) throws Exception { 299 // Wait for the raw channel to fill up its send buffer 300 writeStall.await(); 301 long overall = 0; 302 byte[] array = new byte[1024]; 303 for (int n = 0; n != -1; n = is.read(array)) { 304 TimeUnit.MILLISECONDS.sleep(1); 305 overall += n; 306 } 307 return overall; 308 } 309 } 310 311 private static void print(String format, Object... args) { 312 System.out.println(Thread.currentThread() + ": " + String.format(format, args)); 313 } 314 315 private static byte[] byteArrayOfSize(int bound) { 316 return new byte[new Random().nextInt(1 + bound)]; 317 } 318 }