# HG changeset patch # User aph # Date 1434472313 -3600 # Tue Jun 16 17:31:53 2015 +0100 # Node ID f3adbca056e3cb13ebf1812449bf0b45ac4d6846 # Parent 5a9d5d58e667b7a49b5d07ccfdb02c00252a9800 8130150: Implement BigInteger.montgomeryMultiply intrinsic Summary: Add montgomeryMultiply intrinsics Reviewed-by: kvn diff --git a/src/cpu/x86/vm/sharedRuntime_x86_64.cpp b/src/cpu/x86/vm/sharedRuntime_x86_64.cpp --- a/src/cpu/x86/vm/sharedRuntime_x86_64.cpp +++ b/src/cpu/x86/vm/sharedRuntime_x86_64.cpp @@ -23,6 +23,7 @@ */ #include "precompiled.hpp" +#include "alloca.h" #include "asm/macroAssembler.hpp" #include "asm/macroAssembler.inline.hpp" #include "code/debugInfoRec.hpp" @@ -3511,6 +3512,250 @@ } +//------------------------------Montgomery multiplication------------------------ +// + +#ifndef _WINDOWS + +#define ASM_SUBTRACT + +#ifdef ASM_SUBTRACT +// Subtract 0:b from carry:a. Return carry. +static unsigned long +sub(unsigned long a[], unsigned long b[], unsigned long carry, long len) { + long i = 0, cnt = len; + unsigned long tmp; + asm volatile("clc; " + "0: ; " + "mov (%[b], %[i], 8), %[tmp]; " + "sbb %[tmp], (%[a], %[i], 8); " + "inc %[i]; dec %[cnt]; " + "jne 0b; " + "mov %[carry], %[tmp]; sbb $0, %[tmp]; " + : [i]"+r"(i), [cnt]"+r"(cnt), [tmp]"=&r"(tmp) + : [a]"r"(a), [b]"r"(b), [carry]"r"(carry) + : "memory"); + return tmp; +} +#else // ASM_SUBTRACT +typedef int __attribute__((mode(TI))) int128; + +// Subtract 0:b from carry:a. Return carry. +static unsigned long +sub(unsigned long a[], unsigned long b[], unsigned long carry, int len) { + int128 tmp = 0; + int i; + for (i = 0; i < len; i++) { + tmp += a[i]; + tmp -= b[i]; + a[i] = tmp; + tmp >>= 64; + assert(-1 <= tmp && tmp <= 0, "invariant"); + } + return tmp + carry; +} +#endif // ! ASM_SUBTRACT + +// Multiply (unsigned) Long A by Long B, accumulating the double- +// length result into the accumulator formed of T0, T1, and T2. 
+#define MACC(A, B, T0, T1, T2) \ +do { \ + unsigned long hi, lo; \ + __asm__ ("mul %5; add %%rax, %2; adc %%rdx, %3; adc $0, %4" \ + : "=&d"(hi), "=a"(lo), "+r"(T0), "+r"(T1), "+g"(T2) \ + : "r"(A), "a"(B) : "cc"); \ + } while(0) + +// As above, but add twice the double-length result into the +// accumulator. +#define MACC2(A, B, T0, T1, T2) \ +do { \ + unsigned long hi, lo; \ + __asm__ ("mul %5; add %%rax, %2; adc %%rdx, %3; adc $0, %4; " \ + "add %%rax, %2; adc %%rdx, %3; adc $0, %4" \ + : "=&d"(hi), "=a"(lo), "+r"(T0), "+r"(T1), "+g"(T2) \ + : "r"(A), "a"(B) : "cc"); \ + } while(0) + +// Fast Montgomery multiplication. The derivation of the algorithm is +// in A Cryptographic Library for the Motorola DSP56000, +// Dusse and Kaliski, Proc. EUROCRYPT 90, pp. 230-237. + +static void __attribute__((noinline)) +montgomery_multiply(unsigned long a[], unsigned long b[], unsigned long n[], + unsigned long m[], unsigned long inv, int len) { + unsigned long t0 = 0, t1 = 0, t2 = 0; // Triple-precision accumulator + int i; + + assert(inv * n[0] == -1UL, "broken inverse in Montgomery multiply"); + + for (i = 0; i < len; i++) { + int j; + for (j = 0; j < i; j++) { + MACC(a[j], b[i-j], t0, t1, t2); + MACC(m[j], n[i-j], t0, t1, t2); + } + MACC(a[i], b[0], t0, t1, t2); + m[i] = t0 * inv; + MACC(m[i], n[0], t0, t1, t2); + + assert(t0 == 0, "broken Montgomery multiply"); + + t0 = t1; t1 = t2; t2 = 0; + } + + for (i = len; i < 2*len; i++) { + int j; + for (j = i-len+1; j < len; j++) { + MACC(a[j], b[i-j], t0, t1, t2); + MACC(m[j], n[i-j], t0, t1, t2); + } + m[i-len] = t0; + t0 = t1; t1 = t2; t2 = 0; + } + + while (t0) + t0 = sub(m, n, t0, len); +} + +// Fast Montgomery squaring. This uses asymptotically 25% fewer +// multiplies so it should be up to 25% faster than Montgomery +// multiplication. However, its loop control is more complex and it +// may actually run slower on some machines. 
+ +static void __attribute__((noinline)) +montgomery_square(unsigned long a[], unsigned long n[], + unsigned long m[], unsigned long inv, int len) { + unsigned long t0 = 0, t1 = 0, t2 = 0; // Triple-precision accumulator + int i; + + assert(inv * n[0] == -1UL, "broken inverse in Montgomery multiply"); + + for (i = 0; i < len; i++) { + int j; + int end = (i+1)/2; + for (j = 0; j < end; j++) { + MACC2(a[j], a[i-j], t0, t1, t2); + MACC(m[j], n[i-j], t0, t1, t2); + } + if ((i & 1) == 0) { + MACC(a[j], a[j], t0, t1, t2); + } + for (; j < i; j++) { + MACC(m[j], n[i-j], t0, t1, t2); + } + m[i] = t0 * inv; + MACC(m[i], n[0], t0, t1, t2); + + assert(t0 == 0, "broken Montgomery square"); + + t0 = t1; t1 = t2; t2 = 0; + } + + for (i = len; i < 2*len; i++) { + int start = i-len+1; + int end = start + (len - start)/2; + int j; + for (j = start; j < end; j++) { + MACC2(a[j], a[i-j], t0, t1, t2); + MACC(m[j], n[i-j], t0, t1, t2); + } + if ((i & 1) == 0) { + MACC(a[j], a[j], t0, t1, t2); + } + for (; j < len; j++) { + MACC(m[j], n[i-j], t0, t1, t2); + } + m[i-len] = t0; + t0 = t1; t1 = t2; t2 = 0; + } + + while (t0) + t0 = sub(m, n, t0, len); +} + +// Swap words in a longword. +static unsigned long swap(unsigned long x) { + return (x << 32) | (x >> 32); +} + +// Copy len longwords from s to d, word-swapping as we go. The +// destination array is reversed. +static void reverse_words(unsigned long *s, unsigned long *d, int len) { + d += len; + while(len-- > 0) { + d--; + *d = swap(*s); + s++; + } +} + +// The threshold at which squaring is advantageous was determined +// experimentally on an i7-3930K (Ivy Bridge) CPU @ 3.5GHz. +#define MONTGOMERY_SQUARING_THRESHOLD 64 + +void SharedRuntime::montgomery_multiply(jint *a_ints, jint *b_ints, jint *n_ints, + jint len, jlong inv, + jint *m_ints) { + assert(len % 2 == 0, "array length in montgomery_multiply must be even"); + int longwords = len/2; + + // Make very sure we don't use so much space that the stack might + // overflow. 
512 jints corresponds to an 16384-bit integer and + // will use here a total of 8k bytes of stack space. + int total_allocation = longwords * sizeof (unsigned long) * 4; + guarantee(total_allocation <= 8192, "must be"); + unsigned long *scratch = (unsigned long *)alloca(total_allocation); + + // Local scratch arrays + unsigned long + *a = scratch + 0 * longwords, + *b = scratch + 1 * longwords, + *n = scratch + 2 * longwords, + *m = scratch + 3 * longwords; + + reverse_words((unsigned long *)a_ints, a, longwords); + reverse_words((unsigned long *)b_ints, b, longwords); + reverse_words((unsigned long *)n_ints, n, longwords); + + ::montgomery_multiply(a, b, n, m, (unsigned long)inv, longwords); + + reverse_words(m, (unsigned long *)m_ints, longwords); +} + +void SharedRuntime::montgomery_square(jint *a_ints, jint *n_ints, + jint len, jlong inv, + jint *m_ints) { + assert(len % 2 == 0, "array length in montgomery_square must be even"); + int longwords = len/2; + + // Make very sure we don't use so much space that the stack might + // overflow. 512 jints corresponds to an 16384-bit integer and + // will use here a total of 6k bytes of stack space. 
+ int total_allocation = longwords * sizeof (unsigned long) * 3; + guarantee(total_allocation <= 8192, "must be"); + unsigned long *scratch = (unsigned long *)alloca(total_allocation); + + // Local scratch arrays + unsigned long + *a = scratch + 0 * longwords, + *n = scratch + 1 * longwords, + *m = scratch + 2 * longwords; + + reverse_words((unsigned long *)a_ints, a, longwords); + reverse_words((unsigned long *)n_ints, n, longwords); + + if (len >= MONTGOMERY_SQUARING_THRESHOLD) { + ::montgomery_square(a, n, m, (unsigned long)inv, longwords); + } else { + ::montgomery_multiply(a, a, n, m, (unsigned long)inv, longwords); + } + + reverse_words(m, (unsigned long *)m_ints, longwords); +} + +#endif // !_WINDOWS + #ifdef COMPILER2 // This is here instead of runtime_x86_64.cpp because it uses SimpleRuntimeFrame // diff --git a/src/cpu/x86/vm/stubGenerator_x86_64.cpp b/src/cpu/x86/vm/stubGenerator_x86_64.cpp --- a/src/cpu/x86/vm/stubGenerator_x86_64.cpp +++ b/src/cpu/x86/vm/stubGenerator_x86_64.cpp @@ -4137,7 +4137,18 @@ if (UseMulAddIntrinsic) { StubRoutines::_mulAdd = generate_mulAdd(); } -#endif + +#ifndef _WINDOWS + if (UseMontgomeryMultiplyIntrinsic) { + StubRoutines::_montgomeryMultiply + = CAST_FROM_FN_PTR(address, SharedRuntime::montgomery_multiply); + } + if (UseMontgomerySquareIntrinsic) { + StubRoutines::_montgomerySquare + = CAST_FROM_FN_PTR(address, SharedRuntime::montgomery_square); + } +#endif // !_WINDOWS +#endif // COMPILER2 } public: diff --git a/src/cpu/x86/vm/vm_version_x86.cpp b/src/cpu/x86/vm/vm_version_x86.cpp --- a/src/cpu/x86/vm/vm_version_x86.cpp +++ b/src/cpu/x86/vm/vm_version_x86.cpp @@ -796,6 +796,12 @@ if (FLAG_IS_DEFAULT(UseMulAddIntrinsic)) { UseMulAddIntrinsic = true; } + if (FLAG_IS_DEFAULT(UseMontgomeryMultiplyIntrinsic)) { + UseMontgomeryMultiplyIntrinsic = true; + } + if (FLAG_IS_DEFAULT(UseMontgomerySquareIntrinsic)) { + UseMontgomerySquareIntrinsic = true; + } #else if (UseMultiplyToLenIntrinsic) { if 
(!FLAG_IS_DEFAULT(UseMultiplyToLenIntrinsic)) { @@ -803,6 +809,18 @@ } FLAG_SET_DEFAULT(UseMultiplyToLenIntrinsic, false); } + if (UseMontgomeryMultiplyIntrinsic) { + if (!FLAG_IS_DEFAULT(UseMontgomeryMultiplyIntrinsic)) { + warning("montgomeryMultiply intrinsic is not available in 32-bit VM"); + } + FLAG_SET_DEFAULT(UseMontgomeryMultiplyIntrinsic, false); + } + if (UseMontgomerySquareIntrinsic) { + if (!FLAG_IS_DEFAULT(UseMontgomerySquareIntrinsic)) { + warning("montgomerySquare intrinsic is not available in 32-bit VM"); + } + FLAG_SET_DEFAULT(UseMontgomerySquareIntrinsic, false); + } if (UseSquareToLenIntrinsic) { if (!FLAG_IS_DEFAULT(UseSquareToLenIntrinsic)) { warning("squareToLen intrinsic is not available in 32-bit VM"); diff --git a/src/share/vm/classfile/vmSymbols.hpp b/src/share/vm/classfile/vmSymbols.hpp --- a/src/share/vm/classfile/vmSymbols.hpp +++ b/src/share/vm/classfile/vmSymbols.hpp @@ -796,7 +796,7 @@ do_signature(encodeISOArray_signature, "([CI[BII)I") \ \ do_class(java_math_BigInteger, "java/math/BigInteger") \ - do_intrinsic(_multiplyToLen, java_math_BigInteger, multiplyToLen_name, multiplyToLen_signature, F_R) \ + do_intrinsic(_multiplyToLen, java_math_BigInteger, multiplyToLen_name, multiplyToLen_signature, F_S) \ do_name( multiplyToLen_name, "multiplyToLen") \ do_signature(multiplyToLen_signature, "([II[II[I)[I") \ \ @@ -808,6 +808,14 @@ do_name( mulAdd_name, "implMulAdd") \ do_signature(mulAdd_signature, "([I[IIII)I") \ \ + do_intrinsic(_montgomeryMultiply, java_math_BigInteger, montgomeryMultiply_name, montgomeryMultiply_signature, F_S) \ + do_name( montgomeryMultiply_name, "implMontgomeryMultiply") \ + do_signature(montgomeryMultiply_signature, "([I[I[IIJ[I)[I") \ + \ + do_intrinsic(_montgomerySquare, java_math_BigInteger, montgomerySquare_name, montgomerySquare_signature, F_S) \ + do_name( montgomerySquare_name, "implMontgomerySquare") \ + do_signature(montgomerySquare_signature, "([I[IIJ[I)[I") \ + \ /* java/lang/ref/Reference */ \ 
do_intrinsic(_Reference_get, java_lang_ref_Reference, get_name, void_object_signature, F_R) \ \ diff --git a/src/share/vm/opto/c2_globals.hpp b/src/share/vm/opto/c2_globals.hpp --- a/src/share/vm/opto/c2_globals.hpp +++ b/src/share/vm/opto/c2_globals.hpp @@ -671,6 +671,12 @@ product(bool, UseMulAddIntrinsic, false, \ "Enables intrinsification of BigInteger.mulAdd()") \ \ + product(bool, UseMontgomeryMultiplyIntrinsic, false, \ + "Enables intrinsification of BigInteger.montgomeryMultiply()") \ + \ + product(bool, UseMontgomerySquareIntrinsic, false, \ + "Enables intrinsification of BigInteger.montgomerySquare()") \ + \ product(bool, UseTypeSpeculation, true, \ "Speculatively propagate types from profiles") \ \ diff --git a/src/share/vm/opto/escape.cpp b/src/share/vm/opto/escape.cpp --- a/src/share/vm/opto/escape.cpp +++ b/src/share/vm/opto/escape.cpp @@ -974,8 +974,10 @@ strcmp(call->as_CallLeaf()->_name, "sha512_implCompressMB") == 0 || strcmp(call->as_CallLeaf()->_name, "multiplyToLen") == 0 || strcmp(call->as_CallLeaf()->_name, "squareToLen") == 0 || - strcmp(call->as_CallLeaf()->_name, "mulAdd") == 0) - ))) { + strcmp(call->as_CallLeaf()->_name, "mulAdd") == 0 || + strcmp(call->as_CallLeaf()->_name, "montgomery_multiply") == 0 || + strcmp(call->as_CallLeaf()->_name, "montgomery_square") == 0) + ))) { call->dump(); fatal(err_msg_res("EA unexpected CallLeaf %s", call->as_CallLeaf()->_name)); } diff --git a/src/share/vm/opto/library_call.cpp b/src/share/vm/opto/library_call.cpp --- a/src/share/vm/opto/library_call.cpp +++ b/src/share/vm/opto/library_call.cpp @@ -293,6 +293,8 @@ bool inline_multiplyToLen(); bool inline_squareToLen(); bool inline_mulAdd(); + bool inline_montgomeryMultiply(); + bool inline_montgomerySquare(); bool inline_profileBoolean(); bool inline_isCompileConstant(); @@ -504,6 +506,13 @@ if (!UseMulAddIntrinsic) return NULL; break; + case vmIntrinsics::_montgomeryMultiply: + if (!UseMontgomeryMultiplyIntrinsic) return NULL; + break; + case 
vmIntrinsics::_montgomerySquare: + if (!UseMontgomerySquareIntrinsic) return NULL; + break; + case vmIntrinsics::_cipherBlockChaining_encryptAESCrypt: case vmIntrinsics::_cipherBlockChaining_decryptAESCrypt: if (!UseAESIntrinsics) return NULL; @@ -929,6 +938,11 @@ case vmIntrinsics::_mulAdd: return inline_mulAdd(); + case vmIntrinsics::_montgomeryMultiply: + return inline_montgomeryMultiply(); + case vmIntrinsics::_montgomerySquare: + return inline_montgomerySquare(); + case vmIntrinsics::_encodeISOArray: return inline_encodeISOArray(); @@ -5233,11 +5247,12 @@ assert(callee()->signature()->size() == 5, "multiplyToLen has 5 parameters"); - Node* x = argument(1); - Node* xlen = argument(2); - Node* y = argument(3); - Node* ylen = argument(4); - Node* z = argument(5); + // no receiver because it is a static method + Node* x = argument(0); + Node* xlen = argument(1); + Node* y = argument(2); + Node* ylen = argument(3); + Node* z = argument(4); const Type* x_type = x->Value(&_gvn); const Type* y_type = y->Value(&_gvn); @@ -5416,6 +5431,121 @@ return true; } +//-------------inline_montgomeryMultiply----------------------------------- +bool LibraryCallKit::inline_montgomeryMultiply() { + address stubAddr = StubRoutines::montgomeryMultiply(); + if (stubAddr == NULL) { + return false; // Intrinsic's stub is not implemented on this platform + } + + assert(UseMontgomeryMultiplyIntrinsic, "not implemented on this platform"); + const char* stubName = "montgomery_multiply"; + + assert(callee()->signature()->size() == 7, "montgomeryMultiply has 7 parameters"); + + Node* a = argument(0); + Node* b = argument(1); + Node* n = argument(2); + Node* len = argument(3); + Node* inv = argument(4); + Node* m = argument(6); + + const Type* a_type = a->Value(&_gvn); + const TypeAryPtr* top_a = a_type->isa_aryptr(); + const Type* b_type = b->Value(&_gvn); + const TypeAryPtr* top_b = b_type->isa_aryptr(); + const Type* n_type = n->Value(&_gvn); + const TypeAryPtr* top_n = n_type->isa_aryptr(); + 
const Type* m_type = m->Value(&_gvn); + const TypeAryPtr* top_m = m_type->isa_aryptr(); + if (top_a == NULL || top_a->klass() == NULL || + top_b == NULL || top_b->klass() == NULL || + top_n == NULL || top_n->klass() == NULL || + top_m == NULL || top_m->klass() == NULL) { + // failed array check: a, b, n or m is not a known array type + return false; + } + + BasicType a_elem = a_type->isa_aryptr()->klass()->as_array_klass()->element_type()->basic_type(); + BasicType b_elem = b_type->isa_aryptr()->klass()->as_array_klass()->element_type()->basic_type(); + BasicType n_elem = n_type->isa_aryptr()->klass()->as_array_klass()->element_type()->basic_type(); + BasicType m_elem = m_type->isa_aryptr()->klass()->as_array_klass()->element_type()->basic_type(); + if (a_elem != T_INT || b_elem != T_INT || n_elem != T_INT || m_elem != T_INT) { + return false; + } + + // Make the call + { + Node* a_start = array_element_address(a, intcon(0), a_elem); + Node* b_start = array_element_address(b, intcon(0), b_elem); + Node* n_start = array_element_address(n, intcon(0), n_elem); + Node* m_start = array_element_address(m, intcon(0), m_elem); + + Node* call = make_runtime_call(RC_LEAF, + OptoRuntime::montgomeryMultiply_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + a_start, b_start, n_start, len, inv, top(), + m_start); + set_result(m); + } + + return true; +} + +bool LibraryCallKit::inline_montgomerySquare() { + address stubAddr = StubRoutines::montgomerySquare(); + if (stubAddr == NULL) { + return false; // Intrinsic's stub is not implemented on this platform + } + + assert(UseMontgomerySquareIntrinsic, "not implemented on this platform"); + const char* stubName = "montgomery_square"; + + assert(callee()->signature()->size() == 6, "montgomerySquare has 6 parameters"); + + Node* a = argument(0); + Node* n = argument(1); + Node* len = argument(2); + Node* inv = argument(3); + Node* m = argument(5); + + const Type* a_type = a->Value(&_gvn); + const TypeAryPtr* top_a = a_type->isa_aryptr(); + const Type* n_type = 
n->Value(&_gvn); + const TypeAryPtr* top_n = n_type->isa_aryptr(); + const Type* m_type = m->Value(&_gvn); + const TypeAryPtr* top_m = m_type->isa_aryptr(); + if (top_a == NULL || top_a->klass() == NULL || + top_n == NULL || top_n->klass() == NULL || + top_m == NULL || top_m->klass() == NULL) { + // failed array check: a, n or m is not a known array type + return false; + } + + BasicType a_elem = a_type->isa_aryptr()->klass()->as_array_klass()->element_type()->basic_type(); + BasicType n_elem = n_type->isa_aryptr()->klass()->as_array_klass()->element_type()->basic_type(); + BasicType m_elem = m_type->isa_aryptr()->klass()->as_array_klass()->element_type()->basic_type(); + if (a_elem != T_INT || n_elem != T_INT || m_elem != T_INT) { + return false; + } + + // Make the call + { + Node* a_start = array_element_address(a, intcon(0), a_elem); + Node* n_start = array_element_address(n, intcon(0), n_elem); + Node* m_start = array_element_address(m, intcon(0), m_elem); + + Node* call = make_runtime_call(RC_LEAF, + OptoRuntime::montgomerySquare_Type(), + stubAddr, stubName, TypePtr::BOTTOM, + a_start, n_start, len, inv, top(), + m_start); + set_result(m); + } + + return true; +} + /** * Calculate CRC32 for byte. 
diff --git a/src/share/vm/opto/runtime.cpp b/src/share/vm/opto/runtime.cpp --- a/src/share/vm/opto/runtime.cpp +++ b/src/share/vm/opto/runtime.cpp @@ -987,6 +987,52 @@ return TypeFunc::make(domain, range); } +const TypeFunc* OptoRuntime::montgomeryMultiply_Type() { + // create input type (domain) + int num_args = 7; + int argcnt = num_args; + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // a + fields[argp++] = TypePtr::NOTNULL; // b + fields[argp++] = TypePtr::NOTNULL; // n + fields[argp++] = TypeInt::INT; // len + fields[argp++] = TypeLong::LONG; // inv + fields[argp++] = Type::HALF; + fields[argp++] = TypePtr::NOTNULL; // result + assert(argp == TypeFunc::Parms+argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms+argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms+0] = TypePtr::NOTNULL; + + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms, fields); + return TypeFunc::make(domain, range); +} + +const TypeFunc* OptoRuntime::montgomerySquare_Type() { + // create input type (domain) + int num_args = 6; + int argcnt = num_args; + const Type** fields = TypeTuple::fields(argcnt); + int argp = TypeFunc::Parms; + fields[argp++] = TypePtr::NOTNULL; // a + fields[argp++] = TypePtr::NOTNULL; // n + fields[argp++] = TypeInt::INT; // len + fields[argp++] = TypeLong::LONG; // inv + fields[argp++] = Type::HALF; + fields[argp++] = TypePtr::NOTNULL; // result + assert(argp == TypeFunc::Parms+argcnt, "correct decoding"); + const TypeTuple* domain = TypeTuple::make(TypeFunc::Parms+argcnt, fields); + + // result type needed + fields = TypeTuple::fields(1); + fields[TypeFunc::Parms+0] = TypePtr::NOTNULL; + + const TypeTuple* range = TypeTuple::make(TypeFunc::Parms, fields); + return TypeFunc::make(domain, range); +} //------------- Interpreter state access for on stack replacement diff --git a/src/share/vm/opto/runtime.hpp 
b/src/share/vm/opto/runtime.hpp --- a/src/share/vm/opto/runtime.hpp +++ b/src/share/vm/opto/runtime.hpp @@ -311,6 +311,8 @@ static const TypeFunc* digestBase_implCompressMB_Type(); static const TypeFunc* multiplyToLen_Type(); + static const TypeFunc* montgomeryMultiply_Type(); + static const TypeFunc* montgomerySquare_Type(); static const TypeFunc* squareToLen_Type(); diff --git a/src/share/vm/runtime/sharedRuntime.hpp b/src/share/vm/runtime/sharedRuntime.hpp --- a/src/share/vm/runtime/sharedRuntime.hpp +++ b/src/share/vm/runtime/sharedRuntime.hpp @@ -145,6 +145,12 @@ static double dsqrt(double f); #endif + // Montgomery multiplication + static void montgomery_multiply(jint *a_ints, jint *b_ints, jint *n_ints, + jint len, jlong inv, jint *m_ints); + static void montgomery_square(jint *a_ints, jint *n_ints, + jint len, jlong inv, jint *m_ints); + #ifdef __SOFTFP__ // C++ compiler generates soft float instructions as well as passing // float and double in registers. diff --git a/src/share/vm/runtime/stubRoutines.cpp b/src/share/vm/runtime/stubRoutines.cpp --- a/src/share/vm/runtime/stubRoutines.cpp +++ b/src/share/vm/runtime/stubRoutines.cpp @@ -139,6 +139,8 @@ address StubRoutines::_multiplyToLen = NULL; address StubRoutines::_squareToLen = NULL; address StubRoutines::_mulAdd = NULL; +address StubRoutines::_montgomeryMultiply = NULL; +address StubRoutines::_montgomerySquare = NULL; double (* StubRoutines::_intrinsic_log )(double) = NULL; double (* StubRoutines::_intrinsic_log10 )(double) = NULL; diff --git a/src/share/vm/runtime/stubRoutines.hpp b/src/share/vm/runtime/stubRoutines.hpp --- a/src/share/vm/runtime/stubRoutines.hpp +++ b/src/share/vm/runtime/stubRoutines.hpp @@ -199,6 +199,8 @@ static address _multiplyToLen; static address _squareToLen; static address _mulAdd; + static address _montgomeryMultiply; + static address _montgomerySquare; // These are versions of the java.lang.Math methods which perform // the same operations as the intrinsic version. 
They are used for @@ -360,6 +362,8 @@ static address multiplyToLen() {return _multiplyToLen; } static address squareToLen() {return _squareToLen; } static address mulAdd() {return _mulAdd; } + static address montgomeryMultiply() { return _montgomeryMultiply; } + static address montgomerySquare() { return _montgomerySquare; } static address select_fill_function(BasicType t, bool aligned, const char* &name); diff --git a/test/compiler/intrinsics/montgomerymultiply/MontgomeryMultiplyTest.java b/test/compiler/intrinsics/montgomerymultiply/MontgomeryMultiplyTest.java new file mode 100644 --- /dev/null +++ b/test/compiler/intrinsics/montgomerymultiply/MontgomeryMultiplyTest.java @@ -0,0 +1,253 @@ +// +// Copyright (c) 2000, 2015, Oracle and/or its affiliates. All rights reserved. +// Copyright (c) 2015, Red Hat Inc. All rights reserved. +// DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. +// +// This code is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License version 2 only, as +// published by the Free Software Foundation. +// +// This code is distributed in the hope that it will be useful, but WITHOUT +// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// version 2 for more details (a copy is included in the LICENSE file that +// accompanied this code). +// +// You should have received a copy of the GNU General Public License version +// 2 along with this work; if not, write to the Free Software Foundation, +// Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. +// +// Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA +// or visit www.oracle.com if you need additional information or have any +// questions. 
+// +// + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.Random; + +/** + * @test + * @bug 8130150 + * @library /testlibrary + * @summary Verify that the Montgomery multiply intrinsic works and correctly checks its arguments. + */ + +public class MontgomeryMultiplyTest { + + static final MethodHandles.Lookup lookup = MethodHandles.lookup(); + + static final MethodHandle montgomeryMultiplyHandle, montgomerySquareHandle; + static final MethodHandle bigIntegerConstructorHandle; + static final Field bigIntegerMagField; + + static { + // Use reflection to gain access to the methods we want to test. + try { + Method m = BigInteger.class.getDeclaredMethod("montgomeryMultiply", + /*a*/int[].class, /*b*/int[].class, /*n*/int[].class, /*len*/int.class, + /*inv*/long.class, /*product*/int[].class); + m.setAccessible(true); + montgomeryMultiplyHandle = lookup.unreflect(m); + + m = BigInteger.class.getDeclaredMethod("montgomerySquare", + /*a*/int[].class, /*n*/int[].class, /*len*/int.class, + /*inv*/long.class, /*product*/int[].class); + m.setAccessible(true); + montgomerySquareHandle = lookup.unreflect(m); + + Constructor c + = BigInteger.class.getDeclaredConstructor(int.class, int[].class); + c.setAccessible(true); + bigIntegerConstructorHandle = lookup.unreflectConstructor(c); + + bigIntegerMagField = BigInteger.class.getDeclaredField("mag"); + bigIntegerMagField.setAccessible(true); + + } catch (Throwable ex) { + throw new RuntimeException(ex); + } + } + + // Invoke either the BigInteger.montgomeryMultiply or + // BigInteger.montgomerySquare intrinsics. + int[] montgomeryMultiply(int[] a, int[] b, int[] n, int len, long inv, + int[] product) throws Throwable { + int[] result = + (a == b) ? 
(int[]) montgomerySquareHandle.invokeExact(a, n, len, inv, product) + : (int[]) montgomeryMultiplyHandle.invokeExact(a, b, n, len, inv, product); + return Arrays.copyOf(result, len); + } + + // Invoke the private constructor BigInteger(int, int[]) with signum 1. + BigInteger newBigInteger(int[] val) throws Throwable { + return (BigInteger) bigIntegerConstructorHandle.invokeExact(1, val); + } + + // Get the private field BigInteger.mag + int[] mag(BigInteger n) { + try { + return (int[]) bigIntegerMagField.get(n); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + // Montgomery multiplication + // Calculate (a * b * R^-1 mod N) + // + // R is a power of the word size + // N' = -N^-1 mod R + // + // T := ab + // m := (T mod R)N' mod R [so 0 <= m < R] + // t := (T + mN)/R + // if t >= N then return t - N else return t + // + BigInteger montgomeryMultiply(BigInteger a, BigInteger b, BigInteger N, + int len, BigInteger n_prime) + throws Throwable { + BigInteger T = a.multiply(b); + BigInteger R = BigInteger.ONE.shiftLeft(len*32); + BigInteger mask = R.subtract(BigInteger.ONE); + BigInteger m = (T.and(mask)).multiply(n_prime); + m = m.and(mask); // i.e. m.mod(R) + T = T.add(m.multiply(N)); + T = T.shiftRight(len*32); // i.e. T.divide(R) + if (T.compareTo(N) >= 0) { + T = T.subtract(N); + } + return T; + } + + BigInteger montgomeryMultiply(int[] a_words, int[] b_words, int[] n_words, + int len, BigInteger inv) + throws Throwable { + BigInteger t = montgomeryMultiply( + newBigInteger(a_words), + newBigInteger(b_words), + newBigInteger(n_words), + len, inv); + return t; + } + + // Check that the Montgomery multiply intrinsic returns the same + // result as the longhand calculation. 
+ void check(int[] a_words, int[] b_words, int[] n_words, int len, BigInteger inv) + throws Throwable { + BigInteger n = newBigInteger(n_words); + BigInteger slow = montgomeryMultiply(a_words, b_words, n_words, len, inv); + BigInteger fast + = newBigInteger(montgomeryMultiply + (a_words, b_words, n_words, len, inv.longValue(), null)); + // The intrinsic may not return the same value as the longhand + // calculation but they must have the same residue mod N. + if (!slow.mod(n).equals(fast.mod(n))) { + throw new RuntimeException(); + } + } + + Random rnd = new Random(0); + + int[] random_val(int len) { + int[] val = new int[len]; + for (int i = 0; i < len; i++) + val[i] = rnd.nextInt(); + return val; + } + + // Test the Montgomery multiply intrinsic with a bunch of random + // values of varying lengths. Do this for long enough that the + // caller of the intrinsic is C2-compiled. + void testResultValues() throws Throwable { + for (int j = 10; j > 0; j--) { + // Construct a random prime whose length in words is even + int lenInBits = (rnd.nextInt(32)+1)*64 + rnd.nextInt(32) + 32; + int lenInInts = (lenInBits + 31)/32; + BigInteger mod = new BigInteger(lenInBits , 2, rnd); + BigInteger r = BigInteger.ONE.shiftLeft(lenInInts*32); + BigInteger n_prime = mod.modInverse(r).negate(); + + for (int i = 0; i < 10000; i++) { + // multiply + check(random_val(lenInInts), random_val(lenInInts), mag(mod), lenInInts, n_prime); + // square + int[] tmp = random_val(lenInInts); + check(tmp, tmp, mag(mod), lenInInts, n_prime); + } + } + } + + // Range checks + void testOneMontgomeryMultiplyCheck(int[] a, int[] b, int[] n, int len, long inv, + int[] product, Class klass) { + try { + montgomeryMultiply(a, b, n, len, inv, product); + } catch (Throwable ex) { + if (klass.isAssignableFrom(ex.getClass())) + return; + throw new RuntimeException(klass + " expected, " + ex + " was thrown"); + } + throw new RuntimeException(klass + " expected, was not thrown"); + } + + void 
testOneMontgomeryMultiplyCheck(int[] a, int[] b, BigInteger n, int len, BigInteger inv, + Class klass) { + testOneMontgomeryMultiplyCheck(a, b, mag(n), len, inv.longValue(), null, klass); + } + + void testOneMontgomeryMultiplyCheck(int[] a, int[] b, BigInteger n, int len, BigInteger inv, + int[] product, Class klass) { + testOneMontgomeryMultiplyCheck(a, b, mag(n), len, inv.longValue(), product, klass); + } + + void testMontgomeryMultiplyChecks() { + int[] blah = random_val(40); + int[] small = random_val(39); + BigInteger mod = new BigInteger(40*32 , 2, rnd); + BigInteger r = BigInteger.ONE.shiftLeft(40*32); + BigInteger n_prime = mod.modInverse(r).negate(); + + // Length out of range: square + testOneMontgomeryMultiplyCheck(blah, blah, mod, 41, n_prime, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(blah, blah, mod, 0, n_prime, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(blah, blah, mod, -1, n_prime, IllegalArgumentException.class); + // As above, but for multiply + testOneMontgomeryMultiplyCheck(blah, blah.clone(), mod, 41, n_prime, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(blah, blah.clone(), mod, 0, n_prime, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(blah, blah.clone(), mod, -1, n_prime, IllegalArgumentException.class); + + // Length odd + testOneMontgomeryMultiplyCheck(small, small, mod, 39, n_prime, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(small, small, mod, 0, n_prime, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(small, small, mod, -1, n_prime, IllegalArgumentException.class); + // As above, but for multiply + testOneMontgomeryMultiplyCheck(small, small.clone(), mod, 39, n_prime, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(small, small.clone(), mod, 0, n_prime, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(small, small.clone(), mod, -1, n_prime, IllegalArgumentException.class); + + // 
array too small + testOneMontgomeryMultiplyCheck(blah, blah, mod, 40, n_prime, small, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(blah, blah.clone(), mod, 40, n_prime, small, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(small, blah, mod, 40, n_prime, blah, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(blah, small, mod, 40, n_prime, blah, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(blah, blah, mod, 40, n_prime, small, IllegalArgumentException.class); + testOneMontgomeryMultiplyCheck(small, small, mod, 40, n_prime, blah, IllegalArgumentException.class); + } + + public static void main(String args[]) { + try { + new MontgomeryMultiplyTest().testMontgomeryMultiplyChecks(); + new MontgomeryMultiplyTest().testResultValues(); + } catch (Throwable ex) { + throw new RuntimeException(ex); + } + } +}