1 /* 2 * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have 21 * questions. 22 */ 23 24 package benchmark.crypto; 25 26 import org.openjdk.jmh.annotations.*; 27 import jdk.incubator.vector.*; 28 import java.util.Arrays; 29 30 @State(Scope.Thread) 31 @BenchmarkMode(Mode.Throughput) 32 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) 33 @Warmup(iterations = 3, time = 3) 34 @Measurement(iterations = 8, time = 2) 35 public class ChaChaBench { 36 37 @Param({"16384", "65536"}) 38 private int dataSize; 39 40 private ChaChaVector cc20_S128 = makeCC20(Vector.Shape.S_128_BIT); 41 private ChaChaVector cc20_S256 = makeCC20(Vector.Shape.S_256_BIT); 42 private ChaChaVector cc20_S512 = makeCC20(Vector.Shape.S_512_BIT); 43 44 private byte[] in; 45 private byte[] out; 46 47 private byte[] key = new byte[32]; 48 private byte[] nonce = new byte[12]; 49 private long counter = 0; 50 51 private static ChaChaVector makeCC20(Vector.Shape shape) { 52 ChaChaVector cc20 = new ChaChaVector(shape); 53 runKAT(cc20); 54 return cc20; 55 } 56 57 @Setup 58 public void setup() { 59 60 in = new byte[dataSize]; 61 out = new byte[dataSize]; 62 } 63 64 @Benchmark 65 public void encrypt128() { 66 cc20_S128.chacha20(key, nonce, counter, in, out); 67 } 68 69 @Benchmark 70 public void encrypt256() { 71 cc20_S256.chacha20(key, nonce, counter, in, out); 72 } 73 74 @Benchmark 75 public void encrypt512() { 76 cc20_S512.chacha20(key, nonce, counter, in, out); 77 } 78 79 private static class ChaChaVector { 80 81 private static final int[] STATE_CONSTANTS = 82 new int[]{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574}; 83 84 private final Vector.Species<Integer> intSpecies; 85 private final int numBlocks; 86 87 private final Vector.Shuffle<Integer> rot1; 88 private final Vector.Shuffle<Integer> rot2; 89 private final Vector.Shuffle<Integer> rot3; 90 91 private final IntVector counterAdd; 92 93 private final Vector.Shuffle<Integer> shuf0; 94 private final Vector.Shuffle<Integer> shuf1; 95 private final Vector.Shuffle<Integer> shuf2; 96 private final Vector.Shuffle<Integer> shuf3; 97 98 private final Vector.Mask<Integer> mask0; 99 private final Vector.Mask<Integer> mask1; 100 private final Vector.Mask<Integer> mask2; 101 private final Vector.Mask<Integer> mask3; 102 103 private final int[] state; 104 105 public ChaChaVector(Vector.Shape shape) { 106 this.intSpecies = Vector.Species.of(Integer.class, shape); 107 this.numBlocks = intSpecies.length() / 4; 108 109 this.rot1 = makeRotate(1); 110 this.rot2 = makeRotate(2); 111 this.rot3 = makeRotate(3); 112 113 this.counterAdd = makeCounterAdd(); 114 115 this.shuf0 = makeRearrangeShuffle(0); 116 this.shuf1 = makeRearrangeShuffle(1); 117 this.shuf2 = makeRearrangeShuffle(2); 118 this.shuf3 = makeRearrangeShuffle(3); 119 120 this.mask0 = makeRearrangeMask(0); 121 this.mask1 = makeRearrangeMask(1); 122 this.mask2 = makeRearrangeMask(2); 123 this.mask3 = makeRearrangeMask(3); 124 125 this.state = new int[numBlocks * 16]; 126 } 127 128 private Vector.Shuffle<Integer> makeRotate(int amount) { 129 int[] shuffleArr = new int[intSpecies.length()]; 130 131 for (int i = 0; i < intSpecies.length(); i ++) { 132 int offset = (i / 4) * 4; 133 shuffleArr[i] = offset + ((i + amount) % 4); 134 } 135 136 return IntVector.shuffleFromValues(intSpecies, shuffleArr); 137 } 138 139 private IntVector makeCounterAdd() { 140 int[] addArr = new int[intSpecies.length()]; 141 for(int i = 0; i < numBlocks; i++) { 142 addArr[4 * i] = numBlocks; 143 } 144 return IntVector.fromArray(intSpecies, addArr, 0); 145 } 146 147 private Vector.Shuffle<Integer> makeRearrangeShuffle(int order) { 148 int[] shuffleArr = new int[intSpecies.length()]; 149 int start = order * 4; 150 for (int i = 0; i < shuffleArr.length; i++) { 151 shuffleArr[i] = (i % 4) + start; 152 } 153 return IntVector.shuffleFromArray(intSpecies, shuffleArr, 0); 154 } 155 156 private Vector.Mask<Integer> makeRearrangeMask(int order) { 157 boolean[] maskArr = new boolean[intSpecies.length()]; 158 int start = order * 4; 159 if (start < maskArr.length) { 160 for (int i = 0; i < 4; i++) { 161 maskArr[i + start] = true; 162 } 163 } 164 165 return IntVector.maskFromValues(intSpecies, maskArr); 166 } 167 168 public void makeState(byte[] key, byte[] nonce, long counter, 169 int[] out) { 170 171 // first field is constants 172 for (int i = 0; i < 4; i++) { 173 for (int j = 0; j < numBlocks; j++) { 174 out[4*j + i] = STATE_CONSTANTS[i]; 175 } 176 } 177 178 // second field is first part of key 179 int fieldStart = 4 * numBlocks; 180 for (int i = 0; i < 4; i++) { 181 int keyInt = 0; 182 for (int j = 0; j < 4; j++) { 183 keyInt += (0xFF & key[4 * i + j]) << 8 * j; 184 } 185 for (int j = 0; j < numBlocks; j++) { 186 out[fieldStart + j*4 + i] = keyInt; 187 } 188 } 189 190 // third field is second part of key 191 fieldStart = 8 * numBlocks; 192 for (int i = 0; i < 4; i++) { 193 int keyInt = 0; 194 for (int j = 0; j < 4; j++) { 195 keyInt += (0xFF & key[4 * (i + 4) + j]) << 8 * j; 196 } 197 198 for (int j = 0; j < numBlocks; j++) { 199 out[fieldStart + j*4 + i] = keyInt; 200 } 201 } 202 203 // fourth field is counter and nonce 204 fieldStart = 12 * numBlocks; 205 for (int j = 0; j < numBlocks; j++) { 206 out[fieldStart + j*4] = (int) (counter + j); 207 } 208 209 for (int i = 0; i < 3; i++) { 210 int nonceInt = 0; 211 for (int j = 0; j < 4; j++) { 212 nonceInt += (0xFF & nonce[4 * i + j]) << 8 * j; 213 } 214 215 for (int j = 0; j < numBlocks; j++) { 216 out[fieldStart + j*4 + 1 + i] = nonceInt; 217 } 218 } 219 } 220 221 public void chacha20(byte[] key, byte[] nonce, long counter, 222 byte[] in, byte[] out) { 223 224 makeState(key, nonce, counter, state); 225 226 int len = intSpecies.length(); 227 228 IntVector sa = IntVector.fromArray(intSpecies, state, 0); 229 IntVector sb = IntVector.fromArray(intSpecies, state, len); 230 IntVector sc = IntVector.fromArray(intSpecies, state, 2 * len); 231 IntVector sd = IntVector.fromArray(intSpecies, state, 3 * len); 232 233 int stateLenBytes = state.length * 4; 234 int numStates = (in.length + stateLenBytes - 1) / stateLenBytes; 235 for (int j = 0; j < numStates; j++){ 236 237 IntVector a = sa; 238 IntVector b = sb; 239 IntVector c = sc; 240 IntVector d = sd; 241 242 for (int i = 0; i < 10; i++) { 243 // first round 244 a = a.add(b); 245 d = d.xor(a); 246 d = d.rotateL(16); 247 248 c = c.add(d); 249 b = b.xor(c); 250 b = b.rotateL(12); 251 252 a = a.add(b); 253 d = d.xor(a); 254 d = d.rotateL(8); 255 256 c = c.add(d); 257 b = b.xor(c); 258 b = b.rotateL(7); 259 260 // makeRotate 261 b = b.rearrange(rot1); 262 c = c.rearrange(rot2); 263 d = d.rearrange(rot3); 264 265 // second round 266 a = a.add(b); 267 d = d.xor(a); 268 d = d.rotateL(16); 269 270 c = c.add(d); 271 b = b.xor(c); 272 b = b.rotateL(12); 273 274 a = a.add(b); 275 d = d.xor(a); 276 d = d.rotateL(8); 277 278 c = c.add(d); 279 b = b.xor(c); 280 b = b.rotateL(7); 281 282 // makeRotate 283 b = b.rearrange(rot3); 284 c = c.rearrange(rot2); 285 d = d.rearrange(rot1); 286 } 287 288 a = a.add(sa); 289 b = b.add(sb); 290 c = c.add(sc); 291 d = d.add(sd); 292 293 // rearrange the vectors 294 if (intSpecies.length() == 4) { 295 // no rearrange needed 296 } else if (intSpecies.length() == 8) { 297 IntVector a_r = a.rearrange(b, shuf0, mask1); 298 IntVector b_r = c.rearrange(d, shuf0, mask1); 299 IntVector c_r = a.rearrange(b, shuf1, mask1); 300 IntVector d_r = c.rearrange(d, shuf1, mask1); 301 302 a = a_r; 303 b = b_r; 304 c = c_r; 305 d = d_r; 306 } else if (intSpecies.length() == 16) { 307 IntVector a_r = a; 308 a_r = a_r.blend(b.rearrange(shuf0), mask1); 309 a_r = a_r.blend(c.rearrange(shuf0), mask2); 310 a_r = a_r.blend(d.rearrange(shuf0), mask3); 311 312 IntVector b_r = b; 313 b_r = b_r.blend(a.rearrange(shuf1), mask0); 314 b_r = b_r.blend(c.rearrange(shuf1), mask2); 315 b_r = b_r.blend(d.rearrange(shuf1), mask3); 316 317 IntVector c_r = c; 318 c_r = c_r.blend(a.rearrange(shuf2), mask0); 319 c_r = c_r.blend(b.rearrange(shuf2), mask1); 320 c_r = c_r.blend(d.rearrange(shuf2), mask3); 321 322 IntVector d_r = d; 323 d_r = d_r.blend(a.rearrange(shuf3), mask0); 324 d_r = d_r.blend(b.rearrange(shuf3), mask1); 325 d_r = d_r.blend(c.rearrange(shuf3), mask2); 326 327 a = a_r; 328 b = b_r; 329 c = c_r; 330 d = d_r; 331 } else { 332 throw new RuntimeException("not supported"); 333 } 334 335 // xor keystream with input 336 int inOff = stateLenBytes * j; 337 IntVector ina = IntVector.fromByteArray(intSpecies, in, inOff); 338 IntVector inb = IntVector.fromByteArray(intSpecies, in, inOff + 4 * len); 339 IntVector inc = IntVector.fromByteArray(intSpecies, in, inOff + 8 * len); 340 IntVector ind = IntVector.fromByteArray(intSpecies, in, inOff + 12 * len); 341 342 ina.xor(a).intoByteArray(out, inOff); 343 inb.xor(b).intoByteArray(out, inOff + 4 * len); 344 inc.xor(c).intoByteArray(out, inOff + 8 * len); 345 ind.xor(d).intoByteArray(out, inOff + 12 * len); 346 347 // increment counter 348 sd = sd.add(counterAdd); 349 } 350 } 351 352 public int implBlockSize() { 353 return numBlocks * 64; 354 } 355 } 356 357 private static byte[] hexStringToByteArray(String str) { 358 byte[] result = new byte[str.length() / 2]; 359 for (int i = 0; i < result.length; i++) { 360 result[i] = (byte) Character.digit(str.charAt(2 * i), 16); 361 result[i] <<= 4; 362 result[i] += Character.digit(str.charAt(2 * i + 1), 16); 363 } 364 return result; 365 } 366 367 private static void runKAT(ChaChaVector cc20, String keyStr, 368 String nonceStr, long counter, String inStr, String outStr) { 369 370 byte[] key = hexStringToByteArray(keyStr); 371 byte[] nonce = hexStringToByteArray(nonceStr); 372 byte[] in = hexStringToByteArray(inStr); 373 byte[] expOut = hexStringToByteArray(outStr); 374 375 // implementation only works at multiples of some size 376 int blockSize = cc20.implBlockSize(); 377 378 int length = blockSize * ((in.length + blockSize - 1) / blockSize); 379 in = Arrays.copyOf(in, length); 380 byte[] out = new byte[length]; 381 382 cc20.chacha20(key, nonce, counter, in, out); 383 384 byte[] actOut = new byte[expOut.length]; 385 System.arraycopy(out, 0, actOut, 0, expOut.length); 386 387 if (!Arrays.equals(out, 0, expOut.length, expOut, 0, expOut.length)) { 388 throw new RuntimeException("Incorrect result"); 389 } 390 } 391 392 /* 393 * ChaCha20 Known Answer Tests to ensure that the implementation is correct. 394 */ 395 private static void runKAT(ChaChaVector cc20) { 396 runKAT(cc20, 397 "0000000000000000000000000000000000000000000000000000000000000001", 398 "000000000000000000000002", 399 1, 400 "416e79207375626d697373696f6e20746f20746865204945544620696e74656e" + 401 "6465642062792074686520436f6e7472696275746f7220666f72207075626c69" + 402 "636174696f6e20617320616c6c206f722070617274206f6620616e2049455446" + 403 "20496e7465726e65742d4472616674206f722052464320616e6420616e792073" + 404 "746174656d656e74206d6164652077697468696e2074686520636f6e74657874" + 405 "206f6620616e204945544620616374697669747920697320636f6e7369646572" + 406 "656420616e20224945544620436f6e747269627574696f6e222e205375636820" + 407 "73746174656d656e747320696e636c756465206f72616c2073746174656d656e" + 408 "747320696e20494554462073657373696f6e732c2061732077656c6c20617320" + 409 "7772697474656e20616e6420656c656374726f6e696320636f6d6d756e696361" + 410 "74696f6e73206d61646520617420616e792074696d65206f7220706c6163652c" + 411 "207768696368206172652061646472657373656420746f", 412 "a3fbf07df3fa2fde4f376ca23e82737041605d9f4f4f57bd8cff2c1d4b7955ec" + 413 "2a97948bd3722915c8f3d337f7d370050e9e96d647b7c39f56e031ca5eb6250d" + 414 "4042e02785ececfa4b4bb5e8ead0440e20b6e8db09d881a7c6132f420e527950" + 415 "42bdfa7773d8a9051447b3291ce1411c680465552aa6c405b7764d5e87bea85a" + 416 "d00f8449ed8f72d0d662ab052691ca66424bc86d2df80ea41f43abf937d3259d" + 417 "c4b2d0dfb48a6c9139ddd7f76966e928e635553ba76c5c879d7b35d49eb2e62b" + 418 "0871cdac638939e25e8a1e0ef9d5280fa8ca328b351c3c765989cbcf3daa8b6c" + 419 "cc3aaf9f3979c92b3720fc88dc95ed84a1be059c6499b9fda236e7e818b04b0b" + 420 "c39c1e876b193bfe5569753f88128cc08aaa9b63d1a16f80ef2554d7189c411f" + 421 "5869ca52c5b83fa36ff216b9c1d30062bebcfd2dc5bce0911934fda79a86f6e6" + 422 "98ced759c3ff9b6477338f3da4f9cd8514ea9982ccafb341b2384dd902f3d1ab" + 423 "7ac61dd29c6f21ba5b862f3730e37cfdc4fd806c22f221" 424 ); 425 } 426 }