--- /dev/null 2018-05-11 10:42:23.849000000 -0700 +++ new/src/java.base/share/classes/sun/security/rsa/RSAPSSSignature.java 2018-05-11 15:08:35.193153700 -0700 @@ -0,0 +1,618 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package sun.security.rsa; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import java.security.*; +import java.security.spec.AlgorithmParameterSpec; +import java.security.spec.PSSParameterSpec; +import java.security.spec.MGF1ParameterSpec; +import java.security.interfaces.*; + +import java.util.Arrays; +import java.util.Hashtable; + +import sun.security.util.*; +import sun.security.jca.JCAUtil; + + +/** + * PKCS#1 v2.2 RSASSA-PSS signatures with various message digest algorithms. + * RSASSA-PSS implementation takes the message digest algorithm, MGF algorithm, + * and salt length values through the required signature PSS parameters. + * We support SHA-1, SHA-224, SHA-256, SHA-384, SHA-512, SHA-512/224, and + * SHA-512/256 message digest algorithms and MGF1 mask generation function. + * + * @since 11 + */ +public class RSAPSSSignature extends SignatureSpi { + + private static final boolean DEBUG = false; + + // utility method for comparing digest algorithms + // NOTE that first argument is assumed to be standard digest name + private boolean isDigestEqual(String stdAlg, String givenAlg) { + if (stdAlg == null || givenAlg == null) return false; + + if (givenAlg.indexOf("-") != -1) { + return stdAlg.equalsIgnoreCase(givenAlg); + } else { + if (stdAlg.equals("SHA-1")) { + return (givenAlg.equalsIgnoreCase("SHA") + || givenAlg.equalsIgnoreCase("SHA1")); + } else { + StringBuilder sb = new StringBuilder(givenAlg); + // case-insensitive check + if (givenAlg.regionMatches(true, 0, "SHA", 0, 3)) { + givenAlg = sb.insert(3, "-").toString(); + return stdAlg.equalsIgnoreCase(givenAlg); + } else { + throw new ProviderException("Unsupported digest algorithm " + + givenAlg); + } + } + } + } + + private static final byte[] EIGHT_BYTES_OF_ZEROS = new byte[8]; + + private static final Hashtable DIGEST_LENGTHS = + new Hashtable(); + static { + DIGEST_LENGTHS.put("SHA-1", 20); + DIGEST_LENGTHS.put("SHA", 20); + DIGEST_LENGTHS.put("SHA1", 20); + DIGEST_LENGTHS.put("SHA-224", 28); + DIGEST_LENGTHS.put("SHA224", 28); + DIGEST_LENGTHS.put("SHA-256", 32); + DIGEST_LENGTHS.put("SHA256", 32); + DIGEST_LENGTHS.put("SHA-384", 48); + DIGEST_LENGTHS.put("SHA384", 48); + DIGEST_LENGTHS.put("SHA-512", 64); + DIGEST_LENGTHS.put("SHA512", 64); + DIGEST_LENGTHS.put("SHA-512/224", 28); + DIGEST_LENGTHS.put("SHA512/224", 28); + DIGEST_LENGTHS.put("SHA-512/256", 32); + DIGEST_LENGTHS.put("SHA512/256", 32); + } + + // message digest implementation we use for hashing the data + private MessageDigest md; + // flag indicating whether the digest is reset + private boolean digestReset = true; + + // private key, if initialized for signing + private RSAPrivateKey privKey = null; + // public key, if initialized for verifying + private RSAPublicKey pubKey = null; + // PSS parameters from signatures and keys respectively + private PSSParameterSpec sigParams = null; // required for PSS signatures + + // PRNG used to generate salt bytes if none given + private SecureRandom random; + + /** + * Construct a new RSAPSSSignatur with arbitrary digest algorithm + */ + public RSAPSSSignature() { + this.md = null; + } + + // initialize for verification. See JCA doc + @Override + protected void engineInitVerify(PublicKey publicKey) + throws InvalidKeyException { + if (!(publicKey instanceof RSAPublicKey)) { + throw new InvalidKeyException("key must be RSAPublicKey"); + } + this.pubKey = (RSAPublicKey) isValid((RSAKey)publicKey); + this.privKey = null; + + } + + // initialize for signing. See JCA doc + @Override + protected void engineInitSign(PrivateKey privateKey) + throws InvalidKeyException { + engineInitSign(privateKey, null); + } + + // initialize for signing. See JCA doc + @Override + protected void engineInitSign(PrivateKey privateKey, SecureRandom random) + throws InvalidKeyException { + if (!(privateKey instanceof RSAPrivateKey)) { + throw new InvalidKeyException("key must be RSAPrivateKey"); + } + this.privKey = (RSAPrivateKey) isValid((RSAKey)privateKey); + this.pubKey = null; + this.random = + (random == null? JCAUtil.getSecureRandom() : random); + } + + /** + * Utility method for checking the key PSS parameters against signature + * PSS parameters. + * Returns false if any of the digest/MGF algorithms and trailerField + * values does not match or if the salt length in key parameters is + * larger than the value in signature parameters. + */ + private static boolean isCompatible(AlgorithmParameterSpec keyParams, + PSSParameterSpec sigParams) { + if (keyParams == null) { + // key with null PSS parameters means no restriction + return true; + } + if (!(keyParams instanceof PSSParameterSpec)) { + return false; + } + // nothing to compare yet, defer the check to when sigParams is set + if (sigParams == null) { + return true; + } + PSSParameterSpec pssKeyParams = (PSSParameterSpec) keyParams; + // first check the salt length requirement + if (pssKeyParams.getSaltLength() > sigParams.getSaltLength()) { + return false; + } + + // compare equality of the rest of fields based on DER encoding + PSSParameterSpec keyParams2 = + new PSSParameterSpec(pssKeyParams.getDigestAlgorithm(), + pssKeyParams.getMGFAlgorithm(), + pssKeyParams.getMGFParameters(), + sigParams.getSaltLength(), + pssKeyParams.getTrailerField()); + PSSParameters ap = new PSSParameters(); + try { + ap.engineInit(keyParams2); + byte[] encoded = ap.engineGetEncoded(); + ap.engineInit(sigParams); + byte[] encoded2 = ap.engineGetEncoded(); + return Arrays.equals(encoded, encoded2); + } catch (Exception e) { + if (DEBUG) { + e.printStackTrace(); + } + return false; + } + } + + /** + * Validate the specified RSAKey and its associated parameters against + * internal signature parameters. + */ + private RSAKey isValid(RSAKey rsaKey) throws InvalidKeyException { + try { + AlgorithmParameterSpec keyParams = rsaKey.getParams(); + // validate key parameters + if (!isCompatible(rsaKey.getParams(), this.sigParams)) { + throw new InvalidKeyException + ("Key contains incompatible PSS parameter values"); + } + // validate key length + if (this.sigParams != null) { + Integer hLen = + DIGEST_LENGTHS.get(this.sigParams.getDigestAlgorithm()); + if (hLen == null) { + throw new ProviderException("Unsupported digest algo: " + + this.sigParams.getDigestAlgorithm()); + } + checkKeyLength(rsaKey, hLen, this.sigParams.getSaltLength()); + } + return rsaKey; + } catch (SignatureException e) { + throw new InvalidKeyException(e); + } + } + + /** + * Validate the specified Signature PSS parameters. + */ + private PSSParameterSpec validateSigParams(AlgorithmParameterSpec p) + throws InvalidAlgorithmParameterException { + if (p == null) { + throw new InvalidAlgorithmParameterException + ("Parameters cannot be null"); + } + if (!(p instanceof PSSParameterSpec)) { + throw new InvalidAlgorithmParameterException + ("parameters must be type PSSParameterSpec"); + } + // no need to validate again if same as current signature parameters + PSSParameterSpec params = (PSSParameterSpec) p; + if (params == this.sigParams) return params; + + RSAKey key = (this.privKey == null? this.pubKey : this.privKey); + // check against keyParams if set + if (key != null) { + if (!isCompatible(key.getParams(), params)) { + throw new InvalidAlgorithmParameterException + ("Signature parameters does not match key parameters"); + } + } + // now sanity check the parameter values + if (!(params.getMGFAlgorithm().equalsIgnoreCase("MGF1"))) { + throw new InvalidAlgorithmParameterException("Only supports MGF1"); + + } + if (params.getTrailerField() != 1) { + throw new InvalidAlgorithmParameterException + ("Only supports TrailerFieldBC(1)"); + + } + String digestAlgo = params.getDigestAlgorithm(); + // check key length again + if (key != null) { + try { + int hLen = DIGEST_LENGTHS.get(digestAlgo); + checkKeyLength(key, hLen, params.getSaltLength()); + } catch (SignatureException e) { + throw new InvalidAlgorithmParameterException(e); + } + } + return params; + } + + /** + * Ensure the object is initialized with key and parameters and + * reset digest + */ + private void ensureInit() throws SignatureException { + RSAKey key = (this.privKey == null? this.pubKey : this.privKey); + if (key == null) { + throw new SignatureException("Missing key"); + } + if (this.sigParams == null) { + // Parameters are required for signature verification + throw new SignatureException + ("Parameters required for RSASSA-PSS signatures"); + } + } + + /** + * Utility method for checking key length against digest length and + * salt length + */ + private static void checkKeyLength(RSAKey key, int digestLen, + int saltLen) throws SignatureException { + if (key != null) { + int keyLength = getKeyLengthInBits(key) >> 3; + int minLength = Math.addExact(Math.addExact(digestLen, saltLen), 2); + if (keyLength < minLength) { + throw new SignatureException + ("Key is too short, need min " + minLength); + } + } + } + + /** + * Reset the message digest if it is not already reset. + */ + private void resetDigest() { + if (digestReset == false) { + this.md.reset(); + digestReset = true; + } + } + + /** + * Return the message digest value. + */ + private byte[] getDigestValue() { + digestReset = true; + return this.md.digest(); + } + + // update the signature with the plaintext data. See JCA doc + @Override + protected void engineUpdate(byte b) throws SignatureException { + ensureInit(); + this.md.update(b); + digestReset = false; + } + + // update the signature with the plaintext data. See JCA doc + @Override + protected void engineUpdate(byte[] b, int off, int len) + throws SignatureException { + ensureInit(); + this.md.update(b, off, len); + digestReset = false; + } + + // update the signature with the plaintext data. See JCA doc + @Override + protected void engineUpdate(ByteBuffer b) { + try { + ensureInit(); + } catch (SignatureException se) { + // hack for working around API bug + throw new RuntimeException(se.getMessage()); + } + this.md.update(b); + digestReset = false; + } + + // sign the data and return the signature. See JCA doc + @Override + protected byte[] engineSign() throws SignatureException { + ensureInit(); + byte[] mHash = getDigestValue(); + try { + byte[] encoded = encodeSignature(mHash); + byte[] encrypted = RSACore.rsa(encoded, privKey, true); + return encrypted; + } catch (GeneralSecurityException e) { + throw new SignatureException("Could not sign data", e); + } catch (IOException e) { + throw new SignatureException("Could not encode data", e); + } + } + + // verify the data and return the result. See JCA doc + // should be reset to the state after engineInitVerify call. + @Override + protected boolean engineVerify(byte[] sigBytes) throws SignatureException { + ensureInit(); + try { + if (sigBytes.length != RSACore.getByteLength(this.pubKey)) { + throw new SignatureException + ("Signature length not correct: got " + + sigBytes.length + " but was expecting " + + RSACore.getByteLength(this.pubKey)); + } + byte[] mHash = getDigestValue(); + byte[] decrypted = RSACore.rsa(sigBytes, this.pubKey); + return decodeSignature(mHash, decrypted); + } catch (javax.crypto.BadPaddingException e) { + // occurs if the app has used the wrong RSA public key + // or if sigBytes is invalid + // return false rather than propagating the exception for + // compatibility/ease of use + return false; + } catch (IOException e) { + throw new SignatureException("Signature encoding error", e); + } finally { + resetDigest(); + } + } + + // return the modulus length in bits + private static int getKeyLengthInBits(RSAKey k) { + if (k != null) { + return k.getModulus().bitLength(); + } + return -1; + } + + /** + * Encode the digest 'mHash', return the to-be-signed data. + * Also used by the PKCS#11 provider. + */ + private byte[] encodeSignature(byte[] mHash) + throws IOException, DigestException { + AlgorithmParameterSpec mgfParams = this.sigParams.getMGFParameters(); + String mgfDigestAlgo; + if (mgfParams != null) { + mgfDigestAlgo = + ((MGF1ParameterSpec) mgfParams).getDigestAlgorithm(); + } else { + mgfDigestAlgo = this.md.getAlgorithm(); + } + try { + int emBits = getKeyLengthInBits(this.privKey) - 1; + int emLen =(emBits + 7) >> 3; + int hLen = this.md.getDigestLength(); + int dbLen = emLen - hLen - 1; + int sLen = this.sigParams.getSaltLength(); + + // maps DB into the corresponding region of EM and + // stores its bytes directly into EM + byte[] em = new byte[emLen]; + + // step7 and some of step8 + em[dbLen - sLen - 1] = (byte) 1; // set DB's padding2 into EM + em[em.length - 1] = (byte) 0xBC; // set trailer field of EM + + if (!digestReset) { + throw new ProviderException("Digest should be reset"); + } + // step5: generates M' using padding1, mHash, and salt + this.md.update(EIGHT_BYTES_OF_ZEROS); + digestReset = false; // mark digest as it now has data + this.md.update(mHash); + if (sLen != 0) { + // step4: generate random salt + byte[] salt = new byte[sLen]; + this.random.nextBytes(salt); + this.md.update(salt); + + // step8: set DB's salt into EM + System.arraycopy(salt, 0, em, dbLen - sLen, sLen); + } + // step6: generate H using M' + this.md.digest(em, dbLen, hLen); // set H field of EM + digestReset = true; + + // step7 and 8 are already covered by the code which setting up + // EM as above + + // step9 and 10: feed H into MGF and xor with DB in EM + MGF1 mgf1 = new MGF1(mgfDigestAlgo); + mgf1.generateAndXor(em, dbLen, hLen, dbLen, em, 0); + + // step11: set the leftmost (8emLen - emBits) bits of the leftmost + // octet to 0 + int numZeroBits = (emLen << 3) - emBits; + if (numZeroBits != 0) { + byte MASK = (byte) (0xff >>> numZeroBits); + em[0] = (byte) (em[0] & MASK); + } + + // step12: em should now holds maskedDB || hash h || 0xBC + return em; + } catch (NoSuchAlgorithmException e) { + throw new IOException(e.toString()); + } + } + + /** + * Decode the signature data. Verify that the object identifier matches + * and return the message digest. + */ + private boolean decodeSignature(byte[] mHash, byte[] em) + throws IOException { + int hLen = mHash.length; + int sLen = this.sigParams.getSaltLength(); + int emLen = em.length; + int emBits = getKeyLengthInBits(this.pubKey) - 1; + + // step3 + if (emLen < (hLen + sLen + 2)) { + return false; + } + + // step4 + if (em[emLen - 1] != (byte) 0xBC) { + return false; + } + + // step6: check if the leftmost (8emLen - emBits) bits of the leftmost + // octet are 0 + int numZeroBits = (emLen << 3) - emBits; + if (numZeroBits != 0) { + byte MASK = (byte) (0xff << (8 - numZeroBits)); + if ((em[0] & MASK) != 0) { + return false; + } + } + String mgfDigestAlgo; + AlgorithmParameterSpec mgfParams = this.sigParams.getMGFParameters(); + if (mgfParams != null) { + mgfDigestAlgo = + ((MGF1ParameterSpec) mgfParams).getDigestAlgorithm(); + } else { + mgfDigestAlgo = this.md.getAlgorithm(); + } + // step 7 and 8 + int dbLen = emLen - hLen - 1; + try { + MGF1 mgf1 = new MGF1(mgfDigestAlgo); + mgf1.generateAndXor(em, dbLen, hLen, dbLen, em, 0); + } catch (NoSuchAlgorithmException nsae) { + throw new IOException(nsae.toString()); + } + + // step9: set the leftmost (8emLen - emBits) bits of the leftmost + // octet to 0 + if (numZeroBits != 0) { + byte MASK = (byte) (0xff >>> numZeroBits); + em[0] = (byte) (em[0] & MASK); + } + + // step10 + int i = 0; + for (; i < dbLen - sLen - 1; i++) { + if (em[i] != 0) { + return false; + } + } + if (em[i] != 0x01) { + return false; + } + // step12 and 13 + this.md.update(EIGHT_BYTES_OF_ZEROS); + digestReset = false; + this.md.update(mHash); + if (sLen > 0) { + this.md.update(em, (dbLen - sLen), sLen); + } + byte[] digest2 = this.md.digest(); + digestReset = true; + + // step14 + byte[] digestInEM = Arrays.copyOfRange(em, dbLen, emLen - 1); + return MessageDigest.isEqual(digest2, digestInEM); + } + + // set parameter, not supported. See JCA doc + @Deprecated + @Override + protected void engineSetParameter(String param, Object value) + throws InvalidParameterException { + throw new UnsupportedOperationException("setParameter() not supported"); + } + + @Override + protected void engineSetParameter(AlgorithmParameterSpec params) + throws InvalidAlgorithmParameterException { + this.sigParams = validateSigParams(params); + // disallow changing parameters when digest has been used + if (!digestReset) { + throw new ProviderException + ("Cannot set parameters during operations"); + } + String newHashAlg = this.sigParams.getDigestAlgorithm(); + // re-allocate md if not yet assigned or algorithm changed + if ((this.md == null) || + !(this.md.getAlgorithm().equalsIgnoreCase(newHashAlg))) { + try { + this.md = MessageDigest.getInstance(newHashAlg); + } catch (NoSuchAlgorithmException nsae) { + // should not happen as we pick default digest algorithm + throw new InvalidAlgorithmParameterException + ("Unsupported digest algorithm " + + newHashAlg, nsae); + } + } + } + + // get parameter, not supported. See JCA doc + @Deprecated + @Override + protected Object engineGetParameter(String param) + throws InvalidParameterException { + throw new UnsupportedOperationException("getParameter() not supported"); + } + + @Override + protected AlgorithmParameters engineGetParameters() { + if (this.sigParams == null) { + throw new ProviderException("Missing required PSS parameters"); + } + try { + AlgorithmParameters ap = + AlgorithmParameters.getInstance("RSASSA-PSS"); + ap.init(this.sigParams); + return ap; + } catch (GeneralSecurityException gse) { + throw new ProviderException(gse.getMessage()); + } + } +}