1 /* 2 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 
*/

/*
 * @test
 * @bug 8181594
 * @summary Test proper operation of integer field arithmetic
 * @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly
 * @build BigIntegerModuloP
 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0
 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1
 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2
 */

import sun.security.util.math.*;
import java.util.function.*;

import java.util.*;
import java.math.*;
import java.nio.*;

/**
 * Randomized differential test for {@code IntegerFieldModuloP}
 * implementations.
 *
 * <p>Every operation is applied both to an element of the field under test
 * and to the corresponding element of a {@code BigIntegerModuloP} baseline
 * built over the same modulus. After each step the two results must agree.
 *
 * <p>Arguments: the fully-qualified field class name, the byte length used
 * for byte-array conversions, and the seed for {@link Random}.
 */
public class TestIntegerModuloP {

    static final BigInteger TWO = BigInteger.valueOf(2);

    // The test has a list of functions, and it selects randomly from that list

    // The function types

    /** Binary element operation returning a (possibly new) element. */
    @FunctionalInterface
    interface ElemFunction extends BiFunction
        <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { }

    /** Binary element operation whose result is a byte array. */
    @FunctionalInterface
    interface ElemArrayFunction extends BiFunction
        <MutableIntegerModuloP, IntegerModuloP, byte[]> { }

    /** Three-argument analogue of {@link BiConsumer}. */
    @FunctionalInterface
    interface TriConsumer <T, U, V> {
        void accept(T t, U u, V v);
    }

    /** Operation that overwrites a mutable element from its arguments. */
    @FunctionalInterface
    interface ElemSetFunction extends TriConsumer
        <MutableIntegerModuloP, IntegerModuloP, byte[]> { }

    // The lists of functions. Multiple lists are needed because the test
    // respects the limitations of the arithmetic implementations.
    static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>();
    static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>();
    static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>();
    static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>();

    /**
     * Computes the {@code highByte} argument passed to the
     * {@code setValue}/{@code getElement} overloads that take one.
     *
     * @param field  the field whose modulus size is inspected
     * @param length the byte length used for conversions
     * @return 1 if the modulus has more than {@code length * 8} bits,
     *         otherwise 0
     */
    private static byte computeHighByte(IntegerFieldModuloP field,
                                        int length) {
        return (byte) (field.getSize().bitLength() > length * 8 ? 1 : 0);
    }

