1 /*
   2  * Copyright (c) 2001, 2014, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 // SunJSSE does not support dynamic system properties, so there is no way
  25 // to re-use system properties in samevm/agentvm mode.
  26 
  27 /*
  28  * @test
  29  * @bug 4416068 4478803 4479736
  30  * @summary 4273544 JSSE request for function forceV3ClientHello()
  31  *          4479736 setEnabledProtocols API does not work correctly
  32  *          4478803 Need APIs to determine the protocol versions used in an SSL
  33  *                  session
  34  *          4701722 protocol mismatch exceptions should be consistent between
  35  *                  SSLv3 and TLSv1
  36  * @run main/othervm TestEnabledProtocols
  37  * @author Ram Marti
  38  */
  39 
  40 import java.io.*;
  41 import java.net.*;
  42 import java.util.*;
  43 import java.security.*;
  44 import javax.net.ssl.*;
  45 import java.security.cert.*;
  46 
  public class TestEnabledProtocols {

      /*
       * For each of the valid protocol combinations, start a server thread
       * that sets up an SSLServerSocket supporting that protocol. Then run
       * a client thread that attempts to open a connection with all
       * possible protocol combinations.  Verify that we get handshake
       * exceptions correctly. Whenever the connection is established
       * successfully, verify that the negotiated protocol was correct.
       * See results file in this directory for complete results.
       */

      /*
       * The protocol combinations under test.  Each entry is used both as a
       * server configuration (row index i of the tables below) and as a
       * client configuration (column index j of the tables below).
       */
      static final String[][] protocolStrings = {
                                  {"TLSv1"},
                                  {"TLSv1", "SSLv2Hello"},
                                  {"TLSv1", "SSLv3"},
                                  {"SSLv3", "SSLv2Hello"},
                                  {"SSLv3"},
                                  {"TLSv1", "SSLv3", "SSLv2Hello"}
                                  };

      /*
       * eXceptionArray[i][j] is true when a client enabling protocolStrings[j]
       * is expected to get an SSLHandshakeException from a server enabling
       * protocolStrings[i].
       */
      static final boolean [][] eXceptionArray = {
          // Do we expect exception?       Protocols supported by the server
          { false, true,  false, true,  true,  true }, // TLSv1
          { false, false, false, true,  true,  false}, // TLSv1,SSLv2Hello
          { false, true,  false, true,  false, true }, // TLSv1,SSLv3
          { true,  true,  false, false, false, false}, // SSLv3, SSLv2Hello
          { true,  true,  false, true,  false, true }, // SSLv3
          { false, false, false, false, false, false } // TLSv1,SSLv3,SSLv2Hello
          };

      /*
       * protocolSelected[i][j] is the protocol expected to be negotiated
       * between server protocolStrings[i] and client protocolStrings[j], or
       * null exactly where eXceptionArray[i][j] expects a handshake failure.
       */
      static final String[][] protocolSelected = {
          // TLSv1
          { "TLSv1",  null,   "TLSv1",  null,   null,     null },

          // TLSv1,SSLv2Hello
          { "TLSv1", "TLSv1", "TLSv1",  null,   null,    "TLSv1"},

          // TLSv1,SSLv3
          { "TLSv1",  null,   "TLSv1",  null,   "SSLv3",  null },

          // SSLv3, SSLv2Hello
          {  null,    null,   "SSLv3", "SSLv3", "SSLv3",  "SSLv3"},

          // SSLv3
          {  null,    null,   "SSLv3",  null,   "SSLv3",  null },

          // TLSv1,SSLv3,SSLv2Hello
          { "TLSv1", "TLSv1", "TLSv1", "SSLv3", "SSLv3", "TLSv1" }

      };

      /*
       * Where do we find the keystores?
       */
      final static String pathToStores = "../etc";
      static String passwd = "passphrase";
      static String keyStoreFile = "keystore";
      static String trustStoreFile = "truststore";

      /*
       * Is the server ready to serve?  Written by the server thread and
       * polled by the client threads.
       */
      volatile static boolean serverReady = false;

      /*
       * Turn on SSL debugging?
       */
      final static boolean debug = false;

      // use any free port by default
      volatile int serverPort = 0;

      // Failure recorded by a client thread; rethrown by the main thread.
      volatile Exception clientException = null;

      // Unexpected failure recorded by a server thread.  Waiting clients
      // check it so they do not poll serverReady forever after the server
      // thread has died, and the main thread rethrows it after joining the
      // server so a failure on the last connection is not silently lost.
      volatile Exception serverException = null;

      /**
       * Points the default JSSE key/trust stores at the test keystores and
       * runs the full server x client protocol matrix.
       */
      public static void main(String[] args) throws Exception {
          // reset the security property to make sure that the algorithms
          // and keys used in this test are not disabled.
          Security.setProperty("jdk.tls.disabledAlgorithms", "");

          String keyFilename =
              System.getProperty("test.src", "./") + "/" + pathToStores +
                  "/" + keyStoreFile;
          String trustFilename =
              System.getProperty("test.src", "./") + "/" + pathToStores +
                  "/" + trustStoreFile;

          System.setProperty("javax.net.ssl.keyStore", keyFilename);
          System.setProperty("javax.net.ssl.keyStorePassword", passwd);
          System.setProperty("javax.net.ssl.trustStore", trustFilename);
          System.setProperty("javax.net.ssl.trustStorePassword", passwd);

          if (debug) {
              System.setProperty("javax.net.debug", "all");
          }

          new TestEnabledProtocols();
      }

      /**
       * Runs the tests: for every server configuration, starts one server
       * thread and then runs one client per client configuration, one at a
       * time.
       *
       * @throws Exception the first client or server failure encountered
       */
      TestEnabledProtocols() throws Exception  {
          SSLServerSocketFactory sslssf =
              (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
          // The listening socket is shared by all server rounds.  It is
          // closed when the matrix completes or a failure aborts the test.
          try (SSLServerSocket sslServerSocket =
                  (SSLServerSocket) sslssf.createServerSocket(serverPort)) {
              serverPort = sslServerSocket.getLocalPort();

              for (int i = 0; i < protocolStrings.length; i++) {
                  String[] serverProtocols = protocolStrings[i];
                  startServer ss = new startServer(serverProtocols,
                      sslServerSocket, protocolStrings.length);
                  ss.setDaemon(true);
                  ss.start();
                  for (int j = 0; j < protocolStrings.length; j++) {
                      String[] clientProtocols = protocolStrings[j];
                      startClient sc = new startClient(
                          clientProtocols, serverProtocols,
                          eXceptionArray[i][j], protocolSelected[i][j]);
                      sc.start();
                      sc.join();
                      if (clientException != null) {
                          ss.requestStop();
                          throw clientException;
                      }
                  }
                  ss.requestStop();
                  System.out.println("Waiting for the server to complete");
                  ss.join();
                  // join() gives visibility of anything the server recorded.
                  if (serverException != null) {
                      throw serverException;
                  }
              }
          }
      }

      /**
       * Server side of one round: accepts up to numExpConns connections on
       * the shared listening socket with the given protocols enabled.
       */
      class startServer extends Thread  {
          private final String[] enabledP;
          final SSLServerSocket sslServerSocket;
          final int numExpConns;
          volatile boolean stopRequested = false;

          public startServer(String[] enabledProtocols,
                              SSLServerSocket sslServerSocket,
                              int numExpConns) {
              super("Server Thread");
              serverReady = false;
              enabledP = enabledProtocols;
              this.sslServerSocket = sslServerSocket;
              sslServerSocket.setEnabledProtocols(enabledP);
              this.numExpConns = numExpConns;
          }

          /** Asks the accept loop to stop before its next accept(). */
          public void requestStop() {
              stopRequested = true;
          }

          @Override
          public void run() {
              int conns = 0;
              while (!stopRequested) {
                  try {
                      serverReady = true;
                      try (SSLSocket socket =
                              (SSLSocket) sslServerSocket.accept()) {
                          conns++;

                          // set ready to false. this is just to make the
                          // client wait and synchronise exception messages
                          serverReady = false;
                          socket.startHandshake();
                          SSLSession session = socket.getSession();
                          session.invalidate();

                          // Exchange one byte each way with the client; only
                          // the low 8 bits of 280 are actually written.
                          InputStream in = socket.getInputStream();
                          OutputStream out = socket.getOutputStream();
                          out.write(280);
                          in.read();
                      }
                      // sleep for a while so that the server thread can be
                      // stopped
                      Thread.sleep(30);
                  } catch (SSLHandshakeException se) {
                      // ignore it; this is part of the testing
                      // log it for debugging
                      System.out.println("Server SSLHandshakeException:");
                      se.printStackTrace(System.out);
                  } catch (InterruptedIOException ioe) {
                      // must have been interrupted, no harm
                      break;
                  } catch (InterruptedException ie) {
                      // must have been interrupted, no harm
                      break;
                  } catch (Exception e) {
                      if (sslServerSocket.isClosed()) {
                          // The main thread closed the listening socket
                          // while aborting the test; nothing left to serve.
                          break;
                      }
                      // Publish the failure before dying so that clients
                      // and the main thread can react to it.
                      serverException = e;
                      System.out.println("Server exception:");
                      e.printStackTrace(System.out);
                      throw new RuntimeException(e);
                  }
                  if (conns >= numExpConns) {
                      break;
                  }
              }
          }
      }

      private static void showProtocols(String name, String[] protocols) {
          System.out.println("Enabled protocols on the " + name + " are: " + Arrays.asList(protocols));
      }

      /**
       * Client side of one matrix cell: connects with the given protocols
       * enabled and checks either the expected handshake failure or the
       * expected negotiated protocol.  Any failure is stored in
       * clientException.
       */
      class startClient extends Thread {
          private final boolean exceptionExpected;
          private final String[] enabledP;
          private final String[] serverP; // used to print the result
          private final String protocolToUse;

          startClient(String[] enabledProtocol,
                      String[] serverP,
                      boolean eXception,
                      String protocol) throws Exception {
              super("Client Thread");
              this.enabledP = enabledProtocol;
              this.serverP = serverP;
              this.exceptionExpected = eXception;
              this.protocolToUse = protocol;
          }

          /**
           * Polls until the server thread is ready to accept.  Fails fast if
           * the server thread has recorded a failure, because it will never
           * become ready again.
           */
          private void awaitServer() throws InterruptedException {
              while (true) {
                  Exception failure = serverException;
                  if (failure != null) {
                      throw new RuntimeException("Server thread failed", failure);
                  }
                  if (serverReady) {
                      return;
                  }
                  Thread.sleep(50);
              }
          }

          @Override
          public void run() {
              try {
                  awaitServer();
                  System.out.flush();
                  System.out.println("=== Starting new test run ===");
                  showProtocols("server", serverP);
                  showProtocols("client", enabledP);

                  SSLSocketFactory sslsf =
                      (SSLSocketFactory) SSLSocketFactory.getDefault();
                  // Closed on every path, including the expected
                  // handshake-failure path.
                  try (SSLSocket sslSocket = (SSLSocket)
                          sslsf.createSocket("localhost", serverPort)) {
                      sslSocket.setEnabledProtocols(enabledP);
                      sslSocket.startHandshake();

                      SSLSession session = sslSocket.getSession();
                      session.invalidate();
                      String protocolName = session.getProtocol();
                      System.out.println("Protocol name after getSession is " +
                          protocolName);

                      // protocolToUse is null where a failure was expected,
                      // so a successful handshake there is reported here.
                      if (protocolName.equals(protocolToUse)) {
                          System.out.println("** Success **");
                      } else {
                          System.out.println("** FAILURE ** ");
                          throw new RuntimeException
                              ("expected protocol " + protocolToUse +
                               " but using " + protocolName);
                      }

                      InputStream in = sslSocket.getInputStream();
                      OutputStream out = sslSocket.getOutputStream();
                      in.read();
                      out.write(280);
                  }

              } catch (SSLHandshakeException e) {
                  if (!exceptionExpected) {
                      System.out.println("Client got UNEXPECTED SSLHandshakeException:");
                      e.printStackTrace(System.out);
                      System.out.println("** FAILURE **");
                      clientException = e;
                  } else {
                      System.out.println("Client got expected SSLHandshakeException:");
                      e.printStackTrace(System.out);
                      System.out.println("** Success **");
                  }
              } catch (RuntimeException e) {
                  // Protocol mismatch or server-thread failure; the main
                  // thread rethrows it.
                  clientException = e;
              } catch (Exception e) {
                  System.out.println("Client got UNEXPECTED Exception:");
                  e.printStackTrace(System.out);
                  System.out.println("** FAILURE **");
                  clientException = e;
              }
          }
      }

  }