< prev index next >

src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp

Print this page
rev 60737 : 8252204: AArch64: Implement SHA3 accelerator/intrinsic
Reviewed-by: duke
Contributed-by: dongbo4@huawei.com

@@ -3289,10 +3289,229 @@
     __ ret(lr);
 
     return start;
   }
 
+  // Arguments:
+  //
+  // Inputs:
+  //   c_rarg0   - byte[]  source+offset
+  //   c_rarg1   - byte[]  SHA.state
+  //   c_rarg2   - int     digest_length
+  //   c_rarg3   - int     offset
+  //   c_rarg4   - int     limit
+  //
+  // Output (multi_block only):
+  //   c_rarg0   - int     updated offset
+  //
+  // Absorbs one input block into the 25-lane (200-byte) Keccak state and
+  // then applies the 24-round Keccak-f[1600] permutation, using the
+  // Armv8.2 SHA3 extension instructions (EOR3, RAX1, XAR, BCAX).
+  // When multi_block is true, ofs is advanced by the block size after each
+  // block and another block is processed while ofs <= limit; the final ofs
+  // is returned in c_rarg0.
+  //
+  // The block size (rate) is 200 - 2 * digest_length bytes:
+  //   SHA3-224 (28) -> 144 bytes = 18 lanes
+  //   SHA3-256 (32) -> 136 bytes = 17 lanes
+  //   SHA3-384 (48) -> 104 bytes = 13 lanes
+  //   SHA3-512 (64) ->  72 bytes =  9 lanes
+  // Only bits 6, 4 and 2 of digest_length are tested to pick the variant,
+  // so values other than the four above are not rejected here.
+  //
+  // Register map: state lane i lives in the low 64 bits of v<i> (v0..v24);
+  // v25..v31 are scratch.
+  //
+  address generate_sha3_implCompress(bool multi_block, const char *name) {
+    // iota round constants, one per Keccak-f[1600] round
+    static const uint64_t round_consts[24] = {
+      0x0000000000000001L, 0x0000000000008082L, 0x800000000000808AL,
+      0x8000000080008000L, 0x000000000000808BL, 0x0000000080000001L,
+      0x8000000080008081L, 0x8000000000008009L, 0x000000000000008AL,
+      0x0000000000000088L, 0x0000000080008009L, 0x000000008000000AL,
+      0x000000008000808BL, 0x800000000000008BL, 0x8000000000008089L,
+      0x8000000000008003L, 0x8000000000008002L, 0x8000000000000080L,
+      0x000000000000800AL, 0x800000008000000AL, 0x8000000080008081L,
+      0x8000000000008080L, 0x0000000080000001L, 0x8000000080008008L
+    };
+
+    __ align(CodeEntryAlignment);
+    StubCodeMark mark(this, "StubRoutines", name);
+    address start = __ pc();
+
+    Register buf           = c_rarg0;
+    Register state         = c_rarg1;
+    Register digest_length = c_rarg2;
+    Register ofs           = c_rarg3;
+    Register limit         = c_rarg4;
+
+    Label sha3_loop, rounds24_loop;
+    Label sha3_512, sha3_384_or_224;
+
+    // v8-v15 are overwritten below; their low halves (d8-d15) are
+    // callee-saved under the AArch64 procedure call standard.
+    __ stpd(v8, v9, __ pre(sp, -64));
+    __ stpd(v10, v11, Address(sp, 16));
+    __ stpd(v12, v13, Address(sp, 32));
+    __ stpd(v14, v15, Address(sp, 48));
+
+    // load state: 25 lanes -> v0..v24
+    __ add(rscratch1, state, 32);
+    __ ld1(v0, v1, v2,  v3,  __ T1D, state);
+    __ ld1(v4, v5, v6,  v7,  __ T1D, __ post(rscratch1, 32));
+    __ ld1(v8, v9, v10, v11, __ T1D, __ post(rscratch1, 32));
+    __ ld1(v12, v13, v14, v15, __ T1D, __ post(rscratch1, 32));
+    __ ld1(v16, v17, v18, v19, __ T1D, __ post(rscratch1, 32));
+    __ ld1(v20, v21, v22, v23, __ T1D, __ post(rscratch1, 32));
+    __ ld1(v24, __ T1D, rscratch1);
+
+    __ BIND(sha3_loop);
+
+    // 24 keccak rounds; rscratch2 counts them down
+    __ movw(rscratch2, 24);
+
+    // load round_constants base; post-incremented by 8 in every round
+    __ lea(rscratch1, ExternalAddress((address) round_consts));
+
+    // absorb: XOR the first 7 lanes (56 bytes), common to all variants
+    __ ld1(v25, v26, v27, v28, __ T8B, __ post(buf, 32));
+    __ ld1(v29, v30, v31, __ T8B, __ post(buf, 24));
+    __ eor(v0, __ T8B, v0, v25);
+    __ eor(v1, __ T8B, v1, v26);
+    __ eor(v2, __ T8B, v2, v27);
+    __ eor(v3, __ T8B, v3, v28);
+    __ eor(v4, __ T8B, v4, v29);
+    __ eor(v5, __ T8B, v5, v30);
+    __ eor(v6, __ T8B, v6, v31);
+
+    // digest_length == 64, SHA3-512
+    __ tbnz(digest_length, 6, sha3_512);
+
+    // SHA3-224/256/384: 6 more lanes (13 in total)
+    __ ld1(v25, v26, v27, v28, __ T8B, __ post(buf, 32));
+    __ ld1(v29, v30, __ T8B, __ post(buf, 16));
+    __ eor(v7, __ T8B, v7, v25);
+    __ eor(v8, __ T8B, v8, v26);
+    __ eor(v9, __ T8B, v9, v27);
+    __ eor(v10, __ T8B, v10, v28);
+    __ eor(v11, __ T8B, v11, v29);
+    __ eor(v12, __ T8B, v12, v30);
+
+    // digest_length == 28, SHA3-224;  digest_length == 48, SHA3-384
+    __ tbnz(digest_length, 4, sha3_384_or_224);
+
+    // SHA3-256: 4 more lanes (17 in total)
+    __ ld1(v25, v26, v27, v28, __ T8B, __ post(buf, 32));
+    __ eor(v13, __ T8B, v13, v25);
+    __ eor(v14, __ T8B, v14, v26);
+    __ eor(v15, __ T8B, v15, v27);
+    __ eor(v16, __ T8B, v16, v28);
+    __ b(rounds24_loop);
+
+    __ BIND(sha3_384_or_224);
+    // bit 2 clear: SHA3-384 (48), already complete at 13 lanes
+    __ tbz(digest_length, 2, rounds24_loop);
+
+    // SHA3-224 (28): 5 more lanes (18 in total)
+    __ ld1(v25, v26, v27, v28, __ T8B, __ post(buf, 32));
+    __ ld1(v29, __ T8B, __ post(buf, 8));
+    __ eor(v13, __ T8B, v13, v25);
+    __ eor(v14, __ T8B, v14, v26);
+    __ eor(v15, __ T8B, v15, v27);
+    __ eor(v16, __ T8B, v16, v28);
+    __ eor(v17, __ T8B, v17, v29);
+    __ b(rounds24_loop);
+
+    // SHA3-512: 2 more lanes (9 in total)
+    __ BIND(sha3_512);
+    __ ld1(v25, v26, __ T8B, __ post(buf, 16));
+    __ eor(v7, __ T8B, v7, v25);
+    __ eor(v8, __ T8B, v8, v26);
+
+    __ BIND(rounds24_loop);
+    __ subw(rscratch2, rscratch2, 1);
+
+    // theta: column parities, v25..v29 = C[0..4]
+    __ eor3(v29, __ T16B, v4, v9, v14);
+    __ eor3(v26, __ T16B, v1, v6, v11);
+    __ eor3(v28, __ T16B, v3, v8, v13);
+    __ eor3(v25, __ T16B, v0, v5, v10);
+    __ eor3(v27, __ T16B, v2, v7, v12);
+    __ eor3(v29, __ T16B, v29, v19, v24);
+    __ eor3(v26, __ T16B, v26, v16, v21);
+    __ eor3(v28, __ T16B, v28, v18, v23);
+    __ eor3(v25, __ T16B, v25, v15, v20);
+    __ eor3(v27, __ T16B, v27, v17, v22);
+
+    // theta: D[x] = C[x-1] ^ rol(C[x+1], 1)
+    // afterwards v30, v25, v26, v27, v28 = D[0], D[1], D[2], D[3], D[4]
+    __ rax1(v30, __ T2D, v29, v26);
+    __ rax1(v26, __ T2D, v26, v28);
+    __ rax1(v28, __ T2D, v28, v25);
+    __ rax1(v25, __ T2D, v25, v27);
+    __ rax1(v27, __ T2D, v27, v29);
+
+    // theta + rho + pi: XAR xors each lane with its column's D value and
+    // rotates it; the destination registers realise the pi permutation
+    __ eor(v0, __ T16B, v0, v30);
+    __ xar(v29, __ T2D, v1,  v25, (64 - 1));
+    __ xar(v1,  __ T2D, v6,  v25, (64 - 44));
+    __ xar(v6,  __ T2D, v9,  v28, (64 - 20));
+    __ xar(v9,  __ T2D, v22, v26, (64 - 61));
+    __ xar(v22, __ T2D, v14, v28, (64 - 39));
+    __ xar(v14, __ T2D, v20, v30, (64 - 18));
+    __ xar(v31, __ T2D, v2,  v26, (64 - 62));
+    __ xar(v2,  __ T2D, v12, v26, (64 - 43));
+    __ xar(v12, __ T2D, v13, v27, (64 - 25));
+    __ xar(v13, __ T2D, v19, v28, (64 - 8));
+    __ xar(v19, __ T2D, v23, v27, (64 - 56));
+    __ xar(v23, __ T2D, v15, v30, (64 - 41));
+    __ xar(v15, __ T2D, v4,  v28, (64 - 27));
+    __ xar(v28, __ T2D, v24, v28, (64 - 14));
+    __ xar(v24, __ T2D, v21, v25, (64 - 2));
+    __ xar(v8,  __ T2D, v8,  v27, (64 - 55));
+    __ xar(v4,  __ T2D, v16, v25, (64 - 45));
+    __ xar(v16, __ T2D, v5,  v30, (64 - 36));
+    __ xar(v5,  __ T2D, v3,  v27, (64 - 28));
+    __ xar(v27, __ T2D, v18, v27, (64 - 21));
+    __ xar(v3,  __ T2D, v17, v26, (64 - 15));
+    __ xar(v25, __ T2D, v11, v25, (64 - 10));
+    __ xar(v26, __ T2D, v7,  v26, (64 - 6));
+    __ xar(v30, __ T2D, v10, v30, (64 - 3));
+
+    // chi, one row of five lanes at a time:
+    // BCAX d, n, m, a computes n ^ (m & ~a), i.e. B[x] ^ (~B[x+1] & B[x+2])
+    __ bcax(v20, __ T16B, v31, v22, v8);
+    __ bcax(v21, __ T16B, v8,  v23, v22);
+    __ bcax(v22, __ T16B, v22, v24, v23);
+    __ bcax(v23, __ T16B, v23, v31, v24);
+    __ bcax(v24, __ T16B, v24, v8,  v31);
+
+    // this round's iota constant, replicated into both 64-bit halves
+    __ ld1r(v31, __ T2D, __ post(rscratch1, 8));
+
+    __ bcax(v17, __ T16B, v25, v19, v3);
+    __ bcax(v18, __ T16B, v3,  v15, v19);
+    __ bcax(v19, __ T16B, v19, v16, v15);
+    __ bcax(v15, __ T16B, v15, v25, v16);
+    __ bcax(v16, __ T16B, v16, v3,  v25);
+
+    __ bcax(v10, __ T16B, v29, v12, v26);
+    __ bcax(v11, __ T16B, v26, v13, v12);
+    __ bcax(v12, __ T16B, v12, v14, v13);
+    __ bcax(v13, __ T16B, v13, v29, v14);
+    __ bcax(v14, __ T16B, v14, v26, v29);
+
+    __ bcax(v7, __ T16B, v30, v9,  v4);
+    __ bcax(v8, __ T16B, v4,  v5,  v9);
+    __ bcax(v9, __ T16B, v9,  v6,  v5);
+    __ bcax(v5, __ T16B, v5,  v30, v6);
+    __ bcax(v6, __ T16B, v6,  v4,  v30);
+
+    __ bcax(v3, __ T16B, v27, v0,  v28);
+    __ bcax(v4, __ T16B, v28, v1,  v0);
+    __ bcax(v0, __ T16B, v0,  v2,  v1);
+    __ bcax(v1, __ T16B, v1,  v27, v2);
+    __ bcax(v2, __ T16B, v2,  v28, v27);
+
+    // iota
+    __ eor(v0, __ T16B, v0, v31);
+
+    __ cbnzw(rscratch2, rounds24_loop);
+
+    if (multi_block) {
+      // block_size =  200 - 2 * digest_length, ofs += block_size
+      __ add(ofs, ofs, 200);
+      __ sub(ofs, ofs, digest_length, Assembler::LSL, 1);
+
+      __ cmp(ofs, limit);
+      __ br(Assembler::LE, sha3_loop);
+      __ mov(c_rarg0, ofs); // return ofs
+    }
+
+    // write the 25 lanes back to the state array
+    __ st1(v0, v1, v2,  v3,  __ T1D, __ post(state, 32));
+    __ st1(v4, v5, v6,  v7,  __ T1D, __ post(state, 32));
+    __ st1(v8, v9, v10, v11, __ T1D, __ post(state, 32));
+    __ st1(v12, v13, v14, v15, __ T1D, __ post(state, 32));
+    __ st1(v16, v17, v18, v19, __ T1D, __ post(state, 32));
+    __ st1(v20, v21, v22, v23, __ T1D, __ post(state, 32));
+    __ st1(v24, __ T1D, state);
+
+    // restore callee-saved d8-d15
+    __ ldpd(v14, v15, Address(sp, 48));
+    __ ldpd(v12, v13, Address(sp, 32));
+    __ ldpd(v10, v11, Address(sp, 16));
+    __ ldpd(v8, v9, __ post(sp, 64));
+
+    __ ret(lr);
+
+    return start;
+  }
+
   // Safefetch stubs.
   void generate_safefetch(const char* name, int size, address* entry,
                           address* fault_pc, address* continuation_pc) {
     // safefetch signatures:
     //   int      SafeFetch32(int*      adr, int      errValue);

@@ -6020,10 +6239,14 @@
     }
     if (UseSHA512Intrinsics) {
       StubRoutines::_sha512_implCompress   = generate_sha512_implCompress(false, "sha512_implCompress");
       StubRoutines::_sha512_implCompressMB = generate_sha512_implCompress(true,  "sha512_implCompressMB");
     }
+    if (UseSHA3Intrinsics) {
+      StubRoutines::_sha3_implCompress     = generate_sha3_implCompress(false,   "sha3_implCompress");
+      StubRoutines::_sha3_implCompressMB   = generate_sha3_implCompress(true,    "sha3_implCompressMB");
+    }
 
     // generate Adler32 intrinsics code
     if (UseAdler32Intrinsics) {
       StubRoutines::_updateBytesAdler32 = generate_updateBytesAdler32();
     }
< prev index next >