1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8181594
  27  * @summary Test proper operation of integer field arithmetic
  28  * @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly
  29  * @build BigIntegerModuloP
  30  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0
  31  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1
  32  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2
  33  */
  34 
  35 import sun.security.util.math.*;
  36 import java.util.function.*;
  37 
  38 import java.util.*;
  39 import java.math.*;
  40 import java.nio.*;
  41 
  42 public class TestIntegerModuloP {
  43 
  44     static BigInteger TWO = BigInteger.valueOf(2);
  45 
  46     // The test has a list of functions, and it selects randomly from that list
  47 
  48     // The function types
  49     interface ElemFunction extends BiFunction
  50         <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { }
  51     interface ElemArrayFunction extends BiFunction
  52         <MutableIntegerModuloP, IntegerModuloP, byte[]> { }
  53     interface TriConsumer <T, U, V> {
  54         void accept(T t, U u, V v);
  55     }
  56     interface ElemSetFunction extends TriConsumer
  57         <MutableIntegerModuloP, IntegerModuloP, byte[]> { }
  58 
  59     // The lists of functions. Multiple lists are needed because the test
  60     // respects the limitations of the arithmetic implementations.
  61     static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>();
  62     static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>();
  63     static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>();
  64     static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>();
  65 
  66     static void setUpFunctions(IntegerFieldModuloP field, int length) {
  67 
  68         ADD_FUNCTIONS.clear();
  69         MULT_FUNCTIONS.clear();
  70         SET_FUNCTIONS.clear();
  71         ARRAY_FUNCTIONS.clear();
  72 
  73         byte highByte = (byte)
  74             (field.getSize().bitLength() > length * 8 ? 1 : 0);
  75 
  76         // add functions are (im)mutable add/subtract
  77         ADD_FUNCTIONS.add(IntegerModuloP::add);
  78         ADD_FUNCTIONS.add(IntegerModuloP::subtract);
  79         ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum);
  80         ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference);
  81         // also include functions that return the first/second argument
  82         ADD_FUNCTIONS.add((a, b) -> a);
  83         ADD_FUNCTIONS.add((a, b) -> b);
  84 
  85         // mult functions are (im)mutable multiply and square
  86         MULT_FUNCTIONS.add(IntegerModuloP::multiply);
  87         MULT_FUNCTIONS.add((a, b) -> a.square());
  88         MULT_FUNCTIONS.add((a, b) -> b.square());
  89         MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct);
  90         MULT_FUNCTIONS.add((a, b) -> a.setSquare());
  91         // also test multiplication by a small value
  92         MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue(
  93             b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue())));
  94 
  95         // set functions are setValue with various argument types
  96         SET_FUNCTIONS.add((a, b, c) -> a.setValue(b));
  97         SET_FUNCTIONS.add((a, b, c) ->
  98             a.setValue(c, 0, c.length, (byte) 0));
  99         SET_FUNCTIONS.add((a, b, c) ->
 100             a.setValue(ByteBuffer.wrap(c, 0, c.length).order(ByteOrder.LITTLE_ENDIAN),
 101             c.length, highByte));
 102 
 103         // array functions return the (possibly modified) value as byte array
 104         ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length));
 105         ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length));
 106     }
 107 
 108     public static void main(String[] args) {
 109 
 110         String className = args[0];
 111         final int length = Integer.parseInt(args[1]);
 112         int seed = Integer.parseInt(args[2]);
 113 
 114         Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class;
 115         try {
 116             Class<? extends IntegerFieldModuloP> clazz =
 117                 Class.forName(className).asSubclass(fieldBaseClass);
 118             IntegerFieldModuloP field =
 119                 clazz.getDeclaredConstructor().newInstance();
 120 
 121             setUpFunctions(field, length);
 122 
 123             runFieldTest(field, length, seed);
 124         } catch (Exception ex) {
 125             throw new RuntimeException(ex);
 126         }
 127 
 128         System.out.println("All tests passed");
 129     }
 130 
 131 
 132     static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) {
 133 
 134         if (!e1.asBigInteger().equals(e2.asBigInteger())) {
 135             throw new RuntimeException("values not equal: "
 136                 + e1.asBigInteger() + " != " + e2.asBigInteger());
 137         }
 138     }
 139 
 140     // A class that holds pairs of actual/expected values, and allows
 141     // computation on these pairs.
 142     static class TestPair<T extends IntegerModuloP> {
 143         private final T test;
 144         private final T baseline;
 145 
 146         public TestPair(T test, T baseline) {
 147             this.test = test;
 148             this.baseline = baseline;
 149         }
 150 
 151         public T getTest() {
 152             return test;
 153         }
 154         public T getBaseline() {
 155             return baseline;
 156         }
 157 
 158         private void assertEqual() {
 159             TestIntegerModuloP.assertEqual(test, baseline);
 160         }
 161 
 162         public TestPair<MutableIntegerModuloP> mutable() {
 163             return new TestPair<>(test.mutable(), baseline.mutable());
 164         }
 165 
 166         public
 167         <R extends IntegerModuloP, X extends IntegerModuloP>
 168         TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) {
 169             X testResult = func.apply(test, right.test);
 170             X baselineResult = func.apply(baseline, right.baseline);
 171             return new TestPair(testResult, baselineResult);
 172         }
 173 
 174         public
 175         <U extends IntegerModuloP, V>
 176         void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) {
 177             func.accept(test, right.test, argV);
 178             func.accept(baseline, right.baseline, argV);
 179         }
 180 
 181         public
 182         <R extends IntegerModuloP>
 183         void applyAndCheckArray(BiFunction<T, R, byte[]> func,
 184                                 TestPair<R> right) {
 185             byte[] testResult = func.apply(test, right.test);
 186             byte[] baselineResult = func.apply(baseline, right.baseline);
 187             if (!Arrays.equals(testResult, baselineResult)) {
 188                 throw new RuntimeException("Array values do not match: "
 189                     + byteArrayToHexString(testResult) + " != "
 190                     + byteArrayToHexString(baselineResult));
 191             }
 192         }
 193 
 194     }
 195 
 196     static String byteArrayToHexString(byte[] arr) {
 197         StringBuilder result = new StringBuilder();
 198         for (int i = 0; i < arr.length; ++i) {
 199             byte curVal = arr[i];
 200             result.append(Character.forDigit(curVal >> 4 & 0xF, 16));
 201             result.append(Character.forDigit(curVal & 0xF, 16));
 202         }
 203         return result.toString();
 204     }
 205 
 206     static TestPair<IntegerModuloP>
 207     applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left,
 208                   TestPair<IntegerModuloP> right) {
 209 
 210         TestPair<IntegerModuloP> result = left.apply(func, right);
 211         result.assertEqual();
 212         left.assertEqual();
 213         right.assertEqual();
 214 
 215         return result;
 216     }
 217 
 218     static void
 219     setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left,
 220                 TestPair<IntegerModuloP> right, byte[] argV) {
 221 
 222         left.apply(func, right, argV);
 223         left.assertEqual();
 224         right.assertEqual();
 225     }
 226 
 227     static TestPair<MutableIntegerModuloP>
 228     applyAndCheckMutable(ElemFunction func,
 229                          TestPair<MutableIntegerModuloP> left,
 230                          TestPair<IntegerModuloP> right) {
 231 
 232         TestPair<IntegerModuloP> result = applyAndCheck(func, left, right);
 233 
 234         TestPair<MutableIntegerModuloP> mutableResult = result.mutable();
 235         mutableResult.assertEqual();
 236         result.assertEqual();
 237         left.assertEqual();
 238         right.assertEqual();
 239 
 240         return mutableResult;
 241     }
 242 
 243     static void
 244     cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left,
 245                   TestPair<MutableIntegerModuloP> right) {
 246 
 247         left.getTest().conditionalSwapWith(right.getTest(), swap);
 248         left.getBaseline().conditionalSwapWith(right.getBaseline(), swap);
 249 
 250         left.assertEqual();
 251         right.assertEqual();
 252 
 253     }
 254 
 255     // Request arithmetic that should overflow, and ensure that overflow is
 256     // detected.
 257     static void runOverflowTest(TestPair<IntegerModuloP> elem) {
 258 
 259         TestPair<MutableIntegerModuloP> mutableElem = elem.mutable();
 260 
 261         try {
 262             for (int i = 0; i < 1000; i++) {
 263                 applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem);
 264             }
 265             applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem);
 266         } catch (ArithmeticException ex) {
 267             // this is expected
 268         }
 269 
 270         mutableElem = elem.mutable();
 271         try {
 272             for (int i = 0; i < 1000; i++) {
 273                 elem = applyAndCheck(IntegerModuloP::add,
 274                     mutableElem, elem);
 275             }
 276             applyAndCheck(IntegerModuloP::multiply, mutableElem, elem);
 277         } catch (ArithmeticException ex) {
 278             // this is expected
 279         }
 280     }
 281 
 282     // Run a large number of random operations and ensure that
 283     // results are correct
 284     static void runOperationsTest(Random random, int length,
 285                                   TestPair<IntegerModuloP> elem,
 286                                   TestPair<IntegerModuloP> right) {
 287 
 288         TestPair<MutableIntegerModuloP> left = elem.mutable();
 289 
 290         for (int i = 0; i < 10000; i++) {
 291 
 292             ElemFunction addFunc1 =
 293                 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
 294             TestPair<MutableIntegerModuloP> result1 =
 295                 applyAndCheckMutable(addFunc1, left, right);
 296 
 297             // left could have been modified, so turn it back into a summand
 298             applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
 299 
 300             ElemFunction addFunc2 =
 301                 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
 302             TestPair<IntegerModuloP> result2 =
 303                 applyAndCheck(addFunc2, left, right);
 304 
 305             ElemFunction multFunc2 =
 306                 MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size()));
 307             TestPair<MutableIntegerModuloP> multResult =
 308                 applyAndCheckMutable(multFunc2, result1, result2);
 309 
 310             int swap = random.nextInt(2);
 311             cswapAndCheck(swap, left, multResult);
 312 
 313             ElemSetFunction setFunc =
 314                 SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size()));
 315             byte[] valueArr = new byte[length];
 316             random.nextBytes(valueArr);
 317             setAndCheck(setFunc, result1, result2, valueArr);
 318 
 319             // left could have been modified, so to turn it back into a summand
 320             applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
 321 
 322             ElemArrayFunction arrayFunc =
 323                 ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size()));
 324             left.applyAndCheckArray(arrayFunc, right);
 325         }
 326     }
 327 
 328     // Run all the tests for a given field
 329     static void runFieldTest(IntegerFieldModuloP testField,
 330                              int length, int seed) {
 331         System.out.println("Testing: " + testField.getClass().getSimpleName());
 332 
 333         Random random = new Random(seed);
 334 
 335         IntegerFieldModuloP baselineField =
 336             new BigIntegerModuloP(testField.getSize());
 337 
 338         int numBits = testField.getSize().bitLength();
 339         BigInteger r =
 340             new BigInteger(numBits, random).mod(testField.getSize());
 341         TestPair<IntegerModuloP> rand =
 342             new TestPair(testField.getElement(r), baselineField.getElement(r));
 343 
 344         runOverflowTest(rand);
 345 
 346         // check combinations of operations for different kinds of elements
 347         List<TestPair<IntegerModuloP>> testElements = new ArrayList<>();
 348         testElements.add(rand);
 349         testElements.add(new TestPair(testField.get0(), baselineField.get0()));
 350         testElements.add(new TestPair(testField.get1(), baselineField.get1()));
 351         byte[] testArr = {121, 37, -100, -5, 76, 33};
 352         testElements.add(new TestPair(testField.getElement(testArr),
 353             baselineField.getElement(testArr)));
 354 
 355         testArr = new byte[length];
 356         random.nextBytes(testArr);
 357         testElements.add(new TestPair(testField.getElement(testArr),
 358             baselineField.getElement(testArr)));
 359 
 360         random.nextBytes(testArr);
 361         byte highByte = (byte) (numBits > length * 8 ? 1 : 0);
 362         testElements.add(
 363             new TestPair(
 364                 testField.getElement(testArr, 0, testArr.length, highByte),
 365                 baselineField.getElement(testArr, 0, testArr.length, highByte)
 366             )
 367         );
 368 
 369         for (int i = 0; i < testElements.size(); i++) {
 370             for (int j = 0; j < testElements.size(); j++) {
 371                 runOperationsTest(random, length, testElements.get(i),
 372                     testElements.get(j));
 373             }
 374         }
 375     }
 376 }
 377