1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ec;
  27 
  28 import sun.security.ec.point.*;
  29 import sun.security.util.math.*;
  30 import sun.security.util.math.intpoly.*;
  31 
  32 import java.math.BigInteger;
  33 import java.security.ProviderException;
  34 import java.security.spec.ECFieldFp;
  35 import java.security.spec.ECParameterSpec;
  36 import java.security.spec.EllipticCurve;
  37 import java.util.Map;
  38 import java.util.Optional;
  39 
  40 /*
  41  * Elliptic curve point arithmetic for prime-order curves where a=-3.
  42  * Formulas are derived from "Complete addition formulas for prime order
  43  * elliptic curves" by Renes, Costello, and Batina.
  44  */
  45 
  46 public class ECOperations {
  47 
  48     /*
  49      * An exception indicating a problem with an intermediate value produced
  50      * by some part of the computation. For example, the signing operation
  51      * will throw this exception to indicate that the r or s value is 0, and
  52      * that the signing operation should be tried again with a different nonce.
  53      */
  54     static class IntermediateValueException extends Exception {
  55         private static final long serialVersionUID = 1;
  56     }
  57 
  58     static final Map<BigInteger, IntegerFieldModuloP> fields = Map.of(
  59         IntegerPolynomialP256.MODULUS, new IntegerPolynomialP256(),
  60         IntegerPolynomialP384.MODULUS, new IntegerPolynomialP384(),
  61         IntegerPolynomialP521.MODULUS, new IntegerPolynomialP521()
  62     );
  63 
  64     static final Map<BigInteger, IntegerFieldModuloP> orderFields = Map.of(
  65         P256OrderField.MODULUS, new P256OrderField(),
  66         P384OrderField.MODULUS, new P384OrderField(),
  67         P521OrderField.MODULUS, new P521OrderField()
  68     );
  69 
  70     public static Optional<ECOperations> forParameters(ECParameterSpec params) {
  71 
  72         EllipticCurve curve = params.getCurve();
  73         if (!(curve.getField() instanceof ECFieldFp)) {
  74             return Optional.empty();
  75         }
  76         ECFieldFp primeField = (ECFieldFp) curve.getField();
  77 
  78         BigInteger three = BigInteger.valueOf(3);
  79         if (!primeField.getP().subtract(curve.getA()).equals(three)) {
  80             return Optional.empty();
  81         }
  82         IntegerFieldModuloP field = fields.get(primeField.getP());
  83         if (field == null) {
  84             return Optional.empty();
  85         }
  86 
  87         IntegerFieldModuloP orderField = orderFields.get(params.getOrder());
  88         if (orderField == null) {
  89             return Optional.empty();
  90         }
  91 
  92         ImmutableIntegerModuloP b = field.getElement(curve.getB());
  93         ECOperations ecOps = new ECOperations(b, orderField);
  94         return Optional.of(ecOps);
  95     }
  96 
  97     final ImmutableIntegerModuloP b;
  98     final SmallValue one;
  99     final SmallValue two;
 100     final SmallValue three;
 101     final SmallValue four;
 102     final ProjectivePoint.Immutable neutral;
 103     private final IntegerFieldModuloP orderField;
 104 
 105     public ECOperations(IntegerModuloP b, IntegerFieldModuloP orderField) {
 106         this.b = b.fixed();
 107         this.orderField = orderField;
 108 
 109         this.one = b.getField().getSmallValue(1);
 110         this.two = b.getField().getSmallValue(2);
 111         this.three = b.getField().getSmallValue(3);
 112         this.four = b.getField().getSmallValue(4);
 113 
 114         IntegerFieldModuloP field = b.getField();
 115         this.neutral = new ProjectivePoint.Immutable(field.get0(),
 116             field.get1(), field.get0());
 117     }
 118 
 119     public IntegerFieldModuloP getField() {
 120         return b.getField();
 121     }
 122     public IntegerFieldModuloP getOrderField() {
 123         return orderField;
 124     }
 125 
 126     protected ProjectivePoint.Immutable getNeutral() {
 127         return neutral;
 128     }
 129 
 130     public boolean isNeutral(Point p) {
 131         ProjectivePoint<?> pp = (ProjectivePoint<?>) p;
 132 
 133         IntegerModuloP z = pp.getZ();
 134 
 135         IntegerFieldModuloP field = z.getField();
 136         int byteLength = (field.getSize().bitLength() + 7) / 8;
 137         byte[] zBytes = z.asByteArray(byteLength);
 138         return allZero(zBytes);
 139     }
 140 
 141     byte[] seedToScalar(byte[] seedBytes)
 142         throws IntermediateValueException {
 143 
 144         // Produce a nonce from the seed using FIPS 186-4,section B.5.1:
 145         // Per-Message Secret Number Generation Using Extra Random Bits
 146         // or
 147         // Produce a scalar from the seed using FIPS 186-4, section B.4.1:
 148         // Key Pair Generation Using Extra Random Bits
 149 
 150         // To keep the implementation simple, sample in the range [0,n)
 151         // and throw IntermediateValueException in the (unlikely) event
 152         // that the result is 0.
 153 
 154         // Get 64 extra bits and reduce in to the nonce
 155         int seedBits = orderField.getSize().bitLength() + 64;
 156         if (seedBytes.length * 8 < seedBits) {
 157             throw new ProviderException("Incorrect seed length: " +
 158             seedBytes.length * 8 + " < " + seedBits);
 159         }
 160 
 161         // input conversion only works on byte boundaries
 162         // clear high-order bits of last byte so they don't influence nonce
 163         int lastByteBits = seedBits % 8;
 164         if (lastByteBits != 0) {
 165             int lastByteIndex = seedBits / 8;
 166             byte mask = (byte) (0xFF >>> (8 - lastByteBits));
 167             seedBytes[lastByteIndex] &= mask;
 168         }
 169 
 170         int seedLength = (seedBits + 7) / 8;
 171         IntegerModuloP scalarElem =
 172             orderField.getElement(seedBytes, 0, seedLength, (byte) 0);
 173         int scalarLength = (orderField.getSize().bitLength() + 7) / 8;
 174         byte[] scalarArr = new byte[scalarLength];
 175         scalarElem.asByteArray(scalarArr);
 176         if (ECOperations.allZero(scalarArr)) {
 177             throw new IntermediateValueException();
 178         }
 179         return scalarArr;
 180     }
 181 
 182     /*
 183      * Compare all values in the array to 0 without branching on any value
 184      *
 185      */
 186     public static boolean allZero(byte[] arr) {
 187         byte acc = 0;
 188         for (int i = 0; i < arr.length; i++) {
 189             acc |= arr[i];
 190         }
 191         return acc == 0;
 192     }
 193 
 194     /*
 195      * 4-bit branchless array lookup for projective points.
 196      */
 197     private void lookup4(ProjectivePoint.Immutable[] arr, int index,
 198         ProjectivePoint.Mutable result, IntegerModuloP zero) {
 199 
 200         for (int i = 0; i < 16; i++) {
 201             int xor = index ^ i;
 202             int bit3 = (xor & 0x8) >>> 3;
 203             int bit2 = (xor & 0x4) >>> 2;
 204             int bit1 = (xor & 0x2) >>> 1;
 205             int bit0 = (xor & 0x1);
 206             int inverse = bit0 | bit1 | bit2 | bit3;
 207             int set = 1 - inverse;
 208 
 209             ProjectivePoint.Immutable pi = arr[i];
 210             result.conditionalSet(pi, set);
 211         }
 212     }
 213 
 214     private void double4(ProjectivePoint.Mutable p, MutableIntegerModuloP t0,
 215         MutableIntegerModuloP t1, MutableIntegerModuloP t2,
 216         MutableIntegerModuloP t3, MutableIntegerModuloP t4) {
 217 
 218         for (int i = 0; i < 4; i++) {
 219             setDouble(p, t0, t1, t2, t3, t4);
 220         }
 221     }
 222 
 223     /**
 224      * Multiply an affine point by a scalar and return the result as a mutable
 225      * point.
 226      *
 227      * @param affineP the point
 228      * @param s the scalar as a little-endian array
 229      * @return the product
 230      */
 231     public MutablePoint multiply(AffinePoint affineP, byte[] s) {
 232 
 233         // 4-bit windowed multiply with branchless lookup.
 234         // The mixed addition is faster, so it is used to construct the array
 235         // at the beginning of the operation.
 236 
 237         IntegerFieldModuloP field = affineP.getX().getField();
 238         ImmutableIntegerModuloP zero = field.get0();
 239         // temporaries
 240         MutableIntegerModuloP t0 = zero.mutable();
 241         MutableIntegerModuloP t1 = zero.mutable();
 242         MutableIntegerModuloP t2 = zero.mutable();
 243         MutableIntegerModuloP t3 = zero.mutable();
 244         MutableIntegerModuloP t4 = zero.mutable();
 245 
 246         ProjectivePoint.Mutable result = new ProjectivePoint.Mutable(field);
 247         result.getY().setValue(field.get1().mutable());
 248 
 249         ProjectivePoint.Immutable[] pointMultiples =
 250             new ProjectivePoint.Immutable[16];
 251         // 0P is neutral---same as initial result value
 252         pointMultiples[0] = result.fixed();
 253 
 254         ProjectivePoint.Mutable ps = new ProjectivePoint.Mutable(field);
 255         ps.setValue(affineP);
 256         // 1P = P
 257         pointMultiples[1] = ps.fixed();
 258 
 259         // the rest are calculated using mixed point addition
 260         for (int i = 2; i < 16; i++) {
 261             setSum(ps, affineP, t0, t1, t2, t3, t4);
 262             pointMultiples[i] = ps.fixed();
 263         }
 264 
 265         ProjectivePoint.Mutable lookupResult = ps.mutable();
 266 
 267         for (int i = s.length - 1; i >= 0; i--) {
 268 
 269             double4(result, t0, t1, t2, t3, t4);
 270 
 271             int high = (0xFF & s[i]) >>> 4;
 272             lookup4(pointMultiples, high, lookupResult, zero);
 273             setSum(result, lookupResult, t0, t1, t2, t3, t4);
 274 
 275             double4(result, t0, t1, t2, t3, t4);
 276 
 277             int low = 0xF & s[i];
 278             lookup4(pointMultiples, low, lookupResult, zero);
 279             setSum(result, lookupResult, t0, t1, t2, t3, t4);
 280         }
 281 
 282         return result;
 283 
 284     }
 285 
 286     /*
 287      * Point double
 288      */
 289     private void setDouble(ProjectivePoint.Mutable p, MutableIntegerModuloP t0,
 290         MutableIntegerModuloP t1, MutableIntegerModuloP t2,
 291         MutableIntegerModuloP t3, MutableIntegerModuloP t4) {
 292 
 293         t0.setValue(p.getX()).setSquare();
 294         t1.setValue(p.getY()).setSquare();
 295         t2.setValue(p.getZ()).setSquare();
 296         t3.setValue(p.getX()).setProduct(p.getY());
 297         t4.setValue(p.getY()).setProduct(p.getZ());
 298 
 299         t3.setSum(t3);
 300         p.getZ().setProduct(p.getX());
 301 
 302         p.getZ().setProduct(two);
 303 
 304         p.getY().setValue(t2).setProduct(b);
 305         p.getY().setDifference(p.getZ());
 306 
 307         p.getX().setValue(p.getY()).setProduct(two);
 308         p.getY().setSum(p.getX());
 309         p.getY().setReduced();
 310         p.getX().setValue(t1).setDifference(p.getY());
 311 
 312         p.getY().setSum(t1);
 313         p.getY().setProduct(p.getX());
 314         p.getX().setProduct(t3);
 315 
 316         t3.setValue(t2).setProduct(two);
 317         t2.setSum(t3);
 318         p.getZ().setProduct(b);
 319 
 320         t2.setReduced();
 321         p.getZ().setDifference(t2);
 322         p.getZ().setDifference(t0);
 323         t3.setValue(p.getZ()).setProduct(two);
 324         p.getZ().setReduced();
 325         p.getZ().setSum(t3);
 326         t0.setProduct(three);
 327 
 328         t0.setDifference(t2);
 329         t0.setProduct(p.getZ());
 330         p.getY().setSum(t0);
 331 
 332         t4.setSum(t4);
 333         p.getZ().setProduct(t4);
 334 
 335         p.getX().setDifference(p.getZ());
 336         p.getZ().setValue(t4).setProduct(t1);
 337 
 338         p.getZ().setProduct(four);
 339 
 340     }
 341 
 342     /*
 343      * Mixed point addition. This method constructs new temporaries each time
 344      * it is called. For better efficiency, the method that reuses temporaries
 345      * should be used if more than one sum will be computed.
 346      */
 347     public void setSum(MutablePoint p, AffinePoint p2) {
 348 
 349         IntegerModuloP zero = p.getField().get0();
 350         MutableIntegerModuloP t0 = zero.mutable();
 351         MutableIntegerModuloP t1 = zero.mutable();
 352         MutableIntegerModuloP t2 = zero.mutable();
 353         MutableIntegerModuloP t3 = zero.mutable();
 354         MutableIntegerModuloP t4 = zero.mutable();
 355         setSum((ProjectivePoint.Mutable) p, p2, t0, t1, t2, t3, t4);
 356 
 357     }
 358 
 359     /*
 360      * Mixed point addition
 361      */
 362     private void setSum(ProjectivePoint.Mutable p, AffinePoint p2,
 363         MutableIntegerModuloP t0, MutableIntegerModuloP t1,
 364         MutableIntegerModuloP t2, MutableIntegerModuloP t3,
 365         MutableIntegerModuloP t4) {
 366 
 367         t0.setValue(p.getX()).setProduct(p2.getX());
 368         t1.setValue(p.getY()).setProduct(p2.getY());
 369         t3.setValue(p2.getX()).setSum(p2.getY());
 370         t4.setValue(p.getX()).setSum(p.getY());
 371         p.getX().setReduced();
 372         t3.setProduct(t4);
 373         t4.setValue(t0).setSum(t1);
 374 
 375         t3.setDifference(t4);
 376         t4.setValue(p2.getY()).setProduct(p.getZ());
 377         t4.setSum(p.getY());
 378 
 379         p.getY().setValue(p2.getX()).setProduct(p.getZ());
 380         p.getY().setSum(p.getX());
 381         t2.setValue(p.getZ());
 382         p.getZ().setProduct(b);
 383 
 384         p.getX().setValue(p.getY()).setDifference(p.getZ());
 385         p.getX().setReduced();
 386         p.getZ().setValue(p.getX()).setProduct(two);
 387         p.getX().setSum(p.getZ());
 388 
 389         p.getZ().setValue(t1).setDifference(p.getX());
 390         p.getX().setSum(t1);
 391         p.getY().setProduct(b);
 392 
 393         t1.setValue(t2).setProduct(two);
 394         t2.setSum(t1);
 395         t2.setReduced();
 396         p.getY().setDifference(t2);
 397 
 398         p.getY().setDifference(t0);
 399         p.getY().setReduced();
 400         t1.setValue(p.getY()).setProduct(two);
 401         p.getY().setSum(t1);
 402 
 403         t1.setValue(t0).setProduct(two);
 404         t0.setSum(t1);
 405         t0.setDifference(t2);
 406 
 407         t1.setValue(t4).setProduct(p.getY());
 408         t2.setValue(t0).setProduct(p.getY());
 409         p.getY().setValue(p.getX()).setProduct(p.getZ());
 410 
 411         p.getY().setSum(t2);
 412         p.getX().setProduct(t3);
 413         p.getX().setDifference(t1);
 414 
 415         p.getZ().setProduct(t4);
 416         t1.setValue(t3).setProduct(t0);
 417         p.getZ().setSum(t1);
 418 
 419     }
 420 
 421     /*
 422      * Projective point addition
 423      */
 424     private void setSum(ProjectivePoint.Mutable p, ProjectivePoint.Mutable p2,
 425         MutableIntegerModuloP t0, MutableIntegerModuloP t1,
 426         MutableIntegerModuloP t2, MutableIntegerModuloP t3,
 427         MutableIntegerModuloP t4) {
 428 
 429         t0.setValue(p.getX()).setProduct(p2.getX());
 430         t1.setValue(p.getY()).setProduct(p2.getY());
 431         t2.setValue(p.getZ()).setProduct(p2.getZ());
 432 
 433         t3.setValue(p.getX()).setSum(p.getY());
 434         t4.setValue(p2.getX()).setSum(p2.getY());
 435         t3.setProduct(t4);
 436 
 437         t4.setValue(t0).setSum(t1);
 438         t3.setDifference(t4);
 439         t4.setValue(p.getY()).setSum(p.getZ());
 440 
 441         p.getY().setValue(p2.getY()).setSum(p2.getZ());
 442         t4.setProduct(p.getY());
 443         p.getY().setValue(t1).setSum(t2);
 444 
 445         t4.setDifference(p.getY());
 446         p.getX().setSum(p.getZ());
 447         p.getY().setValue(p2.getX()).setSum(p2.getZ());
 448 
 449         p.getX().setProduct(p.getY());
 450         p.getY().setValue(t0).setSum(t2);
 451         p.getY().setAdditiveInverse().setSum(p.getX());
 452         p.getY().setReduced();
 453 
 454         p.getZ().setValue(t2).setProduct(b);
 455         p.getX().setValue(p.getY()).setDifference(p.getZ());
 456         p.getZ().setValue(p.getX()).setProduct(two);
 457 
 458         p.getX().setSum(p.getZ());
 459         p.getX().setReduced();
 460         p.getZ().setValue(t1).setDifference(p.getX());
 461         p.getX().setSum(t1);
 462 
 463         p.getY().setProduct(b);
 464         t1.setValue(t2).setSum(t2);
 465         t2.setSum(t1);
 466         t2.setReduced();
 467 
 468         p.getY().setDifference(t2);
 469         p.getY().setDifference(t0);
 470         p.getY().setReduced();
 471         t1.setValue(p.getY()).setSum(p.getY());
 472 
 473         p.getY().setSum(t1);
 474         t1.setValue(t0).setProduct(two);
 475         t0.setSum(t1);
 476 
 477         t0.setDifference(t2);
 478         t1.setValue(t4).setProduct(p.getY());
 479         t2.setValue(t0).setProduct(p.getY());
 480 
 481         p.getY().setValue(p.getX()).setProduct(p.getZ());
 482         p.getY().setSum(t2);
 483         p.getX().setProduct(t3);
 484 
 485         p.getX().setDifference(t1);
 486         p.getZ().setProduct(t4);
 487         t1.setValue(t3).setProduct(t0);
 488 
 489         p.getZ().setSum(t1);
 490 
 491     }
 492 }
 493