1 /*
   2  * Copyright (c) 2003, 2012, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 import java.io.*;
  25 import java.net.*;
  26 import java.util.*;
  27 import java.security.*;
  28 import javax.net.*;
  29 import javax.net.ssl.*;
  30 import java.lang.reflect.*;
  31 
  32 public class ClientAuth extends PKCS11Test {
  33 
  34     /*
  35      * =============================================================
  36      * Set the various variables needed for the tests, then
  37      * specify what tests to run on each side.
  38      */
  39 
  40     private static Provider provider;
  41     private static final String NSS_PWD = "test12";
  42     private static final String JKS_PWD = "passphrase";
  43     private static final String SERVER_KS = "server.keystore";
  44     private static final String TS = "truststore";
  45     private static String p11config;
  46 
  47     private static String DIR = System.getProperty("DIR");
  48 
  49     /*
  50      * Should we run the client or server in a separate thread?
  51      * Both sides can throw exceptions, but do you have a preference
  52      * as to which side should be the main thread.
  53      */
  54     static boolean separateServerThread = false;
  55 
  56     /*
  57      * Is the server ready to serve?
  58      */
  59     volatile static boolean serverReady = false;
  60 
  61     /*
  62      * Turn on SSL debugging?
  63      */
  64     static boolean debug = false;
  65 
  66     /*
  67      * If the client or server is doing some kind of object creation
  68      * that the other side depends on, and that thread prematurely
  69      * exits, you may experience a hang.  The test harness will
  70      * terminate all hung threads after its timeout has expired,
  71      * currently 3 minutes by default, but you might try to be
  72      * smart about it....
  73      */
  74 
  75     /*
  76      * Define the server side of the test.
  77      *
  78      * If the server prematurely exits, serverReady will be set to true
  79      * to avoid infinite hangs.
  80      */
  81     void doServerSide() throws Exception {
  82 
  83         SSLContext ctx = SSLContext.getInstance("TLS");
  84         char[] passphrase = JKS_PWD.toCharArray();
  85 
  86         // server gets KeyStore from JKS keystore
  87         KeyStore ks = KeyStore.getInstance("JKS");
  88         ks.load(new FileInputStream(new File(DIR, SERVER_KS)), passphrase);
  89         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
  90         kmf.init(ks, passphrase);
  91 
  92         // server gets TrustStore from PKCS#11 token
  93 /*
  94         passphrase = NSS_PWD.toCharArray();
  95         KeyStore ts = KeyStore.getInstance("PKCS11", "SunPKCS11-nss");
  96         ts.load(null, passphrase);
  97         TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
  98         tmf.init(ts);
  99 */
 100 
 101         //ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
 102         ctx.init(kmf.getKeyManagers(), null, null);
 103         ServerSocketFactory ssf = ctx.getServerSocketFactory();
 104         SSLServerSocket sslServerSocket = (SSLServerSocket)
 105                                 ssf.createServerSocket(serverPort);
 106         sslServerSocket.setNeedClientAuth(true);
 107         serverPort = sslServerSocket.getLocalPort();
 108         System.out.println("serverPort = " + serverPort);
 109 
 110         /*
 111          * Signal Client, we're ready for his connect.
 112          */
 113         serverReady = true;
 114 
 115         SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
 116         InputStream sslIS = sslSocket.getInputStream();
 117         OutputStream sslOS = sslSocket.getOutputStream();
 118 
 119         sslIS.read();
 120         sslOS.write(85);
 121         sslOS.flush();
 122 
 123         sslSocket.close();
 124     }
 125 
 126     /*
 127      * Define the client side of the test.
 128      *
 129      * If the server prematurely exits, serverReady will be set to true
 130      * to avoid infinite hangs.
 131      */
 132     void doClientSide() throws Exception {
 133 
 134         /*
 135          * Wait for server to get started.
 136          */
 137         while (!serverReady) {
 138             Thread.sleep(50);
 139         }
 140 
 141         SSLContext ctx = SSLContext.getInstance("TLS");
 142         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
 143 
 144         // client gets KeyStore from PKCS#11 token,
 145         // and gets TrustStore from JKS KeyStore (using system properties)
 146         char[] passphrase = NSS_PWD.toCharArray();
 147         KeyStore ks = KeyStore.getInstance("PKCS11", "SunPKCS11-nss");
 148         ks.load(null, passphrase);
 149 
 150         kmf = KeyManagerFactory.getInstance("SunX509");
 151         kmf.init(ks, passphrase);
 152         ctx.init(kmf.getKeyManagers(), null, null);
 153 
 154         SSLSocketFactory sslsf = ctx.getSocketFactory();
 155         SSLSocket sslSocket = (SSLSocket)
 156             sslsf.createSocket("localhost", serverPort);
 157 
 158         if (clientProtocol != null) {
 159             sslSocket.setEnabledProtocols(new String[] {clientProtocol});
 160         }
 161 
 162         if (clientCiperSuite != null) {
 163             sslSocket.setEnabledCipherSuites(new String[] {clientCiperSuite});
 164         }
 165 
 166         InputStream sslIS = sslSocket.getInputStream();
 167         OutputStream sslOS = sslSocket.getOutputStream();
 168 
 169         sslOS.write(280);
 170         sslOS.flush();
 171         sslIS.read();
 172 
 173         sslSocket.close();
 174     }
 175 
 176     /*
 177      * =============================================================
 178      * The remainder is just support stuff
 179      */
 180 
 181     // use any free port by default
 182     volatile int serverPort = 0;
 183 
 184     volatile Exception serverException = null;
 185     volatile Exception clientException = null;
 186 
 187     private static String clientProtocol = null;
 188     private static String clientCiperSuite = null;
 189 
 190     private static void parseArguments(String[] args) {
 191         if (args.length > 0) {
 192             clientProtocol = args[0];
 193         }
 194 
 195         if (args.length > 1) {
 196             clientCiperSuite = args[1];
 197         }
 198     }
 199 
 200     public static void main(String[] args) throws Exception {
 201         // Get the customized arguments.
 202         parseArguments(args);
 203         main(new ClientAuth());
 204     }
 205 
 206     public void main(Provider p) throws Exception {
 207         // SSL RSA client auth currently needs an RSA cipher
 208         // (cf. NONEwithRSA hack), which is currently not available in
 209         // open builds.
 210         try {
 211             javax.crypto.Cipher.getInstance("RSA/ECB/PKCS1Padding", p);
 212         } catch (GeneralSecurityException e) {
 213             System.out.println("Not supported by provider, skipping");
 214             return;
 215         }
 216 
 217         this.provider = p;
 218 
 219         System.setProperty("javax.net.ssl.trustStore",
 220                                         new File(DIR, TS).toString());
 221         System.setProperty("javax.net.ssl.trustStoreType", "JKS");
 222         System.setProperty("javax.net.ssl.trustStoreProvider", "SUN");
 223         System.setProperty("javax.net.ssl.trustStorePassword", JKS_PWD);
 224 
 225         // perform Security.addProvider of P11 provider
 226         ProviderLoader.go(System.getProperty("CUSTOM_P11_CONFIG"));
 227 
 228         if (debug) {
 229             System.setProperty("javax.net.debug", "all");
 230         }
 231 
 232         /*
 233          * Start the tests.
 234          */
 235         go();
 236     }
 237 
 238     Thread clientThread = null;
 239     Thread serverThread = null;
 240 
 241     /*
 242      * Fork off the other side, then do your work.
 243      */
 244     private void go() throws Exception {
 245         try {
 246             if (separateServerThread) {
 247                 startServer(true);
 248                 startClient(false);
 249             } else {
 250                 startClient(true);
 251                 startServer(false);
 252             }
 253         } catch (Exception e) {
 254             //swallow for now.  Show later
 255         }
 256 
 257         /*
 258          * Wait for other side to close down.
 259          */
 260         if (separateServerThread) {
 261             serverThread.join();
 262         } else {
 263             clientThread.join();
 264         }
 265 
 266         /*
 267          * When we get here, the test is pretty much over.
 268          * Which side threw the error?
 269          */
 270         Exception local;
 271         Exception remote;
 272         String whichRemote;
 273 
 274         if (separateServerThread) {
 275             remote = serverException;
 276             local = clientException;
 277             whichRemote = "server";
 278         } else {
 279             remote = clientException;
 280             local = serverException;
 281             whichRemote = "client";
 282         }
 283 
 284         /*
 285          * If both failed, return the curthread's exception, but also
 286          * print the remote side Exception
 287          */
 288         if ((local != null) && (remote != null)) {
 289             System.out.println(whichRemote + " also threw:");
 290             remote.printStackTrace();
 291             System.out.println();
 292             throw local;
 293         }
 294 
 295         if (remote != null) {
 296             throw remote;
 297         }
 298 
 299         if (local != null) {
 300             throw local;
 301         }
 302     }
 303 
 304     void startServer(boolean newThread) throws Exception {
 305         if (newThread) {
 306             serverThread = new Thread() {
 307                 public void run() {
 308                     try {
 309                         doServerSide();
 310                     } catch (Exception e) {
 311                         /*
 312                          * Our server thread just died.
 313                          *
 314                          * Release the client, if not active already...
 315                          */
 316                         System.err.println("Server died...");
 317                         serverReady = true;
 318                         serverException = e;
 319                     }
 320                 }
 321             };
 322             serverThread.start();
 323         } else {
 324             try {
 325                 doServerSide();
 326             } catch (Exception e) {
 327                 serverException = e;
 328             } finally {
 329                 serverReady = true;
 330             }
 331         }
 332     }
 333 
 334     void startClient(boolean newThread) throws Exception {
 335         if (newThread) {
 336             clientThread = new Thread() {
 337                 public void run() {
 338                     try {
 339                         doClientSide();
 340                     } catch (Exception e) {
 341                         /*
 342                          * Our client thread just died.
 343                          */
 344                         System.err.println("Client died...");
 345                         clientException = e;
 346                     }
 347                 }
 348             };
 349             clientThread.start();
 350         } else {
 351             try {
 352                 doClientSide();
 353             } catch (Exception e) {
 354                 clientException = e;
 355             }
 356         }
 357     }
 358 }