1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8164879
  27  * @library /lib/testlibrary ../../
  28  * @library /test/lib
 * @summary Verify that AES/GCM limits set in the jdk.tls.keyLimits property
 * trigger a key update (TLS 1.3 KeyUpdate) that refreshes the symmetric key
 * on an SSLEngine connection.  This test verifies via debugging output that
 * the key update was triggered.  It does not verify that the key update was
 * successful, as that is very hard.
  34  *
  35  * @run main SSLEngineKeyLimit 0 server AES/GCM/NoPadding keyupdate 1050000
  36  * @run main SSLEngineKeyLimit 1 client AES/GCM/NoPadding keyupdate 2^22
  37  */
  38 
  39 /*
  40  * This test runs in another process so we can monitor the debug
  41  * results.  The OutputAnalyzer must see correct debug output to return a
  42  * success.
  43  */
  44 
  45 import javax.net.ssl.KeyManagerFactory;
  46 import javax.net.ssl.SSLContext;
  47 import javax.net.ssl.SSLEngine;
  48 import javax.net.ssl.SSLEngineResult;
  49 import javax.net.ssl.TrustManagerFactory;
  50 import java.io.File;
  51 import java.io.PrintWriter;
  52 import java.nio.ByteBuffer;
  53 import java.security.KeyStore;
  54 import java.security.SecureRandom;
  55 import java.util.Arrays;
  56 
  57 import jdk.test.lib.process.ProcessTools;
  58 import jdk.test.lib.process.OutputAnalyzer;
  59 import jdk.testlibrary.Utils;
  60 
  61 public class SSLEngineKeyLimit {
  62 
  63     SSLEngine eng;
  64     static ByteBuffer cTos;
  65     static ByteBuffer sToc;
  66     static ByteBuffer outdata;
  67     ByteBuffer buf;
  68     static boolean ready = false;
  69 
  70     static String pathToStores = "../../../../javax/net/ssl/etc/";
  71     static String keyStoreFile = "keystore";
  72     static String passwd = "passphrase";
  73     static String keyFilename;
  74     static int dataLen = 10240;
  75     static boolean serverwrite = true;
  76     int totalDataLen = 0;
  77     static boolean sc = true;
  78     int delay = 1;
  79     static boolean readdone = false;
  80 
  81     // Turn on debugging
  82     static boolean debug = false;
  83 
  84     SSLEngineKeyLimit() {
  85         buf = ByteBuffer.allocate(dataLen*4);
  86     }
  87 
  88     /**
  89      * args should have two values:  server|client, <limit size>
  90      * Prepending 'p' is for internal use only.
  91      */
  92     public static void main(String args[]) throws Exception {
  93 
  94         for (int i = 0; i < args.length; i++) {
  95             System.out.print(" " + args[i]);
  96         }
  97         System.out.println();
  98         if (args[0].compareTo("p") != 0) {
  99             boolean expectedFail = (Integer.parseInt(args[0]) == 1);
 100             if (expectedFail) {
 101                 System.out.println("Test expected to not find updated msg");
 102             }
 103 
 104             // Write security property file to overwrite default
 105             File f = new File("keyusage."+ System.nanoTime());
 106             PrintWriter p = new PrintWriter(f);
 107             p.write("jdk.tls.keyLimits=");
 108             for (int i = 2; i < args.length; i++) {
 109                 p.write(" "+ args[i]);
 110             }
 111             p.close();
 112 
 113             System.setProperty("test.java.opts",
 114                     "-Dtest.src=" + System.getProperty("test.src") +
 115                             " -Dtest.jdk=" + System.getProperty("test.jdk") +
 116                             " -Djavax.net.debug=ssl,handshake" +
 117                             " -Djava.security.properties=" + f.getName());
 118 
 119             System.out.println("test.java.opts: " +
 120                     System.getProperty("test.java.opts"));
 121 
 122             ProcessBuilder pb = ProcessTools.createJavaProcessBuilder(true,
 123                     Utils.addTestJavaOpts("SSLEngineKeyLimit", "p", args[1]));
 124 
 125             OutputAnalyzer output = ProcessTools.executeProcess(pb);
 126             try {
 127                 if (expectedFail) {
 128                     output.shouldNotContain("KeyUpdate: write key updated");
 129                     output.shouldNotContain("KeyUpdate: read key updated");
 130                 } else {
 131                     output.shouldContain("trigger key update");
 132                     output.shouldContain("KeyUpdate: write key updated");
 133                     output.shouldContain("KeyUpdate: read key updated");
 134                 }
 135             } catch (Exception e) {
 136                 throw e;
 137             } finally {
 138                 System.out.println("-- BEGIN Stdout:");
 139                 System.out.println(output.getStdout());
 140                 System.out.println("-- END Stdout");
 141                 System.out.println("-- BEGIN Stderr:");
 142                 System.out.println(output.getStderr());
 143                 System.out.println("-- END Stderr");
 144             }
 145             return;
 146         }
 147 
 148         if (args[0].compareTo("p") != 0) {
 149             throw new Exception ("Tried to run outside of a spawned process");
 150         }
 151 
 152         if (args[1].compareTo("client") == 0) {
 153             serverwrite = false;
 154         }
 155 
 156         cTos = ByteBuffer.allocateDirect(dataLen*4);
 157         keyFilename =
 158             System.getProperty("test.src", "./") + "/" + pathToStores +
 159                 "/" + keyStoreFile;
 160 
 161         System.setProperty("javax.net.ssl.keyStore", keyFilename);
 162         System.setProperty("javax.net.ssl.keyStorePassword", passwd);
 163 
 164         sToc = ByteBuffer.allocateDirect(dataLen*4);
 165         outdata = ByteBuffer.allocateDirect(dataLen);
 166 
 167         byte[] data  = new byte[dataLen];
 168         Arrays.fill(data, (byte)0x0A);
 169         outdata.put(data);
 170         outdata.flip();
 171         cTos.clear();
 172         sToc.clear();
 173 
 174         Thread ts = new Thread(serverwrite ? new Client() : new Server());
 175         ts.start();
 176         (serverwrite ? new Server() : new Client()).run();
 177         ts.interrupt();
 178         ts.join();
 179     }
 180 
 181     private static void doTask(SSLEngineResult result,
 182             SSLEngine engine) throws Exception {
 183 
 184         if (result.getHandshakeStatus() ==
 185                 SSLEngineResult.HandshakeStatus.NEED_TASK) {
 186             Runnable runnable;
 187             while ((runnable = engine.getDelegatedTask()) != null) {
 188                 print("\trunning delegated task...");
 189                 runnable.run();
 190             }
 191             SSLEngineResult.HandshakeStatus hsStatus =
 192                     engine.getHandshakeStatus();
 193             if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
 194                 throw new Exception(
 195                     "handshake shouldn't need additional tasks");
 196             }
 197             print("\tnew HandshakeStatus: " + hsStatus);
 198         }
 199     }
 200 
 201     static void print(String s) {
 202         if (debug) {
 203             System.out.println(s);
 204         }
 205     }
 206 
 207     static void log(String s, SSLEngineResult r) {
 208         if (!debug) {
 209             return;
 210         }
 211         System.out.println(s + ": " +
 212                 r.getStatus() + "/" + r.getHandshakeStatus()+ " " +
 213                 r.bytesConsumed() + "/" + r.bytesProduced() + " ");
 214 
 215     }
 216 
 217     void write() throws Exception {
 218         int i = 0;
 219         SSLEngineResult r;
 220         boolean again = true;
 221 
 222         while (!ready) {
 223             Thread.sleep(delay);
 224         }
 225         print("Write-side. ");
 226 
 227         while (i++ < 150) {
 228             while (sc) {
 229                 if (readdone) {
 230                     return;
 231                 }
 232                 Thread.sleep(delay);
 233             }
 234 
 235             outdata.rewind();
 236             print("write wrap");
 237 
 238             while (true) {
 239                 r = eng.wrap(outdata, getWriteBuf());
 240                 log("write wrap", r);
 241                 if (debug && r.getStatus() != SSLEngineResult.Status.OK) {
 242                     print("outdata pos: " + outdata.position() +
 243                             " rem: " + outdata.remaining() +
 244                             " lim: " + outdata.limit() +
 245                             " cap: " + outdata.capacity());
 246                     print("writebuf pos: " + getWriteBuf().position() +
 247                             " rem: " + getWriteBuf().remaining() +
 248                             " lim: " + getWriteBuf().limit() +
 249                             " cap: " + getWriteBuf().capacity());
 250                 }
 251                 if (again && r.getStatus() == SSLEngineResult.Status.OK &&
 252                         r.getHandshakeStatus() ==
 253                                 SSLEngineResult.HandshakeStatus.NEED_WRAP) {
 254                     print("again");
 255                     again = false;
 256                     continue;
 257                 }
 258                 break;
 259             }
 260             doTask(r, eng);
 261             getWriteBuf().flip();
 262             sc = true;
 263             while (sc) {
 264                 if (readdone) {
 265                     return;
 266                 }
 267                 Thread.sleep(delay);
 268             }
 269 
 270             while (true) {
 271                 buf.clear();
 272                 r = eng.unwrap(getReadBuf(), buf);
 273                 log("write unwrap", r);
 274                 if (debug && r.getStatus() != SSLEngineResult.Status.OK) {
 275                     print("buf pos: " + buf.position() +
 276                             " rem: " + buf.remaining() +
 277                             " lim: " + buf.limit() +
 278                             " cap: " + buf.capacity());
 279                     print("readbuf pos: " + getReadBuf().position() +
 280                             " rem: " + getReadBuf().remaining() +
 281                             " lim: " + getReadBuf().limit() +
 282                             " cap:"  + getReadBuf().capacity());
 283                 }
 284                 break;
 285             }
 286             doTask(r, eng);
 287             getReadBuf().compact();
 288             print("compacted readbuf pos: " + getReadBuf().position() +
 289                     " rem: " + getReadBuf().remaining() +
 290                     " lim: " + getReadBuf().limit() +
 291                     " cap: " + getReadBuf().capacity());
 292             sc = true;
 293         }
 294     }
 295 
 296     void read() throws Exception {
 297         byte b = 0x0B;
 298         ByteBuffer buf2 = ByteBuffer.allocateDirect(dataLen);
 299         SSLEngineResult r = null;
 300         boolean exit, again = true;
 301 
 302         while (eng == null) {
 303             Thread.sleep(delay);
 304         }
 305 
 306         try {
 307             System.out.println("connected");
 308             print("entering read loop");
 309             ready = true;
 310             while (true) {
 311 
 312                 while (!sc) {
 313                     Thread.sleep(delay);
 314                 }
 315 
 316                 print("read wrap");
 317                 exit = false;
 318                 while (!exit) {
 319                     buf2.put(b);
 320                     buf2.flip();
 321                     r = eng.wrap(buf2, getWriteBuf());
 322                     log("read wrap", r);
 323                     if (debug) {
 324                              // && r.getStatus() != SSLEngineResult.Status.OK) {
 325                         print("buf2 pos: " + buf2.position() +
 326                                 " rem: " + buf2.remaining() +
 327                                 " cap: " + buf2.capacity());
 328                         print("writebuf pos: " + getWriteBuf().position() +
 329                                 " rem: " + getWriteBuf().remaining() +
 330                                 " cap: " + getWriteBuf().capacity());
 331                     }
 332                     if (again && r.getStatus() == SSLEngineResult.Status.OK &&
 333                             r.getHandshakeStatus() ==
 334                                 SSLEngineResult.HandshakeStatus.NEED_WRAP) {
 335                         buf2.compact();
 336                         again = false;
 337                         continue;
 338                     }
 339                     exit = true;
 340                 }
 341                 doTask(r, eng);
 342                 buf2.clear();
 343                 getWriteBuf().flip();
 344 
 345                 sc = false;
 346 
 347                 while (!sc) {
 348                     Thread.sleep(delay);
 349                 }
 350 
 351                 while (true) {
 352                         buf.clear();
 353                         r = eng.unwrap(getReadBuf(), buf);
 354                         log("read unwrap", r);
 355                         if (debug &&
 356                                 r.getStatus() != SSLEngineResult.Status.OK) {
 357                             print("buf pos " + buf.position() +
 358                                     " rem: " + buf.remaining() +
 359                                     " lim: " + buf.limit() +
 360                                     " cap: " + buf.capacity());
 361                             print("readbuf pos: " + getReadBuf().position() +
 362                                     " rem: " + getReadBuf().remaining() +
 363                                     " lim: " + getReadBuf().limit() +
 364                                     " cap: " + getReadBuf().capacity());
 365                             doTask(r, eng);
 366                         }
 367 
 368                     if (again && r.getStatus() == SSLEngineResult.Status.OK &&
 369                             r.getHandshakeStatus() ==
 370                                 SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
 371                         buf.clear();
 372                         print("again");
 373                         again = false;
 374                         continue;
 375 
 376                     }
 377                     break;
 378                 }
 379                 buf.clear();
 380                 getReadBuf().compact();
 381 
 382                 totalDataLen += r.bytesProduced();
 383                 sc = false;
 384             }
 385         } catch (Exception e) {
 386             sc = false;
 387             readdone = true;
 388             System.out.println(e.getMessage());
 389             e.printStackTrace();
 390             System.out.println("Total data read = " + totalDataLen);
 391         }
 392     }
 393 
 394     ByteBuffer getReadBuf() {
 395         return null;
 396     }
 397 
 398     ByteBuffer getWriteBuf() {
 399         return null;
 400     }
 401 
 402 
 403     SSLContext initContext() throws Exception {
 404         SSLContext sc = SSLContext.getInstance("TLSv1.3");
 405         KeyStore ks = KeyStore.getInstance(
 406                 new File(System.getProperty("javax.net.ssl.keyStore")),
 407                 passwd.toCharArray());
 408         KeyManagerFactory kmf = KeyManagerFactory.getInstance(
 409                 KeyManagerFactory.getDefaultAlgorithm());
 410         kmf.init(ks, passwd.toCharArray());
 411         TrustManagerFactory tmf = TrustManagerFactory.getInstance(
 412                 TrustManagerFactory.getDefaultAlgorithm());
 413         tmf.init(ks);
 414         sc.init(kmf.getKeyManagers(),
 415                 tmf.getTrustManagers(), new SecureRandom());
 416         return sc;
 417     }
 418 
 419     static class Server extends SSLEngineKeyLimit implements Runnable {
 420         Server() throws Exception {
 421             super();
 422             eng = initContext().createSSLEngine();
 423             eng.setUseClientMode(false);
 424             eng.setNeedClientAuth(true);
 425         }
 426 
 427         public void run() {
 428             try {
 429                 if (serverwrite) {
 430                     write();
 431                 } else {
 432                     read();
 433                 }
 434 
 435             } catch (Exception e) {
 436                 System.out.println("server: " + e.getMessage());
 437                 e.printStackTrace();
 438             }
 439             System.out.println("Server closed");
 440         }
 441 
 442         @Override
 443         ByteBuffer getWriteBuf() {
 444             return sToc;
 445         }
 446         @Override
 447         ByteBuffer getReadBuf() {
 448             return cTos;
 449         }
 450     }
 451 
 452 
 453     static class Client extends SSLEngineKeyLimit implements Runnable {
 454         Client() throws Exception {
 455             super();
 456             eng = initContext().createSSLEngine();
 457             eng.setUseClientMode(true);
 458         }
 459 
 460         public void run() {
 461             try {
 462                 if (!serverwrite) {
 463                     write();
 464                 } else {
 465                     read();
 466                 }
 467             } catch (Exception e) {
 468                 System.out.println("client: " + e.getMessage());
 469                 e.printStackTrace();
 470             }
 471             System.out.println("Client closed");
 472         }
 473         @Override
 474         ByteBuffer getWriteBuf() {
 475             return cTos;
 476         }
 477         @Override
 478         ByteBuffer getReadBuf() {
 479             return sToc;
 480         }
 481     }
 482 }