1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 package jdk.incubator.http;
  25 
  26 import java.io.BufferedOutputStream;
  27 import java.io.File;
  28 import java.io.FileInputStream;
  29 import java.io.IOException;
  30 import java.io.InputStream;
  31 import java.io.OutputStream;
  32 import java.net.Socket;
  33 import java.nio.ByteBuffer;
  34 import java.security.KeyManagementException;
  35 import java.security.KeyStore;
  36 import java.security.KeyStoreException;
  37 import java.security.NoSuchAlgorithmException;
  38 import java.security.UnrecoverableKeyException;
  39 import java.security.cert.CertificateException;
  40 import java.util.List;
  41 import java.util.Random;
  42 import java.util.StringTokenizer;
  43 import java.util.concurrent.BlockingQueue;
  44 import java.util.concurrent.CompletableFuture;
  45 import java.util.concurrent.ExecutorService;
  46 import java.util.concurrent.Executors;
  47 import java.util.concurrent.Flow;
  48 import java.util.concurrent.Flow.Subscriber;
  49 import java.util.concurrent.LinkedBlockingQueue;
  50 import java.util.concurrent.SubmissionPublisher;
  51 import java.util.concurrent.atomic.AtomicInteger;
  52 import java.util.concurrent.atomic.AtomicLong;
  53 import javax.net.ssl.KeyManagerFactory;
  54 import javax.net.ssl.*;
  55 import javax.net.ssl.TrustManagerFactory;
  56 import jdk.incubator.http.internal.common.Utils;
  57 import org.testng.annotations.Test;
  58 import jdk.incubator.http.internal.common.SSLFlowDelegate;
  59 
  60 @Test
  61 public class FlowTest {
  62 
  63     private final SubmissionPublisher<List<ByteBuffer>> srcPublisher;
  64     private final ExecutorService executor;
  65     private static final long COUNTER = 3000;
  66     private static final int LONGS_PER_BUF = 800;
  67     static final long TOTAL_LONGS = COUNTER * LONGS_PER_BUF;
  68     public static final ByteBuffer SENTINEL = ByteBuffer.allocate(0);
  69     static volatile String alpn;
  70 
  71     private final CompletableFuture<Void> completion;
  72 
  73     public FlowTest() throws IOException {
  74         executor = Executors.newCachedThreadPool();
  75         srcPublisher = new SubmissionPublisher<>(executor, 20,
  76                                                  this::handlePublisherException);
  77         SSLContext ctx = (new SimpleSSLContext()).get();
  78         SSLEngine engineClient = ctx.createSSLEngine();
  79         SSLParameters params = ctx.getSupportedSSLParameters();
  80         params.setApplicationProtocols(new String[]{"proto1", "proto2"}); // server will choose proto2
  81         params.setProtocols(new String[]{"TLSv1.2"}); // TODO: This is essential. Needs to be protocol impl
  82         engineClient.setSSLParameters(params);
  83         engineClient.setUseClientMode(true);
  84         completion = new CompletableFuture<>();
  85         SSLLoopbackSubscriber looper = new SSLLoopbackSubscriber(ctx, executor);
  86         looper.start();
  87         EndSubscriber end = new EndSubscriber(TOTAL_LONGS, completion);
  88         SSLFlowDelegate sslClient = new SSLFlowDelegate(engineClient, executor, end, looper);
  89         // going to measure how long handshake takes
  90         final long start = System.currentTimeMillis();
  91         sslClient.alpn().whenComplete((String s, Throwable t) -> {
  92             if (t != null)
  93                 t.printStackTrace();
  94             long endTime = System.currentTimeMillis();
  95             alpn = s;
  96             System.out.println("ALPN: " + alpn);
  97             long period = (endTime - start);
  98             System.out.printf("Handshake took %d ms\n", period);
  99         });
 100         Subscriber<List<ByteBuffer>> reader = sslClient.upstreamReader();
 101         Subscriber<List<ByteBuffer>> writer = sslClient.upstreamWriter();
 102         looper.setReturnSubscriber(reader);
 103         // now connect all the pieces
 104         srcPublisher.subscribe(writer);
 105         String aa = sslClient.alpn().join();
 106         System.out.println("AAALPN = " + aa);
 107     }
 108 
 109     static Random rand = new Random();
 110 
 111     static int randomRange(int lower, int upper) {
 112         if (lower > upper)
 113             throw new IllegalArgumentException("lower > upper");
 114         int diff = upper - lower;
 115         int r = lower + rand.nextInt(diff);
 116         return r - (r % 8); // round down to multiple of 8 (align for longs)
 117     }
 118 
 119     private void handlePublisherException(Object o, Throwable t) {
 120         System.out.println("Src Publisher exception");
 121         t.printStackTrace(System.out);
 122     }
 123 
 124     private static ByteBuffer getBuffer(long startingAt) {
 125         ByteBuffer buf = ByteBuffer.allocate(LONGS_PER_BUF * 8);
 126         for (int j = 0; j < LONGS_PER_BUF; j++) {
 127             buf.putLong(startingAt++);
 128         }
 129         buf.flip();
 130         return buf;
 131     }
 132 
 133     @Test
 134     public void run() {
 135         long count = 0;
 136         System.out.printf("Submitting %d buffer arrays\n", COUNTER);
 137         System.out.printf("LoopCount should be %d\n", TOTAL_LONGS);
 138         for (long i = 0; i < COUNTER; i++) {
 139             ByteBuffer b = getBuffer(count);
 140             count += LONGS_PER_BUF;
 141             srcPublisher.submit(List.of(b));
 142         }
 143         System.out.println("Finished submission. Waiting for loopback");
 144         srcPublisher.close();
 145         try {
 146             completion.join();
 147             if (!alpn.equals("proto2")) {
 148                 throw new RuntimeException("wrong alpn received");
 149             }
 150             System.out.println("OK");
 151         } finally {
 152             executor.shutdownNow();
 153         }
 154     }
 155 
 156 /*
 157     public static void main(String[]args) throws Exception {
 158         FlowTest test = new FlowTest();
 159         test.run();
 160     }
 161 */
 162 
 163     /**
 164      * This Subscriber simulates an SSL loopback network. The object itself
 165      * accepts outgoing SSL encrypted data which is looped back via two sockets
 166      * (one of which is an SSLSocket emulating a server). The method
 167      * {@link #setReturnSubscriber(java.util.concurrent.Flow.Subscriber) }
 168      * is used to provide the Subscriber which feeds the incoming side
 169      * of SSLFlowDelegate. Three threads are used to implement this behavior
 170      * and a SubmissionPublisher drives the incoming read side.
 171      * <p>
 172      * A thread reads from the buffer, writes
 173      * to the client j.n.Socket which is connected to a SSLSocket operating
 174      * in server mode. A second thread loops back data read from the SSLSocket back to the
 175      * client again. A third thread reads the client socket and pushes the data to
 176      * a SubmissionPublisher that drives the reader side of the SSLFlowDelegate
 177      */
 178     static class SSLLoopbackSubscriber implements Subscriber<List<ByteBuffer>> {
 179         private final BlockingQueue<ByteBuffer> buffer;
 180         private final Socket clientSock;
 181         private final SSLSocket serverSock;
 182         private final Thread thread1, thread2, thread3;
 183         private volatile Flow.Subscription clientSubscription;
 184         private final SubmissionPublisher<List<ByteBuffer>> publisher;
 185 
 186         SSLLoopbackSubscriber(SSLContext ctx, ExecutorService exec) throws IOException {
 187             SSLServerSocketFactory fac = ctx.getServerSocketFactory();
 188             SSLServerSocket serv = (SSLServerSocket) fac.createServerSocket(0);
 189             SSLParameters params = serv.getSSLParameters();
 190             params.setApplicationProtocols(new String[]{"proto2"});
 191             serv.setSSLParameters(params);
 192 
 193 
 194             int serverPort = serv.getLocalPort();
 195             clientSock = new Socket("127.0.0.1", serverPort);
 196             serverSock = (SSLSocket) serv.accept();
 197             this.buffer = new LinkedBlockingQueue<>();
 198             thread1 = new Thread(this::clientWriter, "clientWriter");
 199             thread2 = new Thread(this::serverLoopback, "serverLoopback");
 200             thread3 = new Thread(this::clientReader, "clientReader");
 201             publisher = new SubmissionPublisher<>(exec, Flow.defaultBufferSize(),
 202                     this::handlePublisherException);
 203             SSLFlowDelegate.Monitor.add(this::monitor);
 204         }
 205 
 206         public void start() {
 207             thread1.start();
 208             thread2.start();
 209             thread3.start();
 210         }
 211 
 212         private void handlePublisherException(Object o, Throwable t) {
 213             System.out.println("Loopback Publisher exception");
 214             t.printStackTrace(System.out);
 215         }
 216 
 217         private final AtomicInteger readCount = new AtomicInteger();
 218 
 219         // reads off the SSLSocket the data from the "server"
 220         private void clientReader() {
 221             try {
 222                 InputStream is = clientSock.getInputStream();
 223                 final int bufsize = FlowTest.randomRange(512, 16 * 1024);
 224                 System.out.println("clientReader: bufsize = " + bufsize);
 225                 while (true) {
 226                     byte[] buf = new byte[bufsize];
 227                     int n = is.read(buf);
 228                     if (n == -1) {
 229                         System.out.println("clientReader close: read "
 230                                 + readCount.get() + " bytes");
 231                         publisher.close();
 232                         sleep(2000);
 233                         Utils.close(is, clientSock);
 234                         return;
 235                     }
 236                     ByteBuffer bb = ByteBuffer.wrap(buf, 0, n);
 237                     readCount.addAndGet(n);
 238                     publisher.submit(List.of(bb));
 239                 }
 240             } catch (Throwable e) {
 241                 e.printStackTrace();
 242                 Utils.close(clientSock);
 243             }
 244         }
 245 
 246         // writes the encrypted data from SSLFLowDelegate to the j.n.Socket
 247         // which is connected to the SSLSocket emulating a server.
 248         private void clientWriter() {
 249             long nbytes = 0;
 250             try {
 251                 OutputStream os =
 252                         new BufferedOutputStream(clientSock.getOutputStream());
 253 
 254                 while (true) {
 255                     ByteBuffer buf = buffer.take();
 256                     if (buf == FlowTest.SENTINEL) {
 257                         // finished
 258                         //Utils.sleep(2000);
 259                         System.out.println("clientWriter close: " + nbytes + " written");
 260                         clientSock.shutdownOutput();
 261                         System.out.println("clientWriter close return");
 262                         return;
 263                     }
 264                     int len = buf.remaining();
 265                     int written = writeToStream(os, buf);
 266                     assert len == written;
 267                     nbytes += len;
 268                     assert !buf.hasRemaining()
 269                             : "buffer has " + buf.remaining() + " bytes left";
 270                     clientSubscription.request(1);
 271                 }
 272             } catch (Throwable e) {
 273                 e.printStackTrace();
 274             }
 275         }
 276 
 277         private int writeToStream(OutputStream os, ByteBuffer buf) throws IOException {
 278             byte[] b = buf.array();
 279             int offset = buf.arrayOffset() + buf.position();
 280             int n = buf.limit() - buf.position();
 281             os.write(b, offset, n);
 282             buf.position(buf.limit());
 283             os.flush();
 284             return n;
 285         }
 286 
 287         private final AtomicInteger loopCount = new AtomicInteger();
 288 
 289         public String monitor() {
 290             return "serverLoopback: loopcount = " + loopCount.toString()
 291                     + " clientRead: count = " + readCount.toString();
 292         }
 293 
 294         // thread2
 295         private void serverLoopback() {
 296             try {
 297                 InputStream is = serverSock.getInputStream();
 298                 OutputStream os = serverSock.getOutputStream();
 299                 final int bufsize = FlowTest.randomRange(512, 16 * 1024);
 300                 System.out.println("serverLoopback: bufsize = " + bufsize);
 301                 byte[] bb = new byte[bufsize];
 302                 while (true) {
 303                     int n = is.read(bb);
 304                     if (n == -1) {
 305                         sleep(2000);
 306                         is.close();
 307                         serverSock.close();
 308                         return;
 309                     }
 310                     os.write(bb, 0, n);
 311                     os.flush();
 312                     loopCount.addAndGet(n);
 313                 }
 314             } catch (Throwable e) {
 315                 e.printStackTrace();
 316             }
 317         }
 318 
 319 
 320         /**
 321          * This needs to be called before the chain is subscribed. It can't be
 322          * supplied in the constructor.
 323          */
 324         public void setReturnSubscriber(Subscriber<List<ByteBuffer>> returnSubscriber) {
 325             publisher.subscribe(returnSubscriber);
 326         }
 327 
 328         @Override
 329         public void onSubscribe(Flow.Subscription subscription) {
 330             clientSubscription = subscription;
 331             clientSubscription.request(5);
 332         }
 333 
 334         @Override
 335         public void onNext(List<ByteBuffer> item) {
 336             try {
 337                 for (ByteBuffer b : item)
 338                     buffer.put(b);
 339             } catch (InterruptedException e) {
 340                 e.printStackTrace();
 341                 Utils.close(clientSock);
 342             }
 343         }
 344 
 345         @Override
 346         public void onError(Throwable throwable) {
 347             throwable.printStackTrace();
 348             Utils.close(clientSock);
 349         }
 350 
 351         @Override
 352         public void onComplete() {
 353             try {
 354                 buffer.put(FlowTest.SENTINEL);
 355             } catch (InterruptedException e) {
 356                 e.printStackTrace();
 357                 Utils.close(clientSock);
 358             }
 359         }
 360     }
 361 
 362     /**
 363      * The final subscriber which receives the decrypted looped-back data.
 364      * Just needs to compare the data with what was sent. The given CF is
 365      * either completed exceptionally with an error or normally on success.
 366      */
 367     static class EndSubscriber implements Subscriber<List<ByteBuffer>> {
 368 
 369         private final long nbytes;
 370 
 371         private final AtomicLong counter;
 372         private volatile Flow.Subscription subscription;
 373         private final CompletableFuture<Void> completion;
 374 
 375         EndSubscriber(long nbytes, CompletableFuture<Void> completion) {
 376             counter = new AtomicLong(0);
 377             this.nbytes = nbytes;
 378             this.completion = completion;
 379         }
 380 
 381         @Override
 382         public void onSubscribe(Flow.Subscription subscription) {
 383             this.subscription = subscription;
 384             subscription.request(5);
 385         }
 386 
 387         public static String info(List<ByteBuffer> i) {
 388             StringBuilder sb = new StringBuilder();
 389             sb.append("size: ").append(Integer.toString(i.size()));
 390             int x = 0;
 391             for (ByteBuffer b : i)
 392                 x += b.remaining();
 393             sb.append(" bytes: " + Integer.toString(x));
 394             return sb.toString();
 395         }
 396 
 397         @Override
 398         public void onNext(List<ByteBuffer> buffers) {
 399             long currval = counter.get();
 400             //if (currval % 500 == 0) {
 401             //System.out.println("End: " + currval);
 402             //}
 403 
 404             for (ByteBuffer buf : buffers) {
 405                 while (buf.hasRemaining()) {
 406                     long n = buf.getLong();
 407                     //if (currval > (FlowTest.TOTAL_LONGS - 50)) {
 408                     //System.out.println("End: " + currval);
 409                     //}
 410                     if (n != currval++) {
 411                         System.out.println("ERROR at " + n + " != " + (currval - 1));
 412                         completion.completeExceptionally(new RuntimeException("ERROR"));
 413                         subscription.cancel();
 414                         return;
 415                     }
 416                 }
 417             }
 418 
 419             counter.set(currval);
 420             subscription.request(1);
 421         }
 422 
 423         @Override
 424         public void onError(Throwable throwable) {
 425             completion.completeExceptionally(throwable);
 426         }
 427 
 428         @Override
 429         public void onComplete() {
 430             long n = counter.get();
 431             if (n != nbytes) {
 432                 System.out.printf("nbytes=%d n=%d\n", nbytes, n);
 433                 completion.completeExceptionally(new RuntimeException("ERROR AT END"));
 434             } else {
 435                 System.out.println("DONE OK: counter = " + n);
 436                 completion.complete(null);
 437             }
 438         }
 439     }
 440 
 441     /**
 442      * Creates a simple usable SSLContext for SSLSocketFactory
 443      * or a HttpsServer using either a given keystore or a default
 444      * one in the test tree.
 445      * <p>
 446      * Using this class with a security manager requires the following
 447      * permissions to be granted:
 448      * <p>
 449      * permission "java.util.PropertyPermission" "test.src.path", "read";
 450      * permission java.io.FilePermission
 451      * "${test.src}/../../../../lib/testlibrary/jdk/testlibrary/testkeys", "read";
 452      * The exact path above depends on the location of the test.
 453      */
 454     static class SimpleSSLContext {
 455 
 456         private final SSLContext ssl;
 457 
 458         /**
 459          * Loads default keystore from SimpleSSLContext source directory
 460          */
 461         public SimpleSSLContext() throws IOException {
 462             String paths = System.getProperty("test.src.path");
 463             StringTokenizer st = new StringTokenizer(paths, File.pathSeparator);
 464             boolean securityExceptions = false;
 465             SSLContext sslContext = null;
 466             while (st.hasMoreTokens()) {
 467                 String path = st.nextToken();
 468                 try {
 469                     File f = new File(path, "../../../../lib/testlibrary/jdk/testlibrary/testkeys");
 470                     if (f.exists()) {
 471                         try (FileInputStream fis = new FileInputStream(f)) {
 472                             sslContext = init(fis);
 473                             break;
 474                         }
 475                     }
 476                 } catch (SecurityException e) {
 477                     // catch and ignore because permission only required
 478                     // for one entry on path (at most)
 479                     securityExceptions = true;
 480                 }
 481             }
 482             if (securityExceptions) {
 483                 System.out.println("SecurityExceptions thrown on loading testkeys");
 484             }
 485             ssl = sslContext;
 486         }
 487 
 488         private SSLContext init(InputStream i) throws IOException {
 489             try {
 490                 char[] passphrase = "passphrase".toCharArray();
 491                 KeyStore ks = KeyStore.getInstance("JKS");
 492                 ks.load(i, passphrase);
 493 
 494                 KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
 495                 kmf.init(ks, passphrase);
 496 
 497                 TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
 498                 tmf.init(ks);
 499 
 500                 SSLContext ssl = SSLContext.getInstance("TLS");
 501                 ssl.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
 502                 return ssl;
 503             } catch (KeyManagementException | KeyStoreException |
 504                     UnrecoverableKeyException | CertificateException |
 505                     NoSuchAlgorithmException e) {
 506                 throw new RuntimeException(e.getMessage());
 507             }
 508         }
 509 
 510         public SSLContext get() {
 511             return ssl;
 512         }
 513     }
 514 
 515     private static void sleep(int millis) {
 516         try {
 517             Thread.sleep(millis);
 518         } catch (Exception e) {
 519             e.printStackTrace();
 520         }
 521     }
 522 }