/*
 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package sun.security.util.math.intpoly;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.math.BigInteger;
import java.nio.*;

/**
 * An IntegerFieldModuloP designed for use with the Poly1305 authenticator.
 * The representation uses 5 signed long values.
 *
 * In addition to the branch-free operations specified in the parent class,
 * the following operations are branch-free:
 *
 * addModPowerTwo
 * asByteArray
 *
 */

public class IntegerPolynomial1305 extends IntegerPolynomial {

    protected static final int SUBTRAHEND = 5;
    protected static final int NUM_LIMBS = 5;
    private static final int POWER = 130;
    private static final int BITS_PER_LIMB = 26;
    // The field modulus: 2^130 - 5.
    private static final BigInteger MODULUS
        = TWO.pow(POWER).subtract(BigInteger.valueOf(SUBTRAHEND));

    // Limb form of MODULUS, filled in once by the constructor.
    private final long[] posModLimbs;

    /**
     * Allocates a NUM_LIMBS-sized array and has the inherited
     * setLimbsValuePositive helper fill it from MODULUS.
     *
     * @return the freshly populated limb array
     */
    private long[] setPosModLimbs() {
        long[] modulusLimbs = new long[NUM_LIMBS];
        setLimbsValuePositive(MODULUS, modulusLimbs);
        return modulusLimbs;
    }

    /**
     * Creates the field with 26-bit limbs, 5 limbs and modulus 2^130 - 5.
     */
    public IntegerPolynomial1305() {
        super(BITS_PER_LIMB, NUM_LIMBS, MODULUS);
        posModLimbs = setPosModLimbs();
    }

    /**
     * Multiplies two limb arrays and stores the carried/reduced product.
     *
     * @param a first operand (limbs 0..4 are read)
     * @param b second operand (limbs 0..4 are read)
     * @param r receives the five result limbs
     */
    protected void mult(long[] a, long[] b, long[] r) {

        // Grade-school multiplication into primitives, which avoids
        // allocating a temporary array. The column sums below are
        // equivalent to:
        //     long[] c = new long[2 * NUM_LIMBS - 1];
        //     for (int i = 0; i < NUM_LIMBS; i++) {
        //         for (int j = 0; j < NUM_LIMBS; j++) {
        //             c[i + j] += a[i] * b[j]
        //         }
        //     }

        final long a0 = a[0];
        final long b0 = b[0];
        final long b1 = b[1];
        final long a1 = a[1];
        final long b2 = b[2];
        final long a2 = a[2];
        final long b3 = b[3];
        final long a3 = a[3];
        final long b4 = b[4];
        final long a4 = a[4];

        final long c0 = a0 * b0;
        final long c1 = a0 * b1 + a1 * b0;
        final long c2 = a0 * b2 + a1 * b1 + a2 * b0;
        final long c3 = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0;
        final long c4 = a0 * b4 + a1 * b3 + a2 * b2 + a3 * b1 + a4 * b0;
        final long c5 = a1 * b4 + a2 * b3 + a3 * b2 + a4 * b1;
        final long c6 = a2 * b4 + a3 * b3 + a4 * b2;
        final long c7 = a3 * b4 + a4 * b3;
        final long c8 = a4 * b4;

        carryReduce(r, c0, c1, c2, c3, c4, c5, c6, c7, c8);
    }

    /**
     * Folds the nine column sums of a product down to five limbs in r.
     * A column at index i + NUM_LIMBS is added, times SUBTRAHEND, to
     * column i; this relies on BITS_PER_LIMB * NUM_LIMBS == POWER.
     *
     * @param r  receives the five result limbs
     * @param c0 column sum 0 (c1..c8 are the following columns)
     */
    private void carryReduce(long[] r, long c0, long c1, long c2, long c3,
                             long c4, long c5, long c6, long c7, long c8) {
        // fold columns 7 and 8 into columns 2 and 3
        r[2] = c2 + (c7 * SUBTRAHEND);
        final long col3 = c3 + (c8 * SUBTRAHEND);

        // carry out of column 3 into column 4 ...
        final long carryFrom3 = carryValue(col3);
        r[3] = col3 - (carryFrom3 << BITS_PER_LIMB);
        final long col4 = c4 + carryFrom3;

        // ... and out of column 4 into column 5
        final long carryFrom4 = carryValue(col4);
        r[4] = col4 - (carryFrom4 << BITS_PER_LIMB);
        final long col5 = c5 + carryFrom4;

        // fold columns 5 and 6 into columns 0 and 1
        r[0] = c0 + (col5 * SUBTRAHEND);
        r[1] = c1 + (c6 * SUBTRAHEND);

        // finish with the inherited carry pass over r
        carry(r);
    }

    /**
     * Multiplies every limb of a by the scalar b, writes the products to
     * r, then calls reduce on r.
     *
     * @param a limbs to scale
     * @param b scalar multiplier
     * @param r receives the result; needs at least a.length entries
     */
    protected void multByInt(long[] a, long b, long[] r) {

        for (int idx = 0; idx < a.length; idx++) {
            r[idx] = a[idx] * b;
        }

        reduce(r);
    }

