1 /*
   2  * Copyright (c) 2017, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8006259
  27  * @summary Test several modes of operation using vectors from SP 800-38A
  28  * @run main CheckExampleVectors
  29  */
  30 
  31 import java.io.*;
  32 import java.security.*;
  33 import java.util.*;
  34 import java.util.function.*;
  35 import javax.crypto.*;
  36 import javax.crypto.spec.*;
  37 
  38 public class CheckExampleVectors {
  39 
  40     private enum Mode {
  41         ECB,
  42         CBC,
  43         CFB1,
  44         CFB8,
  45         CFB128,
  46         OFB,
  47         CTR
  48     }
  49 
  50     private enum Operation {
  51         Encrypt,
  52         Decrypt
  53     }
  54 
  55     private static class Block {
  56         private byte[] input;
  57         private byte[] output;
  58 
  59         public Block() {
  60 
  61         }
  62         public Block(String settings) {
  63             String[] settingsParts = settings.split(",");
  64             input = stringToBytes(settingsParts[0]);
  65             output = stringToBytes(settingsParts[1]);
  66         }
  67         public byte[] getInput() {
  68             return input;
  69         }
  70         public byte[] getOutput() {
  71             return output;
  72         }
  73     }
  74 
  75     private static class TestVector {
  76         private Mode mode;
  77         private Operation operation;
  78         private byte[] key;
  79         private byte[] iv;
  80         private List<Block> blocks = new ArrayList<Block>();
  81 
  82         public TestVector(String settings) {
  83             String[] settingsParts = settings.split(",");
  84             mode = Mode.valueOf(settingsParts[0]);
  85             operation = Operation.valueOf(settingsParts[1]);
  86             key = stringToBytes(settingsParts[2]);
  87             if (settingsParts.length > 3) {
  88                 iv = stringToBytes(settingsParts[3]);
  89             }
  90         }
  91 
  92         public Mode getMode() {
  93             return mode;
  94         }
  95         public Operation getOperation() {
  96             return operation;
  97         }
  98         public byte[] getKey() {
  99             return key;
 100         }
 101         public byte[] getIv() {
 102             return iv;
 103         }
 104         public void addBlock (Block b) {
 105             blocks.add(b);
 106         }
 107         public Iterable<Block> getBlocks() {
 108             return blocks;
 109         }
 110     }
 111 
 112     private static final String VECTOR_FILE_NAME = "NIST_800_38A_vectors.txt";
 113     private static final Mode[] REQUIRED_MODES = {Mode.ECB, Mode.CBC, Mode.CTR};
 114     private static Set<Mode> supportedModes = new HashSet<Mode>();
 115 
 116     public static void main(String[] args) throws Exception {
 117         checkAllProviders();
 118         checkSupportedModes();
 119     }
 120 
 121     private static byte[] stringToBytes(String v) {
 122         if (v.equals("")) {
 123             return null;
 124         }
 125         return Base64.getDecoder().decode(v);
 126     }
 127 
 128     private static String toModeString(Mode mode) {
 129         return mode.toString();
 130     }
 131 
 132     private static int toCipherOperation(Operation op) {
 133         switch (op) {
 134             case Encrypt:
 135                 return Cipher.ENCRYPT_MODE;
 136             case Decrypt:
 137                 return Cipher.DECRYPT_MODE;
 138         }
 139 
 140         throw new RuntimeException("Unknown operation: " + op);
 141     }
 142 
 143     private static void log(String str) {
 144         System.out.println(str);
 145     }
 146 
 147     private static void checkVector(String providerName, TestVector test) {
 148 
 149         String modeString = toModeString(test.getMode());
 150         String cipherString = "AES" + "/" + modeString + "/" + "NoPadding";
 151         log("checking: " + cipherString + " on " + providerName);
 152         try {
 153             Cipher cipher = Cipher.getInstance(cipherString, providerName);
 154             SecretKeySpec key = new SecretKeySpec(test.getKey(), "AES");
 155             if (test.getIv() != null) {
 156                 IvParameterSpec iv = new IvParameterSpec(test.getIv());
 157                 cipher.init(toCipherOperation(test.getOperation()), key, iv);
 158             }
 159             else {
 160                 cipher.init(toCipherOperation(test.getOperation()), key);
 161             }
 162             int blockIndex = 0;
 163             for (Block curBlock : test.getBlocks()) {
 164                 byte[] blockOutput = cipher.update(curBlock.getInput());
 165                 byte[] expectedBlockOutput = curBlock.getOutput();
 166                 if (!Arrays.equals(blockOutput, expectedBlockOutput)) {
 167                     throw new RuntimeException("Blocks do not match at index "
 168                         + blockIndex);
 169                 }
 170                 blockIndex++;
 171             }
 172             log("success");
 173             supportedModes.add(test.getMode());
 174         } catch (NoSuchAlgorithmException ex) {
 175             log("algorithm not supported");
 176         } catch (NoSuchProviderException | NoSuchPaddingException
 177             | InvalidKeyException | InvalidAlgorithmParameterException ex) {
 178             throw new RuntimeException(ex);
 179         }
 180     }
 181 
 182     private static boolean isComment(String line) {
 183         return (line != null) && line.startsWith("//");
 184     }
 185 
 186     private static TestVector readVector(BufferedReader in) throws IOException {
 187         String line;
 188         while (isComment(line = in.readLine())) {
 189             // skip comment lines
 190         }
 191         if (line == null || line.isEmpty()) {
 192             return null;
 193         }
 194 
 195         TestVector newVector = new TestVector(line);
 196         String numBlocksStr = in.readLine();
 197         int numBlocks = Integer.parseInt(numBlocksStr);
 198         for (int i = 0; i < numBlocks; i++) {
 199             Block newBlock = new Block(in.readLine());
 200             newVector.addBlock(newBlock);
 201         }
 202 
 203         return newVector;
 204     }
 205 
 206     private static void checkAllProviders() throws IOException {
 207         File dataFile = new File(System.getProperty("test.src", "."),
 208                                  VECTOR_FILE_NAME);
 209         BufferedReader in = new BufferedReader(new FileReader(dataFile));
 210         List<TestVector> allTests = new ArrayList<>();
 211         TestVector newTest;
 212         while ((newTest = readVector(in)) != null) {
 213             allTests.add(newTest);
 214         }
 215 
 216         for (Provider provider : Security.getProviders()) {
 217             checkProvider(provider.getName(), allTests);
 218         }
 219     }
 220 
 221     private static void checkProvider(String providerName,
 222                                       List<TestVector> allVectors)
 223         throws IOException {
 224 
 225         for (TestVector curVector : allVectors) {
 226             checkVector(providerName, curVector);
 227         }
 228     }
 229 
 230     /*
 231      *  This method helps ensure that the test is working properly by
 232      *  verifying that the test was able to check the test vectors for
 233      *  some of the modes of operation.
 234      */
 235     private static void checkSupportedModes() {
 236         for (Mode curMode : REQUIRED_MODES) {
 237             if (!supportedModes.contains(curMode)) {
 238                 throw new RuntimeException(
 239                     "Mode not supported by any provider: " + curMode);
 240             }
 241         }
 242 
 243     }
 244 
 245 }
 246