1 /*
   2  * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 //
// SunJSSE does not support dynamic system properties, so system properties
// cannot be re-used in samevm/agentvm mode; this test must run in othervm.
  27 //
  28 
  29 /**
  30  * @test
  31  * @bug 7068321
  32  * @summary Support TLS Server Name Indication (SNI) Extension in JSSE Server
  33  * @library ../templates
  34  * @modules java.base/sun.misc
  35  * @build SSLCapabilities SSLExplorer
  36  * @run main/othervm SSLSocketExplorerMatchedSNI www.example.com
  37  *     www\.example\.com
  38  * @run main/othervm SSLSocketExplorerMatchedSNI www.example.com
  39  *     www\.example\.(com|org)
  40  * @run main/othervm SSLSocketExplorerMatchedSNI example.com
  41  *     (.*\.)*example\.(com|org)
  42  * @run main/othervm SSLSocketExplorerMatchedSNI www.example.com
  43  *     (.*\.)*example\.(com|org)
  44  * @run main/othervm SSLSocketExplorerMatchedSNI www.us.example.com
  45  *     (.*\.)*example\.(com|org)
  46  */
  47 
  48 import java.io.*;
  49 import java.nio.*;
  50 import java.nio.channels.*;
  51 import java.util.*;
  52 import java.net.*;
  53 import javax.net.ssl.*;
  54 
  55 public class SSLSocketExplorerMatchedSNI {
  56 
  57     /*
  58      * =============================================================
  59      * Set the various variables needed for the tests, then
  60      * specify what tests to run on each side.
  61      */
  62 
  63     /*
  64      * Should we run the client or server in a separate thread?
  65      * Both sides can throw exceptions, but do you have a preference
  66      * as to which side should be the main thread.
  67      */
  68     static boolean separateServerThread = false;
  69 
  70     /*
  71      * Where do we find the keystores?
  72      */
  73     static String pathToStores = "../etc";
  74     static String keyStoreFile = "keystore";
  75     static String trustStoreFile = "truststore";
  76     static String passwd = "passphrase";
  77 
  78     /*
  79      * Is the server ready to serve?
  80      */
  81     volatile static boolean serverReady = false;
  82 
  83     /*
  84      * Turn on SSL debugging?
  85      */
  86     static boolean debug = false;
  87 
  88     /*
  89      * If the client or server is doing some kind of object creation
  90      * that the other side depends on, and that thread prematurely
  91      * exits, you may experience a hang.  The test harness will
  92      * terminate all hung threads after its timeout has expired,
  93      * currently 3 minutes by default, but you might try to be
  94      * smart about it....
  95      */
  96 
  97     /*
  98      * Define the server side of the test.
  99      *
 100      * If the server prematurely exits, serverReady will be set to true
 101      * to avoid infinite hangs.
 102      */
 103     void doServerSide() throws Exception {
 104 
 105         ServerSocket serverSocket = new ServerSocket(serverPort);
 106 
 107         // Signal Client, we're ready for his connect.
 108         serverPort = serverSocket.getLocalPort();
 109         serverReady = true;
 110 
 111         Socket socket = serverSocket.accept();
 112         InputStream ins = socket.getInputStream();
 113 
 114         byte[] buffer = new byte[0xFF];
 115         int position = 0;
 116         SSLCapabilities capabilities = null;
 117 
 118         // Read the header of TLS record
 119         while (position < SSLExplorer.RECORD_HEADER_SIZE) {
 120             int count = SSLExplorer.RECORD_HEADER_SIZE - position;
 121             int n = ins.read(buffer, position, count);
 122             if (n < 0) {
 123                 throw new Exception("unexpected end of stream!");
 124             }
 125             position += n;
 126         }
 127 
 128         int recordLength = SSLExplorer.getRequiredSize(buffer, 0, position);
 129         if (buffer.length < recordLength) {
 130             buffer = Arrays.copyOf(buffer, recordLength);
 131         }
 132 
 133         while (position < recordLength) {
 134             int count = recordLength - position;
 135             int n = ins.read(buffer, position, count);
 136             if (n < 0) {
 137                 throw new Exception("unexpected end of stream!");
 138             }
 139             position += n;
 140         }
 141 
 142         capabilities = SSLExplorer.explore(buffer, 0, recordLength);;
 143         if (capabilities != null) {
 144             System.out.println("Record version: " +
 145                     capabilities.getRecordVersion());
 146             System.out.println("Hello version: " +
 147                     capabilities.getHelloVersion());
 148         }
 149 
 150         SSLSocketFactory sslsf =
 151             (SSLSocketFactory) SSLSocketFactory.getDefault();
 152         ByteArrayInputStream bais =
 153             new ByteArrayInputStream(buffer, 0, position);
 154         SSLSocket sslSocket = (SSLSocket)sslsf.createSocket(socket, bais, true);
 155 
 156         SNIMatcher matcher = SNIHostName.createSNIMatcher(
 157                                                 serverAcceptableHostname);
 158         Collection<SNIMatcher> matchers = new ArrayList<>(1);
 159         matchers.add(matcher);
 160         SSLParameters params = sslSocket.getSSLParameters();
 161         params.setSNIMatchers(matchers);
 162         sslSocket.setSSLParameters(params);
 163 
 164         InputStream sslIS = sslSocket.getInputStream();
 165         OutputStream sslOS = sslSocket.getOutputStream();
 166 
 167         sslIS.read();
 168         sslOS.write(85);
 169         sslOS.flush();
 170 
 171         ExtendedSSLSession session = (ExtendedSSLSession)sslSocket.getSession();
 172         checkCapabilities(capabilities, session);
 173 
 174         sslSocket.close();
 175         serverSocket.close();
 176     }
 177 
 178 
 179     /*
 180      * Define the client side of the test.
 181      *
 182      * If the server prematurely exits, serverReady will be set to true
 183      * to avoid infinite hangs.
 184      */
 185     void doClientSide() throws Exception {
 186 
 187         /*
 188          * Wait for server to get started.
 189          */
 190         while (!serverReady) {
 191             Thread.sleep(50);
 192         }
 193 
 194         SSLSocketFactory sslsf =
 195             (SSLSocketFactory) SSLSocketFactory.getDefault();
 196         SSLSocket sslSocket = (SSLSocket)
 197             sslsf.createSocket("localhost", serverPort);
 198 
 199         SNIHostName serverName = new SNIHostName(clientRequestedHostname);
 200         List<SNIServerName> serverNames = new ArrayList<>(1);
 201         serverNames.add(serverName);
 202         SSLParameters params = sslSocket.getSSLParameters();
 203         params.setServerNames(serverNames);
 204         sslSocket.setSSLParameters(params);
 205 
 206         InputStream sslIS = sslSocket.getInputStream();
 207         OutputStream sslOS = sslSocket.getOutputStream();
 208 
 209         sslOS.write(280);
 210         sslOS.flush();
 211         sslIS.read();
 212 
 213         ExtendedSSLSession session = (ExtendedSSLSession)sslSocket.getSession();
 214         checkSNIInSession(session);
 215 
 216         sslSocket.close();
 217     }
 218 
 219 
 220     void checkCapabilities(SSLCapabilities capabilities,
 221             ExtendedSSLSession session) throws Exception {
 222 
 223         List<SNIServerName> sessionSNI = session.getRequestedServerNames();
 224         if (!sessionSNI.equals(capabilities.getServerNames())) {
 225             for (SNIServerName sni : sessionSNI) {
 226                 System.out.println("SNI in session is " + sni);
 227             }
 228 
 229             List<SNIServerName> capaSNI = capabilities.getServerNames();
 230             for (SNIServerName sni : capaSNI) {
 231                 System.out.println("SNI in session is " + sni);
 232             }
 233 
 234             throw new Exception(
 235                     "server name indication does not match capabilities");
 236         }
 237 
 238         checkSNIInSession(session);
 239     }
 240 
 241     void checkSNIInSession(ExtendedSSLSession session) throws Exception {
 242         List<SNIServerName> sessionSNI = session.getRequestedServerNames();
 243         if (sessionSNI.isEmpty()) {
 244             throw new Exception(
 245                     "unexpected empty request server name indication");
 246         }
 247 
 248         if (sessionSNI.size() != 1) {
 249             throw new Exception(
 250                     "unexpected request server name indication");
 251         }
 252 
 253         SNIServerName serverName = sessionSNI.get(0);
 254         if (!(serverName instanceof SNIHostName)) {
 255             throw new Exception(
 256                     "unexpected instance of request server name indication");
 257         }
 258 
 259         String hostname = ((SNIHostName)serverName).getAsciiName();
 260         if (!clientRequestedHostname.equalsIgnoreCase(hostname)) {
 261             throw new Exception(
 262                     "unexpected request server name indication value");
 263         }
 264     }
 265 
 266     private static String clientRequestedHostname;
 267     private static String serverAcceptableHostname;
 268 
 269     private static void parseArguments(String[] args) {
 270         clientRequestedHostname = args[0];
 271         serverAcceptableHostname = args[1];
 272     }
 273 
 274 
 275     /*
 276      * =============================================================
 277      * The remainder is just support stuff
 278      */
 279 
 280     // use any free port by default
 281     volatile int serverPort = 0;
 282 
 283     volatile Exception serverException = null;
 284     volatile Exception clientException = null;
 285 
 286 
 287     public static void main(String[] args) throws Exception {
 288         String keyFilename =
 289             System.getProperty("test.src", ".") + "/" + pathToStores +
 290                 "/" + keyStoreFile;
 291         String trustFilename =
 292             System.getProperty("test.src", ".") + "/" + pathToStores +
 293                 "/" + trustStoreFile;
 294 
 295         System.setProperty("javax.net.ssl.keyStore", keyFilename);
 296         System.setProperty("javax.net.ssl.keyStorePassword", passwd);
 297         System.setProperty("javax.net.ssl.trustStore", trustFilename);
 298         System.setProperty("javax.net.ssl.trustStorePassword", passwd);
 299 
 300         if (debug)
 301             System.setProperty("javax.net.debug", "all");
 302 
 303         /*
 304          * Get the customized arguments.
 305          */
 306         parseArguments(args);
 307 
 308         /*
 309          * Start the tests.
 310          */
 311         new SSLSocketExplorerMatchedSNI();
 312     }
 313 
 314     Thread clientThread = null;
 315     Thread serverThread = null;
 316 
 317     /*
 318      * Primary constructor, used to drive remainder of the test.
 319      *
 320      * Fork off the other side, then do your work.
 321      */
 322     SSLSocketExplorerMatchedSNI() throws Exception {
 323         try {
 324             if (separateServerThread) {
 325                 startServer(true);
 326                 startClient(false);
 327             } else {
 328                 startClient(true);
 329                 startServer(false);
 330             }
 331         } catch (Exception e) {
 332             // swallow for now.  Show later
 333         }
 334 
 335         /*
 336          * Wait for other side to close down.
 337          */
 338         if (separateServerThread) {
 339             serverThread.join();
 340         } else {
 341             clientThread.join();
 342         }
 343 
 344         /*
 345          * When we get here, the test is pretty much over.
 346          * Which side threw the error?
 347          */
 348         Exception local;
 349         Exception remote;
 350         String whichRemote;
 351 
 352         if (separateServerThread) {
 353             remote = serverException;
 354             local = clientException;
 355             whichRemote = "server";
 356         } else {
 357             remote = clientException;
 358             local = serverException;
 359             whichRemote = "client";
 360         }
 361 
 362         /*
 363          * If both failed, return the curthread's exception, but also
 364          * print the remote side Exception
 365          */
 366         if ((local != null) && (remote != null)) {
 367             System.out.println(whichRemote + " also threw:");
 368             remote.printStackTrace();
 369             System.out.println();
 370             throw local;
 371         }
 372 
 373         if (remote != null) {
 374             throw remote;
 375         }
 376 
 377         if (local != null) {
 378             throw local;
 379         }
 380     }
 381 
 382     void startServer(boolean newThread) throws Exception {
 383         if (newThread) {
 384             serverThread = new Thread() {
 385                 public void run() {
 386                     try {
 387                         doServerSide();
 388                     } catch (Exception e) {
 389                         /*
 390                          * Our server thread just died.
 391                          *
 392                          * Release the client, if not active already...
 393                          */
 394                         System.err.println("Server died...");
 395                         serverReady = true;
 396                         serverException = e;
 397                     }
 398                 }
 399             };
 400             serverThread.start();
 401         } else {
 402             try {
 403                 doServerSide();
 404             } catch (Exception e) {
 405                 serverException = e;
 406             } finally {
 407                 serverReady = true;
 408             }
 409         }
 410     }
 411 
 412     void startClient(boolean newThread) throws Exception {
 413         if (newThread) {
 414             clientThread = new Thread() {
 415                 public void run() {
 416                     try {
 417                         doClientSide();
 418                     } catch (Exception e) {
 419                         /*
 420                          * Our client thread just died.
 421                          */
 422                         System.err.println("Client died...");
 423                         clientException = e;
 424                     }
 425                 }
 426             };
 427             clientThread.start();
 428         } else {
 429             try {
 430                 doClientSide();
 431             } catch (Exception e) {
 432                 clientException = e;
 433             }
 434         }
 435     }
 436 }