    /**
     * Squares a limb array and stores the carried/reduced result.
     *
     * @param a operand (limbs 0..4 are read)
     * @param r receives the five result limbs
     */
    @Override
    protected void square(long[] a, long[] r) {
        // Grade-school multiplication with a simple squaring optimization:
        // each cross term is computed once and doubled. Working in
        // primitives avoids allocating a temporary array. Equivalent to:
        //     long[] c = new long[2 * NUM_LIMBS - 1];
        //     for (int i = 0; i < NUM_LIMBS; i++) {
        //         c[2 * i] = a[i] * a[i];
        //         for (int j = i + 1; j < NUM_LIMBS; j++) {
        //             c[i + j] += 2 * a[i] * a[j]
        //         }
        //     }

        final long a0 = a[0];
        final long a1 = a[1];
        final long a2 = a[2];
        final long a3 = a[3];
        final long a4 = a[4];

        final long c0 = a0 * a0;
        final long c1 = 2 * (a0 * a1);
        final long c2 = 2 * (a0 * a2) + a1 * a1;
        final long c3 = 2 * (a0 * a3 + a1 * a2);
        final long c4 = 2 * (a0 * a4 + a1 * a3) + a2 * a2;
        final long c5 = 2 * (a1 * a4 + a2 * a3);
        final long c6 = 2 * (a2 * a4) + a3 * a3;
        final long c7 = 2 * (a3 * a4);
        final long c8 = a4 * a4;

        carryReduce(r, c0, c1, c2, c3, c4, c5, c6, c7, c8);
    }

    /**
     * Encodes bytes from a buffer into limbs. A 16-byte input is read as
     * two longs (first the low half, then the high half) and packed
     * directly; any other length is delegated to the parent class.
     *
     * NOTE(review): getLong() honours the buffer's current byte order;
     * presumably callers configure it as little-endian - confirm.
     *
     * @param buf      source buffer, read from its current position
     * @param length   number of bytes to encode
     * @param highByte extra byte placed above the data bits
     * @param result   receives the limbs
     */
    @Override
    protected void encode(ByteBuffer buf, int length, byte highByte,
                          long[] result) {
        if (length != 16) {
            super.encode(buf, length, highByte, result);
            return;
        }

        final long lowHalf = buf.getLong();
        final long highHalf = buf.getLong();
        encode(highHalf, lowHalf, highByte, result);
    }

    /**
     * Packs a 128-bit value, given as two 64-bit halves, plus highByte
     * into five 26-bit limbs.
     *
     * @param high     upper 64 bits of the value
     * @param low      lower 64 bits of the value
     * @param highByte byte shifted to bit 24 of the top limb, i.e. just
     *                 above the 24 remaining bits of high
     * @param result   receives limbs 0..4
     */
    protected void encode(long high, long low, byte highByte, long[] result) {
        // bits 0..25 and 26..51 come straight from the low half
        result[0] = low & 0x3FFFFFFL;
        result[1] = (low >>> 26) & 0x3FFFFFFL;

        // limb 2 straddles the halves: 12 bits of low, then 14 bits of high
        final long lowTail = low >>> 52;
        final long highHead = (high & 0x3FFFL) << 12;
        result[2] = lowTail + highHead;

        result[3] = (high >>> 14) & 0x3FFFFFFL;

        // the shift is done in int arithmetic before widening to long
        final int topBits = highByte << 24;
        result[4] = (high >>> 40) + topBits;
    }

