1 /*
   2  * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 import java.io.*;
  25 import java.net.*;
  26 import java.util.*;
  27 import java.security.*;
  28 import javax.net.*;
  29 import javax.net.ssl.*;
  30 import java.lang.reflect.*;
  31 
  32 import sun.security.util.KeyUtil;
  33 
  34 public class ShortRSAKeyWithinTLS {
  35 
  36     /*
  37      * =============================================================
  38      * Set the various variables needed for the tests, then
  39      * specify what tests to run on each side.
  40      */
  41 
  42     /*
  43      * Should we run the client or server in a separate thread?
  44      * Both sides can throw exceptions, but do you have a preference
  45      * as to which side should be the main thread.
  46      */
  47     static boolean separateServerThread = false;
  48 
  49     /*
  50      * Is the server ready to serve?
  51      */
  52     volatile static boolean serverReady = false;
  53 
  54     /*
  55      * Turn on SSL debugging?
  56      */
  57     static boolean debug = false;
  58 
  59     /*
  60      * If the client or server is doing some kind of object creation
  61      * that the other side depends on, and that thread prematurely
  62      * exits, you may experience a hang.  The test harness will
  63      * terminate all hung threads after its timeout has expired,
  64      * currently 3 minutes by default, but you might try to be
  65      * smart about it....
  66      */
  67 
  68     /*
  69      * Define the server side of the test.
  70      *
  71      * If the server prematurely exits, serverReady will be set to true
  72      * to avoid infinite hangs.
  73      */
  74     void doServerSide() throws Exception {
  75 
  76         // load the key store
  77         KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI");
  78         ks.load(null, null);
  79         System.out.println("Loaded keystore: Windows-MY");
  80 
  81         // check key size
  82         checkKeySize(ks);
  83 
  84         // initialize the SSLContext
  85         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
  86         kmf.init(ks, null);
  87 
  88         TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
  89         tmf.init(ks);
  90 
  91         SSLContext ctx = SSLContext.getInstance("TLS");
  92         ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
  93 
  94         ServerSocketFactory ssf = ctx.getServerSocketFactory();
  95         SSLServerSocket sslServerSocket = (SSLServerSocket)
  96                                 ssf.createServerSocket(serverPort);
  97         sslServerSocket.setNeedClientAuth(true);
  98         serverPort = sslServerSocket.getLocalPort();
  99         System.out.println("serverPort = " + serverPort);
 100 
 101         /*
 102          * Signal Client, we're ready for his connect.
 103          */
 104         serverReady = true;
 105 
 106         SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
 107         InputStream sslIS = sslSocket.getInputStream();
 108         OutputStream sslOS = sslSocket.getOutputStream();
 109 
 110         sslIS.read();
 111         sslOS.write(85);
 112         sslOS.flush();
 113 
 114         sslSocket.close();
 115     }
 116 
 117     /*
 118      * Define the client side of the test.
 119      *
 120      * If the server prematurely exits, serverReady will be set to true
 121      * to avoid infinite hangs.
 122      */
 123     void doClientSide() throws Exception {
 124 
 125         /*
 126          * Wait for server to get started.
 127          */
 128         while (!serverReady) {
 129             Thread.sleep(50);
 130         }
 131 
 132         // load the key store
 133         KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI");
 134         ks.load(null, null);
 135         System.out.println("Loaded keystore: Windows-MY");
 136 
 137         // initialize the SSLContext
 138         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
 139         kmf.init(ks, null);
 140 
 141         TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
 142         tmf.init(ks);
 143 
 144         SSLContext ctx = SSLContext.getInstance("TLS");
 145         ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
 146 
 147         SSLSocketFactory sslsf = ctx.getSocketFactory();
 148         SSLSocket sslSocket = (SSLSocket)
 149             sslsf.createSocket("localhost", serverPort);
 150 
 151         if (clientProtocol != null) {
 152             sslSocket.setEnabledProtocols(new String[] {clientProtocol});
 153         }
 154 
 155         if (clientCiperSuite != null) {
 156             sslSocket.setEnabledCipherSuites(new String[] {clientCiperSuite});
 157         }
 158 
 159         InputStream sslIS = sslSocket.getInputStream();
 160         OutputStream sslOS = sslSocket.getOutputStream();
 161 
 162         sslOS.write(280);
 163         sslOS.flush();
 164         sslIS.read();
 165 
 166         sslSocket.close();
 167     }
 168 
 169     private void checkKeySize(KeyStore ks) throws Exception {
 170         PrivateKey privateKey = null;
 171         PublicKey publicKey = null;
 172 
 173         if (ks.containsAlias(keyAlias)) {
 174             System.out.println("Loaded entry: " + keyAlias);
 175             privateKey = (PrivateKey)ks.getKey(keyAlias, null);
 176             publicKey = (PublicKey)ks.getCertificate(keyAlias).getPublicKey();
 177 
 178             int privateKeySize = KeyUtil.getKeySize(privateKey);
 179             if (privateKeySize != keySize) {
 180                 throw new Exception("Expected key size is " + keySize +
 181                         ", but the private key size is " + privateKeySize);
 182             }
 183 
 184             int publicKeySize = KeyUtil.getKeySize(publicKey);
 185             if (publicKeySize != keySize) {
 186                 throw new Exception("Expected key size is " + keySize +
 187                         ", but the public key size is " + publicKeySize);
 188             }
 189         }
 190     }
 191 
 192     /*
 193      * =============================================================
 194      * The remainder is just support stuff
 195      */
 196 
 197     // use any free port by default
 198     volatile int serverPort = 0;
 199 
 200     volatile Exception serverException = null;
 201     volatile Exception clientException = null;
 202 
 203     private static String keyAlias;
 204     private static int keySize;
 205     private static String clientProtocol = null;
 206     private static String clientCiperSuite = null;
 207 
 208     private static void parseArguments(String[] args) {
 209         keyAlias = args[0];
 210         keySize = Integer.parseInt(args[1]);
 211 
 212         if (args.length > 2) {
 213             clientProtocol = args[2];
 214         }
 215 
 216         if (args.length > 3) {
 217             clientCiperSuite = args[3];
 218         }
 219     }
 220 
 221     public static void main(String[] args) throws Exception {
 222         if (debug) {
 223             System.setProperty("javax.net.debug", "all");
 224         }
 225 
 226         // Get the customized arguments.
 227         parseArguments(args);
 228 
 229         new ShortRSAKeyWithinTLS();
 230     }
 231 
 232     Thread clientThread = null;
 233     Thread serverThread = null;
 234 
 235     /*
 236      * Primary constructor, used to drive remainder of the test.
 237      *
 238      * Fork off the other side, then do your work.
 239      */
 240     ShortRSAKeyWithinTLS() throws Exception {
 241         try {
 242             if (separateServerThread) {
 243                 startServer(true);
 244                 startClient(false);
 245             } else {
 246                 startClient(true);
 247                 startServer(false);
 248             }
 249         } catch (Exception e) {
 250             // swallow for now.  Show later
 251         }
 252 
 253         /*
 254          * Wait for other side to close down.
 255          */
 256         if (separateServerThread) {
 257             serverThread.join();
 258         } else {
 259             clientThread.join();
 260         }
 261 
 262         /*
 263          * When we get here, the test is pretty much over.
 264          * Which side threw the error?
 265          */
 266         Exception local;
 267         Exception remote;
 268         String whichRemote;
 269 
 270         if (separateServerThread) {
 271             remote = serverException;
 272             local = clientException;
 273             whichRemote = "server";
 274         } else {
 275             remote = clientException;
 276             local = serverException;
 277             whichRemote = "client";
 278         }
 279 
 280         /*
 281          * If both failed, return the curthread's exception, but also
 282          * print the remote side Exception
 283          */
 284         if ((local != null) && (remote != null)) {
 285             System.out.println(whichRemote + " also threw:");
 286             remote.printStackTrace();
 287             System.out.println();
 288             throw local;
 289         }
 290 
 291         if (remote != null) {
 292             throw remote;
 293         }
 294 
 295         if (local != null) {
 296             throw local;
 297         }
 298     }
 299 
 300     void startServer(boolean newThread) throws Exception {
 301         if (newThread) {
 302             serverThread = new Thread() {
 303                 public void run() {
 304                     try {
 305                         doServerSide();
 306                     } catch (Exception e) {
 307                         /*
 308                          * Our server thread just died.
 309                          *
 310                          * Release the client, if not active already...
 311                          */
 312                         System.err.println("Server died...");
 313                         serverReady = true;
 314                         serverException = e;
 315                     }
 316                 }
 317             };
 318             serverThread.start();
 319         } else {
 320             try {
 321                 doServerSide();
 322             } catch (Exception e) {
 323                 serverException = e;
 324             } finally {
 325                 serverReady = true;
 326             }
 327         }
 328     }
 329 
 330     void startClient(boolean newThread) throws Exception {
 331         if (newThread) {
 332             clientThread = new Thread() {
 333                 public void run() {
 334                     try {
 335                         doClientSide();
 336                     } catch (Exception e) {
 337                         /*
 338                          * Our client thread just died.
 339                          */
 340                         System.err.println("Client died...");
 341                         clientException = e;
 342                     }
 343                 }
 344             };
 345             clientThread.start();
 346         } else {
 347             try {
 348                 doClientSide();
 349             } catch (Exception e) {
 350                 clientException = e;
 351             }
 352         }
 353     }
 354 }
 355