test/java/math/BigInteger/BigIntegerTest.java

Print this page
rev 7462 : 4837946: Faster multiplication and exponentiation of large integers
4646474: BigInteger.pow() algorithm slow in 1.4.0
Summary: Implement Karatsuba and 3-way Toom-Cook multiplication as well as exponentiation using Karatsuba and Toom-Cook squaring.
Reviewed-by: alanb, bpb, martin
Contributed-by: Alan Eliasen <eliasen@mindspring.com>

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 1998, 2011, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1998, 2013, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.

@@ -21,19 +21,23 @@
  * questions.
  */
 
 /*
  * @test
- * @bug 4181191 4161971 4227146 4194389 4823171 4624738 4812225
+ * @bug 4181191 4161971 4227146 4194389 4823171 4624738 4812225 4837946
  * @summary tests methods in BigInteger
  * @run main/timeout=400 BigIntegerTest
  * @author madbot
  */
 
-import java.util.Random;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.math.BigInteger;
-import java.io.*;
+import java.util.Random;
 
 /**
  * This is a simple test class created to ensure that the results
  * generated by BigInteger adhere to certain identities. Passing
  * this test is a strong assurance that the BigInteger operations

@@ -46,78 +50,229 @@
  * generated by a Random class as well as special cases which
  * throw in boundary numbers such as 0, 1, maximum sized, etc.
  *
  */
 public class BigIntegerTest {
+    //
+    // Bit-length thresholds for large numbers, based on the int (mag array
+    // length) thresholds defined in BigInteger itself:
+    //
+    // KARATSUBA_THRESHOLD        = 50  ints = 1600 bits
+    // TOOM_COOK_THRESHOLD        = 75  ints = 2400 bits
+    // KARATSUBA_SQUARE_THRESHOLD = 90  ints = 2880 bits
+    // TOOM_COOK_SQUARE_THRESHOLD = 140 ints = 4480 bits
+    //
+    static final int BITS_KARATSUBA = 1600;
+    static final int BITS_TOOM_COOK = 2400;
+    static final int BITS_KARATSUBA_SQUARE = 2880;
+    static final int BITS_TOOM_COOK_SQUARE = 4480;
+
+    static final int ORDER_SMALL = 60;    // #bits, far below every threshold
+    static final int ORDER_MEDIUM = 100;  // #bits, far below every threshold
+    // #bits for testing Karatsuba and Burnikel-Ziegler
+    static final int ORDER_KARATSUBA = 1800;
+    // #bits for testing Toom-Cook
+    static final int ORDER_TOOM_COOK = 3000;
+    // #bits for testing Karatsuba squaring
+    static final int ORDER_KARATSUBA_SQUARE = 3200;
+    // #bits for testing Toom-Cook squaring
+    static final int ORDER_TOOM_COOK_SQUARE = 4600;
+
     static Random rnd = new Random();
     static int size = 1000; // numbers per batch
     static boolean failure = false;

-    // Some variables for sizing test numbers in bits
-    private static int order1 = 100;
-    private static int order2 = 60;
-    private static int order3 = 30;
-
-    public static void pow() {
+    public static void pow(int order) {
         int failCount1 = 0;

         for (int i=0; i<size; i++) {
-            int power = rnd.nextInt(6) +2;
-            BigInteger x = fetchNumber(order1);
+            // Test identity x^power == x*x*...*x for a random power in [2, 7]
+            int power = rnd.nextInt(6) + 2;
+            BigInteger x = fetchNumber(order);
             BigInteger y = x.pow(power);
             BigInteger z = x;

             for (int j=1; j<power; j++)
                 z = z.multiply(x);

             if (!y.equals(z))
                 failCount1++;
         }
-        report("pow", failCount1);
+        report("pow for " + order + " bits", failCount1);
     }
 
-    public static void arithmetic() {
+    public static void square(int order) {
+        int mismatches = 0;
+
+        for (int trial = 0; trial < size; trial++) {
+            // pow(2) must agree with multiplying the operand by itself
+            BigInteger operand = fetchNumber(order);
+            BigInteger viaPow = operand.pow(2);
+            BigInteger viaMultiply = operand.multiply(operand);
+
+            if (!viaMultiply.equals(viaPow))
+                mismatches++;
+        }
+        report("square for " + order + " bits", mismatches);
+    }
+
+    public static void arithmetic(int order) {
         int failCount = 0;

         for (int i=0; i<size; i++) {
-            BigInteger x = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
             while(x.compareTo(BigInteger.ZERO) != 1)
-                x = fetchNumber(order1);
-            BigInteger y = fetchNumber(order1/2);
+                x = fetchNumber(order);
+            BigInteger y = fetchNumber(order/2);
             while(x.compareTo(y) == -1)
-                y = fetchNumber(order1/2);
+                y = fetchNumber(order/2);
             if (y.equals(BigInteger.ZERO))
                 y = y.add(BigInteger.ONE);

+            // Test identity (x/y)*y + x%y - x == 0
+            // using separate divide() and remainder()
             BigInteger baz = x.divide(y);
             baz = baz.multiply(y);
             baz = baz.add(x.remainder(y));
             baz = baz.subtract(x);
             if (!baz.equals(BigInteger.ZERO))
                 failCount++;
         }
-        report("Arithmetic I", failCount);
+        report("Arithmetic I for " + order + " bits", failCount);

         failCount = 0;
         for (int i=0; i<100; i++) {
-            BigInteger x = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
             while(x.compareTo(BigInteger.ZERO) != 1)
-                x = fetchNumber(order1);
-            BigInteger y = fetchNumber(order1/2);
+                x = fetchNumber(order);
+            BigInteger y = fetchNumber(order/2);
             while(x.compareTo(y) == -1)
-                y = fetchNumber(order1/2);
+                y = fetchNumber(order/2);
             if (y.equals(BigInteger.ZERO))
                 y = y.add(BigInteger.ONE);

+            // Test identity (x/y)*y + x%y - x == 0
+            // using divideAndRemainder()
             BigInteger baz[] = x.divideAndRemainder(y);
             baz[0] = baz[0].multiply(y);
             baz[0] = baz[0].add(baz[1]);
             baz[0] = baz[0].subtract(x);
             if (!baz[0].equals(BigInteger.ZERO))
                 failCount++;
         }
-        report("Arithmetic II", failCount);
+        report("Arithmetic II for " + order + " bits", failCount);
+    }
+
+    /**
+     * Sanity test for Karatsuba and 3-way Toom-Cook multiplication.
+     * For each of the Karatsuba and 3-way Toom-Cook multiplication thresholds,
+     * construct two positive factors each with a mag array one element shorter
+     * than the threshold, with the most significant bit set and the rest of the
+     * bits random (the random part is non-negative so it cannot clear the top
+     * bit). Each number is thus below the threshold but reaches it once shifted
+     * left. Call the numbers 'u' and 'v' and define random shifts 'a' and 'b'
+     * in the range [1,32]. Then we have the identity
+     * <pre>
+     * {@code (u << a)*(v << b) = (u*v) << (a + b)}
+     * </pre>
+     * For Karatsuba multiplication, the right hand expression will be evaluated
+     * using the standard naive algorithm, and the left hand expression using
+     * the Karatsuba algorithm. For 3-way Toom-Cook multiplication (where the
+     * shifts are fixed at a = b = 1), the right hand expression will be
+     * evaluated using Karatsuba and the left hand using 3-way Toom-Cook.
+     */
+    public static void multiplyLarge() {
+        int failCount = 0;
+        // x is made non-negative so that u = base + x always keeps its top bit
+        BigInteger base = BigInteger.ONE.shiftLeft(BITS_KARATSUBA - 32 - 1);
+        for (int i=0; i<size; i++) {
+            BigInteger x = fetchNumber(BITS_KARATSUBA - 32 - 1).abs();
+            BigInteger u = base.add(x);
+            int a = 1 + rnd.nextInt(32);
+            BigInteger w = u.shiftLeft(a);
+
+            BigInteger y = fetchNumber(BITS_KARATSUBA - 32 - 1).abs();
+            BigInteger v = base.add(y);
+            int b = 1 + rnd.nextInt(32);
+            BigInteger z = v.shiftLeft(b);
+
+            BigInteger multiplyResult = u.multiply(v).shiftLeft(a + b);
+            BigInteger karatsubaMultiplyResult = w.multiply(z);
+
+            if (!multiplyResult.equals(karatsubaMultiplyResult)) {
+                failCount++;
+            }
+        }
+
+        report("multiplyLarge Karatsuba", failCount);
+
+        failCount = 0;
+        base = base.shiftLeft(BITS_TOOM_COOK - BITS_KARATSUBA);
+        for (int i=0; i<size; i++) {
+            BigInteger x = fetchNumber(BITS_TOOM_COOK - 32 - 1).abs();
+            BigInteger u = base.add(x);
+            BigInteger u2 = u.shiftLeft(1);
+            BigInteger y = fetchNumber(BITS_TOOM_COOK - 32 - 1).abs();
+            BigInteger v = base.add(y);
+            BigInteger v2 = v.shiftLeft(1);
+
+            BigInteger multiplyResult = u.multiply(v).shiftLeft(2);
+            BigInteger toomCookMultiplyResult = u2.multiply(v2);
+
+            if (!multiplyResult.equals(toomCookMultiplyResult)) {
+                failCount++;
+            }
+        }
+
+        report("multiplyLarge Toom-Cook", failCount);
+    }
+
+    /**
+     * Sanity test for Karatsuba and 3-way Toom-Cook squaring.
+     * This test is analogous to {@link #multiplyLarge} with both factors
+     * equal: {@code (u << a)*(u << a) == (u*u) << (2*a)}. The squaring
+     * methods will not be tested unless {@code bigInteger.multiply(bigInteger)}
+     * checks whether the argument is the same instance on which the method is
+     * invoked and calls {@code square()} accordingly.
+     */
+    public static void squareLarge() {
+        int failCount = 0;
+        // x is made non-negative so that u = base + x always keeps its top bit
+        BigInteger base = BigInteger.ONE.shiftLeft(BITS_KARATSUBA_SQUARE - 32 - 1);
+        for (int i=0; i<size; i++) {
+            BigInteger x = fetchNumber(BITS_KARATSUBA_SQUARE - 32 - 1).abs();
+            BigInteger u = base.add(x);
+            int a = 1 + rnd.nextInt(32);
+            BigInteger w = u.shiftLeft(a);
+
+            BigInteger squareResult = u.multiply(u).shiftLeft(2*a);
+            BigInteger karatsubaSquareResult = w.multiply(w);
+
+            if (!squareResult.equals(karatsubaSquareResult)) {
+                failCount++;
+            }
+        }
+
+        report("squareLarge Karatsuba", failCount);
+
+        failCount = 0;
+        base = base.shiftLeft(BITS_TOOM_COOK_SQUARE - BITS_KARATSUBA_SQUARE);
+        for (int i=0; i<size; i++) {
+            BigInteger x = fetchNumber(BITS_TOOM_COOK_SQUARE - 32 - 1).abs();
+            BigInteger u = base.add(x);
+            int a = 1 + rnd.nextInt(32);
+            BigInteger w = u.shiftLeft(a);
+
+            BigInteger squareResult = u.multiply(u).shiftLeft(2*a);
+            BigInteger toomCookSquareResult = w.multiply(w);
+
+            if (!squareResult.equals(toomCookSquareResult)) {
+                failCount++;
+            }
+        }
+
+        report("squareLarge Toom-Cook", failCount);
     }
 
     public static void bitCount() {
         int failCount = 0;
 

@@ -158,18 +313,18 @@
         }
 
         report("BitLength", failCount);
     }
 
-    public static void bitOps() {
+    public static void bitOps(int order) {
         int failCount1 = 0, failCount2 = 0, failCount3 = 0;
 
         for (int i=0; i<size*5; i++) {
-            BigInteger x = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
             BigInteger y;
 
-            /* Test setBit and clearBit (and testBit) */
+            // Test setBit and clearBit (and testBit)
             if (x.signum() < 0) {
                 y = BigInteger.valueOf(-1);
                 for (int j=0; j<x.bitLength(); j++)
                     if (!x.testBit(j))
                         y = y.clearBit(j);

@@ -180,25 +335,25 @@
                         y = y.setBit(j);
             }
             if (!x.equals(y))
                 failCount1++;
 
-            /* Test flipBit (and testBit) */
+            // Test flipBit (and testBit)
             y = BigInteger.valueOf(x.signum()<0 ? -1 : 0);
             for (int j=0; j<x.bitLength(); j++)
                 if (x.signum()<0  ^  x.testBit(j))
                     y = y.flipBit(j);
             if (!x.equals(y))
                 failCount2++;
         }
-        report("clearBit/testBit", failCount1);
-        report("flipBit/testBit", failCount2);
+        report("clearBit/testBit for " + order + " bits", failCount1);
+        report("flipBit/testBit for " + order + " bits", failCount2);
 
         for (int i=0; i<size*5; i++) {
-            BigInteger x = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
 
-            /* Test getLowestSetBit() */
+            // Test getLowestSetBit()
             int k = x.getLowestSetBit();
             if (x.signum() == 0) {
                 if (k != -1)
                     failCount3++;
             } else {

@@ -208,47 +363,47 @@
                     ;
                 if (k != j)
                     failCount3++;
             }
         }
-        report("getLowestSetBit", failCount3);
+        report("getLowestSetBit for " + order + " bits", failCount3);
     }
 
-    public static void bitwise() {
+    public static void bitwise(int order) {

-        /* Test identity x^y == x|y &~ x&y */
+        // Test identity x^y == (x|y) &~ (x&y)
         int failCount = 0;
         for (int i=0; i<size; i++) {
-            BigInteger x = fetchNumber(order1);
-            BigInteger y = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
+            BigInteger y = fetchNumber(order);
             BigInteger z = x.xor(y);
             BigInteger w = x.or(y).andNot(x.and(y));
             if (!z.equals(w))
                 failCount++;
         }
-        report("Logic (^ | & ~)", failCount);
+        report("Logic (^ | & ~) for " + order + " bits", failCount);

-        /* Test identity x &~ y == ~(~x | y) */
+        // Test identity x &~ y == ~(~x | y)
         failCount = 0;
         for (int i=0; i<size; i++) {
-            BigInteger x = fetchNumber(order1);
-            BigInteger y = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
+            BigInteger y = fetchNumber(order);
             BigInteger z = x.andNot(y);
             BigInteger w = x.not().or(y).not();
             if (!z.equals(w))
                 failCount++;
         }
-        report("Logic (&~ | ~)", failCount);
+        report("Logic (&~ | ~) for " + order + " bits", failCount);
     }
 
-    public static void shift() {
+    public static void shift(int order) {
         int failCount1 = 0;
         int failCount2 = 0;
         int failCount3 = 0;
 
         for (int i=0; i<100; i++) {
-            BigInteger x = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
             int n = Math.abs(rnd.nextInt()%200);
 
             if (!x.shiftLeft(n).equals
                 (x.multiply(BigInteger.valueOf(2L).pow(n))))
                 failCount1++;

@@ -272,22 +427,22 @@
             }
 
             if (!x.shiftLeft(n).shiftRight(n).equals(x))
                 failCount3++;
         }
-        report("baz shiftLeft", failCount1);
-        report("baz shiftRight", failCount2);
-        report("baz shiftLeft/Right", failCount3);
+        report("baz shiftLeft for " + order + " bits", failCount1);
+        report("baz shiftRight for " + order + " bits", failCount2);
+        report("baz shiftLeft/Right for " + order + " bits", failCount3);
     }
 
-    public static void divideAndRemainder() {
+    public static void divideAndRemainder(int order) {
         int failCount1 = 0;
 
         for (int i=0; i<size; i++) {
-            BigInteger x = fetchNumber(order1).abs();
+            BigInteger x = fetchNumber(order).abs();
             while(x.compareTo(BigInteger.valueOf(3L)) != 1)
-                x = fetchNumber(order1).abs();
+                x = fetchNumber(order).abs();
             BigInteger z = x.divide(BigInteger.valueOf(2L));
             BigInteger y[] = x.divideAndRemainder(x);
             if (!y[0].equals(BigInteger.ONE)) {
                 failCount1++;
                 System.err.println("fail1 x :"+x);

@@ -304,11 +459,11 @@
                 failCount1++;
                 System.err.println("fail3 x :"+x);
                 System.err.println("      y :"+y);
             }
         }
-        report("divideAndRemainder I", failCount1);
+        report("divideAndRemainder for " + order + " bits", failCount1);
     }
 
     public static void stringConv() {
         int failCount = 0;
 

@@ -329,37 +484,37 @@
             }
         }
         report("String Conversion", failCount);
     }
 
-    public static void byteArrayConv() {
+    public static void byteArrayConv(int order) {
         int failCount = 0;

         for (int i=0; i<size; i++) {
-            BigInteger x = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
             while (x.equals(BigInteger.ZERO))
-                x = fetchNumber(order1);
+                x = fetchNumber(order);
             BigInteger y = new BigInteger(x.toByteArray());
             if (!x.equals(y)) {
                 failCount++;
                 System.err.println("orig is "+x);
                 System.err.println("new is "+y);
             }
         }
-        report("Array Conversion", failCount);
+        report("Array Conversion for " + order + " bits", failCount);
     }
 
-    public static void modInv() {
+    public static void modInv(int order) {
         int failCount = 0, successCount = 0, nonInvCount = 0;
 
         for (int i=0; i<size; i++) {
-            BigInteger x = fetchNumber(order1);
+            BigInteger x = fetchNumber(order);
             while(x.equals(BigInteger.ZERO))
-                x = fetchNumber(order1);
-            BigInteger m = fetchNumber(order1).abs();
+                x = fetchNumber(order);
+            BigInteger m = fetchNumber(order).abs();
             while(m.compareTo(BigInteger.ONE) != 1)
-                m = fetchNumber(order1).abs();
+                m = fetchNumber(order).abs();
 
             try {
                 BigInteger inv = x.modInverse(m);
                 BigInteger prod = inv.multiply(x).remainder(m);
 

@@ -372,14 +527,14 @@
                     failCount++;
             } catch(ArithmeticException e) {
                 nonInvCount++;
             }
         }
-        report("Modular Inverse", failCount);
+        report("Modular Inverse for " + order + " bits", failCount);
     }
 
-    public static void modExp() {
+    public static void modExp(int order1, int order2) {
         int failCount = 0;
 
         for (int i=0; i<size/10; i++) {
             BigInteger m = fetchNumber(order1).abs();
             while(m.compareTo(BigInteger.ONE) != 1)

@@ -396,39 +551,40 @@
                 System.err.println("base is "+base);
                 System.err.println("exp is "+exp);
                 failCount++;
             }
         }
-        report("Exponentiation I", failCount);
+        report("Exponentiation I for " + order1 + " and " +
+               order2 + " bits", failCount);
     }
 
     // This test is based on Fermat's theorem
     // which is not ideal because base must not be multiple of modulus
     // and modulus must be a prime or pseudoprime (Carmichael number)
-    public static void modExp2() {
+    public static void modExp2(int order) {
         int failCount = 0;

         for (int i=0; i<10; i++) {
             BigInteger m = new BigInteger(100, 5, rnd);
             while(m.compareTo(BigInteger.ONE) != 1)
                 m = new BigInteger(100, 5, rnd);
             BigInteger exp = m.subtract(BigInteger.ONE);
-            BigInteger base = fetchNumber(order1).abs();
+            BigInteger base = fetchNumber(order).abs();
             while(base.compareTo(m) != -1)
-                base = fetchNumber(order1).abs();
+                base = fetchNumber(order).abs();
             while(base.equals(BigInteger.ZERO))
-                base = fetchNumber(order1).abs();
+                base = fetchNumber(order).abs();

             BigInteger one = base.modPow(exp, m);
             if (!one.equals(BigInteger.ONE)) {
                 System.err.println("m is "+m);
                 System.err.println("base is "+base);
                 System.err.println("exp is "+exp);
                 failCount++;
             }
         }
-        report("Exponentiation II", failCount);
+        report("Exponentiation II for " + order + " bits", failCount);
     }
 
     private static final int[] mersenne_powers = {
         521, 607, 1279, 2203, 2281, 3217, 4253, 4423, 9689, 9941, 11213, 19937,
         21701, 23209, 44497, 86243, 110503, 132049, 216091, 756839, 859433,

@@ -702,40 +858,66 @@
      * the maximum number of decimal digits that the parameters will have.
      *
      */
     public static void main(String[] args) throws Exception {

+        // Test number sizes in bits; args[0..3], in decimal digits, override these
+        int order1 = ORDER_MEDIUM;
+        int order2 = ORDER_SMALL;
+        int order3 = ORDER_KARATSUBA;
+        int order4 = ORDER_TOOM_COOK;
+
         if (args.length >0)
             order1 = (int)((Integer.parseInt(args[0]))* 3.333);
         if (args.length >1)
             order2 = (int)((Integer.parseInt(args[1]))* 3.333);
         if (args.length >2)
             order3 = (int)((Integer.parseInt(args[2]))* 3.333);
+        if (args.length >3)
+            order4 = (int)((Integer.parseInt(args[3]))* 3.333);

         prime();
         nextProbablePrime();

-        arithmetic();
-        divideAndRemainder();
-        pow();
+        arithmetic(order1);   // small numbers
+        arithmetic(order3);   // Karatsuba / Burnikel-Ziegler range
+        arithmetic(order4);   // Toom-Cook range
+
+        divideAndRemainder(order1);   // small numbers
+        divideAndRemainder(order3);   // Karatsuba / Burnikel-Ziegler range
+        divideAndRemainder(order4);   // Toom-Cook range
+
+        pow(order1);   // small numbers
+        pow(order3);   // Karatsuba range
+        pow(order4);   // Toom-Cook range
+
+        square(ORDER_MEDIUM);             // small numbers; square sizes ignore args
+        square(ORDER_KARATSUBA_SQUARE);   // Karatsuba squaring range
+        square(ORDER_TOOM_COOK_SQUARE);   // Toom-Cook squaring range

         bitCount();
         bitLength();
-        bitOps();
-        bitwise();
+        bitOps(order1);
+        bitwise(order1);

-        shift();
+        shift(order1);

-        byteArrayConv();
+        byteArrayConv(order1);

-        modInv();
-        modExp();
-        modExp2();
+        modInv(order1);   // small numbers
+        modInv(order3);   // Karatsuba / Burnikel-Ziegler range
+        modInv(order4);   // Toom-Cook range
+
+        modExp(order1, order2);
+        modExp2(order1);

         stringConv();
         serialize();

+        multiplyLarge();
+        squareLarge();
+
         if (failure)
             throw new RuntimeException("Failure in BigIntegerTest.");
     }
 
     /*

@@ -745,11 +927,11 @@
      *
      * If order is less than 2, order is changed to 2.
      */
     private static BigInteger fetchNumber(int order) {
         boolean negative = rnd.nextBoolean();
-        int numType = rnd.nextInt(6);
+        int numType = rnd.nextInt(7);
         BigInteger result = null;
         if (order < 2) order = 2;
 
         switch (numType) {
             case 0: // Empty

@@ -781,10 +963,23 @@
                     BigInteger temp = BigInteger.ONE.shiftLeft(
                                                 rnd.nextInt(order));
                     result = result.or(temp);
                 }
                 break;
+            case 5: // Runs of consecutive ones and zeros
+                result = ZERO;
+                int remaining = order;
+                int bit = rnd.nextInt(2);
+                while (remaining > 0) {
+                    int runLength = Math.min(remaining, rnd.nextInt(order));
+                    result = result.shiftLeft(runLength);
+                    if (bit > 0)
+                        result = result.add(ONE.shiftLeft(runLength).subtract(ONE));
+                    remaining -= runLength;
+                    bit = 1 - bit;
+                }
+                break;
 
             default: // random bits
                 result = new BigInteger(order, rnd);
         }