    /**
     * Clears and repopulates the function lists for the given field.
     *
     * @param field  the field under test
     * @param length the byte length used when converting elements to and
     *               from byte arrays
     */
    static void setUpFunctions(IntegerFieldModuloP field, int length) {

        ADD_FUNCTIONS.clear();
        MULT_FUNCTIONS.clear();
        SET_FUNCTIONS.clear();
        ARRAY_FUNCTIONS.clear();

        byte highByte = computeHighByte(field, length);

        // add functions are (im)mutable add/subtract
        ADD_FUNCTIONS.add(IntegerModuloP::add);
        ADD_FUNCTIONS.add(IntegerModuloP::subtract);
        ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum);
        ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference);
        // also include functions that return the first/second argument
        ADD_FUNCTIONS.add((a, b) -> a);
        ADD_FUNCTIONS.add((a, b) -> b);

        // mult functions are (im)mutable multiply and square
        MULT_FUNCTIONS.add(IntegerModuloP::multiply);
        MULT_FUNCTIONS.add((a, b) -> a.square());
        MULT_FUNCTIONS.add((a, b) -> b.square());
        MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct);
        MULT_FUNCTIONS.add((a, b) -> a.setSquare());
        // also test multiplication by a small value (b reduced mod 2^18)
        MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue(
            b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue())));

        // set functions are setValue with various argument types
        SET_FUNCTIONS.add((a, b, c) -> a.setValue(b));
        SET_FUNCTIONS.add((a, b, c) ->
            a.setValue(c, 0, c.length, (byte) 0));
        SET_FUNCTIONS.add((a, b, c) ->
            a.setValue(ByteBuffer.wrap(c, 0, c.length).order(ByteOrder.LITTLE_ENDIAN),
            c.length, highByte));

        // array functions return the (possibly modified) value as byte array
        ARRAY_FUNCTIONS.add((a, b) -> a.asByteArray(length));
        ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length));
    }

    /**
     * Entry point used by the jtreg {@code @run} lines.
     *
     * @param args {@code args[0]} fully-qualified name of the field class
     *             under test, {@code args[1]} byte length used for array
     *             conversions, {@code args[2]} seed for the random generator
     */
    public static void main(String[] args) {

        String className = args[0];
        final int length = Integer.parseInt(args[1]);
        int seed = Integer.parseInt(args[2]);

        IntegerFieldModuloP field = newField(className);

        setUpFunctions(field, length);

        // Assertion failures thrown by the test itself propagate directly,
        // without being re-wrapped.
        runFieldTest(field, length, seed);

        System.out.println("All tests passed");
    }

    /**
     * Reflectively instantiates the named field class through its no-arg
     * constructor.
     *
     * @param className fully-qualified name of an IntegerFieldModuloP class
     * @return a new instance of that class
     * @throws RuntimeException   wrapping the reflective failure if the class
     *                            cannot be loaded or instantiated
     * @throws ClassCastException if the class is not an IntegerFieldModuloP
     */
    private static IntegerFieldModuloP newField(String className) {
        try {
            Class<? extends IntegerFieldModuloP> clazz =
                Class.forName(className).asSubclass(IntegerFieldModuloP.class);
            return clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException ex) {
            throw new RuntimeException(
                "Unable to instantiate field class " + className, ex);
        }
    }


    /**
     * Throws a RuntimeException if the two elements do not have the same
     * BigInteger value.
     */
    static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) {

        if (!e1.asBigInteger().equals(e2.asBigInteger())) {
            throw new RuntimeException("values not equal: "
                + e1.asBigInteger() + " != " + e2.asBigInteger());
        }
    }

    // A class that holds pairs of actual/expected values, and allows
    // computation on these pairs.
    static class TestPair<T extends IntegerModuloP> {
        private final T test;
        private final T baseline;

        public TestPair(T test, T baseline) {
            this.test = test;
            this.baseline = baseline;
        }

        public T getTest() {
            return test;
        }
        public T getBaseline() {
            return baseline;
        }

        private void assertEqual() {
            TestIntegerModuloP.assertEqual(test, baseline);
        }

        /** Returns a pair holding mutable copies of both values. */
        public TestPair<MutableIntegerModuloP> mutable() {
            return new TestPair<>(test.mutable(), baseline.mutable());
        }

        /**
         * Applies {@code func} to both sides, pairing test with test and
         * baseline with baseline, and returns the pair of results. The
         * results are not checked here.
         */
        public
        <R extends IntegerModuloP, X extends IntegerModuloP>
        TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) {
            X testResult = func.apply(test, right.test);
            X baselineResult = func.apply(baseline, right.baseline);
            return new TestPair<>(testResult, baselineResult);
        }

        /** Applies {@code func} to both sides with the same extra argument. */
        public
        <U extends IntegerModuloP, V>
        void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) {
            func.accept(test, right.test, argV);
            func.accept(baseline, right.baseline, argV);
        }

        /**
         * Applies {@code func} to both sides and throws if the resulting
         * byte arrays differ.
         */
        public
        <R extends IntegerModuloP>
        void applyAndCheckArray(BiFunction<T, R, byte[]> func,
                                TestPair<R> right) {
            byte[] testResult = func.apply(test, right.test);
            byte[] baselineResult = func.apply(baseline, right.baseline);
            if (!Arrays.equals(testResult, baselineResult)) {
                throw new RuntimeException("Array values do not match: "
                    + byteArrayToHexString(testResult) + " != "
                    + byteArrayToHexString(baselineResult));
            }
        }

    }

    // Formats the array as lowercase hex, two digits per byte, for use in
    // failure messages.
    static String byteArrayToHexString(byte[] arr) {
        StringBuilder result = new StringBuilder(arr.length * 2);
        for (byte curVal : arr) {
            result.append(Character.forDigit((curVal >> 4) & 0xF, 16));
            result.append(Character.forDigit(curVal & 0xF, 16));
        }
        return result.toString();
    }

    // Applies func to both sides, then checks the result and both inputs
    // (which func may have mutated) against their baselines.
    static TestPair<IntegerModuloP>
    applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left,
                  TestPair<IntegerModuloP> right) {

        TestPair<IntegerModuloP> result = left.apply(func, right);
        result.assertEqual();
        left.assertEqual();
        right.assertEqual();

        return result;
    }

    // Applies a set function to both sides and checks both inputs.
    static void
    setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left,
                TestPair<IntegerModuloP> right, byte[] argV) {

        left.apply(func, right, argV);
        left.assertEqual();
        right.assertEqual();
    }

    // Like applyAndCheck, but returns a mutable copy of the result and also
    // checks that taking the copy left the original pair intact.
    static TestPair<MutableIntegerModuloP>
    applyAndCheckMutable(ElemFunction func,
                         TestPair<MutableIntegerModuloP> left,
                         TestPair<IntegerModuloP> right) {

        TestPair<IntegerModuloP> result = applyAndCheck(func, left, right);

        TestPair<MutableIntegerModuloP> mutableResult = result.mutable();
        mutableResult.assertEqual();
        result.assertEqual();
        left.assertEqual();
        right.assertEqual();

        return mutableResult;
    }

    // Performs conditionalSwapWith on both sides with the same swap flag and
    // checks both pairs afterwards.
    static void
    cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left,
                  TestPair<MutableIntegerModuloP> right) {

        left.getTest().conditionalSwapWith(right.getTest(), swap);
        left.getBaseline().conditionalSwapWith(right.getBaseline(), swap);

        left.assertEqual();
        right.assertEqual();

    }

    // Repeatedly accumulate an element to push the implementation past its
    // limits, then multiply. Every step is compared against the baseline, so
    // an overflow that silently produces a wrong value fails the test. An
    // ArithmeticException is accepted as a legitimate refusal. The test
    // does NOT require that exception to be thrown.
    static void runOverflowTest(TestPair<IntegerModuloP> elem) {

        TestPair<MutableIntegerModuloP> mutableElem = elem.mutable();

        try {
            for (int i = 0; i < 1000; i++) {
                applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem);
            }
            applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem);
        } catch (ArithmeticException expected) {
            // acceptable outcome; see method comment
        }

        mutableElem = elem.mutable();
        try {
            for (int i = 0; i < 1000; i++) {
                // elem is reassigned so the summand itself grows each step
                elem = applyAndCheck(IntegerModuloP::add,
                    mutableElem, elem);
            }
            applyAndCheck(IntegerModuloP::multiply, mutableElem, elem);
        } catch (ArithmeticException expected) {
            // acceptable outcome; see method comment
        }
    }

    // Run a large number of random operations and ensure that
    // results are correct
    static void runOperationsTest(Random random, int length,
                                  TestPair<IntegerModuloP> elem,
                                  TestPair<IntegerModuloP> right) {

        TestPair<MutableIntegerModuloP> left = elem.mutable();

        for (int i = 0; i < 10000; i++) {

            ElemFunction addFunc1 =
                ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
            TestPair<MutableIntegerModuloP> result1 =
                applyAndCheckMutable(addFunc1, left, right);

            // left could have been modified, so turn it back into a summand
            applyAndCheckMutable((a, b) -> a.setSquare(), left, right);

            ElemFunction addFunc2 =
                ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
            TestPair<IntegerModuloP> result2 =
                applyAndCheck(addFunc2, left, right);

            ElemFunction multFunc2 =
                MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size()));
            TestPair<MutableIntegerModuloP> multResult =
                applyAndCheckMutable(multFunc2, result1, result2);

            int swap = random.nextInt(2);
            cswapAndCheck(swap, left, multResult);

            ElemSetFunction setFunc =
                SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size()));
            byte[] valueArr = new byte[length];
            random.nextBytes(valueArr);
            setAndCheck(setFunc, result1, result2, valueArr);

            // left could have been modified, so turn it back into a summand
            applyAndCheckMutable((a, b) -> a.setSquare(), left, right);

            ElemArrayFunction arrayFunc =
                ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size()));
            left.applyAndCheckArray(arrayFunc, right);
        }
    }

    // Run all the tests for a given field
    static void runFieldTest(IntegerFieldModuloP testField,
                             int length, int seed) {
        System.out.println("Testing: " + testField.getClass().getSimpleName());

        Random random = new Random(seed);

        IntegerFieldModuloP baselineField =
            new BigIntegerModuloP(testField.getSize());

        int numBits = testField.getSize().bitLength();
        BigInteger r =
            new BigInteger(numBits, random).mod(testField.getSize());
        TestPair<IntegerModuloP> rand =
            new TestPair<>(testField.getElement(r), baselineField.getElement(r));

        runOverflowTest(rand);

        // check combinations of operations for different kinds of elements
        List<TestPair<IntegerModuloP>> testElements = new ArrayList<>();
        testElements.add(rand);
        testElements.add(new TestPair<>(testField.get0(), baselineField.get0()));
        testElements.add(new TestPair<>(testField.get1(), baselineField.get1()));
        byte[] testArr = {121, 37, -100, -5, 76, 33};
        testElements.add(new TestPair<>(testField.getElement(testArr),
            baselineField.getElement(testArr)));

        testArr = new byte[length];
        random.nextBytes(testArr);
        testElements.add(new TestPair<>(testField.getElement(testArr),
            baselineField.getElement(testArr)));

        random.nextBytes(testArr);
        byte highByte = computeHighByte(testField, length);
        testElements.add(
            new TestPair<>(
                testField.getElement(testArr, 0, testArr.length, highByte),
                baselineField.getElement(testArr, 0, testArr.length, highByte)
            )
        );

        for (int i = 0; i < testElements.size(); i++) {
            for (int j = 0; j < testElements.size(); j++) {
                runOperationsTest(random, length, testElements.get(i),
                    testElements.get(j));
            }
        }
    }
}