1 /*
   2  * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8206929 8212885
  27  * @summary ensure that client only resumes a session if certain properties
  28  *    of the session are compatible with the new connection
  29  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=false ResumeChecksClient BASIC
  30  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=false ResumeChecksClient BASIC
  31  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
  32  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
  33  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.2 -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
  34  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
  35  * @run main/othervm -Djdk.tls.server.enableSessionTicketExtension=false -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient BASIC
  36  * @run main/othervm -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient VERSION_2_TO_3
  37  * @run main/othervm -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient VERSION_3_TO_2
  38  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient CIPHER_SUITE
  39  * @run main/othervm -Djdk.tls.client.protocols=TLSv1.3 -Djdk.tls.server.enableSessionTicketExtension=true -Djdk.tls.client.enableSessionTicketExtension=true ResumeChecksClient SIGNATURE_SCHEME
  40  *
  41  */
  42 
  43 import javax.net.*;
  44 import javax.net.ssl.*;
  45 import java.io.*;
  46 import java.security.*;
  47 import java.net.*;
  48 import java.util.*;
  49 
  50 public class ResumeChecksClient {
  51 
  52     static String pathToStores = "../../../../javax/net/ssl/etc";
  53     static String keyStoreFile = "keystore";
  54     static String trustStoreFile = "truststore";
  55     static String passwd = "passphrase";
  56 
  57     enum TestMode {
  58         BASIC,
  59         VERSION_2_TO_3,
  60         VERSION_3_TO_2,
  61         CIPHER_SUITE,
  62         SIGNATURE_SCHEME
  63     }
  64 
  65     public static void main(String[] args) throws Exception {
  66 
  67         TestMode mode = TestMode.valueOf(args[0]);
  68 
  69         String keyFilename =
  70             System.getProperty("test.src", "./") + "/" + pathToStores +
  71                 "/" + keyStoreFile;
  72         String trustFilename =
  73             System.getProperty("test.src", "./") + "/" + pathToStores +
  74                 "/" + trustStoreFile;
  75 
  76         System.setProperty("javax.net.ssl.keyStore", keyFilename);
  77         System.setProperty("javax.net.ssl.keyStorePassword", passwd);
  78         System.setProperty("javax.net.ssl.trustStore", trustFilename);
  79         System.setProperty("javax.net.ssl.trustStorePassword", passwd);
  80 
  81         Server server = startServer();
  82         server.signal();
  83         SSLContext sslContext = SSLContext.getDefault();
  84         while (!server.started) {
  85             Thread.yield();
  86         }
  87         SSLSession firstSession = connect(sslContext, server.port, mode, false);
  88 
  89         server.signal();
  90         long secondStartTime = System.currentTimeMillis();
  91         Thread.sleep(10);
  92         SSLSession secondSession = connect(sslContext, server.port, mode, true);
  93 
  94         server.go = false;
  95         server.signal();
  96 
  97         switch (mode) {
  98         case BASIC:
  99             // fail if session is not resumed
 100             checkResumedSession(firstSession, secondSession);
 101             break;
 102         case VERSION_2_TO_3:
 103         case VERSION_3_TO_2:
 104         case CIPHER_SUITE:
 105         case SIGNATURE_SCHEME:
 106             // fail if a new session is not created
 107             if (secondSession.getCreationTime() <= secondStartTime) {
 108                 throw new RuntimeException("Existing session was used");
 109             }
 110             break;
 111         default:
 112             throw new RuntimeException("unknown mode: " + mode);
 113         }
 114     }
 115 
 116     private static class NoSig implements AlgorithmConstraints {
 117 
 118         private final String alg;
 119 
 120         NoSig(String alg) {
 121             this.alg = alg;
 122         }
 123 
 124 
 125         private boolean test(String a) {
 126             return !a.toLowerCase().contains(alg.toLowerCase());
 127         }
 128 
 129         @Override
 130         public boolean permits(Set<CryptoPrimitive> primitives, Key key) {
 131             return true;
 132         }
 133         @Override
 134         public boolean permits(Set<CryptoPrimitive> primitives,
 135             String algorithm, AlgorithmParameters parameters) {
 136 
 137             return test(algorithm);
 138         }
 139         @Override
 140         public boolean permits(Set<CryptoPrimitive> primitives,
 141             String algorithm, Key key, AlgorithmParameters parameters) {
 142 
 143             return test(algorithm);
 144         }
 145     }
 146 
 147     private static SSLSession connect(SSLContext sslContext, int port,
 148         TestMode mode, boolean second) {
 149 
 150         try {
 151             SSLSocket sock = (SSLSocket)
 152                 sslContext.getSocketFactory().createSocket();
 153             SSLParameters params = sock.getSSLParameters();
 154 
 155             switch (mode) {
 156             case BASIC:
 157                 // do nothing to ensure resumption works
 158                 break;
 159             case VERSION_2_TO_3:
 160                 if (second) {
 161                     params.setProtocols(new String[] {"TLSv1.3"});
 162                 } else {
 163                     params.setProtocols(new String[] {"TLSv1.2"});
 164                 }
 165                 break;
 166             case VERSION_3_TO_2:
 167                 if (second) {
 168                     params.setProtocols(new String[] {"TLSv1.2"});
 169                 } else {
 170                     params.setProtocols(new String[] {"TLSv1.3"});
 171                 }
 172                 break;
 173             case CIPHER_SUITE:
 174                 if (second) {
 175                     params.setCipherSuites(
 176                         new String[] {"TLS_AES_256_GCM_SHA384"});
 177                 } else {
 178                     params.setCipherSuites(
 179                         new String[] {"TLS_AES_128_GCM_SHA256"});
 180                 }
 181                 break;
 182             case SIGNATURE_SCHEME:
 183                 AlgorithmConstraints constraints =
 184                     params.getAlgorithmConstraints();
 185                 if (second) {
 186                     params.setAlgorithmConstraints(new NoSig("ecdsa"));
 187                 } else {
 188                     params.setAlgorithmConstraints(new NoSig("rsa"));
 189                 }
 190                 break;
 191             default:
 192                 throw new RuntimeException("unknown mode: " + mode);
 193             }
 194             sock.setSSLParameters(params);
 195             sock.connect(new InetSocketAddress("localhost", port));
 196             PrintWriter out = new PrintWriter(
 197                 new OutputStreamWriter(sock.getOutputStream()));
 198             out.println("message");
 199             out.flush();
 200             BufferedReader reader = new BufferedReader(
 201                 new InputStreamReader(sock.getInputStream()));
 202             String inMsg = reader.readLine();
 203             System.out.println("Client received: " + inMsg);
 204             SSLSession result = sock.getSession();
 205             sock.close();
 206             return result;
 207         } catch (Exception ex) {
 208             // unexpected exception
 209             throw new RuntimeException(ex);
 210         }
 211     }
 212 
 213     private static void checkResumedSession(SSLSession initSession,
 214             SSLSession resSession) throws Exception {
 215         StringBuilder diffLog = new StringBuilder();
 216 
 217         // Initial and resumed SSLSessions should have the same creation
 218         // times so they get invalidated together.
 219         long initCt = initSession.getCreationTime();
 220         long resumeCt = resSession.getCreationTime();
 221         if (initCt != resumeCt) {
 222             diffLog.append("Session creation time is different. Initial: ").
 223                     append(initCt).append(", Resumed: ").append(resumeCt).
 224                     append("\n");
 225         }
 226 
 227         // Ensure that peer and local certificate lists are preserved
 228         if (!Arrays.equals(initSession.getLocalCertificates(),
 229                 resSession.getLocalCertificates())) {
 230             diffLog.append("Local certificate mismatch between initial " +
 231                     "and resumed sessions\n");
 232         }
 233 
 234         if (!Arrays.equals(initSession.getPeerCertificates(),
 235                 resSession.getPeerCertificates())) {
 236             diffLog.append("Peer certificate mismatch between initial " +
 237                     "and resumed sessions\n");
 238         }
 239 
 240         // Buffer sizes should also be the same
 241         if (initSession.getApplicationBufferSize() !=
 242                 resSession.getApplicationBufferSize()) {
 243             diffLog.append(String.format(
 244                     "App Buffer sizes differ: Init: %d, Res: %d\n",
 245                     initSession.getApplicationBufferSize(),
 246                     resSession.getApplicationBufferSize()));
 247         }
 248 
 249         if (initSession.getPacketBufferSize() !=
 250                 resSession.getPacketBufferSize()) {
 251             diffLog.append(String.format(
 252                     "Packet Buffer sizes differ: Init: %d, Res: %d\n",
 253                     initSession.getPacketBufferSize(),
 254                     resSession.getPacketBufferSize()));
 255         }
 256 
 257         // Cipher suite should match
 258         if (!initSession.getCipherSuite().equals(
 259                 resSession.getCipherSuite())) {
 260             diffLog.append(String.format(
 261                     "CipherSuite does not match - Init: %s, Res: %s\n",
 262                     initSession.getCipherSuite(), resSession.getCipherSuite()));
 263         }
 264 
 265         // Peer host/port should match
 266         if (!initSession.getPeerHost().equals(resSession.getPeerHost()) ||
 267                 initSession.getPeerPort() != resSession.getPeerPort()) {
 268             diffLog.append(String.format(
 269                     "Host/Port mismatch - Init: %s/%d, Res: %s/%d\n",
 270                     initSession.getPeerHost(), initSession.getPeerPort(),
 271                     resSession.getPeerHost(), resSession.getPeerPort()));
 272         }
 273 
 274         // Check protocol
 275         if (!initSession.getProtocol().equals(resSession.getProtocol())) {
 276             diffLog.append(String.format(
 277                     "Protocol mismatch - Init: %s, Res: %s\n",
 278                     initSession.getProtocol(), resSession.getProtocol()));
 279         }
 280 
 281         // If the StringBuilder has any data in it then one of the checks
 282         // above failed and we should throw an exception.
 283         if (diffLog.length() > 0) {
 284             throw new RuntimeException(diffLog.toString());
 285         }
 286     }
 287 
 288     private static Server startServer() {
 289         Server server = new Server();
 290         new Thread(server).start();
 291         return server;
 292     }
 293 
 294     private static class Server implements Runnable {
 295 
 296         public volatile boolean go = true;
 297         private boolean signal = false;
 298         public volatile int port = 0;
 299         public volatile boolean started = false;
 300 
 301         private synchronized void waitForSignal() {
 302             while (!signal) {
 303                 try {
 304                     wait();
 305                 } catch (InterruptedException ex) {
 306                     // do nothing
 307                 }
 308             }
 309             signal = false;
 310         }
 311         public synchronized void signal() {
 312             signal = true;
 313             notify();
 314         }
 315 
 316         @Override
 317         public void run() {
 318             try {
 319 
 320                 SSLContext sc = SSLContext.getDefault();
 321                 ServerSocketFactory fac = sc.getServerSocketFactory();
 322                 SSLServerSocket ssock = (SSLServerSocket)
 323                     fac.createServerSocket(0);
 324                 this.port = ssock.getLocalPort();
 325 
 326                 waitForSignal();
 327                 started = true;
 328                 while (go) {
 329                     try {
 330                         System.out.println("Waiting for connection");
 331                         Socket sock = ssock.accept();
 332                         BufferedReader reader = new BufferedReader(
 333                             new InputStreamReader(sock.getInputStream()));
 334                         String line = reader.readLine();
 335                         System.out.println("server read: " + line);
 336                         PrintWriter out = new PrintWriter(
 337                             new OutputStreamWriter(sock.getOutputStream()));
 338                         out.println(line);
 339                         out.flush();
 340                         waitForSignal();
 341                     } catch (Exception ex) {
 342                         ex.printStackTrace();
 343                     }
 344                 }
 345             } catch (Exception ex) {
 346                 throw new RuntimeException(ex);
 347             }
 348         }
 349     }
 350 }