1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ec;
  27 
  28 import sun.security.util.math.IntegerFieldModuloP;
  29 import sun.security.util.math.IntegerModuloP;
  30 import sun.security.util.math.IntegerModuloP_Base;
  31 import sun.security.util.math.MutableIntegerModuloP;
  32 import sun.security.util.math.SmallValue;
  33 import sun.security.util.math.intpoly.IntegerPolynomial25519;
  34 import sun.security.util.math.intpoly.IntegerPolynomial448;
  35 
  36 import java.math.BigInteger;
  37 import java.security.ProviderException;
  38 import java.security.SecureRandom;
  39 
  40 public class XECOperations {
  41 
  42     private final XECParameters params;
  43     private final IntegerFieldModuloP field;
  44     private final IntegerModuloP zero;
  45     private final IntegerModuloP one;
  46     private final SmallValue a24;
  47     private final IntegerModuloP basePoint;
  48 
  49     public XECOperations(XECParameters c) {
  50         this.params = c;
  51 
  52         BigInteger p = params.getP();
  53         this.field = getIntegerFieldModulo(p);
  54         this.zero = field.getElement(BigInteger.ZERO).fixed();
  55         this.one = field.get1().fixed();
  56         this.a24 = field.getSmallValue(params.getA24());
  57         this.basePoint = field.getElement(
  58             BigInteger.valueOf(c.getBasePoint()));
  59     }
  60 
  61     public XECParameters getParameters() {
  62         return params;
  63     }
  64 
  65     public byte[] generatePrivate(SecureRandom random) {
  66         byte[] result = new byte[this.params.getBytes()];
  67         random.nextBytes(result);
  68         return result;
  69     }
  70 
  71     /**
  72      * Compute a public key from an encoded private key. This method will
  73      * modify the supplied array in order to prune it.
  74      */
  75     public BigInteger computePublic(byte[] k) {
  76         pruneK(k);
  77         return pointMultiply(k, this.basePoint).asBigInteger();
  78     }
  79 
  80     /**
  81      *
  82      * Multiply an encoded scalar with a point as a BigInteger and return an
  83      * encoded point. The array k holding the scalar will be pruned by
  84      * modifying it in place.
  85      *
  86      * @param k an encoded scalar
  87      * @param u the u-coordinate of a point as a BigInteger
  88      * @return the encoded product
  89      */
  90     public byte[] encodedPointMultiply(byte[] k, BigInteger u) {
  91         pruneK(k);
  92         IntegerModuloP elemU = field.getElement(u);
  93         return pointMultiply(k, elemU).asByteArray(params.getBytes());
  94     }
  95 
  96     /**
  97      *
  98      * Multiply an encoded scalar with an encoded point and return an encoded
  99      * point. The array k holding the scalar will be pruned by
 100      * modifying it in place.
 101      *
 102      * @param k an encoded scalar
 103      * @param u an encoded point
 104      * @return the encoded product
 105      */
 106     public byte[] encodedPointMultiply(byte[] k, byte[] u) {
 107         pruneK(k);
 108         IntegerModuloP elemU = decodeU(u);
 109         return pointMultiply(k, elemU).asByteArray(params.getBytes());
 110     }
 111 
 112     /**
 113      * Return the field element corresponding to an encoded u-coordinate.
 114      * This method prunes u by modifying it in place.
 115      *
 116      * @param u
 117      * @param bits
 118      * @return
 119      */
 120     private IntegerModuloP decodeU(byte[] u, int bits) {
 121 
 122         maskHighOrder(u, bits);
 123 
 124         return field.getElement(u);
 125     }
 126 
 127     /**
 128      * Mask off the high order bits of an encoded integer in an array. The
 129      * array is modified in place.
 130      *
 131      * @param arr an array containing an encoded integer
 132      * @param bits the number of bits to keep
 133      * @return the number, in range [1,8], of bits kept in the highest byte
 134      */
 135     private static byte maskHighOrder(byte[] arr, int bits) {
 136 
 137         int lastByteIndex = arr.length - 1;
 138         byte bitsMod8 = (byte) (bits % 8);
 139         byte highBits = bitsMod8 == 0 ? 8 : bitsMod8;
 140         byte msbMaskOff = (byte) ((1 << highBits) - 1);
 141         arr[lastByteIndex] &= msbMaskOff;
 142 
 143         return highBits;
 144     }
 145 
 146     /**
 147      * Prune an encoded scalar value by modifying it in place. The extra
 148      * high-order bits are masked off, the highest valid bit it set, and the
 149      * number is rounded down to a multiple of the cofactor.
 150      *
 151      * @param k an encoded scalar value
 152      * @param bits the number of bits in the scalar
 153      * @param logCofactor the base-2 logarithm of the cofactor
 154      */
 155     private static void pruneK(byte[] k, int bits, int logCofactor) {
 156 
 157         int lastByteIndex = k.length - 1;
 158 
 159         // mask off unused high-order bits
 160         byte highBits = maskHighOrder(k, bits);
 161 
 162         // set the highest bit
 163         byte msbMaskOn = (byte) (1 << (highBits - 1));
 164         k[lastByteIndex] |= msbMaskOn;
 165 
 166         // round down to a multiple of the cofactor
 167         byte lsbMaskOff = (byte) (0xFF << logCofactor);
 168         k[0] &= lsbMaskOff;
 169     }
 170 
 171     private void pruneK(byte[] k) {
 172         pruneK(k, params.getBits(), params.getLogCofactor());
 173     }
 174 
 175     private IntegerModuloP decodeU(byte [] u) {
 176         return decodeU(u, params.getBits());
 177     }
 178 
 179     // Constant-time conditional swap
 180     private static void cswap(int swap, MutableIntegerModuloP x1,
 181         MutableIntegerModuloP x2) {
 182 
 183         x1.conditionalSwapWith(x2, swap);
 184     }
 185 
 186     private static IntegerFieldModuloP getIntegerFieldModulo(BigInteger p) {
 187 
 188         if (p.equals(IntegerPolynomial25519.MODULUS)) {
 189             return new IntegerPolynomial25519();
 190         }
 191         else if (p.equals(IntegerPolynomial448.MODULUS)) {
 192             return new IntegerPolynomial448();
 193         }
 194 
 195         throw new ProviderException("Unsupported prime: " + p.toString());
 196     }
 197 
 198     private int bitAt(byte[] arr, int index) {
 199         int byteIndex = index / 8;
 200         int bitIndex = index % 8;
 201         return (arr[byteIndex] & (1 << bitIndex)) >> bitIndex;
 202     }
 203 
 204     /*
 205      * Constant-time Montgomery ladder that computes k*u and returns the
 206      * result as a field element.
 207      */
 208     private IntegerModuloP_Base pointMultiply(byte[] k, IntegerModuloP u) {
 209 
 210         IntegerModuloP x_1 = u;
 211         MutableIntegerModuloP x_2 = this.one.mutable();
 212         MutableIntegerModuloP z_2 = this.zero.mutable();
 213         MutableIntegerModuloP x_3 = u.mutable();
 214         MutableIntegerModuloP z_3 = this.one.mutable();
 215         int swap = 0;
 216 
 217         // Variables below are reused to avoid unnecessary allocation
 218         // They will be assigned in the loop, so initial value doesn't matter
 219         MutableIntegerModuloP m1 = this.zero.mutable();
 220         MutableIntegerModuloP DA = this.zero.mutable();
 221         MutableIntegerModuloP E = this.zero.mutable();
 222         MutableIntegerModuloP a24_times_E = this.zero.mutable();
 223 
 224         // Comments describe the equivalent operations from RFC 7748
 225         // In comments, A(m1) means the variable m1 holds the value A
 226         for (int t = params.getBits() - 1; t >= 0; t--) {
 227             int k_t = bitAt(k, t);
 228             swap = swap ^ k_t;
 229             cswap(swap, x_2, x_3);
 230             cswap(swap, z_2, z_3);
 231             swap = k_t;
 232 
 233             // A(m1) = x_2 + z_2
 234             m1.setValue(x_2).setSum(z_2);
 235             // D = x_3 - z_3
 236             // DA = D * A(m1)
 237             DA.setValue(x_3).setDifference(z_3).setProduct(m1);
 238             // AA(m1) = A(m1)^2
 239             m1.setSquare();
 240             // B(x_2) = x_2 - z_2
 241             x_2.setDifference(z_2);
 242             // C = x_3 + z_3
 243             // CB(x_3) = C * B(x_2)
 244             x_3.setSum(z_3).setProduct(x_2);
 245             // BB(x_2) = B^2
 246             x_2.setSquare();
 247             // E = AA(m1) - BB(x_2)
 248             E.setValue(m1).setDifference(x_2);
 249             // compute a24 * E using SmallValue
 250             a24_times_E.setValue(E);
 251             a24_times_E.setProduct(this.a24);
 252 
 253             // assign results to x_3, z_3, x_2, z_2
 254             // x_2 = AA(m1) * BB
 255             x_2.setProduct(m1);
 256             // z_2 = E * (AA(m1) + a24 * E)
 257             z_2.setValue(m1).setSum(a24_times_E).setProduct(E);
 258             // z_3 = x_1*(DA - CB(x_3))^2
 259             z_3.setValue(DA).setDifference(x_3).setSquare().setProduct(x_1);
 260             // x_3 = (CB(x_3) + DA)^2
 261             x_3.setSum(DA).setSquare();
 262         }
 263 
 264         cswap(swap, x_2, x_3);
 265         cswap(swap, z_2, z_3);
 266 
 267         // return (x_2 * z_2^(p - 2))
 268         return x_2.setProduct(z_2.multiplicativeInverse());
 269     }
 270 }