1 /* 2 * Copyright (c) 2012, 2015, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import java.io.*; 25 import java.security.*; 26 import javax.net.*; 27 import javax.net.ssl.*; 28 29 import sun.security.util.KeyUtil; 30 31 public class ShortRSAKeyWithinTLS { 32 33 /* 34 * ============================================================= 35 * Set the various variables needed for the tests, then 36 * specify what tests to run on each side. 37 */ 38 39 /* 40 * Should we run the client or server in a separate thread? 41 * Both sides can throw exceptions, but do you have a preference 42 * as to which side should be the main thread. 43 */ 44 static boolean separateServerThread = false; 45 46 /* 47 * Is the server ready to serve? 48 */ 49 volatile static boolean serverReady = false; 50 51 /* 52 * Turn on SSL debugging? 53 */ 54 static boolean debug = false; 55 56 /* 57 * If the client or server is doing some kind of object creation 58 * that the other side depends on, and that thread prematurely 59 * exits, you may experience a hang. The test harness will 60 * terminate all hung threads after its timeout has expired, 61 * currently 3 minutes by default, but you might try to be 62 * smart about it.... 63 */ 64 65 /* 66 * Define the server side of the test. 67 * 68 * If the server prematurely exits, serverReady will be set to true 69 * to avoid infinite hangs. 70 */ 71 void doServerSide() throws Exception { 72 73 // load the key store 74 KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI"); 75 ks.load(null, null); 76 System.out.println("Loaded keystore: Windows-MY"); 77 78 // check key size 79 checkKeySize(ks); 80 81 // initialize the SSLContext 82 KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); 83 kmf.init(ks, null); 84 85 TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); 86 tmf.init(ks); 87 88 SSLContext ctx = SSLContext.getInstance("TLS"); 89 ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); 90 91 ServerSocketFactory ssf = ctx.getServerSocketFactory(); 92 SSLServerSocket sslServerSocket = (SSLServerSocket) 93 ssf.createServerSocket(serverPort); 94 sslServerSocket.setNeedClientAuth(true); 95 serverPort = sslServerSocket.getLocalPort(); 96 System.out.println("serverPort = " + serverPort); 97 98 /* 99 * Signal Client, we're ready for his connect. 100 */ 101 serverReady = true; 102 103 SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept(); 104 InputStream sslIS = sslSocket.getInputStream(); 105 OutputStream sslOS = sslSocket.getOutputStream(); 106 107 sslIS.read(); 108 sslOS.write(85); 109 sslOS.flush(); 110 111 sslSocket.close(); 112 } 113 114 /* 115 * Define the client side of the test. 116 * 117 * If the server prematurely exits, serverReady will be set to true 118 * to avoid infinite hangs. 119 */ 120 void doClientSide() throws Exception { 121 122 /* 123 * Wait for server to get started. 124 */ 125 while (!serverReady) { 126 Thread.sleep(50); 127 } 128 129 // load the key store 130 KeyStore ks = KeyStore.getInstance("Windows-MY", "SunMSCAPI"); 131 ks.load(null, null); 132 System.out.println("Loaded keystore: Windows-MY"); 133 134 // initialize the SSLContext 135 KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); 136 kmf.init(ks, null); 137 138 TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); 139 tmf.init(ks); 140 141 SSLContext ctx = SSLContext.getInstance("TLS"); 142 ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); 143 144 SSLSocketFactory sslsf = ctx.getSocketFactory(); 145 SSLSocket sslSocket = (SSLSocket) 146 sslsf.createSocket("localhost", serverPort); 147 148 if (clientProtocol != null) { 149 sslSocket.setEnabledProtocols(new String[] {clientProtocol}); 150 } 151 152 if (clientCiperSuite != null) { 153 sslSocket.setEnabledCipherSuites(new String[] {clientCiperSuite}); 154 } 155 156 InputStream sslIS = sslSocket.getInputStream(); 157 OutputStream sslOS = sslSocket.getOutputStream(); 158 159 sslOS.write(280); 160 sslOS.flush(); 161 sslIS.read(); 162 163 sslSocket.close(); 164 } 165 166 private void checkKeySize(KeyStore ks) throws Exception { 167 PrivateKey privateKey = null; 168 PublicKey publicKey = null; 169 170 if (ks.containsAlias(keyAlias)) { 171 System.out.println("Loaded entry: " + keyAlias); 172 privateKey = (PrivateKey)ks.getKey(keyAlias, null); 173 publicKey = (PublicKey)ks.getCertificate(keyAlias).getPublicKey(); 174 175 int privateKeySize = KeyUtil.getKeySize(privateKey); 176 if (privateKeySize != keySize) { 177 throw new Exception("Expected key size is " + keySize + 178 ", but the private key size is " + privateKeySize); 179 } 180 181 int publicKeySize = KeyUtil.getKeySize(publicKey); 182 if (publicKeySize != keySize) { 183 throw new Exception("Expected key size is " + keySize + 184 ", but the public key size is " + publicKeySize); 185 } 186 } 187 } 188 189 /* 190 * ============================================================= 191 * The remainder is just support stuff 192 */ 193 194 // use any free port by default 195 volatile int serverPort = 0; 196 197 volatile Exception serverException = null; 198 volatile Exception clientException = null; 199 200 private static String keyAlias; 201 private static int keySize; 202 private static String clientProtocol = null; 203 private static String clientCiperSuite = null; 204 205 private static void parseArguments(String[] args) { 206 keyAlias = args[0]; 207 keySize = Integer.parseInt(args[1]); 208 209 if (args.length > 2) { 210 clientProtocol = args[2]; 211 } 212 213 if (args.length > 3) { 214 clientCiperSuite = args[3]; 215 } 216 } 217 218 public static void main(String[] args) throws Exception { 219 if (debug) { 220 System.setProperty("javax.net.debug", "all"); 221 } 222 223 // Get the customized arguments. 224 parseArguments(args); 225 226 new ShortRSAKeyWithinTLS(); 227 } 228 229 Thread clientThread = null; 230 Thread serverThread = null; 231 232 /* 233 * Primary constructor, used to drive remainder of the test. 234 * 235 * Fork off the other side, then do your work. 236 */ 237 ShortRSAKeyWithinTLS() throws Exception { 238 try { 239 if (separateServerThread) { 240 startServer(true); 241 startClient(false); 242 } else { 243 startClient(true); 244 startServer(false); 245 } 246 } catch (Exception e) { 247 // swallow for now. Show later 248 } 249 250 /* 251 * Wait for other side to close down. 252 */ 253 if (separateServerThread) { 254 serverThread.join(); 255 } else { 256 clientThread.join(); 257 } 258 259 /* 260 * When we get here, the test is pretty much over. 261 * Which side threw the error? 262 */ 263 Exception local; 264 Exception remote; 265 String whichRemote; 266 267 if (separateServerThread) { 268 remote = serverException; 269 local = clientException; 270 whichRemote = "server"; 271 } else { 272 remote = clientException; 273 local = serverException; 274 whichRemote = "client"; 275 } 276 277 /* 278 * If both failed, return the curthread's exception, but also 279 * print the remote side Exception 280 */ 281 if ((local != null) && (remote != null)) { 282 System.out.println(whichRemote + " also threw:"); 283 remote.printStackTrace(); 284 System.out.println(); 285 throw local; 286 } 287 288 if (remote != null) { 289 throw remote; 290 } 291 292 if (local != null) { 293 throw local; 294 } 295 } 296 297 void startServer(boolean newThread) throws Exception { 298 if (newThread) { 299 serverThread = new Thread() { 300 public void run() { 301 try { 302 doServerSide(); 303 } catch (Exception e) { 304 /* 305 * Our server thread just died. 306 * 307 * Release the client, if not active already... 308 */ 309 System.err.println("Server died..."); 310 serverReady = true; 311 serverException = e; 312 } 313 } 314 }; 315 serverThread.start(); 316 } else { 317 try { 318 doServerSide(); 319 } catch (Exception e) { 320 serverException = e; 321 } finally { 322 serverReady = true; 323 } 324 } 325 } 326 327 void startClient(boolean newThread) throws Exception { 328 if (newThread) { 329 clientThread = new Thread() { 330 public void run() { 331 try { 332 doClientSide(); 333 } catch (Exception e) { 334 /* 335 * Our client thread just died. 336 */ 337 System.err.println("Client died..."); 338 clientException = e; 339 } 340 } 341 }; 342 clientThread.start(); 343 } else { 344 try { 345 doClientSide(); 346 } catch (Exception e) { 347 clientException = e; 348 } 349 } 350 } 351 } 352