1 /*
   2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have
  21  * questions.
  22  */
  23   
  24 package benchmark.crypto;
  25 
  26 import org.openjdk.jmh.annotations.*;
  27 import jdk.incubator.vector.*;
  28 import java.util.Arrays;
  29 
  30 @State(Scope.Thread)
  31 @BenchmarkMode(Mode.Throughput)
  32 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
  33 @Warmup(iterations = 3, time = 3)
  34 @Measurement(iterations = 8, time = 2)
  35 public class ChaChaBench {
  36 
  37     @Param({"16384", "65536"})
  38     private int dataSize;
  39     
  40     private ChaChaVector cc20_S128 = makeCC20(Vector.Shape.S_128_BIT);
  41     private ChaChaVector cc20_S256 = makeCC20(Vector.Shape.S_256_BIT);
  42     private ChaChaVector cc20_S512 = makeCC20(Vector.Shape.S_512_BIT);
  43  
  44     private byte[] in;
  45     private byte[] out;
  46     
  47     private byte[] key = new byte[32];
  48     private byte[] nonce = new byte[12];
  49     private long counter = 0;
  50 
  51     private static ChaChaVector makeCC20(Vector.Shape shape) {
  52         ChaChaVector cc20 = new ChaChaVector(shape);
  53         runKAT(cc20);
  54         return cc20;
  55     }
  56 
  57     @Setup
  58     public void setup() {
  59         
  60         in = new byte[dataSize];
  61         out = new byte[dataSize];
  62     }
  63 
  64     @Benchmark
  65     public void encrypt128() {
  66         cc20_S128.chacha20(key, nonce, counter, in, out);
  67     }
  68 
  69     @Benchmark
  70     public void encrypt256() {
  71         cc20_S256.chacha20(key, nonce, counter, in, out);
  72     }
  73 
  74     @Benchmark
  75     public void encrypt512() {
  76         cc20_S512.chacha20(key, nonce, counter, in, out);
  77     }
  78 
  79     private static class ChaChaVector {
  80 
  81         private static final int[] STATE_CONSTANTS =
  82             new int[]{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574};
  83 
  84         private final Vector.Species<Integer> intSpecies;
  85         private final int numBlocks;
  86 
  87         private final Vector.Shuffle<Integer> rot1;
  88         private final Vector.Shuffle<Integer> rot2;
  89         private final Vector.Shuffle<Integer> rot3;
  90 
  91         private final IntVector counterAdd;
  92 
  93         private final Vector.Shuffle<Integer> shuf0;
  94         private final Vector.Shuffle<Integer> shuf1;
  95         private final Vector.Shuffle<Integer> shuf2;
  96         private final Vector.Shuffle<Integer> shuf3;
  97 
  98         private final Vector.Mask<Integer> mask0;
  99         private final Vector.Mask<Integer> mask1;
 100         private final Vector.Mask<Integer> mask2;
 101         private final Vector.Mask<Integer> mask3;
 102 
 103         private final int[] state;
 104 
 105         public ChaChaVector(Vector.Shape shape) {
 106             this.intSpecies = Vector.Species.of(Integer.class, shape);
 107             this.numBlocks = intSpecies.length() / 4;
 108 
 109             this.rot1 = makeRotate(1);
 110             this.rot2 = makeRotate(2);
 111             this.rot3 = makeRotate(3);
 112 
 113             this.counterAdd = makeCounterAdd();
 114 
 115             this.shuf0 = makeRearrangeShuffle(0);
 116             this.shuf1 = makeRearrangeShuffle(1);
 117             this.shuf2 = makeRearrangeShuffle(2);
 118             this.shuf3 = makeRearrangeShuffle(3);
 119 
 120             this.mask0 = makeRearrangeMask(0);
 121             this.mask1 = makeRearrangeMask(1);
 122             this.mask2 = makeRearrangeMask(2);
 123             this.mask3 = makeRearrangeMask(3);
 124 
 125             this.state = new int[numBlocks * 16];
 126         }
 127 
 128         private Vector.Shuffle<Integer>  makeRotate(int amount) {
 129             int[] shuffleArr = new int[intSpecies.length()];
 130 
 131             for (int i = 0; i < intSpecies.length(); i ++) {
 132                 int offset = (i / 4) * 4;
 133                 shuffleArr[i] = offset + ((i + amount) % 4);
 134             }
 135 
 136             return IntVector.shuffleFromValues(intSpecies, shuffleArr);
 137         }
 138 
 139         private IntVector makeCounterAdd() {
 140             int[] addArr = new int[intSpecies.length()];
 141             for(int i = 0; i < numBlocks; i++) {
 142                 addArr[4 * i] = numBlocks;
 143             }
 144             return IntVector.fromArray(intSpecies, addArr, 0);
 145         }
 146 
 147         private Vector.Shuffle<Integer>  makeRearrangeShuffle(int order) {
 148             int[] shuffleArr = new int[intSpecies.length()];
 149             int start = order * 4;
 150             for (int i = 0; i < shuffleArr.length; i++) {
 151                 shuffleArr[i] = (i % 4) + start;
 152             }
 153             return IntVector.shuffleFromArray(intSpecies, shuffleArr, 0);
 154         }
 155 
 156         private Vector.Mask<Integer> makeRearrangeMask(int order) {
 157             boolean[] maskArr = new boolean[intSpecies.length()];
 158             int start = order * 4;
 159             if (start < maskArr.length) {
 160                 for (int i = 0; i < 4; i++) {
 161                     maskArr[i + start] = true;
 162                 }
 163             }
 164 
 165             return IntVector.maskFromValues(intSpecies, maskArr);
 166         }
 167 
 168         public void makeState(byte[] key, byte[] nonce, long counter,
 169             int[] out) {
 170 
 171             // first field is constants
 172             for (int i = 0; i < 4; i++) {
 173                 for (int j = 0; j < numBlocks; j++) {
 174                     out[4*j + i] = STATE_CONSTANTS[i];
 175                 }
 176             }
 177 
 178             // second field is first part of key
 179             int fieldStart = 4 * numBlocks;
 180             for (int i = 0; i < 4; i++) {
 181                 int keyInt = 0;
 182                 for (int j = 0; j < 4; j++) {
 183                     keyInt += (0xFF & key[4 * i + j]) << 8 * j;
 184                 }
 185                 for (int j = 0; j < numBlocks; j++) {
 186                     out[fieldStart + j*4 + i] = keyInt;
 187                 }
 188             }
 189 
 190             // third field is second part of key
 191             fieldStart = 8 * numBlocks;
 192             for (int i = 0; i < 4; i++) {
 193                 int keyInt = 0;
 194                 for (int j = 0; j < 4; j++) {
 195                     keyInt += (0xFF & key[4 * (i + 4) + j]) << 8 * j;
 196                 }
 197 
 198                 for (int j = 0; j < numBlocks; j++) {
 199                     out[fieldStart + j*4 + i] = keyInt;
 200                 }
 201             }
 202 
 203             // fourth field is counter and nonce
 204             fieldStart = 12 * numBlocks;
 205             for (int j = 0; j < numBlocks; j++) {
 206                 out[fieldStart + j*4] = (int) (counter + j);
 207             }
 208 
 209             for (int i = 0; i < 3; i++) {
 210                 int nonceInt = 0;
 211                 for (int j = 0; j < 4; j++) {
 212                     nonceInt += (0xFF & nonce[4 * i + j]) << 8 * j;
 213                 }
 214 
 215                 for (int j = 0; j < numBlocks; j++) {
 216                     out[fieldStart + j*4 + 1 + i] = nonceInt;
 217                 }
 218             }
 219         }
 220 
 221         public void chacha20(byte[] key, byte[] nonce, long counter,
 222             byte[] in, byte[] out) {
 223 
 224             makeState(key, nonce, counter, state);
 225 
 226             int len = intSpecies.length();
 227 
 228             IntVector sa = IntVector.fromArray(intSpecies, state, 0);
 229             IntVector sb = IntVector.fromArray(intSpecies, state, len);
 230             IntVector sc = IntVector.fromArray(intSpecies, state, 2 * len);
 231             IntVector sd = IntVector.fromArray(intSpecies, state, 3 * len);
 232 
 233             int stateLenBytes = state.length * 4;
 234             int numStates = (in.length + stateLenBytes - 1) / stateLenBytes;
 235             for (int j = 0; j < numStates; j++){
 236 
 237                 IntVector a = sa;
 238                 IntVector b = sb;
 239                 IntVector c = sc;
 240                 IntVector d = sd;
 241 
 242                 for (int i = 0; i < 10; i++) {
 243                     // first round
 244                     a = a.add(b);
 245                     d = d.xor(a);
 246                     d = d.rotateL(16);
 247 
 248                     c = c.add(d);
 249                     b = b.xor(c);
 250                     b = b.rotateL(12);
 251 
 252                     a = a.add(b);
 253                     d = d.xor(a);
 254                     d = d.rotateL(8);
 255 
 256                     c = c.add(d);
 257                     b = b.xor(c);
 258                     b = b.rotateL(7);
 259 
 260                     // makeRotate
 261                     b = b.rearrange(rot1);
 262                     c = c.rearrange(rot2);
 263                     d = d.rearrange(rot3);
 264 
 265                     // second round
 266                     a = a.add(b);
 267                     d = d.xor(a);
 268                     d = d.rotateL(16);
 269 
 270                     c = c.add(d);
 271                     b = b.xor(c);
 272                     b = b.rotateL(12);
 273 
 274                     a = a.add(b);
 275                     d = d.xor(a);
 276                     d = d.rotateL(8);
 277 
 278                     c = c.add(d);
 279                     b = b.xor(c);
 280                     b = b.rotateL(7);
 281 
 282                     // makeRotate
 283                     b = b.rearrange(rot3);
 284                     c = c.rearrange(rot2);
 285                     d = d.rearrange(rot1);
 286                 }
 287 
 288                 a = a.add(sa);
 289                 b = b.add(sb);
 290                 c = c.add(sc);
 291                 d = d.add(sd);
 292 
 293                 // rearrange the vectors
 294                 if (intSpecies.length() == 4) {
 295                     // no rearrange needed
 296                 } else if (intSpecies.length() == 8) {
 297                     IntVector a_r = a.rearrange(b, shuf0, mask1);
 298                     IntVector b_r = c.rearrange(d, shuf0, mask1);
 299                     IntVector c_r = a.rearrange(b, shuf1, mask1);
 300                     IntVector d_r = c.rearrange(d, shuf1, mask1);
 301 
 302                     a = a_r;
 303                     b = b_r;
 304                     c = c_r;
 305                     d = d_r;
 306                 } else if (intSpecies.length() == 16) {
 307                     IntVector a_r = a;
 308                     a_r = a_r.blend(b.rearrange(shuf0), mask1);
 309                     a_r = a_r.blend(c.rearrange(shuf0), mask2);
 310                     a_r = a_r.blend(d.rearrange(shuf0), mask3);
 311 
 312                     IntVector b_r = b;
 313                     b_r = b_r.blend(a.rearrange(shuf1), mask0);
 314                     b_r = b_r.blend(c.rearrange(shuf1), mask2);
 315                     b_r = b_r.blend(d.rearrange(shuf1), mask3);
 316 
 317                     IntVector c_r = c;
 318                     c_r = c_r.blend(a.rearrange(shuf2), mask0);
 319                     c_r = c_r.blend(b.rearrange(shuf2), mask1);
 320                     c_r = c_r.blend(d.rearrange(shuf2), mask3);
 321 
 322                     IntVector d_r = d;
 323                     d_r = d_r.blend(a.rearrange(shuf3), mask0);
 324                     d_r = d_r.blend(b.rearrange(shuf3), mask1);
 325                     d_r = d_r.blend(c.rearrange(shuf3), mask2);
 326 
 327                     a = a_r;
 328                     b = b_r;
 329                     c = c_r;
 330                     d = d_r;
 331                 } else {
 332                     throw new RuntimeException("not supported");
 333                 }
 334 
 335                 // xor keystream with input
 336                 int inOff = stateLenBytes * j;
 337                 IntVector ina = IntVector.fromByteArray(intSpecies, in, inOff);
 338                 IntVector inb = IntVector.fromByteArray(intSpecies, in, inOff + 4 * len);
 339                 IntVector inc = IntVector.fromByteArray(intSpecies, in, inOff + 8 * len);
 340                 IntVector ind = IntVector.fromByteArray(intSpecies, in, inOff + 12 * len);
 341 
 342                 ina.xor(a).intoByteArray(out, inOff);
 343                 inb.xor(b).intoByteArray(out, inOff + 4 * len);
 344                 inc.xor(c).intoByteArray(out, inOff + 8 * len);
 345                 ind.xor(d).intoByteArray(out, inOff + 12 * len);
 346 
 347                 // increment counter
 348                 sd = sd.add(counterAdd);
 349             }
 350         }
 351 
 352         public int implBlockSize() {
 353             return numBlocks * 64;
 354         }
 355     }
 356 
 357     private static byte[] hexStringToByteArray(String str) {
 358         byte[] result = new byte[str.length() / 2];
 359         for (int i = 0; i < result.length; i++) {
 360             result[i] = (byte) Character.digit(str.charAt(2 * i), 16);
 361             result[i] <<= 4;
 362             result[i] += Character.digit(str.charAt(2 * i + 1), 16);
 363         }
 364         return result;
 365     }
 366 
 367     private static void runKAT(ChaChaVector cc20, String keyStr,
 368         String nonceStr, long counter, String inStr, String outStr) {
 369 
 370         byte[] key = hexStringToByteArray(keyStr);
 371         byte[] nonce = hexStringToByteArray(nonceStr);
 372         byte[] in = hexStringToByteArray(inStr);
 373         byte[] expOut = hexStringToByteArray(outStr);
 374 
 375         // implementation only works at multiples of some size
 376         int blockSize = cc20.implBlockSize();
 377 
 378         int length = blockSize * ((in.length + blockSize - 1) / blockSize);
 379         in = Arrays.copyOf(in, length);
 380         byte[] out = new byte[length];
 381 
 382         cc20.chacha20(key, nonce, counter, in, out);
 383 
 384         byte[] actOut = new byte[expOut.length];
 385         System.arraycopy(out, 0, actOut, 0, expOut.length);
 386 
 387         if (!Arrays.equals(out, 0, expOut.length, expOut, 0, expOut.length)) {
 388             throw new RuntimeException("Incorrect result");
 389         }
 390     }
 391 
 392     /*
 393      * ChaCha20 Known Answer Tests to ensure that the implementation is correct.
 394      */
 395     private static void runKAT(ChaChaVector cc20) {
 396         runKAT(cc20,
 397         "0000000000000000000000000000000000000000000000000000000000000001",
 398         "000000000000000000000002",
 399         1,
 400         "416e79207375626d697373696f6e20746f20746865204945544620696e74656e" +
 401         "6465642062792074686520436f6e7472696275746f7220666f72207075626c69" +
 402         "636174696f6e20617320616c6c206f722070617274206f6620616e2049455446" +
 403         "20496e7465726e65742d4472616674206f722052464320616e6420616e792073" +
 404         "746174656d656e74206d6164652077697468696e2074686520636f6e74657874" +
 405         "206f6620616e204945544620616374697669747920697320636f6e7369646572" +
 406         "656420616e20224945544620436f6e747269627574696f6e222e205375636820" +
 407         "73746174656d656e747320696e636c756465206f72616c2073746174656d656e" +
 408         "747320696e20494554462073657373696f6e732c2061732077656c6c20617320" +
 409         "7772697474656e20616e6420656c656374726f6e696320636f6d6d756e696361" +
 410         "74696f6e73206d61646520617420616e792074696d65206f7220706c6163652c" +
 411         "207768696368206172652061646472657373656420746f",
 412         "a3fbf07df3fa2fde4f376ca23e82737041605d9f4f4f57bd8cff2c1d4b7955ec" +
 413         "2a97948bd3722915c8f3d337f7d370050e9e96d647b7c39f56e031ca5eb6250d" +
 414         "4042e02785ececfa4b4bb5e8ead0440e20b6e8db09d881a7c6132f420e527950" +
 415         "42bdfa7773d8a9051447b3291ce1411c680465552aa6c405b7764d5e87bea85a" +
 416         "d00f8449ed8f72d0d662ab052691ca66424bc86d2df80ea41f43abf937d3259d" +
 417         "c4b2d0dfb48a6c9139ddd7f76966e928e635553ba76c5c879d7b35d49eb2e62b" +
 418         "0871cdac638939e25e8a1e0ef9d5280fa8ca328b351c3c765989cbcf3daa8b6c" +
 419         "cc3aaf9f3979c92b3720fc88dc95ed84a1be059c6499b9fda236e7e818b04b0b" +
 420         "c39c1e876b193bfe5569753f88128cc08aaa9b63d1a16f80ef2554d7189c411f" +
 421         "5869ca52c5b83fa36ff216b9c1d30062bebcfd2dc5bce0911934fda79a86f6e6" +
 422         "98ced759c3ff9b6477338f3da4f9cd8514ea9982ccafb341b2384dd902f3d1ab" +
 423         "7ac61dd29c6f21ba5b862f3730e37cfdc4fd806c22f221"
 424         );
 425     }
 426 }