1 /*
   2  * Copyright (c) 2012, 2015, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 import java.io.*;
  25 import java.security.*;
  26 import javax.net.*;
  27 import javax.net.ssl.*;
  28 
  29 import sun.security.util.KeyUtil;
  30 
  31 public class ShortRSAKeyWithinTLS {
  32 
  33     /*
  34      * =============================================================
  35      * Set the various variables needed for the tests, then
  36      * specify what tests to run on each side.
  37      */
  38 
  39     /*
  40      * Should we run the client or server in a separate thread?
  41      * Both sides can throw exceptions, but do you have a preference
  42      * as to which side should be the main thread.
  43      */
  44     static boolean separateServerThread = false;
  45 
  46     /*
  47      * Is the server ready to serve?
  48      */
  49     volatile static boolean serverReady = false;
  50 
  51     /*
  52      * Turn on SSL debugging?
  53      */
  54     static boolean debug = false;
  55 
  56     /*
  57      * If the client or server is doing some kind of object creation
  58      * that the other side depends on, and that thread prematurely
  59      * exits, you may experience a hang.  The test harness will
  60      * terminate all hung threads after its timeout has expired,
  61      * currently 3 minutes by default, but you might try to be
  62      * smart about it....
  63      */
  64 
  65     /*
  66      * Define the server side of the test.
  67      *
  68      * If the server prematurely exits, serverReady will be set to true
  69      * to avoid infinite hangs.
  70      */
  71     void doServerSide() throws Exception {
  72 
  73         // load the key store
  74         KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI");
  75         ks.load(null, null);
  76         System.out.println("Loaded keystore: Windows-MY");
  77 
  78         // check key size
  79         checkKeySize(ks);
  80 
  81         // initialize the SSLContext
  82         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
  83         kmf.init(ks, null);
  84 
  85         TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
  86         tmf.init(ks);
  87 
  88         SSLContext ctx = SSLContext.getInstance("TLS");
  89         ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
  90 
  91         ServerSocketFactory ssf = ctx.getServerSocketFactory();
  92         SSLServerSocket sslServerSocket = (SSLServerSocket)
  93                                 ssf.createServerSocket(serverPort);
  94         sslServerSocket.setNeedClientAuth(true);
  95         serverPort = sslServerSocket.getLocalPort();
  96         System.out.println("serverPort = " + serverPort);
  97 
  98         /*
  99          * Signal Client, we're ready for his connect.
 100          */
 101         serverReady = true;
 102 
 103         SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept();
 104         InputStream sslIS = sslSocket.getInputStream();
 105         OutputStream sslOS = sslSocket.getOutputStream();
 106 
 107         sslIS.read();
 108         sslOS.write(85);
 109         sslOS.flush();
 110 
 111         sslSocket.close();
 112     }
 113 
 114     /*
 115      * Define the client side of the test.
 116      *
 117      * If the server prematurely exits, serverReady will be set to true
 118      * to avoid infinite hangs.
 119      */
 120     void doClientSide() throws Exception {
 121 
 122         /*
 123          * Wait for server to get started.
 124          */
 125         while (!serverReady) {
 126             Thread.sleep(50);
 127         }
 128 
 129         // load the key store
 130         KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI");
 131         ks.load(null, null);
 132         System.out.println("Loaded keystore: Windows-MY");
 133 
 134         // initialize the SSLContext
 135         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
 136         kmf.init(ks, null);
 137 
 138         TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
 139         tmf.init(ks);
 140 
 141         SSLContext ctx = SSLContext.getInstance("TLS");
 142         ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
 143 
 144         SSLSocketFactory sslsf = ctx.getSocketFactory();
 145         SSLSocket sslSocket = (SSLSocket)
 146             sslsf.createSocket("localhost", serverPort);
 147 
 148         if (clientProtocol != null) {
 149             sslSocket.setEnabledProtocols(new String[] {clientProtocol});
 150         }
 151 
 152         if (clientCiperSuite != null) {
 153             sslSocket.setEnabledCipherSuites(new String[] {clientCiperSuite});
 154         }
 155 
 156         InputStream sslIS = sslSocket.getInputStream();
 157         OutputStream sslOS = sslSocket.getOutputStream();
 158 
 159         sslOS.write(280);
 160         sslOS.flush();
 161         sslIS.read();
 162 
 163         sslSocket.close();
 164     }
 165 
 166     private void checkKeySize(KeyStore ks) throws Exception {
 167         PrivateKey privateKey = null;
 168         PublicKey publicKey = null;
 169 
 170         if (ks.containsAlias(keyAlias)) {
 171             System.out.println("Loaded entry: " + keyAlias);
 172             privateKey = (PrivateKey)ks.getKey(keyAlias, null);
 173             publicKey = (PublicKey)ks.getCertificate(keyAlias).getPublicKey();
 174 
 175             int privateKeySize = KeyUtil.getKeySize(privateKey);
 176             if (privateKeySize != keySize) {
 177                 throw new Exception("Expected key size is " + keySize +
 178                         ", but the private key size is " + privateKeySize);
 179             }
 180 
 181             int publicKeySize = KeyUtil.getKeySize(publicKey);
 182             if (publicKeySize != keySize) {
 183                 throw new Exception("Expected key size is " + keySize +
 184                         ", but the public key size is " + publicKeySize);
 185             }
 186         }
 187     }
 188 
 189     /*
 190      * =============================================================
 191      * The remainder is just support stuff
 192      */
 193 
 194     // use any free port by default
 195     volatile int serverPort = 0;
 196 
 197     volatile Exception serverException = null;
 198     volatile Exception clientException = null;
 199 
 200     private static String keyAlias;
 201     private static int keySize;
 202     private static String clientProtocol = null;
 203     private static String clientCiperSuite = null;
 204 
 205     private static void parseArguments(String[] args) {
 206         keyAlias = args[0];
 207         keySize = Integer.parseInt(args[1]);
 208 
 209         if (args.length > 2) {
 210             clientProtocol = args[2];
 211         }
 212 
 213         if (args.length > 3) {
 214             clientCiperSuite = args[3];
 215         }
 216     }
 217 
 218     public static void main(String[] args) throws Exception {
 219         if (debug) {
 220             System.setProperty("javax.net.debug", "all");
 221         }
 222 
 223         // Get the customized arguments.
 224         parseArguments(args);
 225 
 226         new ShortRSAKeyWithinTLS();
 227     }
 228 
 229     Thread clientThread = null;
 230     Thread serverThread = null;
 231 
 232     /*
 233      * Primary constructor, used to drive remainder of the test.
 234      *
 235      * Fork off the other side, then do your work.
 236      */
 237     ShortRSAKeyWithinTLS() throws Exception {
 238         try {
 239             if (separateServerThread) {
 240                 startServer(true);
 241                 startClient(false);
 242             } else {
 243                 startClient(true);
 244                 startServer(false);
 245             }
 246         } catch (Exception e) {
 247             // swallow for now.  Show later
 248         }
 249 
 250         /*
 251          * Wait for other side to close down.
 252          */
 253         if (separateServerThread) {
 254             serverThread.join();
 255         } else {
 256             clientThread.join();
 257         }
 258 
 259         /*
 260          * When we get here, the test is pretty much over.
 261          * Which side threw the error?
 262          */
 263         Exception local;
 264         Exception remote;
 265         String whichRemote;
 266 
 267         if (separateServerThread) {
 268             remote = serverException;
 269             local = clientException;
 270             whichRemote = "server";
 271         } else {
 272             remote = clientException;
 273             local = serverException;
 274             whichRemote = "client";
 275         }
 276 
 277         /*
 278          * If both failed, return the curthread's exception, but also
 279          * print the remote side Exception
 280          */
 281         if ((local != null) && (remote != null)) {
 282             System.out.println(whichRemote + " also threw:");
 283             remote.printStackTrace();
 284             System.out.println();
 285             throw local;
 286         }
 287 
 288         if (remote != null) {
 289             throw remote;
 290         }
 291 
 292         if (local != null) {
 293             throw local;
 294         }
 295     }
 296 
 297     void startServer(boolean newThread) throws Exception {
 298         if (newThread) {
 299             serverThread = new Thread() {
 300                 public void run() {
 301                     try {
 302                         doServerSide();
 303                     } catch (Exception e) {
 304                         /*
 305                          * Our server thread just died.
 306                          *
 307                          * Release the client, if not active already...
 308                          */
 309                         System.err.println("Server died...");
 310                         serverReady = true;
 311                         serverException = e;
 312                     }
 313                 }
 314             };
 315             serverThread.start();
 316         } else {
 317             try {
 318                 doServerSide();
 319             } catch (Exception e) {
 320                 serverException = e;
 321             } finally {
 322                 serverReady = true;
 323             }
 324         }
 325     }
 326 
 327     void startClient(boolean newThread) throws Exception {
 328         if (newThread) {
 329             clientThread = new Thread() {
 330                 public void run() {
 331                     try {
 332                         doClientSide();
 333                     } catch (Exception e) {
 334                         /*
 335                          * Our client thread just died.
 336                          */
 337                         System.err.println("Client died...");
 338                         clientException = e;
 339                     }
 340                 }
 341             };
 342             clientThread.start();
 343         } else {
 344             try {
 345                 doClientSide();
 346             } catch (Exception e) {
 347                 clientException = e;
 348             }
 349         }
 350     }
 351 }
 352