1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.util.math.intpoly;
  27 
  28 import java.lang.invoke.MethodHandles;
  29 import java.lang.invoke.VarHandle;
  30 import java.math.BigInteger;
  31 import java.nio.*;
  32 
  33 /**
  34  * An IntegerFieldModuloP designed for use with the Poly1305 authenticator.
  35  * The representation uses 5 signed long values.
  36  *
  37  * In addition to the branch-free operations specified in the parent class,
  38  * the following operations are branch-free:
  39  *
  40  * addModPowerTwo
  41  * asByteArray
  42  *
  43  */
  44 
  45 public class IntegerPolynomial1305 extends IntegerPolynomial {
  46 
  47     protected static final int SUBTRAHEND = 5;
  48     protected static final int NUM_LIMBS = 5;
  49     private static final int POWER = 130;
  50     private static final int BITS_PER_LIMB = 26;
  51     private static final BigInteger MODULUS
  52         = TWO.pow(POWER).subtract(BigInteger.valueOf(SUBTRAHEND));
  53 
  54     private final long[] posModLimbs;
  55 
  56     private long[] setPosModLimbs() {
  57         long[] result = new long[NUM_LIMBS];
  58         setLimbsValuePositive(MODULUS, result);
  59         return result;
  60     }
  61 
  62     public IntegerPolynomial1305() {
  63         super(BITS_PER_LIMB, NUM_LIMBS, MODULUS);
  64         posModLimbs = setPosModLimbs();
  65     }
  66 
  67     protected void mult(long[] a, long[] b, long[] r) {
  68 
  69         // Use grade-school multiplication into primitives to avoid the
  70         // temporary array allocation. This is equivalent to the following
  71         // code:
  72         //  long[] c = new long[2 * NUM_LIMBS - 1];
  73         //  for(int i = 0; i < NUM_LIMBS; i++) {
  74         //      for(int j - 0; j < NUM_LIMBS; j++) {
  75         //          c[i + j] += a[i] * b[j]
  76         //      }
  77         //  }
  78 
  79         long c0 = (a[0] * b[0]);
  80         long c1 = (a[0] * b[1]) + (a[1] * b[0]);
  81         long c2 = (a[0] * b[2]) + (a[1] * b[1]) + (a[2] * b[0]);
  82         long c3 = (a[0] * b[3]) + (a[1] * b[2]) + (a[2] * b[1]) + (a[3] * b[0]);
  83         long c4 = (a[0] * b[4]) + (a[1] * b[3]) + (a[2] * b[2]) + (a[3] * b[1]) + (a[4] * b[0]);
  84         long c5 = (a[1] * b[4]) + (a[2] * b[3]) + (a[3] * b[2]) + (a[4] * b[1]);
  85         long c6 = (a[2] * b[4]) + (a[3] * b[3]) + (a[4] * b[2]);
  86         long c7 = (a[3] * b[4]) + (a[4] * b[3]);
  87         long c8 = (a[4] * b[4]);
  88 
  89         carryReduce(r, c0, c1, c2, c3, c4, c5, c6, c7, c8);
  90     }
  91 
  92     private void carryReduce(long[] r, long c0, long c1, long c2, long c3,
  93                              long c4, long c5, long c6, long c7, long c8) {
  94         //reduce(2, 2)
  95         r[2] = c2 + (c7 * SUBTRAHEND);
  96         c3 += (c8 * SUBTRAHEND);
  97 
  98         // carry(3, 2)
  99         long carry3 = carryValue(c3);
 100         r[3] = c3 - (carry3 << BITS_PER_LIMB);
 101         c4 += carry3;
 102 
 103         long carry4 = carryValue(c4);
 104         r[4] = c4 - (carry4 << BITS_PER_LIMB);
 105         c5 += carry4;
 106 
 107         // reduce(0, 2)
 108         r[0] = c0 + (c5 * SUBTRAHEND);
 109         r[1] = c1 + (c6 * SUBTRAHEND);
 110 
 111         // carry(0, 4)
 112         carry(r);
 113     }
 114 
 115     protected void multByInt(long[] a, long b, long[] r) {
 116 
 117         for (int i = 0; i < a.length; i++) {
 118             r[i] = a[i] * b;
 119         }
 120 
 121         reduce(r);
 122     }
 123 
 124     @Override
 125     protected void square(long[] a, long[] r) {
 126         // Use grade-school multiplication with a simple squaring optimization.
 127         // Multiply into primitives to avoid the temporary array allocation.
 128         // This is equivalent to the following code:
 129         //  long[] c = new long[2 * NUM_LIMBS - 1];
 130         //  for(int i = 0; i < NUM_LIMBS; i++) {
 131         //      c[2 * i] = a[i] * a[i];
 132         //      for(int j = i + 1; j < NUM_LIMBS; j++) {
 133         //          c[i + j] += 2 * a[i] * a[j]
 134         //      }
 135         //  }
 136 
 137         long c0 = (a[0] * a[0]);
 138         long c1 = 2 * (a[0] * a[1]);
 139         long c2 = 2 * (a[0] * a[2]) + (a[1] * a[1]);
 140         long c3 = 2 * (a[0] * a[3] + a[1] * a[2]);
 141         long c4 = 2 * (a[0] * a[4] + a[1] * a[3]) + (a[2] * a[2]);
 142         long c5 = 2 * (a[1] * a[4] + a[2] * a[3]);
 143         long c6 = 2 * (a[2] * a[4]) + (a[3] * a[3]);
 144         long c7 = 2 * (a[3] * a[4]);
 145         long c8 = (a[4] * a[4]);
 146 
 147         carryReduce(r, c0, c1, c2, c3, c4, c5, c6, c7, c8);
 148     }
 149 
 150     @Override
 151     protected void encode(ByteBuffer buf, int length, byte highByte,
 152                           long[] result) {
 153         if (length == 16) {
 154             long low = buf.getLong();
 155             long high = buf.getLong();
 156             encode(high, low, highByte, result);
 157         } else {
 158             super.encode(buf, length, highByte, result);
 159         }
 160     }
 161 
 162     protected void encode(long high, long low, byte highByte, long[] result) {
 163         result[0] = low & 0x3FFFFFFL;
 164         result[1] = (low >>> 26) & 0x3FFFFFFL;
 165         result[2] = (low >>> 52) + ((high & 0x3FFFL) << 12);
 166         result[3] = (high >>> 14) & 0x3FFFFFFL;
 167         result[4] = (high >>> 40) + (highByte << 24L);
 168     }
 169 
 170     private static final VarHandle AS_LONG_LE = MethodHandles
 171         .byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);
 172 
 173     protected void encode(byte[] v, int offset, int length, byte highByte,
 174                           long[] result) {
 175         if (length == 16) {
 176             long low = (long) AS_LONG_LE.get(v, offset);
 177             long high = (long) AS_LONG_LE.get(v, offset + 8);
 178             encode(high, low, highByte, result);
 179         } else {
 180             super.encode(v, offset, length, highByte, result);
 181         }
 182     }
 183 
 184     protected void modReduceIn(long[] limbs, int index, long x) {
 185         // this only works when BITS_PER_LIMB * NUM_LIMBS = POWER exactly
 186         long reducedValue = (x * SUBTRAHEND);
 187         limbs[index - NUM_LIMBS] += reducedValue;
 188     }
 189 
 190     protected final void modReduce(long[] limbs, int start, int end) {
 191 
 192         for (int i = start; i < end; i++) {
 193             modReduceIn(limbs, i, limbs[i]);
 194             limbs[i] = 0;
 195         }
 196     }
 197 
 198     protected void modReduce(long[] limbs) {
 199 
 200         modReduce(limbs, NUM_LIMBS, NUM_LIMBS - 1);
 201     }
 202 
 203     @Override
 204     protected long carryValue(long x) {
 205         // This representation has plenty of extra space, so we can afford to
 206         // do a simplified carry operation that is more time-efficient.
 207 
 208         return x >> BITS_PER_LIMB;
 209     }
 210 
 211 
 212     protected void reduce(long[] limbs) {
 213         long carry3 = carryOut(limbs, 3);
 214         long new4 = carry3 + limbs[4];
 215 
 216         long carry4 = carryValue(new4);
 217         limbs[4] = new4 - (carry4 << BITS_PER_LIMB);
 218 
 219         modReduceIn(limbs, 5, carry4);
 220         carry(limbs);
 221     }
 222 
 223     // Convert reduced limbs into a number between 0 and MODULUS-1
 224     private void finalReduce(long[] limbs) {
 225 
 226         addLimbs(limbs, posModLimbs, limbs);
 227         // now all values are positive, so remaining operations will be unsigned
 228 
 229         // unsigned carry out of last position and reduce in to first position
 230         long carry = limbs[NUM_LIMBS - 1] >> BITS_PER_LIMB;
 231         limbs[NUM_LIMBS - 1] -= carry << BITS_PER_LIMB;
 232         modReduceIn(limbs, NUM_LIMBS, carry);
 233 
 234         // unsigned carry on all positions
 235         carry = 0;
 236         for (int i = 0; i < NUM_LIMBS; i++) {
 237             limbs[i] += carry;
 238             carry = limbs[i] >> BITS_PER_LIMB;
 239             limbs[i] -= carry << BITS_PER_LIMB;
 240         }
 241         // reduce final carry value back in
 242         modReduceIn(limbs, NUM_LIMBS, carry);
 243         // we only reduce back in a nonzero value if some value was carried out
 244         // of the previous loop. So at least one remaining value is small.
 245 
 246         // One more carry is all that is necessary. Nothing will be carried out
 247         // at the end
 248         carry = 0;
 249         for (int i = 0; i < NUM_LIMBS; i++) {
 250             limbs[i] += carry;
 251             carry = limbs[i] >> BITS_PER_LIMB;
 252             limbs[i] -= carry << BITS_PER_LIMB;
 253         }
 254 
 255         // limbs are positive and all less than 2^BITS_PER_LIMB
 256         // but the value may be greater than the MODULUS.
 257         // Subtract the max limb values only if all limbs end up non-negative
 258         int smallerNonNegative = 1;
 259         long[] smaller = new long[NUM_LIMBS];
 260         for (int i = NUM_LIMBS - 1; i >= 0; i--) {
 261             smaller[i] = limbs[i] - posModLimbs[i];
 262             // expression on right is 1 if smaller[i] is nonnegative,
 263             // 0 otherwise
 264             smallerNonNegative *= (int) (smaller[i] >> 63) + 1;
 265         }
 266         conditionalSwap(smallerNonNegative, limbs, smaller);
 267 
 268     }
 269 
 270     @Override
 271     protected void limbsToByteArray(long[] limbs, byte[] result) {
 272 
 273         long[] reducedLimbs = limbs.clone();
 274         finalReduce(reducedLimbs);
 275 
 276         decode(reducedLimbs, result, 0, result.length);
 277     }
 278 
 279     @Override
 280     protected void addLimbsModPowerTwo(long[] limbs, long[] other,
 281                                        byte[] result) {
 282 
 283         long[] reducedOther = other.clone();
 284         long[] reducedLimbs = limbs.clone();
 285         finalReduce(reducedLimbs);
 286 
 287         addLimbs(reducedLimbs, reducedOther, reducedLimbs);
 288 
 289         // may carry out a value which can be ignored
 290         long carry = 0;
 291         for (int i = 0; i < NUM_LIMBS; i++) {
 292             reducedLimbs[i] += carry;
 293             carry  = reducedLimbs[i] >> BITS_PER_LIMB;
 294             reducedLimbs[i] -= carry << BITS_PER_LIMB;
 295         }
 296 
 297         decode(reducedLimbs, result, 0, result.length);
 298     }
 299 
 300 }
 301