1 /* 2 * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import java.io.*; 25 import java.net.*; 26 import java.util.*; 27 import java.security.*; 28 import javax.net.*; 29 import javax.net.ssl.*; 30 import java.lang.reflect.*; 31 32 import sun.security.util.KeyUtil; 33 34 public class ShortRSAKeyWithinTLS { 35 36 /* 37 * ============================================================= 38 * Set the various variables needed for the tests, then 39 * specify what tests to run on each side. 40 */ 41 42 /* 43 * Should we run the client or server in a separate thread? 44 * Both sides can throw exceptions, but do you have a preference 45 * as to which side should be the main thread. 46 */ 47 static boolean separateServerThread = false; 48 49 /* 50 * Is the server ready to serve? 51 */ 52 volatile static boolean serverReady = false; 53 54 /* 55 * Turn on SSL debugging? 56 */ 57 static boolean debug = false; 58 59 /* 60 * If the client or server is doing some kind of object creation 61 * that the other side depends on, and that thread prematurely 62 * exits, you may experience a hang. The test harness will 63 * terminate all hung threads after its timeout has expired, 64 * currently 3 minutes by default, but you might try to be 65 * smart about it.... 66 */ 67 68 /* 69 * Define the server side of the test. 70 * 71 * If the server prematurely exits, serverReady will be set to true 72 * to avoid infinite hangs. 73 */ 74 void doServerSide() throws Exception { 75 76 // load the key store 77 KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI"); 78 ks.load(null, null); 79 System.out.println("Loaded keystore: Windows-MY"); 80 81 // check key size 82 checkKeySize(ks); 83 84 // initialize the SSLContext 85 KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); 86 kmf.init(ks, null); 87 88 TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); 89 tmf.init(ks); 90 91 SSLContext ctx = SSLContext.getInstance("TLS"); 92 ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); 93 94 ServerSocketFactory ssf = ctx.getServerSocketFactory(); 95 SSLServerSocket sslServerSocket = (SSLServerSocket) 96 ssf.createServerSocket(serverPort); 97 sslServerSocket.setNeedClientAuth(true); 98 serverPort = sslServerSocket.getLocalPort(); 99 System.out.println("serverPort = " + serverPort); 100 101 /* 102 * Signal Client, we're ready for his connect. 103 */ 104 serverReady = true; 105 106 SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept(); 107 InputStream sslIS = sslSocket.getInputStream(); 108 OutputStream sslOS = sslSocket.getOutputStream(); 109 110 sslIS.read(); 111 sslOS.write(85); 112 sslOS.flush(); 113 114 sslSocket.close(); 115 } 116 117 /* 118 * Define the client side of the test. 119 * 120 * If the server prematurely exits, serverReady will be set to true 121 * to avoid infinite hangs. 122 */ 123 void doClientSide() throws Exception { 124 125 /* 126 * Wait for server to get started. 127 */ 128 while (!serverReady) { 129 Thread.sleep(50); 130 } 131 132 // load the key store 133 KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI"); 134 ks.load(null, null); 135 System.out.println("Loaded keystore: Windows-MY"); 136 137 // initialize the SSLContext 138 KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); 139 kmf.init(ks, null); 140 141 TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); 142 tmf.init(ks); 143 144 SSLContext ctx = SSLContext.getInstance("TLS"); 145 ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); 146 147 SSLSocketFactory sslsf = ctx.getSocketFactory(); 148 SSLSocket sslSocket = (SSLSocket) 149 sslsf.createSocket("localhost", serverPort); 150 151 if (clientProtocol != null) { 152 sslSocket.setEnabledProtocols(new String[] {clientProtocol}); 153 } 154 155 if (clientCiperSuite != null) { 156 sslSocket.setEnabledCipherSuites(new String[] {clientCiperSuite}); 157 } 158 159 InputStream sslIS = sslSocket.getInputStream(); 160 OutputStream sslOS = sslSocket.getOutputStream(); 161 162 sslOS.write(280); 163 sslOS.flush(); 164 sslIS.read(); 165 166 sslSocket.close(); 167 } 168 169 private void checkKeySize(KeyStore ks) throws Exception { 170 PrivateKey privateKey = null; 171 PublicKey publicKey = null; 172 173 if (ks.containsAlias(keyAlias)) { 174 System.out.println("Loaded entry: " + keyAlias); 175 privateKey = (PrivateKey)ks.getKey(keyAlias, null); 176 publicKey = (PublicKey)ks.getCertificate(keyAlias).getPublicKey(); 177 178 int privateKeySize = KeyUtil.getKeySize(privateKey); 179 if (privateKeySize != keySize) { 180 throw new Exception("Expected key size is " + keySize + 181 ", but the private key size is " + privateKeySize); 182 } 183 184 int publicKeySize = KeyUtil.getKeySize(publicKey); 185 if (publicKeySize != keySize) { 186 throw new Exception("Expected key size is " + keySize + 187 ", but the public key size is " + publicKeySize); 188 } 189 } 190 } 191 192 /* 193 * ============================================================= 194 * The remainder is just support stuff 195 */ 196 197 // use any free port by default 198 volatile int serverPort = 0; 199 200 volatile Exception serverException = null; 201 volatile Exception clientException = null; 202 203 private static String keyAlias; 204 private static int keySize; 205 private static String clientProtocol = null; 206 private static String clientCiperSuite = null; 207 208 private static void parseArguments(String[] args) { 209 keyAlias = args[0]; 210 keySize = Integer.parseInt(args[1]); 211 212 if (args.length > 2) { 213 clientProtocol = args[2]; 214 } 215 216 if (args.length > 3) { 217 clientCiperSuite = args[3]; 218 } 219 } 220 221 public static void main(String[] args) throws Exception { 222 if (debug) { 223 System.setProperty("javax.net.debug", "all"); 224 } 225 226 // Get the customized arguments. 227 parseArguments(args); 228 229 new ShortRSAKeyWithinTLS(); 230 } 231 232 Thread clientThread = null; 233 Thread serverThread = null; 234 235 /* 236 * Primary constructor, used to drive remainder of the test. 237 * 238 * Fork off the other side, then do your work. 239 */ 240 ShortRSAKeyWithinTLS() throws Exception { 241 try { 242 if (separateServerThread) { 243 startServer(true); 244 startClient(false); 245 } else { 246 startClient(true); 247 startServer(false); 248 } 249 } catch (Exception e) { 250 // swallow for now. Show later 251 } 252 253 /* 254 * Wait for other side to close down. 255 */ 256 if (separateServerThread) { 257 serverThread.join(); 258 } else { 259 clientThread.join(); 260 } 261 262 /* 263 * When we get here, the test is pretty much over. 264 * Which side threw the error? 265 */ 266 Exception local; 267 Exception remote; 268 String whichRemote; 269 270 if (separateServerThread) { 271 remote = serverException; 272 local = clientException; 273 whichRemote = "server"; 274 } else { 275 remote = clientException; 276 local = serverException; 277 whichRemote = "client"; 278 } 279 280 /* 281 * If both failed, return the curthread's exception, but also 282 * print the remote side Exception 283 */ 284 if ((local != null) && (remote != null)) { 285 System.out.println(whichRemote + " also threw:"); 286 remote.printStackTrace(); 287 System.out.println(); 288 throw local; 289 } 290 291 if (remote != null) { 292 throw remote; 293 } 294 295 if (local != null) { 296 throw local; 297 } 298 } 299 300 void startServer(boolean newThread) throws Exception { 301 if (newThread) { 302 serverThread = new Thread() { 303 public void run() { 304 try { 305 doServerSide(); 306 } catch (Exception e) { 307 /* 308 * Our server thread just died. 309 * 310 * Release the client, if not active already... 311 */ 312 System.err.println("Server died..."); 313 serverReady = true; 314 serverException = e; 315 } 316 } 317 }; 318 serverThread.start(); 319 } else { 320 try { 321 doServerSide(); 322 } catch (Exception e) { 323 serverException = e; 324 } finally { 325 serverReady = true; 326 } 327 } 328 } 329 330 void startClient(boolean newThread) throws Exception { 331 if (newThread) { 332 clientThread = new Thread() { 333 public void run() { 334 try { 335 doClientSide(); 336 } catch (Exception e) { 337 /* 338 * Our client thread just died. 339 */ 340 System.err.println("Client died..."); 341 clientException = e; 342 } 343 } 344 }; 345 clientThread.start(); 346 } else { 347 try { 348 doClientSide(); 349 } catch (Exception e) { 350 clientException = e; 351 } 352 } 353 } 354 } 355