1 /*
  2  * Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 import java.security.*;
 24 import java.security.interfaces.RSAPrivateKey;
 25 import java.security.interfaces.RSAPublicKey;
 26 import java.security.spec.*;
 27 import java.util.Arrays;
 28 import java.util.stream.IntStream;
 29 import static javax.crypto.Cipher.PRIVATE_KEY;
 30 import static javax.crypto.Cipher.PUBLIC_KEY;
 31 
 32 /**
 33  * @test
 34  * @bug 8146293 8242556 8172366
 35  * @summary Test RSASSA-PSS AlgorithmParameters impl of SunRsaSign provider.
 36  * @run main PSSParametersTest
 37  */
 38 public class PSSParametersTest {
 39     /**
 40      * JDK default RSA Provider.
 41      */
 42     private static final String PROVIDER = "SunRsaSign";
 43 
 44     private static final String PSS_ALGO = "RSASSA-PSS";
 45     private static final String PSS_OID = "1.2.840.113549.1.1.10";
 46 
 47     public static void main(String[] args) throws Exception {
 48         System.out.println("Testing against DEFAULT parameters");
 49         test(PSSParameterSpec.DEFAULT);
 50         System.out.println("Testing against custom parameters");
 51         test(new PSSParameterSpec("SHA-512/224", "MGF1",
 52                 MGF1ParameterSpec.SHA384, 100, 1));
 53         test(new PSSParameterSpec("SHA3-256", "MGF1",
 54             new MGF1ParameterSpec("SHA3-256"), 256>>3, 1));
 55         System.out.println("Test Passed");
 56     }
 57 
 58     // test the given spec by first initializing w/ it, generate the DER
 59     // bytes, then initialize w/ the DER bytes, retrieve the spec.
 60     // compare both spec for equality and throw exception if the check failed.
 61     private static void test(PSSParameterSpec spec) throws Exception {
 62         System.out.println("Testing PSS spec: " + spec);
 63         String ALGORITHMS[] = { PSS_ALGO, PSS_OID };
 64         for (String alg : ALGORITHMS) {
 65             AlgorithmParameters params = AlgorithmParameters.getInstance
 66                     (alg, PROVIDER);
 67             params.init(spec);
 68             byte[] encoded = params.getEncoded();
 69             AlgorithmParameters params2 = AlgorithmParameters.getInstance
 70                     (alg, PROVIDER);
 71             params2.init(encoded);
 72             PSSParameterSpec spec2 = params2.getParameterSpec
 73                     (PSSParameterSpec.class);
 74             if (!isEqual(spec, spec2)) {
 75                 throw new RuntimeException("Spec check Failed for " + alg);
 76             }
 77         }
 78     }
 79 
 80     private static boolean isEqual(PSSParameterSpec spec,
 81             PSSParameterSpec spec2) throws Exception {
 82         if (spec == spec2) return true;
 83         if (spec == null || spec2 == null) return false;
 84 
 85         if (!spec.getDigestAlgorithm().equals(spec2.getDigestAlgorithm())) {
 86             System.out.println("Different digest algorithms: " +
 87                 spec.getDigestAlgorithm() + " vs " + spec2.getDigestAlgorithm());
 88             return false;
 89         }
 90         if (!spec.getMGFAlgorithm().equals(spec2.getMGFAlgorithm())) {
 91             System.out.println("Different MGF algorithms: " +
 92                 spec.getMGFAlgorithm() + " vs " + spec2.getMGFAlgorithm());
 93             return false;
 94         }
 95         if (spec.getSaltLength() != spec2.getSaltLength()) {
 96             System.out.println("Different Salt Length: " +
 97                 spec.getSaltLength() + " vs " + spec2.getSaltLength());
 98             return false;
 99         }
100         if (spec.getTrailerField() != spec2.getTrailerField()) {
101             System.out.println("Different TrailerField: " +
102                 spec.getTrailerField() + " vs " + spec2.getTrailerField());
103             return false;
104         }
105         // continue checking MGF Parameters
106         AlgorithmParameterSpec mgfParams = spec.getMGFParameters();
107         AlgorithmParameterSpec mgfParams2 = spec2.getMGFParameters();
108         if (mgfParams == mgfParams2) return true;
109         if (mgfParams == null || mgfParams2 == null) {
110             System.out.println("Different MGF Parameters: " +
111                 mgfParams + " vs " + mgfParams2);
112             return false;
113         }
114         if (mgfParams instanceof MGF1ParameterSpec) {
115             if (mgfParams2 instanceof MGF1ParameterSpec) {
116                 boolean result =
117                     ((MGF1ParameterSpec)mgfParams).getDigestAlgorithm().equals
118                          (((MGF1ParameterSpec)mgfParams2).getDigestAlgorithm());
119                 if (!result) {
120                     System.out.println("Different MGF1 digest algorithms: " +
121                         ((MGF1ParameterSpec)mgfParams).getDigestAlgorithm() +
122                         " vs " +
123                         ((MGF1ParameterSpec)mgfParams2).getDigestAlgorithm());
124                 }
125                 return result;
126             } else {
127                 System.out.println("Different MGF Parameters types: " +
128                     mgfParams.getClass() + " vs " + mgfParams2.getClass());
129                 return false;
130             }
131         }
132         throw new RuntimeException("Unrecognized MGFParameters: " + mgfParams);
133     }
134 }