1 /* 2 * Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @bug 8181594 8208648 27 * @summary Test proper operation of integer field arithmetic 28 * @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly 29 * @build BigIntegerModuloP 30 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0 31 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1 32 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2 33 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP256 32 5 34 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP384 48 6 35 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP521 66 7 36 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P256OrderField 32 8 37 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P384OrderField 48 9 38 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P521OrderField 66 10 39 * @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve25519OrderField 32 11 40 * @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve448OrderField 56 12 41 */ 42 43 import sun.security.util.math.*; 44 import sun.security.util.math.intpoly.*; 45 import java.util.function.*; 46 47 import java.util.*; 48 import java.math.*; 49 import java.nio.*; 50 51 public class TestIntegerModuloP { 52 53 static BigInteger TWO = BigInteger.valueOf(2); 54 55 // The test has a list of functions, and it selects randomly from that list 56 57 // The function types 58 interface ElemFunction extends BiFunction 59 <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { } 60 interface ElemArrayFunction extends BiFunction 61 <MutableIntegerModuloP, IntegerModuloP, byte[]> { } 62 interface TriConsumer <T, U, V> { 63 void accept(T t, U u, V v); 64 } 65 interface ElemSetFunction extends TriConsumer 66 <MutableIntegerModuloP, IntegerModuloP, byte[]> { } 67 68 // The lists of functions. Multiple lists are needed because the test 69 // respects the limitations of the arithmetic implementations. 70 static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>(); 71 static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>(); 72 static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>(); 73 static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>(); 74 75 static void setUpFunctions(IntegerFieldModuloP field, int length) { 76 77 ADD_FUNCTIONS.clear(); 78 MULT_FUNCTIONS.clear(); 79 SET_FUNCTIONS.clear(); 80 ARRAY_FUNCTIONS.clear(); 81 82 byte highByte = (byte) 83 (field.getSize().bitLength() > length * 8 ? 1 : 0); 84 85 // add functions are (im)mutable add/subtract 86 ADD_FUNCTIONS.add(IntegerModuloP::add); 87 ADD_FUNCTIONS.add(IntegerModuloP::subtract); 88 ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum); 89 ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference); 90 // also include functions that return the first/second argument 91 ADD_FUNCTIONS.add((a, b) -> a); 92 ADD_FUNCTIONS.add((a, b) -> b); 93 94 // mult functions are (im)mutable multiply and square 95 MULT_FUNCTIONS.add(IntegerModuloP::multiply); 96 MULT_FUNCTIONS.add((a, b) -> a.square()); 97 MULT_FUNCTIONS.add((a, b) -> b.square()); 98 MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct); 99 MULT_FUNCTIONS.add((a, b) -> a.setSquare()); 100 // also test multiplication by a small value 101 MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue( 102 b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue()))); 103 104 // set functions are setValue with various argument types 105 SET_FUNCTIONS.add((a, b, c) -> a.setValue(b)); 106 SET_FUNCTIONS.add((a, b, c) -> 107 a.setValue(c, 0, c.length, (byte) 0)); 108 SET_FUNCTIONS.add((a, b, c) -> 109 a.setValue(c, 0, c.length / 2, (byte) 0)); 110 SET_FUNCTIONS.add((a, b, c) -> 111 a.setValue(ByteBuffer.wrap(c, 0, c.length / 2).order(ByteOrder.LITTLE_ENDIAN), 112 c.length / 2, highByte)); 113 114 // array functions return the (possibly modified) value as byte array 115 ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length)); 116 ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length)); 117 } 118 119 public static void main(String[] args) { 120 121 String className = args[0]; 122 final int length = Integer.parseInt(args[1]); 123 int seed = Integer.parseInt(args[2]); 124 125 Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class; 126 try { 127 Class<? extends IntegerFieldModuloP> clazz = 128 Class.forName(className).asSubclass(fieldBaseClass); 129 IntegerFieldModuloP field = 130 clazz.getDeclaredConstructor().newInstance(); 131 132 setUpFunctions(field, length); 133 134 runFieldTest(field, length, seed); 135 } catch (Exception ex) { 136 throw new RuntimeException(ex); 137 } 138 System.out.println("All tests passed"); 139 } 140 141 static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) { 142 143 if (!e1.asBigInteger().equals(e2.asBigInteger())) { 144 throw new RuntimeException("values not equal: " 145 + e1.asBigInteger() + " != " + e2.asBigInteger()); 146 } 147 } 148 149 // A class that holds pairs of actual/expected values, and allows 150 // computation on these pairs. 151 static class TestPair<T extends IntegerModuloP> { 152 private final T test; 153 private final T baseline; 154 155 public TestPair(T test, T baseline) { 156 this.test = test; 157 this.baseline = baseline; 158 } 159 160 public T getTest() { 161 return test; 162 } 163 public T getBaseline() { 164 return baseline; 165 } 166 167 private void assertEqual() { 168 TestIntegerModuloP.assertEqual(test, baseline); 169 } 170 171 public TestPair<MutableIntegerModuloP> mutable() { 172 return new TestPair<>(test.mutable(), baseline.mutable()); 173 } 174 175 public 176 <R extends IntegerModuloP, X extends IntegerModuloP> 177 TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) { 178 X testResult = func.apply(test, right.test); 179 X baselineResult = func.apply(baseline, right.baseline); 180 return new TestPair(testResult, baselineResult); 181 } 182 183 public 184 <U extends IntegerModuloP, V> 185 void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) { 186 func.accept(test, right.test, argV); 187 func.accept(baseline, right.baseline, argV); 188 } 189 190 public 191 <R extends IntegerModuloP> 192 void applyAndCheckArray(BiFunction<T, R, byte[]> func, 193 TestPair<R> right) { 194 byte[] testResult = func.apply(test, right.test); 195 byte[] baselineResult = func.apply(baseline, right.baseline); 196 if (!Arrays.equals(testResult, baselineResult)) { 197 throw new RuntimeException("Array values do not match: " 198 + byteArrayToHexString(testResult) + " != " 199 + byteArrayToHexString(baselineResult)); 200 } 201 } 202 203 } 204 205 static String byteArrayToHexString(byte[] arr) { 206 StringBuilder result = new StringBuilder(); 207 for (int i = 0; i < arr.length; ++i) { 208 byte curVal = arr[i]; 209 result.append(Character.forDigit(curVal >> 4 & 0xF, 16)); 210 result.append(Character.forDigit(curVal & 0xF, 16)); 211 } 212 return result.toString(); 213 } 214 215 static TestPair<IntegerModuloP> 216 applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left, 217 TestPair<IntegerModuloP> right) { 218 219 TestPair<IntegerModuloP> result = left.apply(func, right); 220 result.assertEqual(); 221 left.assertEqual(); 222 right.assertEqual(); 223 224 return result; 225 } 226 227 static void 228 setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left, 229 TestPair<IntegerModuloP> right, byte[] argV) { 230 231 left.apply(func, right, argV); 232 left.assertEqual(); 233 right.assertEqual(); 234 } 235 236 static TestPair<MutableIntegerModuloP> 237 applyAndCheckMutable(ElemFunction func, 238 TestPair<MutableIntegerModuloP> left, 239 TestPair<IntegerModuloP> right) { 240 241 TestPair<IntegerModuloP> result = applyAndCheck(func, left, right); 242 243 TestPair<MutableIntegerModuloP> mutableResult = result.mutable(); 244 mutableResult.assertEqual(); 245 result.assertEqual(); 246 left.assertEqual(); 247 right.assertEqual(); 248 249 return mutableResult; 250 } 251 252 static void 253 cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left, 254 TestPair<MutableIntegerModuloP> right) { 255 256 left.getTest().conditionalSwapWith(right.getTest(), swap); 257 left.getBaseline().conditionalSwapWith(right.getBaseline(), swap); 258 259 left.assertEqual(); 260 right.assertEqual(); 261 262 } 263 264 // Request arithmetic that should overflow, and ensure that overflow is 265 // detected. 266 static void runOverflowTest(TestPair<IntegerModuloP> elem) { 267 268 TestPair<MutableIntegerModuloP> mutableElem = elem.mutable(); 269 270 try { 271 for (int i = 0; i < 1000; i++) { 272 applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem); 273 } 274 applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem); 275 } catch (ArithmeticException ex) { 276 // this is expected 277 } 278 279 mutableElem = elem.mutable(); 280 try { 281 for (int i = 0; i < 1000; i++) { 282 elem = applyAndCheck(IntegerModuloP::add, 283 mutableElem, elem); 284 } 285 applyAndCheck(IntegerModuloP::multiply, mutableElem, elem); 286 } catch (ArithmeticException ex) { 287 // this is expected 288 } 289 } 290 291 // Run a large number of random operations and ensure that 292 // results are correct 293 static void runOperationsTest(Random random, int length, 294 TestPair<IntegerModuloP> elem, 295 TestPair<IntegerModuloP> right) { 296 297 TestPair<MutableIntegerModuloP> left = elem.mutable(); 298 299 for (int i = 0; i < 10000; i++) { 300 301 ElemFunction addFunc1 = 302 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size())); 303 TestPair<MutableIntegerModuloP> result1 = 304 applyAndCheckMutable(addFunc1, left, right); 305 306 // left could have been modified, so turn it back into a summand 307 applyAndCheckMutable((a, b) -> a.setSquare(), left, right); 308 309 ElemFunction addFunc2 = 310 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size())); 311 TestPair<IntegerModuloP> result2 = 312 applyAndCheck(addFunc2, left, right); 313 314 if (elem.test.getField() instanceof IntegerPolynomial) { 315 IntegerPolynomial field = 316 (IntegerPolynomial) elem.test.getField(); 317 int numAdds = field.getMaxAdds(); 318 for (int j = 1; j < numAdds; j++) { 319 ElemFunction addFunc3 = ADD_FUNCTIONS. 320 get(random.nextInt(ADD_FUNCTIONS.size())); 321 result2 = applyAndCheck(addFunc3, left, right); 322 } 323 } 324 325 ElemFunction multFunc2 = 326 MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size())); 327 TestPair<MutableIntegerModuloP> multResult = 328 applyAndCheckMutable(multFunc2, result1, result2); 329 330 int swap = random.nextInt(2); 331 cswapAndCheck(swap, left, multResult); 332 333 ElemSetFunction setFunc = 334 SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size())); 335 byte[] valueArr = new byte[2 * length]; 336 random.nextBytes(valueArr); 337 setAndCheck(setFunc, result1, result2, valueArr); 338 339 // left could have been modified, so to turn it back into a summand 340 applyAndCheckMutable((a, b) -> a.setSquare(), left, right); 341 342 ElemArrayFunction arrayFunc = 343 ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size())); 344 left.applyAndCheckArray(arrayFunc, right); 345 } 346 } 347 348 // Run all the tests for a given field 349 static void runFieldTest(IntegerFieldModuloP testField, 350 int length, int seed) { 351 System.out.println("Testing: " + testField.getClass().getSimpleName()); 352 353 Random random = new Random(seed); 354 355 IntegerFieldModuloP baselineField = 356 new BigIntegerModuloP(testField.getSize()); 357 358 int numBits = testField.getSize().bitLength(); 359 BigInteger r = 360 new BigInteger(numBits, random).mod(testField.getSize()); 361 TestPair<IntegerModuloP> rand = 362 new TestPair(testField.getElement(r), baselineField.getElement(r)); 363 364 runOverflowTest(rand); 365 366 // check combinations of operations for different kinds of elements 367 List<TestPair<IntegerModuloP>> testElements = new ArrayList<>(); 368 testElements.add(rand); 369 testElements.add(new TestPair(testField.get0(), baselineField.get0())); 370 testElements.add(new TestPair(testField.get1(), baselineField.get1())); 371 byte[] testArr = {121, 37, -100, -5, 76, 33}; 372 testElements.add(new TestPair(testField.getElement(testArr), 373 baselineField.getElement(testArr))); 374 375 testArr = new byte[length]; 376 random.nextBytes(testArr); 377 testElements.add(new TestPair(testField.getElement(testArr), 378 baselineField.getElement(testArr))); 379 380 random.nextBytes(testArr); 381 byte highByte = (byte) (numBits > length * 8 ? 1 : 0); 382 testElements.add( 383 new TestPair( 384 testField.getElement(testArr, 0, testArr.length, highByte), 385 baselineField.getElement(testArr, 0, testArr.length, highByte) 386 ) 387 ); 388 389 for (int i = 0; i < testElements.size(); i++) { 390 for (int j = 0; j < testElements.size(); j++) { 391 runOperationsTest(random, length, testElements.get(i), 392 testElements.get(j)); 393 } 394 } 395 } 396 } 397