1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.rsa;
  27 
  28 import java.io.IOException;
  29 import java.nio.ByteBuffer;
  30 
  31 import java.security.*;
  32 import java.security.spec.AlgorithmParameterSpec;
  33 import java.security.spec.PSSParameterSpec;
  34 import java.security.spec.MGF1ParameterSpec;
  35 import java.security.interfaces.*;
  36 
  37 import java.util.Arrays;
  38 import java.util.Hashtable;
  39 
  40 import sun.security.util.*;
  41 import sun.security.jca.JCAUtil;
  42 
  43 
  44 /**
  45  * PKCS#1 v2.2 RSASSA-PSS signatures with various message digest algorithms.
  46  * RSASSA-PSS implementation takes the message digest algorithm, MGF algorithm,
  47  * and salt length values through the required signature PSS parameters.
  48  * We support SHA-1, SHA-224, SHA-256, SHA-384, SHA-512, SHA-512/224, and
  49  * SHA-512/256 message digest algorithms and MGF1 mask generation function.
  50  *
  51  * @since   11
  52  */
  53 public class RSAPSSSignature extends SignatureSpi {
  54 
  55     private static final boolean DEBUG = false;
  56 
  57     // utility method for comparing digest algorithms
  58     // NOTE that first argument is assumed to be standard digest name
  59     private boolean isDigestEqual(String stdAlg, String givenAlg) {
  60         if (stdAlg == null || givenAlg == null) return false;
  61 
  62         if (givenAlg.indexOf("-") != -1) {
  63             return stdAlg.equalsIgnoreCase(givenAlg);
  64         } else {
  65             if (stdAlg.equals("SHA-1")) {
  66                 return (givenAlg.equalsIgnoreCase("SHA")
  67                         || givenAlg.equalsIgnoreCase("SHA1"));
  68             } else {
  69                 StringBuilder sb = new StringBuilder(givenAlg);
  70                 // case-insensitive check
  71                 if (givenAlg.regionMatches(true, 0, "SHA", 0, 3)) {
  72                     givenAlg = sb.insert(3, "-").toString();
  73                     return stdAlg.equalsIgnoreCase(givenAlg);
  74                 } else {
  75                     throw new ProviderException("Unsupported digest algorithm "
  76                             + givenAlg);
  77                 }
  78             }
  79         }
  80     }
  81 
  82     private static final byte[] EIGHT_BYTES_OF_ZEROS = new byte[8];
  83 
  84     private static final Hashtable<String, Integer> DIGEST_LENGTHS =
  85         new Hashtable<String, Integer>();
  86     static {
  87         DIGEST_LENGTHS.put("SHA-1", 20);
  88         DIGEST_LENGTHS.put("SHA", 20);
  89         DIGEST_LENGTHS.put("SHA1", 20);
  90         DIGEST_LENGTHS.put("SHA-224", 28);
  91         DIGEST_LENGTHS.put("SHA224", 28);
  92         DIGEST_LENGTHS.put("SHA-256", 32);
  93         DIGEST_LENGTHS.put("SHA256", 32);
  94         DIGEST_LENGTHS.put("SHA-384", 48);
  95         DIGEST_LENGTHS.put("SHA384", 48);
  96         DIGEST_LENGTHS.put("SHA-512", 64);
  97         DIGEST_LENGTHS.put("SHA512", 64);
  98         DIGEST_LENGTHS.put("SHA-512/224", 28);
  99         DIGEST_LENGTHS.put("SHA512/224", 28);
 100         DIGEST_LENGTHS.put("SHA-512/256", 32);
 101         DIGEST_LENGTHS.put("SHA512/256", 32);
 102     }
 103 
 104     // message digest implementation we use for hashing the data
 105     private MessageDigest md;
 106     // flag indicating whether the digest is reset
 107     private boolean digestReset = true;
 108 
 109     // private key, if initialized for signing
 110     private RSAPrivateKey privKey = null;
 111     // public key, if initialized for verifying
 112     private RSAPublicKey pubKey = null;
 113     // PSS parameters from signatures and keys respectively
 114     private PSSParameterSpec sigParams = null; // required for PSS signatures
 115 
 116     // PRNG used to generate salt bytes if none given
 117     private SecureRandom random;
 118 
 119     /**
 120      * Construct a new RSAPSSSignatur with arbitrary digest algorithm
 121      */
 122     public RSAPSSSignature() {
 123         this.md = null;
 124     }
 125 
 126     // initialize for verification. See JCA doc
 127     @Override
 128     protected void engineInitVerify(PublicKey publicKey)
 129             throws InvalidKeyException {
 130         if (!(publicKey instanceof RSAPublicKey)) {
 131             throw new InvalidKeyException("key must be RSAPublicKey");
 132         }
 133         this.pubKey = (RSAPublicKey) isValid((RSAKey)publicKey);
 134         this.privKey = null;
 135 
 136     }
 137 
 138     // initialize for signing. See JCA doc
 139     @Override
 140     protected void engineInitSign(PrivateKey privateKey)
 141             throws InvalidKeyException {
 142         engineInitSign(privateKey, null);
 143     }
 144 
 145     // initialize for signing. See JCA doc
 146     @Override
 147     protected void engineInitSign(PrivateKey privateKey, SecureRandom random)
 148             throws InvalidKeyException {
 149         if (!(privateKey instanceof RSAPrivateKey)) {
 150             throw new InvalidKeyException("key must be RSAPrivateKey");
 151         }
 152         this.privKey = (RSAPrivateKey) isValid((RSAKey)privateKey);
 153         this.pubKey = null;
 154         this.random =
 155             (random == null? JCAUtil.getSecureRandom() : random);
 156     }
 157 
 158     /**
 159      * Utility method for checking the key PSS parameters against signature
 160      * PSS parameters.
 161      * Returns false if any of the digest/MGF algorithms and trailerField
 162      * values does not match or if the salt length in key parameters is
 163      * larger than the value in signature parameters.
 164      */
 165     private static boolean isCompatible(AlgorithmParameterSpec keyParams,
 166             PSSParameterSpec sigParams) {
 167         if (keyParams == null) {
 168             // key with null PSS parameters means no restriction
 169             return true;
 170         }
 171         if (!(keyParams instanceof PSSParameterSpec)) {
 172             return false;
 173         }
 174         // nothing to compare yet, defer the check to when sigParams is set
 175         if (sigParams == null) {
 176             return true;
 177         }
 178         PSSParameterSpec pssKeyParams = (PSSParameterSpec) keyParams;
 179         // first check the salt length requirement
 180         if (pssKeyParams.getSaltLength() > sigParams.getSaltLength()) {
 181             return false;
 182         }
 183 
 184         // compare equality of the rest of fields based on DER encoding
 185         PSSParameterSpec keyParams2 =
 186             new PSSParameterSpec(pssKeyParams.getDigestAlgorithm(),
 187                     pssKeyParams.getMGFAlgorithm(),
 188                     pssKeyParams.getMGFParameters(),
 189                     sigParams.getSaltLength(),
 190                     pssKeyParams.getTrailerField());
 191         PSSParameters ap = new PSSParameters();
 192         try {
 193             ap.engineInit(keyParams2);
 194             byte[] encoded = ap.engineGetEncoded();
 195             ap.engineInit(sigParams);
 196             byte[] encoded2 = ap.engineGetEncoded();
 197             return Arrays.equals(encoded, encoded2);
 198         } catch (Exception e) {
 199             if (DEBUG) {
 200                 e.printStackTrace();
 201             }
 202             return false;
 203         }
 204     }
 205 
 206     /**
 207      * Validate the specified RSAKey and its associated parameters against
 208      * internal signature parameters.
 209      */
 210     private RSAKey isValid(RSAKey rsaKey) throws InvalidKeyException {
 211         try {
 212             AlgorithmParameterSpec keyParams = rsaKey.getParams();
 213             // validate key parameters
 214             if (!isCompatible(rsaKey.getParams(), this.sigParams)) {
 215                 throw new InvalidKeyException
 216                     ("Key contains incompatible PSS parameter values");
 217             }
 218             // validate key length
 219             if (this.sigParams != null) {
 220                 Integer hLen =
 221                     DIGEST_LENGTHS.get(this.sigParams.getDigestAlgorithm());
 222                 if (hLen == null) {
 223                     throw new ProviderException("Unsupported digest algo: " +
 224                         this.sigParams.getDigestAlgorithm());
 225                 }
 226                 checkKeyLength(rsaKey, hLen, this.sigParams.getSaltLength());
 227             }
 228             return rsaKey;
 229         } catch (SignatureException e) {
 230             throw new InvalidKeyException(e);
 231         }
 232     }
 233 
 234     /**
 235      * Validate the specified Signature PSS parameters.
 236      */
 237     private PSSParameterSpec validateSigParams(AlgorithmParameterSpec p)
 238             throws InvalidAlgorithmParameterException {
 239         if (p == null) {
 240             throw new InvalidAlgorithmParameterException
 241                 ("Parameters cannot be null");
 242         }
 243         if (!(p instanceof PSSParameterSpec)) {
 244             throw new InvalidAlgorithmParameterException
 245                 ("parameters must be type PSSParameterSpec");
 246         }
 247         // no need to validate again if same as current signature parameters
 248         PSSParameterSpec params = (PSSParameterSpec) p;
 249         if (params == this.sigParams) return params;
 250 
 251         RSAKey key = (this.privKey == null? this.pubKey : this.privKey);
 252         // check against keyParams if set
 253         if (key != null) {
 254             if (!isCompatible(key.getParams(), params)) {
 255                 throw new InvalidAlgorithmParameterException
 256                     ("Signature parameters does not match key parameters");
 257             }
 258         }
 259         // now sanity check the parameter values
 260         if (!(params.getMGFAlgorithm().equalsIgnoreCase("MGF1"))) {
 261             throw new InvalidAlgorithmParameterException("Only supports MGF1");
 262 
 263         }
 264         if (params.getTrailerField() != 1) {
 265             throw new InvalidAlgorithmParameterException
 266                 ("Only supports TrailerFieldBC(1)");
 267 
 268         }
 269         String digestAlgo = params.getDigestAlgorithm();
 270         // check key length again
 271         if (key != null) {
 272             try {
 273                 int hLen = DIGEST_LENGTHS.get(digestAlgo);
 274                 checkKeyLength(key, hLen, params.getSaltLength());
 275             } catch (SignatureException e) {
 276                 throw new InvalidAlgorithmParameterException(e);
 277             }
 278         }
 279         return params;
 280     }
 281 
 282     /**
 283      * Ensure the object is initialized with key and parameters and
 284      * reset digest
 285      */
 286     private void ensureInit() throws SignatureException {
 287         RSAKey key = (this.privKey == null? this.pubKey : this.privKey);
 288         if (key == null) {
 289             throw new SignatureException("Missing key");
 290         }
 291         if (this.sigParams == null) {
 292             // Parameters are required for signature verification
 293             throw new SignatureException
 294                 ("Parameters required for RSASSA-PSS signatures");
 295         }
 296     }
 297 
 298     /**
 299      * Utility method for checking key length against digest length and
 300      * salt length
 301      */
 302     private static void checkKeyLength(RSAKey key, int digestLen,
 303             int saltLen) throws SignatureException {
 304         if (key != null) {
 305             int keyLength = getKeyLengthInBits(key) >> 3;
 306             int minLength = Math.addExact(Math.addExact(digestLen, saltLen), 2);
 307             if (keyLength < minLength) {
 308                 throw new SignatureException
 309                     ("Key is too short, need min " + minLength);
 310             }
 311         }
 312     }
 313 
 314     /**
 315      * Reset the message digest if it is not already reset.
 316      */
 317     private void resetDigest() {
 318         if (digestReset == false) {
 319             this.md.reset();
 320             digestReset = true;
 321         }
 322     }
 323 
 324     /**
 325      * Return the message digest value.
 326      */
 327     private byte[] getDigestValue() {
 328         digestReset = true;
 329         return this.md.digest();
 330     }
 331 
 332     // update the signature with the plaintext data. See JCA doc
 333     @Override
 334     protected void engineUpdate(byte b) throws SignatureException {
 335         ensureInit();
 336         this.md.update(b);
 337         digestReset = false;
 338     }
 339 
 340     // update the signature with the plaintext data. See JCA doc
 341     @Override
 342     protected void engineUpdate(byte[] b, int off, int len)
 343             throws SignatureException {
 344         ensureInit();
 345         this.md.update(b, off, len);
 346         digestReset = false;
 347     }
 348 
 349     // update the signature with the plaintext data. See JCA doc
 350     @Override
 351     protected void engineUpdate(ByteBuffer b) {
 352         try {
 353             ensureInit();
 354         } catch (SignatureException se) {
 355             // hack for working around API bug
 356             throw new RuntimeException(se.getMessage());
 357         }
 358         this.md.update(b);
 359         digestReset = false;
 360     }
 361 
 362     // sign the data and return the signature. See JCA doc
 363     @Override
 364     protected byte[] engineSign() throws SignatureException {
 365         ensureInit();
 366         byte[] mHash = getDigestValue();
 367         try {
 368             byte[] encoded = encodeSignature(mHash);
 369             byte[] encrypted = RSACore.rsa(encoded, privKey, true);
 370             return encrypted;
 371         } catch (GeneralSecurityException e) {
 372             throw new SignatureException("Could not sign data", e);
 373         } catch (IOException e) {
 374             throw new SignatureException("Could not encode data", e);
 375         }
 376     }
 377 
 378     // verify the data and return the result. See JCA doc
 379     // should be reset to the state after engineInitVerify call.
 380     @Override
 381     protected boolean engineVerify(byte[] sigBytes) throws SignatureException {
 382         ensureInit();
 383         try {
 384             if (sigBytes.length != RSACore.getByteLength(this.pubKey)) {
 385                 throw new SignatureException
 386                     ("Signature length not correct: got "
 387                     + sigBytes.length + " but was expecting "
 388                     + RSACore.getByteLength(this.pubKey));
 389             }
 390             byte[] mHash = getDigestValue();
 391             byte[] decrypted = RSACore.rsa(sigBytes, this.pubKey);
 392             return decodeSignature(mHash, decrypted);
 393         } catch (javax.crypto.BadPaddingException e) {
 394             // occurs if the app has used the wrong RSA public key
 395             // or if sigBytes is invalid
 396             // return false rather than propagating the exception for
 397             // compatibility/ease of use
 398             return false;
 399         } catch (IOException e) {
 400             throw new SignatureException("Signature encoding error", e);
 401         } finally {
 402             resetDigest();
 403         }
 404     }
 405 
 406     // return the modulus length in bits
 407     private static int getKeyLengthInBits(RSAKey k) {
 408         if (k != null) {
 409             return k.getModulus().bitLength();
 410         }
 411         return -1;
 412     }
 413 
 414     /**
 415      * Encode the digest 'mHash', return the to-be-signed data.
 416      * Also used by the PKCS#11 provider.
 417      */
 418     private byte[] encodeSignature(byte[] mHash)
 419         throws IOException, DigestException {
 420         AlgorithmParameterSpec mgfParams = this.sigParams.getMGFParameters();
 421         String mgfDigestAlgo;
 422         if (mgfParams != null) {
 423             mgfDigestAlgo =
 424                 ((MGF1ParameterSpec) mgfParams).getDigestAlgorithm();
 425         } else {
 426             mgfDigestAlgo = this.md.getAlgorithm();
 427         }
 428         try {
 429             int emBits = getKeyLengthInBits(this.privKey) - 1;
 430             int emLen =(emBits + 7) >> 3;
 431             int hLen = this.md.getDigestLength();
 432             int dbLen = emLen - hLen - 1;
 433             int sLen = this.sigParams.getSaltLength();
 434 
 435             // maps DB into the corresponding region of EM and
 436             // stores its bytes directly into EM
 437             byte[] em = new byte[emLen];
 438 
 439             // step7 and some of step8
 440             em[dbLen - sLen - 1] = (byte) 1; // set DB's padding2 into EM
 441             em[em.length - 1] = (byte) 0xBC; // set trailer field of EM
 442 
 443             if (!digestReset) {
 444                 throw new ProviderException("Digest should be reset");
 445             }
 446             // step5: generates M' using padding1, mHash, and salt
 447             this.md.update(EIGHT_BYTES_OF_ZEROS);
 448             digestReset = false; // mark digest as it now has data
 449             this.md.update(mHash);
 450             if (sLen != 0) {
 451                 // step4: generate random salt
 452                 byte[] salt = new byte[sLen];
 453                 this.random.nextBytes(salt);
 454                 this.md.update(salt);
 455 
 456                 // step8: set DB's salt into EM
 457                 System.arraycopy(salt, 0, em, dbLen - sLen, sLen);
 458             }
 459             // step6: generate H using M'
 460             this.md.digest(em, dbLen, hLen); // set H field of EM
 461             digestReset = true;
 462 
 463             // step7 and 8 are already covered by the code which setting up
 464             // EM as above
 465 
 466             // step9 and 10: feed H into MGF and xor with DB in EM
 467             MGF1 mgf1 = new MGF1(mgfDigestAlgo);
 468             mgf1.generateAndXor(em, dbLen, hLen, dbLen, em, 0);
 469 
 470             // step11: set the leftmost (8emLen - emBits) bits of the leftmost
 471             // octet to 0
 472             int numZeroBits = (emLen << 3) - emBits;
 473             if (numZeroBits != 0) {
 474                 byte MASK = (byte) (0xff >>> numZeroBits);
 475                 em[0] = (byte) (em[0] & MASK);
 476             }
 477 
 478             // step12: em should now holds maskedDB || hash h || 0xBC
 479             return em;
 480         } catch (NoSuchAlgorithmException e) {
 481             throw new IOException(e.toString());
 482         }
 483     }
 484 
 485     /**
 486      * Decode the signature data. Verify that the object identifier matches
 487      * and return the message digest.
 488      */
 489     private boolean decodeSignature(byte[] mHash, byte[] em)
 490             throws IOException {
 491         int hLen = mHash.length;
 492         int sLen = this.sigParams.getSaltLength();
 493         int emLen = em.length;
 494         int emBits = getKeyLengthInBits(this.pubKey) - 1;
 495 
 496         // step3
 497         if (emLen < (hLen + sLen + 2)) {
 498             return false;
 499         }
 500 
 501         // step4
 502         if (em[emLen - 1] != (byte) 0xBC) {
 503             return false;
 504         }
 505 
 506         // step6: check if the leftmost (8emLen - emBits) bits of the leftmost
 507         // octet are 0
 508         int numZeroBits = (emLen << 3) - emBits;
 509         if (numZeroBits != 0) {
 510             byte MASK = (byte) (0xff << (8 - numZeroBits));
 511             if ((em[0] & MASK) != 0) {
 512                 return false;
 513             }
 514         }
 515         String mgfDigestAlgo;
 516         AlgorithmParameterSpec mgfParams = this.sigParams.getMGFParameters();
 517         if (mgfParams != null) {
 518             mgfDigestAlgo =
 519                 ((MGF1ParameterSpec) mgfParams).getDigestAlgorithm();
 520         } else {
 521             mgfDigestAlgo = this.md.getAlgorithm();
 522         }
 523         // step 7 and 8
 524         int dbLen = emLen - hLen - 1;
 525         try {
 526             MGF1 mgf1 = new MGF1(mgfDigestAlgo);
 527             mgf1.generateAndXor(em, dbLen, hLen, dbLen, em, 0);
 528         } catch (NoSuchAlgorithmException nsae) {
 529             throw new IOException(nsae.toString());
 530         }
 531 
 532         // step9: set the leftmost (8emLen - emBits) bits of the leftmost
 533         //  octet to 0
 534         if (numZeroBits != 0) {
 535             byte MASK = (byte) (0xff >>> numZeroBits);
 536             em[0] = (byte) (em[0] & MASK);
 537         }
 538 
 539         // step10
 540         int i = 0;
 541         for (; i < dbLen - sLen - 1; i++) {
 542             if (em[i] != 0) {
 543                 return false;
 544             }
 545         }
 546         if (em[i] != 0x01) {
 547             return false;
 548         }
 549         // step12 and 13
 550         this.md.update(EIGHT_BYTES_OF_ZEROS);
 551         digestReset = false;
 552         this.md.update(mHash);
 553         if (sLen > 0) {
 554             this.md.update(em, (dbLen - sLen), sLen);
 555         }
 556         byte[] digest2 = this.md.digest();
 557         digestReset = true;
 558 
 559         // step14
 560         byte[] digestInEM = Arrays.copyOfRange(em, dbLen, emLen - 1);
 561         return MessageDigest.isEqual(digest2, digestInEM);
 562     }
 563 
 564     // set parameter, not supported. See JCA doc
 565     @Deprecated
 566     @Override
 567     protected void engineSetParameter(String param, Object value)
 568             throws InvalidParameterException {
 569         throw new UnsupportedOperationException("setParameter() not supported");
 570     }
 571 
 572     @Override
 573     protected void engineSetParameter(AlgorithmParameterSpec params)
 574             throws InvalidAlgorithmParameterException {
 575         this.sigParams = validateSigParams(params);
 576         // disallow changing parameters when digest has been used
 577         if (!digestReset) {
 578             throw new ProviderException
 579                 ("Cannot set parameters during operations");
 580         }
 581         String newHashAlg = this.sigParams.getDigestAlgorithm();
 582         // re-allocate md if not yet assigned or algorithm changed
 583         if ((this.md == null) ||
 584             !(this.md.getAlgorithm().equalsIgnoreCase(newHashAlg))) {
 585             try {
 586                 this.md = MessageDigest.getInstance(newHashAlg);
 587             } catch (NoSuchAlgorithmException nsae) {
 588                 // should not happen as we pick default digest algorithm
 589                 throw new InvalidAlgorithmParameterException
 590                     ("Unsupported digest algorithm " +
 591                      newHashAlg, nsae);
 592             }
 593         }
 594     }
 595 
 596     // get parameter, not supported. See JCA doc
 597     @Deprecated
 598     @Override
 599     protected Object engineGetParameter(String param)
 600             throws InvalidParameterException {
 601         throw new UnsupportedOperationException("getParameter() not supported");
 602     }
 603 
 604     @Override
 605     protected AlgorithmParameters engineGetParameters() {
 606         if (this.sigParams == null) {
 607             throw new ProviderException("Missing required PSS parameters");
 608         }
 609         try {
 610             AlgorithmParameters ap =
 611                 AlgorithmParameters.getInstance("RSASSA-PSS");
 612             ap.init(this.sigParams);
 613             return ap;
 614         } catch (GeneralSecurityException gse) {
 615             throw new ProviderException(gse.getMessage());
 616         }
 617     }
 618 }