1 /*
  2  * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package sun.security.rsa;
 27 
 28 import java.io.*;
 29 import sun.security.util.*;
 30 import sun.security.x509.*;
 31 import java.security.AlgorithmParametersSpi;
 32 import java.security.NoSuchAlgorithmException;
 33 import java.security.spec.AlgorithmParameterSpec;
 34 import java.security.spec.InvalidParameterSpecException;
 35 import java.security.spec.MGF1ParameterSpec;
 36 import java.security.spec.PSSParameterSpec;
 37 import static java.security.spec.PSSParameterSpec.DEFAULT;
 38 
 39 /**
 40  * This class implements the PSS parameters used with the RSA
 41  * signatures in PSS padding. Here is its ASN.1 definition:
 42  * RSASSA-PSS-params ::= SEQUENCE {
 43  *   hashAlgorithm      [0] HashAlgorithm     DEFAULT sha1,
 44  *   maskGenAlgorithm   [1] MaskGenAlgorithm  DEFAULT mgf1SHA1,
 45  *   saltLength         [2] INTEGER           DEFAULT 20
 46  *   trailerField       [3] TrailerField      DEFAULT trailerFieldBC
 47  * }
 48  *
 49  * @author Valerie Peng
 50  *
 51  */
 52 
 53 public final class PSSParameters extends AlgorithmParametersSpi {
 54 
 55     private PSSParameterSpec spec;
 56 
 57     public PSSParameters() {
 58     }
 59 
 60     @Override
 61     protected void engineInit(AlgorithmParameterSpec paramSpec)
 62             throws InvalidParameterSpecException {
 63         if (!(paramSpec instanceof PSSParameterSpec)) {
 64             throw new InvalidParameterSpecException
 65                 ("Inappropriate parameter specification");
 66         }
 67         PSSParameterSpec spec = (PSSParameterSpec) paramSpec;
 68 
 69         String mgfName = spec.getMGFAlgorithm();
 70         if (!spec.getMGFAlgorithm().equalsIgnoreCase("MGF1")) {
 71             throw new InvalidParameterSpecException("Unsupported mgf " +
 72                 mgfName + "; MGF1 only");
 73         }
 74         AlgorithmParameterSpec mgfSpec = spec.getMGFParameters();
 75         if (!(mgfSpec instanceof MGF1ParameterSpec)) {
 76             throw new InvalidParameterSpecException("Inappropriate mgf " +
 77                 "parameters; non-null MGF1ParameterSpec only");
 78         }
 79         this.spec = spec;
 80     }
 81 
 82     @Override
 83     protected void engineInit(byte[] encoded) throws IOException {
 84         // first initialize with the DEFAULT values before
 85         // retrieving from the encoding bytes
 86         String mdName = DEFAULT.getDigestAlgorithm();
 87         MGF1ParameterSpec mgfSpec = (MGF1ParameterSpec) DEFAULT.getMGFParameters();
 88         int saltLength = DEFAULT.getSaltLength();
 89         int trailerField = DEFAULT.getTrailerField();
 90 
 91         DerInputStream der = new DerInputStream(encoded);
 92         DerValue[] datum = der.getSequence(4);
 93 
 94         for (DerValue d : datum) {
 95             if (d.isContextSpecific((byte) 0x00)) {
 96                 // hash algid
 97                 mdName = AlgorithmId.parse
 98                     (d.data.getDerValue()).getName();
 99             } else if (d.isContextSpecific((byte) 0x01)) {
100                 // mgf algid
101                 AlgorithmId val = AlgorithmId.parse(d.data.getDerValue());
102                 if (!val.getOID().equals(AlgorithmId.MGF1_oid)) {
103                     throw new IOException("Only MGF1 mgf is supported");
104                 }
105                 AlgorithmId params = AlgorithmId.parse(
106                     new DerValue(val.getEncodedParams()));
107                 String mgfDigestName = params.getName();
108                 switch (mgfDigestName) {
109                 case "SHA-1":
110                     mgfSpec = MGF1ParameterSpec.SHA1;
111                     break;
112                 case "SHA-224":
113                     mgfSpec = MGF1ParameterSpec.SHA224;
114                     break;
115                 case "SHA-256":
116                     mgfSpec = MGF1ParameterSpec.SHA256;
117                     break;
118                 case "SHA-384":
119                     mgfSpec = MGF1ParameterSpec.SHA384;
120                     break;
121                 case "SHA-512":
122                     mgfSpec = MGF1ParameterSpec.SHA512;
123                     break;
124                 case "SHA-512/224":
125                     mgfSpec = MGF1ParameterSpec.SHA512_224;
126                     break;
127                 case "SHA-512/256":
128                     mgfSpec = MGF1ParameterSpec.SHA512_256;
129                     break;
130                 default:
131                     throw new IOException
132                         ("Unrecognized message digest algorithm " +
133                         mgfDigestName);
134                 }
135             } else if (d.isContextSpecific((byte) 0x02)) {
136                 // salt length
137                 saltLength = d.data.getDerValue().getInteger();
138                 if (saltLength < 0) {
139                     throw new IOException("Negative value for saltLength");
140                 }
141             } else if (d.isContextSpecific((byte) 0x03)) {
142                 // trailer field
143                 trailerField = d.data.getDerValue().getInteger();
144                 if (trailerField != 1) {
145                     throw new IOException("Unsupported trailerField value " +
146                     trailerField);
147                 }
148             } else {
149                 throw new IOException("Invalid encoded PSSParameters");
150             }
151         }
152 
153         this.spec = new PSSParameterSpec(mdName, "MGF1", mgfSpec,
154                 saltLength, trailerField);
155     }
156 
157     @Override
158     protected void engineInit(byte[] encoded, String decodingMethod)
159             throws IOException {
160         if ((decodingMethod != null) &&
161             (!decodingMethod.equalsIgnoreCase("ASN.1"))) {
162             throw new IllegalArgumentException("Only support ASN.1 format");
163         }
164         engineInit(encoded);
165     }
166 
167     @Override
168     protected <T extends AlgorithmParameterSpec>
169             T engineGetParameterSpec(Class<T> paramSpec)
170             throws InvalidParameterSpecException {
171         if (PSSParameterSpec.class.isAssignableFrom(paramSpec)) {
172             return paramSpec.cast(spec);
173         } else {
174             throw new InvalidParameterSpecException
175                 ("Inappropriate parameter specification");
176         }
177     }
178 
179     @Override
180     protected byte[] engineGetEncoded() throws IOException {
181         return getEncoded(spec);
182     }
183 
184     @Override
185     protected byte[] engineGetEncoded(String encMethod) throws IOException {
186         if ((encMethod != null) &&
187             (!encMethod.equalsIgnoreCase("ASN.1"))) {
188             throw new IllegalArgumentException("Only support ASN.1 format");
189         }
190         return engineGetEncoded();
191     }
192 
193     @Override
194     protected String engineToString() {
195         return spec.toString();
196     }
197 
198     /**
199      * Returns the encoding of a {@link PSSParameterSpec} object. This method
200      * is used in this class and {@link AlgorithmId}.
201      *
202      * @param spec a {@code PSSParameterSpec} object
203      * @return its DER encoding
204      * @throws IOException if the name of a MessageDigest or MaskGenAlgorithm
205      *          is unsupported
206      */
207     public static byte[] getEncoded(PSSParameterSpec spec) throws IOException {
208 
209         AlgorithmParameterSpec mgfSpec = spec.getMGFParameters();
210         if (!(mgfSpec instanceof MGF1ParameterSpec)) {
211             throw new IOException("Cannot encode " + mgfSpec);
212         }
213 
214         MGF1ParameterSpec mgf1Spec = (MGF1ParameterSpec)mgfSpec;
215 
216         DerOutputStream tmp = new DerOutputStream();
217         DerOutputStream tmp2, tmp3;
218 
219         // MD
220         AlgorithmId mdAlgId;
221         try {
222             mdAlgId = AlgorithmId.get(spec.getDigestAlgorithm());
223         } catch (NoSuchAlgorithmException nsae) {
224             throw new IOException("AlgorithmId " + spec.getDigestAlgorithm() +
225                     " impl not found");
226         }
227         if (!mdAlgId.getOID().equals(AlgorithmId.SHA_oid)) {
228             tmp2 = new DerOutputStream();
229             mdAlgId.derEncode(tmp2);
230             tmp.write(DerValue.createTag(DerValue.TAG_CONTEXT, true, (byte) 0),
231                     tmp2);
232         }
233 
234         // MGF
235         AlgorithmId mgfDigestId;
236         try {
237             mgfDigestId = AlgorithmId.get(mgf1Spec.getDigestAlgorithm());
238         } catch (NoSuchAlgorithmException nase) {
239             throw new IOException("AlgorithmId " +
240                     mgf1Spec.getDigestAlgorithm() + " impl not found");
241         }
242 
243         if (!mgfDigestId.getOID().equals(AlgorithmId.SHA_oid)) {
244             tmp2 = new DerOutputStream();
245             tmp2.putOID(AlgorithmId.MGF1_oid);
246             mgfDigestId.encode(tmp2);
247             tmp3 = new DerOutputStream();
248             tmp3.write(DerValue.tag_Sequence, tmp2);
249             tmp.write(DerValue.createTag(DerValue.TAG_CONTEXT, true, (byte) 1),
250                     tmp3);
251         }
252 
253         // SaltLength
254         if (spec.getSaltLength() != 20) {
255             tmp2 = new DerOutputStream();
256             tmp2.putInteger(spec.getSaltLength());
257             tmp.write(DerValue.createTag(DerValue.TAG_CONTEXT, true, (byte) 2),
258                     tmp2);
259         }
260 
261         // TrailerField
262         if (spec.getTrailerField() != PSSParameterSpec.TRAILER_FIELD_BC) {
263             tmp2 = new DerOutputStream();
264             tmp2.putInteger(spec.getTrailerField());
265             tmp.write(DerValue.createTag(DerValue.TAG_CONTEXT, true, (byte) 3),
266                     tmp2);
267         }
268 
269         // Put all together under a SEQUENCE tag
270         DerOutputStream out = new DerOutputStream();
271         out.write(DerValue.tag_Sequence, tmp);
272         return out.toByteArray();
273     }
274 }