    // Views a byte[] as little-endian longs at arbitrary byte offsets.
    private static final VarHandle AS_LONG_LE = MethodHandles
        .byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);

    /**
     * Encodes bytes from an array into limbs. A 16-byte input is read as
     * two little-endian longs and packed directly; any other length is
     * delegated to the parent class.
     *
     * @param v        source array
     * @param offset   index of the first byte to read
     * @param length   number of bytes to encode
     * @param highByte extra byte placed above the data bits
     * @param result   receives the limbs
     */
    protected void encode(byte[] v, int offset, int length, byte highByte,
                          long[] result) {
        if (length != 16) {
            super.encode(v, offset, length, highByte, result);
            return;
        }

        final long lowHalf = (long) AS_LONG_LE.get(v, offset);
        final long highHalf = (long) AS_LONG_LE.get(v, offset + 8);
        encode(highHalf, lowHalf, highByte, result);
    }

    /**
     * Adds x * SUBTRAHEND to the limb NUM_LIMBS positions below index.
     *
     * @param limbs limb array to update
     * @param index position the value x notionally occupies
     * @param x     value to fold in
     */
    protected void modReduceIn(long[] limbs, int index, long x) {
        // this only works when BITS_PER_LIMB * NUM_LIMBS = POWER exactly
        limbs[index - NUM_LIMBS] += x * SUBTRAHEND;
    }

    /**
     * Folds each limb in [start, end) into the lower limbs via
     * modReduceIn and then zeroes it.
     *
     * @param limbs limb array to update
     * @param start first index to fold (inclusive)
     * @param end   index to stop at (exclusive)
     */
    protected final void modReduce(long[] limbs, int start, int end) {

        int pos = start;
        while (pos < end) {
            modReduceIn(limbs, pos, limbs[pos]);
            limbs[pos] = 0;
            pos++;
        }
    }

    /**
     * Invokes the ranged modReduce with start NUM_LIMBS and end
     * NUM_LIMBS - 1. Since start is greater than end, the ranged loop
     * body does not execute for this call.
     *
     * @param limbs limb array
     */
    protected void modReduce(long[] limbs) {

        modReduce(limbs, NUM_LIMBS, NUM_LIMBS - 1);
    }

    /**
     * Returns the carry out of a limb value: x shifted right
     * (arithmetically) by BITS_PER_LIMB.
     *
     * @param x limb value
     * @return the carry
     */
    @Override
    protected long carryValue(long x) {
        // This representation has plenty of extra space, so we can afford to
        // do a simplified carry operation that is more time-efficient.

        return x >> BITS_PER_LIMB;
    }


    /**
     * Carries out of limb 3 into limb 4, carries out of limb 4, folds that
     * carry back into limb 0 and finishes with the inherited carry pass.
     *
     * @param limbs limb array, updated in place
     */
    protected void reduce(long[] limbs) {
        // carryOut is inherited; its return value is used here as the
        // carry leaving limb 3
        final long fromLimb3 = carryOut(limbs, 3);
        final long top = fromLimb3 + limbs[4];

        final long fromLimb4 = carryValue(top);
        limbs[4] = top - (fromLimb4 << BITS_PER_LIMB);

        // index 5 maps onto limb 0
        modReduceIn(limbs, 5, fromLimb4);
        carry(limbs);
    }

    /**
     * Runs one low-to-high carry pass over the first NUM_LIMBS limbs: the
     * incoming carry is added to each limb, the part above BITS_PER_LIMB
     * bits is split off and passed to the next limb.
     *
     * @param limbs limb array, updated in place
     * @return the carry leaving the last limb
     */
    private static long propagateCarries(long[] limbs) {
        long carry = 0;
        for (int i = 0; i < NUM_LIMBS; i++) {
            final long withCarry = limbs[i] + carry;
            carry = withCarry >> BITS_PER_LIMB;
            limbs[i] = withCarry - (carry << BITS_PER_LIMB);
        }
        return carry;
    }

    // Convert reduced limbs into a number between 0 and MODULUS-1
    private void finalReduce(long[] limbs) {

        addLimbs(limbs, posModLimbs, limbs);
        // now all values are positive, so remaining operations will be unsigned

        // unsigned carry out of last position and reduce in to first position
        final int last = NUM_LIMBS - 1;
        final long topCarry = limbs[last] >> BITS_PER_LIMB;
        limbs[last] -= topCarry << BITS_PER_LIMB;
        modReduceIn(limbs, NUM_LIMBS, topCarry);

        // unsigned carry on all positions, then reduce the final carry
        // value back in
        final long wrapCarry = propagateCarries(limbs);
        modReduceIn(limbs, NUM_LIMBS, wrapCarry);
        // we only reduce back in a nonzero value if some value was carried out
        // of the previous pass. So at least one remaining value is small.

        // One more carry is all that is necessary. Nothing will be carried out
        // at the end
        propagateCarries(limbs);

        // limbs are positive and all less than 2^BITS_PER_LIMB
        // but the value may be greater than the MODULUS.
        // Subtract the max limb values only if all limbs end up non-negative
        int allNonNegative = 1;
        final long[] candidate = new long[NUM_LIMBS];
        for (int i = NUM_LIMBS - 1; i >= 0; i--) {
            candidate[i] = limbs[i] - posModLimbs[i];
            // the sign bit, shifted down, is -1 for a negative difference
            // and 0 otherwise; adding 1 turns that into a 0/1 flag
            final int nonNegative = (int) (candidate[i] >> 63) + 1;
            allNonNegative *= nonNegative;
        }
        // inherited helper; presumably exchanges the two arrays when the
        // flag is 1 - confirm in the parent class
        conditionalSwap(allNonNegative, limbs, candidate);

    }

    /**
     * Writes the fully reduced value of limbs into result. The input limbs
     * are not modified; a copy is reduced and handed to the inherited
     * decode method.
     *
     * @param limbs  limbs to convert (left untouched)
     * @param result destination byte array
     */
    @Override
    protected void limbsToByteArray(long[] limbs, byte[] result) {

        final long[] canonical = limbs.clone();
        finalReduce(canonical);

        decode(canonical, result, 0, result.length);
    }

    /**
     * Fully reduces a copy of limbs, adds a copy of other to it, carries,
     * and decodes the sum into result. Neither input array is modified.
     *
     * @param limbs  first addend; reduced via finalReduce before adding
     * @param other  second addend; added as given
     * @param result destination byte array
     */
    @Override
    protected void addLimbsModPowerTwo(long[] limbs, long[] other,
                                       byte[] result) {

        final long[] addend = other.clone();
        final long[] sum = limbs.clone();
        finalReduce(sum);

        addLimbs(sum, addend, sum);

        // may carry out a value which can be ignored
        propagateCarries(sum);

        decode(sum, result, 0, result.length);
    }

}