1 /* 2 * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 package jdk.incubator.http; 25 26 import java.io.BufferedOutputStream; 27 import java.io.File; 28 import java.io.FileInputStream; 29 import java.io.IOException; 30 import java.io.InputStream; 31 import java.io.OutputStream; 32 import java.net.Socket; 33 import java.nio.ByteBuffer; 34 import java.security.KeyManagementException; 35 import java.security.KeyStore; 36 import java.security.KeyStoreException; 37 import java.security.NoSuchAlgorithmException; 38 import java.security.UnrecoverableKeyException; 39 import java.security.cert.CertificateException; 40 import java.util.List; 41 import java.util.Random; 42 import java.util.StringTokenizer; 43 import java.util.concurrent.BlockingQueue; 44 import java.util.concurrent.CompletableFuture; 45 import java.util.concurrent.ExecutorService; 46 import java.util.concurrent.Executors; 47 import java.util.concurrent.Flow; 48 import java.util.concurrent.Flow.Subscriber; 49 import java.util.concurrent.LinkedBlockingQueue; 50 import java.util.concurrent.SubmissionPublisher; 51 import java.util.concurrent.atomic.AtomicInteger; 52 import java.util.concurrent.atomic.AtomicLong; 53 import javax.net.ssl.KeyManagerFactory; 54 import javax.net.ssl.*; 55 import javax.net.ssl.TrustManagerFactory; 56 import jdk.incubator.http.internal.common.Utils; 57 import org.testng.annotations.Test; 58 import jdk.incubator.http.internal.common.SSLFlowDelegate; 59 60 @Test 61 public class FlowTest { 62 63 private final SubmissionPublisher<List<ByteBuffer>> srcPublisher; 64 private final ExecutorService executor; 65 private static final long COUNTER = 3000; 66 private static final int LONGS_PER_BUF = 800; 67 static final long TOTAL_LONGS = COUNTER * LONGS_PER_BUF; 68 public static final ByteBuffer SENTINEL = ByteBuffer.allocate(0); 69 static volatile String alpn; 70 71 private final CompletableFuture<Void> completion; 72 73 public FlowTest() throws IOException { 74 executor = Executors.newCachedThreadPool(); 75 srcPublisher = new SubmissionPublisher<>(executor, 20, 76 this::handlePublisherException); 77 SSLContext ctx = (new SimpleSSLContext()).get(); 78 SSLEngine engineClient = ctx.createSSLEngine(); 79 SSLParameters params = ctx.getSupportedSSLParameters(); 80 params.setApplicationProtocols(new String[]{"proto1", "proto2"}); // server will choose proto2 81 params.setProtocols(new String[]{"TLSv1.2"}); // TODO: This is essential. Needs to be protocol impl 82 engineClient.setSSLParameters(params); 83 engineClient.setUseClientMode(true); 84 completion = new CompletableFuture<>(); 85 SSLLoopbackSubscriber looper = new SSLLoopbackSubscriber(ctx, executor); 86 looper.start(); 87 EndSubscriber end = new EndSubscriber(TOTAL_LONGS, completion); 88 SSLFlowDelegate sslClient = new SSLFlowDelegate(engineClient, executor, end, looper); 89 // going to measure how long handshake takes 90 final long start = System.currentTimeMillis(); 91 sslClient.alpn().whenComplete((String s, Throwable t) -> { 92 if (t != null) 93 t.printStackTrace(); 94 long endTime = System.currentTimeMillis(); 95 alpn = s; 96 System.out.println("ALPN: " + alpn); 97 long period = (endTime - start); 98 System.out.printf("Handshake took %d ms\n", period); 99 }); 100 Subscriber<List<ByteBuffer>> reader = sslClient.upstreamReader(); 101 Subscriber<List<ByteBuffer>> writer = sslClient.upstreamWriter(); 102 looper.setReturnSubscriber(reader); 103 // now connect all the pieces 104 srcPublisher.subscribe(writer); 105 String aa = sslClient.alpn().join(); 106 System.out.println("AAALPN = " + aa); 107 } 108 109 static Random rand = new Random(); 110 111 static int randomRange(int lower, int upper) { 112 if (lower > upper) 113 throw new IllegalArgumentException("lower > upper"); 114 int diff = upper - lower; 115 int r = lower + rand.nextInt(diff); 116 return r - (r % 8); // round down to multiple of 8 (align for longs) 117 } 118 119 private void handlePublisherException(Object o, Throwable t) { 120 System.out.println("Src Publisher exception"); 121 t.printStackTrace(System.out); 122 } 123 124 private static ByteBuffer getBuffer(long startingAt) { 125 ByteBuffer buf = ByteBuffer.allocate(LONGS_PER_BUF * 8); 126 for (int j = 0; j < LONGS_PER_BUF; j++) { 127 buf.putLong(startingAt++); 128 } 129 buf.flip(); 130 return buf; 131 } 132 133 @Test 134 public void run() { 135 long count = 0; 136 System.out.printf("Submitting %d buffer arrays\n", COUNTER); 137 System.out.printf("LoopCount should be %d\n", TOTAL_LONGS); 138 for (long i = 0; i < COUNTER; i++) { 139 ByteBuffer b = getBuffer(count); 140 count += LONGS_PER_BUF; 141 srcPublisher.submit(List.of(b)); 142 } 143 System.out.println("Finished submission. Waiting for loopback"); 144 srcPublisher.close(); 145 try { 146 completion.join(); 147 if (!alpn.equals("proto2")) { 148 throw new RuntimeException("wrong alpn received"); 149 } 150 System.out.println("OK"); 151 } finally { 152 executor.shutdownNow(); 153 } 154 } 155 156 /* 157 public static void main(String[]args) throws Exception { 158 FlowTest test = new FlowTest(); 159 test.run(); 160 } 161 */ 162 163 /** 164 * This Subscriber simulates an SSL loopback network. The object itself 165 * accepts outgoing SSL encrypted data which is looped back via two sockets 166 * (one of which is an SSLSocket emulating a server). The method 167 * {@link #setReturnSubscriber(java.util.concurrent.Flow.Subscriber) } 168 * is used to provide the Subscriber which feeds the incoming side 169 * of SSLFlowDelegate. Three threads are used to implement this behavior 170 * and a SubmissionPublisher drives the incoming read side. 171 * <p> 172 * A thread reads from the buffer, writes 173 * to the client j.n.Socket which is connected to a SSLSocket operating 174 * in server mode. A second thread loops back data read from the SSLSocket back to the 175 * client again. A third thread reads the client socket and pushes the data to 176 * a SubmissionPublisher that drives the reader side of the SSLFlowDelegate 177 */ 178 static class SSLLoopbackSubscriber implements Subscriber<List<ByteBuffer>> { 179 private final BlockingQueue<ByteBuffer> buffer; 180 private final Socket clientSock; 181 private final SSLSocket serverSock; 182 private final Thread thread1, thread2, thread3; 183 private volatile Flow.Subscription clientSubscription; 184 private final SubmissionPublisher<List<ByteBuffer>> publisher; 185 186 SSLLoopbackSubscriber(SSLContext ctx, ExecutorService exec) throws IOException { 187 SSLServerSocketFactory fac = ctx.getServerSocketFactory(); 188 SSLServerSocket serv = (SSLServerSocket) fac.createServerSocket(0); 189 SSLParameters params = serv.getSSLParameters(); 190 params.setApplicationProtocols(new String[]{"proto2"}); 191 serv.setSSLParameters(params); 192 193 194 int serverPort = serv.getLocalPort(); 195 clientSock = new Socket("127.0.0.1", serverPort); 196 serverSock = (SSLSocket) serv.accept(); 197 this.buffer = new LinkedBlockingQueue<>(); 198 thread1 = new Thread(this::clientWriter, "clientWriter"); 199 thread2 = new Thread(this::serverLoopback, "serverLoopback"); 200 thread3 = new Thread(this::clientReader, "clientReader"); 201 publisher = new SubmissionPublisher<>(exec, Flow.defaultBufferSize(), 202 this::handlePublisherException); 203 SSLFlowDelegate.Monitor.add(this::monitor); 204 } 205 206 public void start() { 207 thread1.start(); 208 thread2.start(); 209 thread3.start(); 210 } 211 212 private void handlePublisherException(Object o, Throwable t) { 213 System.out.println("Loopback Publisher exception"); 214 t.printStackTrace(System.out); 215 } 216 217 private final AtomicInteger readCount = new AtomicInteger(); 218 219 // reads off the SSLSocket the data from the "server" 220 private void clientReader() { 221 try { 222 InputStream is = clientSock.getInputStream(); 223 final int bufsize = FlowTest.randomRange(512, 16 * 1024); 224 System.out.println("clientReader: bufsize = " + bufsize); 225 while (true) { 226 byte[] buf = new byte[bufsize]; 227 int n = is.read(buf); 228 if (n == -1) { 229 System.out.println("clientReader close: read " 230 + readCount.get() + " bytes"); 231 publisher.close(); 232 sleep(2000); 233 Utils.close(is, clientSock); 234 return; 235 } 236 ByteBuffer bb = ByteBuffer.wrap(buf, 0, n); 237 readCount.addAndGet(n); 238 publisher.submit(List.of(bb)); 239 } 240 } catch (Throwable e) { 241 e.printStackTrace(); 242 Utils.close(clientSock); 243 } 244 } 245 246 // writes the encrypted data from SSLFLowDelegate to the j.n.Socket 247 // which is connected to the SSLSocket emulating a server. 248 private void clientWriter() { 249 long nbytes = 0; 250 try { 251 OutputStream os = 252 new BufferedOutputStream(clientSock.getOutputStream()); 253 254 while (true) { 255 ByteBuffer buf = buffer.take(); 256 if (buf == FlowTest.SENTINEL) { 257 // finished 258 //Utils.sleep(2000); 259 System.out.println("clientWriter close: " + nbytes + " written"); 260 clientSock.shutdownOutput(); 261 System.out.println("clientWriter close return"); 262 return; 263 } 264 int len = buf.remaining(); 265 int written = writeToStream(os, buf); 266 assert len == written; 267 nbytes += len; 268 assert !buf.hasRemaining() 269 : "buffer has " + buf.remaining() + " bytes left"; 270 clientSubscription.request(1); 271 } 272 } catch (Throwable e) { 273 e.printStackTrace(); 274 } 275 } 276 277 private int writeToStream(OutputStream os, ByteBuffer buf) throws IOException { 278 byte[] b = buf.array(); 279 int offset = buf.arrayOffset() + buf.position(); 280 int n = buf.limit() - buf.position(); 281 os.write(b, offset, n); 282 buf.position(buf.limit()); 283 os.flush(); 284 return n; 285 } 286 287 private final AtomicInteger loopCount = new AtomicInteger(); 288 289 public String monitor() { 290 return "serverLoopback: loopcount = " + loopCount.toString() 291 + " clientRead: count = " + readCount.toString(); 292 } 293 294 // thread2 295 private void serverLoopback() { 296 try { 297 InputStream is = serverSock.getInputStream(); 298 OutputStream os = serverSock.getOutputStream(); 299 final int bufsize = FlowTest.randomRange(512, 16 * 1024); 300 System.out.println("serverLoopback: bufsize = " + bufsize); 301 byte[] bb = new byte[bufsize]; 302 while (true) { 303 int n = is.read(bb); 304 if (n == -1) { 305 sleep(2000); 306 is.close(); 307 serverSock.close(); 308 return; 309 } 310 os.write(bb, 0, n); 311 os.flush(); 312 loopCount.addAndGet(n); 313 } 314 } catch (Throwable e) { 315 e.printStackTrace(); 316 } 317 } 318 319 320 /** 321 * This needs to be called before the chain is subscribed. It can't be 322 * supplied in the constructor. 323 */ 324 public void setReturnSubscriber(Subscriber<List<ByteBuffer>> returnSubscriber) { 325 publisher.subscribe(returnSubscriber); 326 } 327 328 @Override 329 public void onSubscribe(Flow.Subscription subscription) { 330 clientSubscription = subscription; 331 clientSubscription.request(5); 332 } 333 334 @Override 335 public void onNext(List<ByteBuffer> item) { 336 try { 337 for (ByteBuffer b : item) 338 buffer.put(b); 339 } catch (InterruptedException e) { 340 e.printStackTrace(); 341 Utils.close(clientSock); 342 } 343 } 344 345 @Override 346 public void onError(Throwable throwable) { 347 throwable.printStackTrace(); 348 Utils.close(clientSock); 349 } 350 351 @Override 352 public void onComplete() { 353 try { 354 buffer.put(FlowTest.SENTINEL); 355 } catch (InterruptedException e) { 356 e.printStackTrace(); 357 Utils.close(clientSock); 358 } 359 } 360 } 361 362 /** 363 * The final subscriber which receives the decrypted looped-back data. 364 * Just needs to compare the data with what was sent. The given CF is 365 * either completed exceptionally with an error or normally on success. 366 */ 367 static class EndSubscriber implements Subscriber<List<ByteBuffer>> { 368 369 private final long nbytes; 370 371 private final AtomicLong counter; 372 private volatile Flow.Subscription subscription; 373 private final CompletableFuture<Void> completion; 374 375 EndSubscriber(long nbytes, CompletableFuture<Void> completion) { 376 counter = new AtomicLong(0); 377 this.nbytes = nbytes; 378 this.completion = completion; 379 } 380 381 @Override 382 public void onSubscribe(Flow.Subscription subscription) { 383 this.subscription = subscription; 384 subscription.request(5); 385 } 386 387 public static String info(List<ByteBuffer> i) { 388 StringBuilder sb = new StringBuilder(); 389 sb.append("size: ").append(Integer.toString(i.size())); 390 int x = 0; 391 for (ByteBuffer b : i) 392 x += b.remaining(); 393 sb.append(" bytes: " + Integer.toString(x)); 394 return sb.toString(); 395 } 396 397 @Override 398 public void onNext(List<ByteBuffer> buffers) { 399 long currval = counter.get(); 400 //if (currval % 500 == 0) { 401 //System.out.println("End: " + currval); 402 //} 403 404 for (ByteBuffer buf : buffers) { 405 while (buf.hasRemaining()) { 406 long n = buf.getLong(); 407 //if (currval > (FlowTest.TOTAL_LONGS - 50)) { 408 //System.out.println("End: " + currval); 409 //} 410 if (n != currval++) { 411 System.out.println("ERROR at " + n + " != " + (currval - 1)); 412 completion.completeExceptionally(new RuntimeException("ERROR")); 413 subscription.cancel(); 414 return; 415 } 416 } 417 } 418 419 counter.set(currval); 420 subscription.request(1); 421 } 422 423 @Override 424 public void onError(Throwable throwable) { 425 completion.completeExceptionally(throwable); 426 } 427 428 @Override 429 public void onComplete() { 430 long n = counter.get(); 431 if (n != nbytes) { 432 System.out.printf("nbytes=%d n=%d\n", nbytes, n); 433 completion.completeExceptionally(new RuntimeException("ERROR AT END")); 434 } else { 435 System.out.println("DONE OK: counter = " + n); 436 completion.complete(null); 437 } 438 } 439 } 440 441 /** 442 * Creates a simple usable SSLContext for SSLSocketFactory 443 * or a HttpsServer using either a given keystore or a default 444 * one in the test tree. 445 * <p> 446 * Using this class with a security manager requires the following 447 * permissions to be granted: 448 * <p> 449 * permission "java.util.PropertyPermission" "test.src.path", "read"; 450 * permission java.io.FilePermission 451 * "${test.src}/../../../../lib/testlibrary/jdk/testlibrary/testkeys", "read"; 452 * The exact path above depends on the location of the test. 453 */ 454 static class SimpleSSLContext { 455 456 private final SSLContext ssl; 457 458 /** 459 * Loads default keystore from SimpleSSLContext source directory 460 */ 461 public SimpleSSLContext() throws IOException { 462 String paths = System.getProperty("test.src.path"); 463 StringTokenizer st = new StringTokenizer(paths, File.pathSeparator); 464 boolean securityExceptions = false; 465 SSLContext sslContext = null; 466 while (st.hasMoreTokens()) { 467 String path = st.nextToken(); 468 try { 469 File f = new File(path, "../../../../lib/testlibrary/jdk/testlibrary/testkeys"); 470 if (f.exists()) { 471 try (FileInputStream fis = new FileInputStream(f)) { 472 sslContext = init(fis); 473 break; 474 } 475 } 476 } catch (SecurityException e) { 477 // catch and ignore because permission only required 478 // for one entry on path (at most) 479 securityExceptions = true; 480 } 481 } 482 if (securityExceptions) { 483 System.out.println("SecurityExceptions thrown on loading testkeys"); 484 } 485 ssl = sslContext; 486 } 487 488 private SSLContext init(InputStream i) throws IOException { 489 try { 490 char[] passphrase = "passphrase".toCharArray(); 491 KeyStore ks = KeyStore.getInstance("JKS"); 492 ks.load(i, passphrase); 493 494 KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); 495 kmf.init(ks, passphrase); 496 497 TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); 498 tmf.init(ks); 499 500 SSLContext ssl = SSLContext.getInstance("TLS"); 501 ssl.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); 502 return ssl; 503 } catch (KeyManagementException | KeyStoreException | 504 UnrecoverableKeyException | CertificateException | 505 NoSuchAlgorithmException e) { 506 throw new RuntimeException(e.getMessage()); 507 } 508 } 509 510 public SSLContext get() { 511 return ssl; 512 } 513 } 514 515 private static void sleep(int millis) { 516 try { 517 Thread.sleep(millis); 518 } catch (Exception e) { 519 e.printStackTrace(); 520 } 521 } 522 }