1 /*
  2  * Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package sun.security.rsa;
 27 
 28 import java.io.IOException;
 29 import java.nio.ByteBuffer;
 30 
 31 import java.security.*;
 32 import java.security.spec.AlgorithmParameterSpec;
 33 import java.security.spec.PSSParameterSpec;
 34 import java.security.spec.MGF1ParameterSpec;
 35 import java.security.interfaces.*;
 36 
 37 import java.util.Arrays;
 38 import java.util.Hashtable;
 39 
 40 import sun.security.util.*;
 41 import sun.security.jca.JCAUtil;
 42 
 43 
 44 /**
 45  * PKCS#1 v2.2 RSASSA-PSS signatures with various message digest algorithms.
 46  * RSASSA-PSS implementation takes the message digest algorithm, MGF algorithm,
 47  * and salt length values through the required signature PSS parameters.
 48  * We support SHA-1, SHA-224, SHA-256, SHA-384, SHA-512, SHA-512/224, and
 49  * SHA-512/256 message digest algorithms and MGF1 mask generation function.
 50  *
 51  * @since   11
 52  */
 53 public class RSAPSSSignature extends SignatureSpi {
 54 
 55     private static final boolean DEBUG = false;
 56 
 57     // utility method for comparing digest algorithms
 58     // NOTE that first argument is assumed to be standard digest name
 59     private boolean isDigestEqual(String stdAlg, String givenAlg) {
 60         if (stdAlg == null || givenAlg == null) return false;
 61 
 62         if (givenAlg.indexOf("-") != -1) {
 63             return stdAlg.equalsIgnoreCase(givenAlg);
 64         } else {
 65             if (stdAlg.equals("SHA-1")) {
 66                 return (givenAlg.equalsIgnoreCase("SHA")
 67                         || givenAlg.equalsIgnoreCase("SHA1"));
 68             } else {
 69                 StringBuilder sb = new StringBuilder(givenAlg);
 70                 // case-insensitive check
 71                 if (givenAlg.regionMatches(true, 0, "SHA", 0, 3)) {
 72                     givenAlg = sb.insert(3, "-").toString();
 73                     return stdAlg.equalsIgnoreCase(givenAlg);
 74                 } else {
 75                     throw new ProviderException("Unsupported digest algorithm "
 76                             + givenAlg);
 77                 }
 78             }
 79         }
 80     }
 81 
 82     private static final byte[] EIGHT_BYTES_OF_ZEROS = new byte[8];
 83 
 84     private static final Hashtable<String, Integer> DIGEST_LENGTHS =
 85         new Hashtable<String, Integer>();
 86     static {
 87         DIGEST_LENGTHS.put("SHA-1", 20);
 88         DIGEST_LENGTHS.put("SHA", 20);
 89         DIGEST_LENGTHS.put("SHA1", 20);
 90         DIGEST_LENGTHS.put("SHA-224", 28);
 91         DIGEST_LENGTHS.put("SHA224", 28);
 92         DIGEST_LENGTHS.put("SHA-256", 32);
 93         DIGEST_LENGTHS.put("SHA256", 32);
 94         DIGEST_LENGTHS.put("SHA-384", 48);
 95         DIGEST_LENGTHS.put("SHA384", 48);
 96         DIGEST_LENGTHS.put("SHA-512", 64);
 97         DIGEST_LENGTHS.put("SHA512", 64);
 98         DIGEST_LENGTHS.put("SHA-512/224", 28);
 99         DIGEST_LENGTHS.put("SHA512/224", 28);
100         DIGEST_LENGTHS.put("SHA-512/256", 32);
101         DIGEST_LENGTHS.put("SHA512/256", 32);
102     }
103 
104     // message digest implementation we use for hashing the data
105     private MessageDigest md;
106     // flag indicating whether the digest is reset
107     private boolean digestReset = true;
108 
109     // private key, if initialized for signing
110     private RSAPrivateKey privKey = null;
111     // public key, if initialized for verifying
112     private RSAPublicKey pubKey = null;
113     // PSS parameters from signatures and keys respectively
114     private PSSParameterSpec sigParams = null; // required for PSS signatures
115 
116     // PRNG used to generate salt bytes if none given
117     private SecureRandom random;
118 
119     /**
120      * Construct a new RSAPSSSignatur with arbitrary digest algorithm
121      */
122     public RSAPSSSignature() {
123         this.md = null;
124     }
125 
126     // initialize for verification. See JCA doc
127     @Override
128     protected void engineInitVerify(PublicKey publicKey)
129             throws InvalidKeyException {
130         if (!(publicKey instanceof RSAPublicKey)) {
131             throw new InvalidKeyException("key must be RSAPublicKey");
132         }
133         this.pubKey = (RSAPublicKey) isValid((RSAKey)publicKey);
134         this.privKey = null;
135         resetDigest();
136     }
137 
138     // initialize for signing. See JCA doc
139     @Override
140     protected void engineInitSign(PrivateKey privateKey)
141             throws InvalidKeyException {
142         engineInitSign(privateKey, null);
143     }
144 
145     // initialize for signing. See JCA doc
146     @Override
147     protected void engineInitSign(PrivateKey privateKey, SecureRandom random)
148             throws InvalidKeyException {
149         if (!(privateKey instanceof RSAPrivateKey)) {
150             throw new InvalidKeyException("key must be RSAPrivateKey");
151         }
152         this.privKey = (RSAPrivateKey) isValid((RSAKey)privateKey);
153         this.pubKey = null;
154         this.random =
155             (random == null? JCAUtil.getSecureRandom() : random);
156         resetDigest();
157     }
158 
159     /**
160      * Utility method for checking the key PSS parameters against signature
161      * PSS parameters.
162      * Returns false if any of the digest/MGF algorithms and trailerField
163      * values does not match or if the salt length in key parameters is
164      * larger than the value in signature parameters.
165      */
166     private static boolean isCompatible(AlgorithmParameterSpec keyParams,
167             PSSParameterSpec sigParams) {
168         if (keyParams == null) {
169             // key with null PSS parameters means no restriction
170             return true;
171         }
172         if (!(keyParams instanceof PSSParameterSpec)) {
173             return false;
174         }
175         // nothing to compare yet, defer the check to when sigParams is set
176         if (sigParams == null) {
177             return true;
178         }
179         PSSParameterSpec pssKeyParams = (PSSParameterSpec) keyParams;
180         // first check the salt length requirement
181         if (pssKeyParams.getSaltLength() > sigParams.getSaltLength()) {
182             return false;
183         }
184 
185         // compare equality of the rest of fields based on DER encoding
186         PSSParameterSpec keyParams2 =
187             new PSSParameterSpec(pssKeyParams.getDigestAlgorithm(),
188                     pssKeyParams.getMGFAlgorithm(),
189                     pssKeyParams.getMGFParameters(),
190                     sigParams.getSaltLength(),
191                     pssKeyParams.getTrailerField());
192         PSSParameters ap = new PSSParameters();
193         // skip the JCA overhead
194         try {
195             ap.engineInit(keyParams2);
196             byte[] encoded = ap.engineGetEncoded();
197             ap.engineInit(sigParams);
198             byte[] encoded2 = ap.engineGetEncoded();
199             return Arrays.equals(encoded, encoded2);
200         } catch (Exception e) {
201             if (DEBUG) {
202                 e.printStackTrace();
203             }
204             return false;
205         }
206     }
207 
208     /**
209      * Validate the specified RSAKey and its associated parameters against
210      * internal signature parameters.
211      */
212     private RSAKey isValid(RSAKey rsaKey) throws InvalidKeyException {
213         try {
214             AlgorithmParameterSpec keyParams = rsaKey.getParams();
215             // validate key parameters
216             if (!isCompatible(rsaKey.getParams(), this.sigParams)) {
217                 throw new InvalidKeyException
218                     ("Key contains incompatible PSS parameter values");
219             }
220             // validate key length
221             if (this.sigParams != null) {
222                 Integer hLen =
223                     DIGEST_LENGTHS.get(this.sigParams.getDigestAlgorithm());
224                 if (hLen == null) {
225                     throw new ProviderException("Unsupported digest algo: " +
226                         this.sigParams.getDigestAlgorithm());
227                 }
228                 checkKeyLength(rsaKey, hLen, this.sigParams.getSaltLength());
229             }
230             return rsaKey;
231         } catch (SignatureException e) {
232             throw new InvalidKeyException(e);
233         }
234     }
235 
236     /**
237      * Validate the specified Signature PSS parameters.
238      */
239     private PSSParameterSpec validateSigParams(AlgorithmParameterSpec p)
240             throws InvalidAlgorithmParameterException {
241         if (p == null) {
242             throw new InvalidAlgorithmParameterException
243                 ("Parameters cannot be null");
244         }
245         if (!(p instanceof PSSParameterSpec)) {
246             throw new InvalidAlgorithmParameterException
247                 ("parameters must be type PSSParameterSpec");
248         }
249         // no need to validate again if same as current signature parameters
250         PSSParameterSpec params = (PSSParameterSpec) p;
251         if (params == this.sigParams) return params;
252 
253         RSAKey key = (this.privKey == null? this.pubKey : this.privKey);
254         // check against keyParams if set
255         if (key != null) {
256             if (!isCompatible(key.getParams(), params)) {
257                 throw new InvalidAlgorithmParameterException
258                     ("Signature parameters does not match key parameters");
259             }
260         }
261         // now sanity check the parameter values
262         if (!(params.getMGFAlgorithm().equalsIgnoreCase("MGF1"))) {
263             throw new InvalidAlgorithmParameterException("Only supports MGF1");
264 
265         }
266         if (params.getTrailerField() != PSSParameterSpec.TRAILER_FIELD_BC) {
267             throw new InvalidAlgorithmParameterException
268                 ("Only supports TrailerFieldBC(1)");
269 
270         }
271         String digestAlgo = params.getDigestAlgorithm();
272         // check key length again
273         if (key != null) {
274             try {
275                 int hLen = DIGEST_LENGTHS.get(digestAlgo);
276                 checkKeyLength(key, hLen, params.getSaltLength());
277             } catch (SignatureException e) {
278                 throw new InvalidAlgorithmParameterException(e);
279             }
280         }
281         return params;
282     }
283 
284     /**
285      * Ensure the object is initialized with key and parameters and
286      * reset digest
287      */
288     private void ensureInit() throws SignatureException {
289         RSAKey key = (this.privKey == null? this.pubKey : this.privKey);
290         if (key == null) {
291             throw new SignatureException("Missing key");
292         }
293         if (this.sigParams == null) {
294             // Parameters are required for signature verification
295             throw new SignatureException
296                 ("Parameters required for RSASSA-PSS signatures");
297         }
298     }
299 
300     /**
301      * Utility method for checking key length against digest length and
302      * salt length
303      */
304     private static void checkKeyLength(RSAKey key, int digestLen,
305             int saltLen) throws SignatureException {
306         if (key != null) {
307             int keyLength = (getKeyLengthInBits(key) + 7) >> 3;
308             int minLength = Math.addExact(Math.addExact(digestLen, saltLen), 2);
309             if (keyLength < minLength) {
310                 throw new SignatureException
311                     ("Key is too short, need min " + minLength + " bytes");
312             }
313         }
314     }
315 
316     /**
317      * Reset the message digest if it is not already reset.
318      */
319     private void resetDigest() {
320         if (digestReset == false) {
321             this.md.reset();
322             digestReset = true;
323         }
324     }
325 
326     /**
327      * Return the message digest value.
328      */
329     private byte[] getDigestValue() {
330         digestReset = true;
331         return this.md.digest();
332     }
333 
334     // update the signature with the plaintext data. See JCA doc
335     @Override
336     protected void engineUpdate(byte b) throws SignatureException {
337         ensureInit();
338         this.md.update(b);
339         digestReset = false;
340     }
341 
342     // update the signature with the plaintext data. See JCA doc
343     @Override
344     protected void engineUpdate(byte[] b, int off, int len)
345             throws SignatureException {
346         ensureInit();
347         this.md.update(b, off, len);
348         digestReset = false;
349     }
350 
351     // update the signature with the plaintext data. See JCA doc
352     @Override
353     protected void engineUpdate(ByteBuffer b) {
354         try {
355             ensureInit();
356         } catch (SignatureException se) {
357             // hack for working around API bug
358             throw new RuntimeException(se.getMessage());
359         }
360         this.md.update(b);
361         digestReset = false;
362     }
363 
364     // sign the data and return the signature. See JCA doc
365     @Override
366     protected byte[] engineSign() throws SignatureException {
367         ensureInit();
368         byte[] mHash = getDigestValue();
369         try {
370             byte[] encoded = encodeSignature(mHash);
371             byte[] encrypted = RSACore.rsa(encoded, privKey, true);
372             return encrypted;
373         } catch (GeneralSecurityException e) {
374             throw new SignatureException("Could not sign data", e);
375         } catch (IOException e) {
376             throw new SignatureException("Could not encode data", e);
377         }
378     }
379 
380     // verify the data and return the result. See JCA doc
381     // should be reset to the state after engineInitVerify call.
382     @Override
383     protected boolean engineVerify(byte[] sigBytes) throws SignatureException {
384         ensureInit();
385         try {
386             if (sigBytes.length != RSACore.getByteLength(this.pubKey)) {
387                 throw new SignatureException
388                     ("Signature length not correct: got "
389                     + sigBytes.length + " but was expecting "
390                     + RSACore.getByteLength(this.pubKey));
391             }
392             byte[] mHash = getDigestValue();
393             byte[] decrypted = RSACore.rsa(sigBytes, this.pubKey);
394             return decodeSignature(mHash, decrypted);
395         } catch (javax.crypto.BadPaddingException e) {
396             // occurs if the app has used the wrong RSA public key
397             // or if sigBytes is invalid
398             // return false rather than propagating the exception for
399             // compatibility/ease of use
400             return false;
401         } catch (IOException e) {
402             throw new SignatureException("Signature encoding error", e);
403         } finally {
404             resetDigest();
405         }
406     }
407 
408     // return the modulus length in bits
409     private static int getKeyLengthInBits(RSAKey k) {
410         if (k != null) {
411             return k.getModulus().bitLength();
412         }
413         return -1;
414     }
415 
416     /**
417      * Encode the digest 'mHash', return the to-be-signed data.
418      * Also used by the PKCS#11 provider.
419      */
420     private byte[] encodeSignature(byte[] mHash)
421         throws IOException, DigestException {
422         AlgorithmParameterSpec mgfParams = this.sigParams.getMGFParameters();
423         String mgfDigestAlgo;
424         if (mgfParams != null) {
425             mgfDigestAlgo =
426                 ((MGF1ParameterSpec) mgfParams).getDigestAlgorithm();
427         } else {
428             mgfDigestAlgo = this.md.getAlgorithm();
429         }
430         try {
431             int emBits = getKeyLengthInBits(this.privKey) - 1;
432             int emLen = (emBits + 7) >> 3;
433             int hLen = this.md.getDigestLength();
434             int dbLen = emLen - hLen - 1;
435             int sLen = this.sigParams.getSaltLength();
436 
437             // maps DB into the corresponding region of EM and
438             // stores its bytes directly into EM
439             byte[] em = new byte[emLen];
440 
441             // step7 and some of step8
442             em[dbLen - sLen - 1] = (byte) 1; // set DB's padding2 into EM
443             em[em.length - 1] = (byte) 0xBC; // set trailer field of EM
444 
445             if (!digestReset) {
446                 throw new ProviderException("Digest should be reset");
447             }
448             // step5: generates M' using padding1, mHash, and salt
449             this.md.update(EIGHT_BYTES_OF_ZEROS);
450             digestReset = false; // mark digest as it now has data
451             this.md.update(mHash);
452             if (sLen != 0) {
453                 // step4: generate random salt
454                 byte[] salt = new byte[sLen];
455                 this.random.nextBytes(salt);
456                 this.md.update(salt);
457 
458                 // step8: set DB's salt into EM
459                 System.arraycopy(salt, 0, em, dbLen - sLen, sLen);
460             }
461             // step6: generate H using M'
462             this.md.digest(em, dbLen, hLen); // set H field of EM
463             digestReset = true;
464 
465             // step7 and 8 are already covered by the code which setting up
466             // EM as above
467 
468             // step9 and 10: feed H into MGF and xor with DB in EM
469             MGF1 mgf1 = new MGF1(mgfDigestAlgo);
470             mgf1.generateAndXor(em, dbLen, hLen, dbLen, em, 0);
471 
472             // step11: set the leftmost (8emLen - emBits) bits of the leftmost
473             // octet to 0
474             int numZeroBits = (emLen << 3) - emBits;
475 
476             if (numZeroBits != 0) {
477                 byte MASK = (byte) (0xff >>> numZeroBits);
478                 em[0] = (byte) (em[0] & MASK);
479             }
480 
481             // step12: em should now holds maskedDB || hash h || 0xBC
482             return em;
483         } catch (NoSuchAlgorithmException e) {
484             throw new IOException(e.toString());
485         }
486     }
487 
488     /**
489      * Decode the signature data as under RFC8017 sec9.1.2 EMSA-PSS-VERIFY
490      */
491     private boolean decodeSignature(byte[] mHash, byte[] em)
492             throws IOException {
493         int hLen = mHash.length;
494         int sLen = this.sigParams.getSaltLength();
495         int emBits = getKeyLengthInBits(this.pubKey) - 1;
496         int emLen = (emBits + 7) >> 3;
497 
498         // When key length is 8N+1 bits (N+1 bytes), emBits = 8N,
499         // emLen = N which is one byte shorter than em.length.
500         // Otherwise, emLen should be same as em.length
501         int emOfs = em.length - emLen;
502         if ((emOfs == 1) && (em[0] != 0)) {
503             return false;
504         }
505 
506         // step3
507         if (emLen < (hLen + sLen + 2)) {
508             return false;
509         }
510 
511         // step4
512         if (em[emOfs + emLen - 1] != (byte) 0xBC) {
513             return false;
514         }
515 
516         // step6: check if the leftmost (8emLen - emBits) bits of the leftmost
517         // octet are 0
518         int numZeroBits = (emLen << 3) - emBits;
519 
520         if (numZeroBits != 0) {
521             byte MASK = (byte) (0xff << (8 - numZeroBits));
522             if ((em[emOfs] & MASK) != 0) {
523                 return false;
524             }
525         }
526         String mgfDigestAlgo;
527         AlgorithmParameterSpec mgfParams = this.sigParams.getMGFParameters();
528         if (mgfParams != null) {
529             mgfDigestAlgo =
530                 ((MGF1ParameterSpec) mgfParams).getDigestAlgorithm();
531         } else {
532             mgfDigestAlgo = this.md.getAlgorithm();
533         }
534         // step 7 and 8
535         int dbLen = emLen - hLen - 1;
536         try {
537             MGF1 mgf1 = new MGF1(mgfDigestAlgo);
538             mgf1.generateAndXor(em, emOfs + dbLen, hLen, dbLen,
539                     em, emOfs);
540         } catch (NoSuchAlgorithmException nsae) {
541             throw new IOException(nsae.toString());
542         }
543 
544         // step9: set the leftmost (8emLen - emBits) bits of the leftmost
545         //  octet to 0
546         if (numZeroBits != 0) {
547             byte MASK = (byte) (0xff >>> numZeroBits);
548             em[emOfs] = (byte) (em[emOfs] & MASK);
549         }
550 
551         // step10
552         int i = emOfs;
553         for (; i < emOfs + (dbLen - sLen - 1); i++) {
554             if (em[i] != 0) {
555                 return false;
556             }
557         }
558         if (em[i] != 0x01) {
559             return false;
560         }
561         // step12 and 13
562         this.md.update(EIGHT_BYTES_OF_ZEROS);
563         digestReset = false;
564         this.md.update(mHash);
565         if (sLen > 0) {
566             this.md.update(em, emOfs + (dbLen - sLen), sLen);
567         }
568         byte[] digest2 = this.md.digest();
569         digestReset = true;
570 
571         // step14
572         byte[] digestInEM = Arrays.copyOfRange(em, emOfs + dbLen,
573                 emOfs + emLen - 1);
574         return MessageDigest.isEqual(digest2, digestInEM);
575     }
576 
577     // set parameter, not supported. See JCA doc
578     @Deprecated
579     @Override
580     protected void engineSetParameter(String param, Object value)
581             throws InvalidParameterException {
582         throw new UnsupportedOperationException("setParameter() not supported");
583     }
584 
585     @Override
586     protected void engineSetParameter(AlgorithmParameterSpec params)
587             throws InvalidAlgorithmParameterException {
588         this.sigParams = validateSigParams(params);
589         // disallow changing parameters when digest has been used
590         if (!digestReset) {
591             throw new ProviderException
592                 ("Cannot set parameters during operations");
593         }
594         String newHashAlg = this.sigParams.getDigestAlgorithm();
595         // re-allocate md if not yet assigned or algorithm changed
596         if ((this.md == null) ||
597             !(this.md.getAlgorithm().equalsIgnoreCase(newHashAlg))) {
598             try {
599                 this.md = MessageDigest.getInstance(newHashAlg);
600             } catch (NoSuchAlgorithmException nsae) {
601                 // should not happen as we pick default digest algorithm
602                 throw new InvalidAlgorithmParameterException
603                     ("Unsupported digest algorithm " +
604                      newHashAlg, nsae);
605             }
606         }
607     }
608 
609     // get parameter, not supported. See JCA doc
610     @Deprecated
611     @Override
612     protected Object engineGetParameter(String param)
613             throws InvalidParameterException {
614         throw new UnsupportedOperationException("getParameter() not supported");
615     }
616 
617     @Override
618     protected AlgorithmParameters engineGetParameters() {
619         AlgorithmParameters ap = null;
620         if (this.sigParams != null) {
621             try {
622                 ap = AlgorithmParameters.getInstance("RSASSA-PSS");
623                 ap.init(this.sigParams);
624             } catch (GeneralSecurityException gse) {
625                 throw new ProviderException(gse.getMessage());
626             }
627         }
628         return ap;
629     }
630 }