1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 import javax.net.ServerSocketFactory;
  25 import javax.net.ssl.SSLContext;
  26 import javax.net.ssl.SSLHandshakeException;
  27 import javax.net.ssl.SSLSocket;
  28 import java.io.DataInputStream;
  29 import java.io.IOException;
  30 import java.io.UncheckedIOException;
  31 import java.net.ServerSocket;
  32 import java.net.Socket;
  33 import java.net.URI;
  34 import java.util.List;
  35 import java.util.concurrent.CompletableFuture;
  36 import java.util.concurrent.CompletionException;
  37 import jdk.incubator.http.HttpClient;
  38 import jdk.incubator.http.HttpClient.Version;
  39 import jdk.incubator.http.HttpResponse;
  40 import jdk.incubator.http.HttpRequest;
  41 import static java.lang.System.out;
  42 import static jdk.incubator.http.HttpResponse.BodyHandler.discard;
  43 
  44 /**
  45  * @test
  46  * @run main/othervm HandshakeFailureTest
 * @summary Verify SSLHandshakeException is received when the handshake fails,
 * either because the server closes the connection (EOF) during handshaking
 * or because no cipher suite (or similar) can be negotiated.
  50  */
  51 // To switch on debugging use:
  52 // @run main/othervm -Djdk.internal.httpclient.debug=true HandshakeFailureTest
  53 public class HandshakeFailureTest {
  54 
  55     // The number of iterations each testXXXClient performs. Can be increased
  56     // when running standalone testing.
  57     static final int TIMES = 10;
  58 
  59     public static void main(String[] args) throws Exception {
  60         HandshakeFailureTest test = new HandshakeFailureTest();
  61         List<AbstractServer> servers = List.of( new PlainServer(), new SSLServer());
  62 
  63         for (AbstractServer server : servers) {
  64             try (server) {
  65                 out.format("%n%n------ Testing with server:%s ------%n", server);
  66                 URI uri = new URI("https://127.0.0.1:" + server.getPort() + "/");
  67 
  68                 test.testSyncSameClient(uri, Version.HTTP_1_1);
  69                 test.testSyncSameClient(uri, Version.HTTP_2);
  70                 test.testSyncDiffClient(uri, Version.HTTP_1_1);
  71                 test.testSyncDiffClient(uri, Version.HTTP_2);
  72 
  73                 test.testAsyncSameClient(uri, Version.HTTP_1_1);
  74                 test.testAsyncSameClient(uri, Version.HTTP_2);
  75                 test.testAsyncDiffClient(uri, Version.HTTP_1_1);
  76                 test.testAsyncDiffClient(uri, Version.HTTP_2);
  77             }
  78         }
  79     }
  80 
  81     void testSyncSameClient(URI uri, Version version) throws Exception {
  82         out.printf("%n--- testSyncSameClient %s ---%n", version);
  83         HttpClient client = HttpClient.newHttpClient();
  84         for (int i = 0; i < TIMES; i++) {
  85             out.printf("iteration %d%n", i);
  86             HttpRequest request = HttpRequest.newBuilder(uri)
  87                                              .version(version)
  88                                              .build();
  89             try {
  90                 HttpResponse<Void> response = client.send(request, discard(null));
  91                 String msg = String.format("UNEXPECTED response=%s%n", response);
  92                 throw new RuntimeException(msg);
  93             } catch (SSLHandshakeException expected) {
  94                 out.printf("Client: caught expected exception: %s%n", expected);
  95             }
  96         }
  97     }
  98 
  99     void testSyncDiffClient(URI uri, Version version) throws Exception {
 100         out.printf("%n--- testSyncDiffClient %s ---%n", version);
 101         for (int i = 0; i < TIMES; i++) {
 102             out.printf("iteration %d%n", i);
 103             // a new client each time
 104             HttpClient client = HttpClient.newHttpClient();
 105             HttpRequest request = HttpRequest.newBuilder(uri)
 106                                              .version(version)
 107                                              .build();
 108             try {
 109                 HttpResponse<Void> response = client.send(request, discard(null));
 110                 String msg = String.format("UNEXPECTED response=%s%n", response);
 111                 throw new RuntimeException(msg);
 112             } catch (SSLHandshakeException expected) {
 113                 out.printf("Client: caught expected exception: %s%n", expected);
 114             }
 115         }
 116     }
 117 
 118     void testAsyncSameClient(URI uri, Version version) throws Exception {
 119         out.printf("%n--- testAsyncSameClient %s ---%n", version);
 120         HttpClient client = HttpClient.newHttpClient();
 121         for (int i = 0; i < TIMES; i++) {
 122             out.printf("iteration %d%n", i);
 123             HttpRequest request = HttpRequest.newBuilder(uri)
 124                                              .version(version)
 125                                              .build();
 126             CompletableFuture<HttpResponse<Void>> response =
 127                         client.sendAsync(request, discard(null));
 128             try {
 129                 response.join();
 130                 String msg = String.format("UNEXPECTED response=%s%n", response);
 131                 throw new RuntimeException(msg);
 132             } catch (CompletionException ce) {
 133                 if (ce.getCause() instanceof SSLHandshakeException) {
 134                     out.printf("Client: caught expected exception: %s%n", ce.getCause());
 135                 } else {
 136                     out.printf("Client: caught UNEXPECTED exception: %s%n", ce.getCause());
 137                     throw ce;
 138                 }
 139             }
 140         }
 141     }
 142 
 143     void testAsyncDiffClient(URI uri, Version version) throws Exception {
 144         out.printf("%n--- testAsyncDiffClient %s ---%n", version);
 145         for (int i = 0; i < TIMES; i++) {
 146             out.printf("iteration %d%n", i);
 147             // a new client each time
 148             HttpClient client = HttpClient.newHttpClient();
 149             HttpRequest request = HttpRequest.newBuilder(uri)
 150                                              .version(version)
 151                                              .build();
 152             CompletableFuture<HttpResponse<Void>> response =
 153                     client.sendAsync(request, discard(null));
 154             try {
 155                 response.join();
 156                 String msg = String.format("UNEXPECTED response=%s%n", response);
 157                 throw new RuntimeException(msg);
 158             } catch (CompletionException ce) {
 159                 if (ce.getCause() instanceof SSLHandshakeException) {
 160                     out.printf("Client: caught expected exception: %s%n", ce.getCause());
 161                 } else {
 162                     out.printf("Client: caught UNEXPECTED exception: %s%n", ce.getCause());
 163                     throw ce;
 164                 }
 165             }
 166         }
 167     }
 168 
 169     /** Common supertype for PlainServer and SSLServer. */
 170     static abstract class AbstractServer extends Thread implements AutoCloseable {
 171         protected final ServerSocket ss;
 172         protected volatile boolean closed;
 173 
 174         AbstractServer(String name, ServerSocket ss) throws IOException {
 175             super(name);
 176             this.ss = ss;
 177             this.start();
 178         }
 179 
 180         int getPort() { return ss.getLocalPort(); }
 181 
 182         @Override
 183         public void close() {
 184             if (closed)
 185                 return;
 186             closed = true;
 187             try {
 188                 ss.close();
 189             } catch (IOException e) {
 190                 throw new UncheckedIOException("Unexpected", e);
 191             }
 192         }
 193     }
 194 
 195     /** Emulates a server-side, using plain cleartext Sockets, that just closes
 196      * the connection, after a small variable delay. */
 197     static class PlainServer extends AbstractServer {
 198         private volatile int count;
 199 
 200         PlainServer() throws IOException {
 201             super("PlainServer", new ServerSocket(0));
 202         }
 203 
 204         @Override
 205         public void run() {
 206             while (!closed) {
 207                 try (Socket s = ss.accept()) {
 208                     count++;
 209 
 210                     /*   SSL record layer - contains the client hello
 211                     struct {
 212                         uint8 major, minor;
 213                     } ProtocolVersion;
 214 
 215                     enum {
 216                         change_cipher_spec(20), alert(21), handshake(22),
 217                         application_data(23), (255)
 218                     } ContentType;
 219 
 220                     struct {
 221                         ContentType type;
 222                         ProtocolVersion version;
 223                         uint16 length;
 224                         opaque fragment[SSLPlaintext.length];
 225                     } SSLPlaintext;   */
 226                     DataInputStream din =  new DataInputStream(s.getInputStream());
 227                     int contentType = din.read();
 228                     out.println("ContentType:" + contentType);
 229                     int majorVersion = din.read();
 230                     out.println("Major:" + majorVersion);
 231                     int minorVersion = din.read();
 232                     out.println("Minor:" + minorVersion);
 233                     int length = din.readShort();
 234                     out.println("length:" + length);
 235                     byte[] ba = new byte[length];
 236                     din.readFully(ba);
 237 
 238                     // simulate various delays in response
 239                     Thread.sleep(10 * (count % 10));
 240                     s.close(); // close without giving any reply
 241                 } catch (IOException e) {
 242                     if (!closed)
 243                         out.println("Unexpected" + e);
 244                 } catch (InterruptedException e) {
 245                     throw new RuntimeException(e);
 246                 }
 247             }
 248         }
 249     }
 250 
 251     /** Emulates a server-side, using SSL Sockets, that will fail during
 252      * handshaking, as there are no cipher suites in common. */
 253     static class SSLServer extends AbstractServer {
 254         static final SSLContext sslContext = createUntrustingContext();
 255         static final ServerSocketFactory factory = sslContext.getServerSocketFactory();
 256 
 257         static SSLContext createUntrustingContext() {
 258             try {
 259                 SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
 260                 sslContext.init(null, null, null);
 261                 return sslContext;
 262             } catch (Throwable t) {
 263                 throw new AssertionError(t);
 264             }
 265         }
 266 
 267         SSLServer() throws IOException {
 268             super("SSLServer", factory.createServerSocket(0));
 269         }
 270 
 271         @Override
 272         public void run() {
 273             while (!closed) {
 274                 try (SSLSocket s = (SSLSocket)ss.accept()) {
 275                     s.getInputStream().read();  // will throw SHE here
 276 
 277                     throw new AssertionError("Should not reach here");
 278                 } catch (SSLHandshakeException expected) {
 279                     // Expected: SSLHandshakeException: no cipher suites in common
 280                     out.printf("Server: caught expected exception: %s%n", expected);
 281                 } catch (IOException e) {
 282                     if (!closed)
 283                         out.printf("UNEXPECTED %s", e);
 284                 }
 285             }
 286         }
 287     }
 288 }