1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8006259
  27  * @summary Test several modes of operation using vectors from SP 800-38A
  28  * @modules java.xml.bind
  29  * @run main CheckExampleVectors
  30  */
  31 
  32 import java.io.*;
  33 import java.security.*;
  34 import java.util.*;
  35 import java.util.function.*;
  36 import javax.xml.bind.DatatypeConverter;
  37 import javax.crypto.*;
  38 import javax.crypto.spec.*;
  39 
  40 public class CheckExampleVectors {
  41 
  42     private enum Mode {
  43         ECB,
  44         CBC,
  45         CFB1,
  46         CFB8,
  47         CFB128,
  48         OFB,
  49         CTR
  50     }
  51 
  52     private enum Operation {
  53         Encrypt,
  54         Decrypt
  55     }
  56 
  57     private static class Block {
  58         private byte[] input;
  59         private byte[] output;
  60 
  61         public Block() {
  62 
  63         }
  64         public Block(String settings) {
  65             String[] settingsParts = settings.split(",");
  66             input = stringToBytes(settingsParts[0]);
  67             output = stringToBytes(settingsParts[1]);
  68         }
  69         public byte[] getInput() {
  70             return input;
  71         }
  72         public byte[] getOutput() {
  73             return output;
  74         }
  75     }
  76 
  77     private static class TestVector {
  78         private Mode mode;
  79         private Operation operation;
  80         private byte[] key;
  81         private byte[] iv;
  82         private List<Block> blocks = new ArrayList<Block>();
  83 
  84         public TestVector(String settings) {
  85             String[] settingsParts = settings.split(",");
  86             mode = Mode.valueOf(settingsParts[0]);
  87             operation = Operation.valueOf(settingsParts[1]);
  88             key = stringToBytes(settingsParts[2]);
  89             if (settingsParts.length > 3) {
  90                 iv = stringToBytes(settingsParts[3]);
  91             }
  92         }
  93 
  94         public Mode getMode() {
  95             return mode;
  96         }
  97         public Operation getOperation() {
  98             return operation;
  99         }
 100         public byte[] getKey() {
 101             return key;
 102         }
 103         public byte[] getIv() {
 104             return iv;
 105         }
 106         public void addBlock (Block b) {
 107             blocks.add(b);
 108         }
 109         public Iterable<Block> getBlocks() {
 110             return blocks;
 111         }
 112     }
 113 
 114     private static final String VECTOR_FILE_NAME = "NIST_800_38A_vectors.txt";
 115     private static final Mode[] REQUIRED_MODES = {Mode.ECB, Mode.CBC, Mode.CTR};
 116     private static Set<Mode> supportedModes = new HashSet<Mode>();
 117 
 118     public static void main(String[] args) throws Exception {
 119         checkAllProviders();
 120         checkSupportedModes();
 121     }
 122 
 123     private static byte[] stringToBytes(String v) {
 124         if (v.equals("")) {
 125             return null;
 126         }
 127         return DatatypeConverter.parseBase64Binary(v);
 128     }
 129 
 130     private static String toModeString(Mode mode) {
 131         return mode.toString();
 132     }
 133 
 134     private static int toCipherOperation(Operation op) {
 135         switch (op) {
 136             case Encrypt:
 137                 return Cipher.ENCRYPT_MODE;
 138             case Decrypt:
 139                 return Cipher.DECRYPT_MODE;
 140         }
 141 
 142         throw new RuntimeException("Unknown operation: " + op);
 143     }
 144 
 145     private static void log(String str) {
 146         System.out.println(str);
 147     }
 148 
 149     private static void checkVector(String providerName, TestVector test) {
 150 
 151         String modeString = toModeString(test.getMode());
 152         String cipherString = "AES" + "/" + modeString + "/" + "NoPadding";
 153         log("checking: " + cipherString + " on " + providerName);
 154         try {
 155             Cipher cipher = Cipher.getInstance(cipherString, providerName);
 156             SecretKeySpec key = new SecretKeySpec(test.getKey(), "AES");
 157             if (test.getIv() != null) {
 158                 IvParameterSpec iv = new IvParameterSpec(test.getIv());
 159                 cipher.init(toCipherOperation(test.getOperation()), key, iv);
 160             }
 161             else {
 162                 cipher.init(toCipherOperation(test.getOperation()), key);
 163             }
 164             int blockIndex = 0;
 165             for (Block curBlock : test.getBlocks()) {
 166                 byte[] blockOutput = cipher.update(curBlock.getInput());
 167                 byte[] expectedBlockOutput = curBlock.getOutput();
 168                 if (!Arrays.equals(blockOutput, expectedBlockOutput)) {
 169                     throw new RuntimeException("Blocks do not match at index "
 170                         + blockIndex);
 171                 }
 172                 blockIndex++;
 173             }
 174             log("success");
 175             supportedModes.add(test.getMode());
 176         } catch (NoSuchAlgorithmException ex) {
 177             log("algorithm not supported");
 178         } catch (NoSuchProviderException | NoSuchPaddingException
 179             | InvalidKeyException | InvalidAlgorithmParameterException ex) {
 180             throw new RuntimeException(ex);
 181         }
 182     }
 183 
 184     private static boolean isComment(String line) {
 185         return (line != null) && line.startsWith("//");
 186     }
 187 
 188     private static TestVector readVector(BufferedReader in) throws IOException {
 189         String line;
 190         while (isComment(line = in.readLine())) {
 191             // skip comment lines
 192         }
 193         if (line == null || line.isEmpty()) {
 194             return null;
 195         }
 196 
 197         TestVector newVector = new TestVector(line);
 198         String numBlocksStr = in.readLine();
 199         int numBlocks = Integer.parseInt(numBlocksStr);
 200         for (int i = 0; i < numBlocks; i++) {
 201             Block newBlock = new Block(in.readLine());
 202             newVector.addBlock(newBlock);
 203         }
 204 
 205         return newVector;
 206     }
 207 
 208     private static void checkAllProviders() throws IOException {
 209         File dataFile = new File(System.getProperty("test.src", "."),
 210                                  VECTOR_FILE_NAME);
 211         BufferedReader in = new BufferedReader(new FileReader(dataFile));
 212         List<TestVector> allTests = new ArrayList<>();
 213         TestVector newTest;
 214         while ((newTest = readVector(in)) != null) {
 215             allTests.add(newTest);
 216         }
 217 
 218         for (Provider provider : Security.getProviders()) {
 219             checkProvider(provider.getName(), allTests);
 220         }
 221     }
 222 
 223     private static void checkProvider(String providerName,
 224                                       List<TestVector> allVectors)
 225         throws IOException {
 226 
 227         for (TestVector curVector : allVectors) {
 228             checkVector(providerName, curVector);
 229         }
 230     }
 231 
 232     /*
 233      *  This method helps ensure that the test is working properly by
 234      *  verifying that the test was able to check the test vectors for
 235      *  some of the modes of operation.
 236      */
 237     private static void checkSupportedModes() {
 238         for (Mode curMode : REQUIRED_MODES) {
 239             if (!supportedModes.contains(curMode)) {
 240                 throw new RuntimeException(
 241                     "Mode not supported by any provider: " + curMode);
 242             }
 243         }
 244 
 245     }
 246 
 247 }
 248