1 /*
  2  * Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @bug 8181594 8208648
 27  * @summary Test proper operation of integer field arithmetic
 28  * @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly
 29  * @build BigIntegerModuloP
 30  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0
 31  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1
 32  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2
 33  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP256 32 5
 34  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP384 48 6
 35  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP521 66 7
 36  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P256OrderField 32 8
 37  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P384OrderField 48 9
 38  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P521OrderField 66 10
 39  * @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve25519OrderField 32 11
 40  * @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve448OrderField 56 12
 41  */
 42 
 43 import sun.security.util.math.*;
 44 import sun.security.util.math.intpoly.*;
 45 import java.util.function.*;
 46 
 47 import java.util.*;
 48 import java.math.*;
 49 import java.nio.*;
 50 
 51 public class TestIntegerModuloP {
 52 
    // The constant 2 as a BigInteger.
    // NOTE(review): not referenced anywhere in the visible part of this
    // file -- confirm whether it is still needed.
    static BigInteger TWO = BigInteger.valueOf(2);

    // The test keeps several lists of functions, and it selects randomly
    // from those lists while running.

    // The function types

    // Takes a mutable left operand and a right operand and returns an
    // element (the result of the operation, or one of the operands).
    interface ElemFunction extends BiFunction
        <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { }
    // Takes a mutable left operand and a right operand and returns a
    // byte array.
    interface ElemArrayFunction extends BiFunction
        <MutableIntegerModuloP, IntegerModuloP, byte[]> { }
    // A three-argument operation that returns nothing.
    interface TriConsumer <T, U, V> {
        void accept(T t, U u, V v);
    }
    // Takes a mutable element, a second element and a byte array, and
    // returns nothing; used for the setValue variants.
    interface ElemSetFunction extends TriConsumer
        <MutableIntegerModuloP, IntegerModuloP, byte[]> { }

    // The lists of functions. Multiple lists are needed because the test
    // respects the limitations of the arithmetic implementations.
    // All four lists are cleared and refilled by setUpFunctions.
    static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>();
    static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>();
    static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>();
    static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>();
 74 
 75     static void setUpFunctions(IntegerFieldModuloP field, int length) {
 76 
 77         ADD_FUNCTIONS.clear();
 78         MULT_FUNCTIONS.clear();
 79         SET_FUNCTIONS.clear();
 80         ARRAY_FUNCTIONS.clear();
 81 
 82         byte highByte = (byte)
 83             (field.getSize().bitLength() > length * 8 ? 1 : 0);
 84 
 85         // add functions are (im)mutable add/subtract
 86         ADD_FUNCTIONS.add(IntegerModuloP::add);
 87         ADD_FUNCTIONS.add(IntegerModuloP::subtract);
 88         ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum);
 89         ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference);
 90         // also include functions that return the first/second argument
 91         ADD_FUNCTIONS.add((a, b) -> a);
 92         ADD_FUNCTIONS.add((a, b) -> b);
 93 
 94         // mult functions are (im)mutable multiply and square
 95         MULT_FUNCTIONS.add(IntegerModuloP::multiply);
 96         MULT_FUNCTIONS.add((a, b) -> a.square());
 97         MULT_FUNCTIONS.add((a, b) -> b.square());
 98         MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct);
 99         MULT_FUNCTIONS.add((a, b) -> a.setSquare());
100         // also test multiplication by a small value
101         MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue(
102             b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue())));
103 
104         // set functions are setValue with various argument types
105         SET_FUNCTIONS.add((a, b, c) -> a.setValue(b));
106         SET_FUNCTIONS.add((a, b, c) ->
107             a.setValue(c, 0, c.length, (byte) 0));
108         SET_FUNCTIONS.add((a, b, c) ->
109             a.setValue(c, 0, c.length / 2, (byte) 0));
110         SET_FUNCTIONS.add((a, b, c) ->
111             a.setValue(ByteBuffer.wrap(c, 0, c.length / 2).order(ByteOrder.LITTLE_ENDIAN),
112             c.length / 2, highByte));
113 
114         // array functions return the (possibly modified) value as byte array
115         ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length));
116         ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length));
117     }
118 
119     public static void main(String[] args) {
120 
121         String className = args[0];
122         final int length = Integer.parseInt(args[1]);
123         int seed = Integer.parseInt(args[2]);
124 
125         Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class;
126         try {
127             Class<? extends IntegerFieldModuloP> clazz =
128                 Class.forName(className).asSubclass(fieldBaseClass);
129             IntegerFieldModuloP field =
130                 clazz.getDeclaredConstructor().newInstance();
131 
132             setUpFunctions(field, length);
133 
134             runFieldTest(field, length, seed);
135         } catch (Exception ex) {
136             throw new RuntimeException(ex);
137         }
138         System.out.println("All tests passed");
139     }
140 
141     static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) {
142 
143         if (!e1.asBigInteger().equals(e2.asBigInteger())) {
144             throw new RuntimeException("values not equal: "
145                 + e1.asBigInteger() + " != " + e2.asBigInteger());
146         }
147     }
148 
149     // A class that holds pairs of actual/expected values, and allows
150     // computation on these pairs.
151     static class TestPair<T extends IntegerModuloP> {
152         private final T test;
153         private final T baseline;
154 
155         public TestPair(T test, T baseline) {
156             this.test = test;
157             this.baseline = baseline;
158         }
159 
160         public T getTest() {
161             return test;
162         }
163         public T getBaseline() {
164             return baseline;
165         }
166 
167         private void assertEqual() {
168             TestIntegerModuloP.assertEqual(test, baseline);
169         }
170 
171         public TestPair<MutableIntegerModuloP> mutable() {
172             return new TestPair<>(test.mutable(), baseline.mutable());
173         }
174 
175         public
176         <R extends IntegerModuloP, X extends IntegerModuloP>
177         TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) {
178             X testResult = func.apply(test, right.test);
179             X baselineResult = func.apply(baseline, right.baseline);
180             return new TestPair(testResult, baselineResult);
181         }
182 
183         public
184         <U extends IntegerModuloP, V>
185         void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) {
186             func.accept(test, right.test, argV);
187             func.accept(baseline, right.baseline, argV);
188         }
189 
190         public
191         <R extends IntegerModuloP>
192         void applyAndCheckArray(BiFunction<T, R, byte[]> func,
193                                 TestPair<R> right) {
194             byte[] testResult = func.apply(test, right.test);
195             byte[] baselineResult = func.apply(baseline, right.baseline);
196             if (!Arrays.equals(testResult, baselineResult)) {
197                 throw new RuntimeException("Array values do not match: "
198                     + byteArrayToHexString(testResult) + " != "
199                     + byteArrayToHexString(baselineResult));
200             }
201         }
202 
203     }
204 
205     static String byteArrayToHexString(byte[] arr) {
206         StringBuilder result = new StringBuilder();
207         for (int i = 0; i < arr.length; ++i) {
208             byte curVal = arr[i];
209             result.append(Character.forDigit(curVal >> 4 & 0xF, 16));
210             result.append(Character.forDigit(curVal & 0xF, 16));
211         }
212         return result.toString();
213     }
214 
215     static TestPair<IntegerModuloP>
216     applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left,
217                   TestPair<IntegerModuloP> right) {
218 
219         TestPair<IntegerModuloP> result = left.apply(func, right);
220         result.assertEqual();
221         left.assertEqual();
222         right.assertEqual();
223 
224         return result;
225     }
226 
227     static void
228     setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left,
229                 TestPair<IntegerModuloP> right, byte[] argV) {
230 
231         left.apply(func, right, argV);
232         left.assertEqual();
233         right.assertEqual();
234     }
235 
236     static TestPair<MutableIntegerModuloP>
237     applyAndCheckMutable(ElemFunction func,
238                          TestPair<MutableIntegerModuloP> left,
239                          TestPair<IntegerModuloP> right) {
240 
241         TestPair<IntegerModuloP> result = applyAndCheck(func, left, right);
242 
243         TestPair<MutableIntegerModuloP> mutableResult = result.mutable();
244         mutableResult.assertEqual();
245         result.assertEqual();
246         left.assertEqual();
247         right.assertEqual();
248 
249         return mutableResult;
250     }
251 
252     static void
253     cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left,
254                   TestPair<MutableIntegerModuloP> right) {
255 
256         left.getTest().conditionalSwapWith(right.getTest(), swap);
257         left.getBaseline().conditionalSwapWith(right.getBaseline(), swap);
258 
259         left.assertEqual();
260         right.assertEqual();
261 
262     }
263 
264     // Request arithmetic that should overflow, and ensure that overflow is
265     // detected.
266     static void runOverflowTest(TestPair<IntegerModuloP> elem) {
267 
268         TestPair<MutableIntegerModuloP> mutableElem = elem.mutable();
269 
270         try {
271             for (int i = 0; i < 1000; i++) {
272                 applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem);
273             }
274             applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem);
275         } catch (ArithmeticException ex) {
276             // this is expected
277         }
278 
279         mutableElem = elem.mutable();
280         try {
281             for (int i = 0; i < 1000; i++) {
282                 elem = applyAndCheck(IntegerModuloP::add,
283                     mutableElem, elem);
284             }
285             applyAndCheck(IntegerModuloP::multiply, mutableElem, elem);
286         } catch (ArithmeticException ex) {
287             // this is expected
288         }
289     }
290 
291     // Run a large number of random operations and ensure that
292     // results are correct
293     static void runOperationsTest(Random random, int length,
294                                   TestPair<IntegerModuloP> elem,
295                                   TestPair<IntegerModuloP> right) {
296 
297         TestPair<MutableIntegerModuloP> left = elem.mutable();
298 
299         for (int i = 0; i < 10000; i++) {
300 
301             ElemFunction addFunc1 =
302                 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
303             TestPair<MutableIntegerModuloP> result1 =
304                 applyAndCheckMutable(addFunc1, left, right);
305 
306             // left could have been modified, so turn it back into a summand
307             applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
308 
309             ElemFunction addFunc2 =
310                 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
311             TestPair<IntegerModuloP> result2 =
312                 applyAndCheck(addFunc2, left, right);
313 
314             if (elem.test.getField() instanceof IntegerPolynomial) {
315                 IntegerPolynomial field =
316                     (IntegerPolynomial) elem.test.getField();
317                 int numAdds = field.getMaxAdds();
318                 for (int j = 1; j < numAdds; j++) {
319                     ElemFunction addFunc3 = ADD_FUNCTIONS.
320                         get(random.nextInt(ADD_FUNCTIONS.size()));
321                     result2 = applyAndCheck(addFunc3, left, right);
322                 }
323             }
324 
325             ElemFunction multFunc2 =
326                 MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size()));
327             TestPair<MutableIntegerModuloP> multResult =
328                 applyAndCheckMutable(multFunc2, result1, result2);
329 
330             int swap = random.nextInt(2);
331             cswapAndCheck(swap, left, multResult);
332 
333             ElemSetFunction setFunc =
334                 SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size()));
335             byte[] valueArr = new byte[2 * length];
336             random.nextBytes(valueArr);
337             setAndCheck(setFunc, result1, result2, valueArr);
338 
339             // left could have been modified, so to turn it back into a summand
340             applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
341 
342             ElemArrayFunction arrayFunc =
343                 ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size()));
344             left.applyAndCheckArray(arrayFunc, right);
345         }
346     }
347 
348     // Run all the tests for a given field
349     static void runFieldTest(IntegerFieldModuloP testField,
350                              int length, int seed) {
351         System.out.println("Testing: " + testField.getClass().getSimpleName());
352 
353         Random random = new Random(seed);
354 
355         IntegerFieldModuloP baselineField =
356             new BigIntegerModuloP(testField.getSize());
357 
358         int numBits = testField.getSize().bitLength();
359         BigInteger r =
360             new BigInteger(numBits, random).mod(testField.getSize());
361         TestPair<IntegerModuloP> rand =
362             new TestPair(testField.getElement(r), baselineField.getElement(r));
363 
364         runOverflowTest(rand);
365 
366         // check combinations of operations for different kinds of elements
367         List<TestPair<IntegerModuloP>> testElements = new ArrayList<>();
368         testElements.add(rand);
369         testElements.add(new TestPair(testField.get0(), baselineField.get0()));
370         testElements.add(new TestPair(testField.get1(), baselineField.get1()));
371         byte[] testArr = {121, 37, -100, -5, 76, 33};
372         testElements.add(new TestPair(testField.getElement(testArr),
373             baselineField.getElement(testArr)));
374 
375         testArr = new byte[length];
376         random.nextBytes(testArr);
377         testElements.add(new TestPair(testField.getElement(testArr),
378             baselineField.getElement(testArr)));
379 
380         random.nextBytes(testArr);
381         byte highByte = (byte) (numBits > length * 8 ? 1 : 0);
382         testElements.add(
383             new TestPair(
384                 testField.getElement(testArr, 0, testArr.length, highByte),
385                 baselineField.getElement(testArr, 0, testArr.length, highByte)
386             )
387         );
388 
389         for (int i = 0; i < testElements.size(); i++) {
390             for (int j = 0; j < testElements.size(); j++) {
391                 runOperationsTest(random, length, testElements.get(i),
392                     testElements.get(j));
393             }
394         }
395     }
396 }
397