8205445: Add RSASSA-PSS Signature support to SunMSCAPI
authorweijun
Fri, 22 Jun 2018 21:42:00 +0800
changeset 50715 46492a773912
parent 50714 2230bb152a9f
child 50716 77fdd64c6334
child 50719 106dc156ce6b
child 56802 a48cca98dea6
8205445: Add RSASSA-PSS Signature support to SunMSCAPI Reviewed-by: xuelei
make/lib/Lib-jdk.crypto.mscapi.gmk
src/java.base/share/classes/sun/security/rsa/RSAPSSSignature.java
src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/RSASignature.java
src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/SunMSCAPI.java
src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp
test/jdk/sun/security/mscapi/InteropWithSunRsaSign.java
test/jdk/sun/security/rsa/pss/InitAgain.java
--- a/make/lib/Lib-jdk.crypto.mscapi.gmk	Fri Jun 22 13:20:55 2018 +0200
+++ b/make/lib/Lib-jdk.crypto.mscapi.gmk	Fri Jun 22 21:42:00 2018 +0800
@@ -35,7 +35,7 @@
       CFLAGS := $(CFLAGS_JDKLIB), \
       LDFLAGS := $(LDFLAGS_JDKLIB) $(LDFLAGS_CXX_JDK) \
           $(call SET_SHARED_LIBRARY_ORIGIN), \
-      LIBS := crypt32.lib advapi32.lib, \
+      LIBS := crypt32.lib advapi32.lib ncrypt.lib, \
   ))
 
   TARGETS += $(BUILD_LIBSUNMSCAPI)
--- a/src/java.base/share/classes/sun/security/rsa/RSAPSSSignature.java	Fri Jun 22 13:20:55 2018 +0200
+++ b/src/java.base/share/classes/sun/security/rsa/RSAPSSSignature.java	Fri Jun 22 21:42:00 2018 +0800
@@ -132,7 +132,7 @@
         }
         this.pubKey = (RSAPublicKey) isValid((RSAKey)publicKey);
         this.privKey = null;
-
+        resetDigest();
     }
 
     // initialize for signing. See JCA doc
@@ -153,6 +153,7 @@
         this.pubKey = null;
         this.random =
             (random == null? JCAUtil.getSecureRandom() : random);
+        resetDigest();
     }
 
     /**
--- a/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/RSASignature.java	Fri Jun 22 13:20:55 2018 +0200
+++ b/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/RSASignature.java	Fri Jun 22 21:42:00 2018 +0800
@@ -29,6 +29,9 @@
 import java.security.*;
 import java.security.spec.AlgorithmParameterSpec;
 import java.math.BigInteger;
+import java.security.spec.MGF1ParameterSpec;
+import java.security.spec.PSSParameterSpec;
+import java.util.Locale;
 
 import sun.security.rsa.RSAKeyFactory;
 
@@ -45,6 +48,7 @@
  *  . "SHA512withRSA"
  *  . "MD5withRSA"
  *  . "MD2withRSA"
+ *  . "RSASSA-PSS"
  *
  * NOTE: RSA keys must be at least 512 bits long.
  *
@@ -59,19 +63,19 @@
 abstract class RSASignature extends java.security.SignatureSpi
 {
     // message digest implementation we use
-    private final MessageDigest messageDigest;
+    protected MessageDigest messageDigest;
 
     // message digest name
     private String messageDigestAlgorithm;
 
     // flag indicating whether the digest has been reset
-    private boolean needsReset;
+    protected boolean needsReset;
 
     // the signing key
-    private Key privateKey = null;
+    protected Key privateKey = null;
 
     // the verification key
-    private Key publicKey = null;
+    protected Key publicKey = null;
 
     /**
      * Constructs a new RSASignature. Used by Raw subclass.
@@ -222,6 +226,254 @@
         }
     }
 
+    public static final class PSS extends RSASignature {
+
+        private PSSParameterSpec pssParams = null;
+
+        // Workaround: Cannot import raw public key to CNG. This signature
+        // will be used for verification if key is not from MSCAPI.
+        private Signature fallbackSignature;
+
+        @Override
+        protected void engineInitSign(PrivateKey key) throws InvalidKeyException {
+            super.engineInitSign(key);
+            fallbackSignature = null;
+        }
+
+        @Override
+        protected void engineInitVerify(PublicKey key) throws InvalidKeyException {
+            // This signature accepts only RSAPublicKey
+            if ((key instanceof java.security.interfaces.RSAPublicKey) == false) {
+                throw new InvalidKeyException("Key type not supported");
+            }
+
+            this.privateKey = null;
+
+            if (key instanceof sun.security.mscapi.RSAPublicKey) {
+                fallbackSignature = null;
+                publicKey = (sun.security.mscapi.RSAPublicKey) key;
+            } else {
+                if (fallbackSignature == null) {
+                    try {
+                        fallbackSignature = Signature.getInstance(
+                                "RSASSA-PSS", "SunRsaSign");
+                    } catch (NoSuchAlgorithmException | NoSuchProviderException e) {
+                        throw new InvalidKeyException("Invalid key", e);
+                    }
+                }
+                fallbackSignature.initVerify(key);
+                if (pssParams != null) {
+                    try {
+                        fallbackSignature.setParameter(pssParams);
+                    } catch (InvalidAlgorithmParameterException e) {
+                        throw new InvalidKeyException("Invalid params", e);
+                    }
+                }
+                publicKey = null;
+            }
+            resetDigest();
+        }
+
+        @Override
+        protected void engineUpdate(byte b) throws SignatureException {
+            ensureInit();
+            if (fallbackSignature != null) {
+                fallbackSignature.update(b);
+            } else {
+                messageDigest.update(b);
+            }
+            needsReset = true;
+        }
+
+        @Override
+        protected void engineUpdate(byte[] b, int off, int len) throws SignatureException {
+            ensureInit();
+            if (fallbackSignature != null) {
+                fallbackSignature.update(b, off, len);
+            } else {
+                messageDigest.update(b, off, len);
+            }
+            needsReset = true;
+        }
+
+        @Override
+        protected void engineUpdate(ByteBuffer input) {
+            try {
+                ensureInit();
+            } catch (SignatureException se) {
+                // hack for working around API bug
+                throw new RuntimeException(se.getMessage());
+            }
+            if (fallbackSignature != null) {
+                try {
+                    fallbackSignature.update(input);
+                } catch (SignatureException se) {
+                    // hack for working around API bug
+                    throw new RuntimeException(se.getMessage());
+                }
+            } else {
+                messageDigest.update(input);
+            }
+            needsReset = true;
+        }
+
+        @Override
+        protected byte[] engineSign() throws SignatureException {
+            ensureInit();
+            byte[] hash = getDigestValue();
+            return signPssHash(hash, hash.length,
+                    pssParams.getSaltLength(),
+                    ((MGF1ParameterSpec)
+                            pssParams.getMGFParameters()).getDigestAlgorithm(),
+                    privateKey.getHCryptProvider(), privateKey.getHCryptKey());
+        }
+
+        @Override
+        protected boolean engineVerify(byte[] sigBytes) throws SignatureException {
+            ensureInit();
+            if (fallbackSignature != null) {
+                needsReset = false;
+                return fallbackSignature.verify(sigBytes);
+            } else {
+                byte[] hash = getDigestValue();
+                return verifyPssSignedHash(
+                        hash, hash.length,
+                        sigBytes, sigBytes.length,
+                        pssParams.getSaltLength(),
+                        ((MGF1ParameterSpec)
+                                pssParams.getMGFParameters()).getDigestAlgorithm(),
+                        publicKey.getHCryptProvider(),
+                        publicKey.getHCryptKey()
+                );
+            }
+        }
+
+        @Override
+        protected void engineSetParameter(AlgorithmParameterSpec params)
+                throws InvalidAlgorithmParameterException {
+            if (needsReset) {
+                throw new ProviderException
+                        ("Cannot set parameters during operations");
+            }
+            this.pssParams = validateSigParams(params);
+            if (fallbackSignature != null) {
+                fallbackSignature.setParameter(params);
+            }
+        }
+
+        @Override
+        protected AlgorithmParameters engineGetParameters() {
+            if (this.pssParams == null) {
+                throw new ProviderException("Missing required PSS parameters");
+            }
+            try {
+                AlgorithmParameters ap =
+                        AlgorithmParameters.getInstance("RSASSA-PSS");
+                ap.init(this.pssParams);
+                return ap;
+            } catch (GeneralSecurityException gse) {
+                throw new ProviderException(gse.getMessage());
+            }
+        }
+
+        private void ensureInit() throws SignatureException {
+            if (this.privateKey == null && this.publicKey == null
+                    && fallbackSignature == null) {
+                throw new SignatureException("Missing key");
+            }
+            if (this.pssParams == null) {
+                // Parameters are required for signature verification
+                throw new SignatureException
+                        ("Parameters required for RSASSA-PSS signatures");
+            }
+            if (fallbackSignature == null && messageDigest == null) {
+                // This could happen if initVerify(softKey), setParameter(),
+                // and initSign() were called. No messageDigest. Create it.
+                try {
+                    messageDigest = MessageDigest
+                            .getInstance(pssParams.getDigestAlgorithm());
+                } catch (NoSuchAlgorithmException e) {
+                    throw new SignatureException(e);
+                }
+            }
+        }
+
+        /**
+         * Validate the specified Signature PSS parameters.
+         */
+        private PSSParameterSpec validateSigParams(AlgorithmParameterSpec p)
+                throws InvalidAlgorithmParameterException {
+
+            if (p == null) {
+                throw new InvalidAlgorithmParameterException
+                        ("Parameters cannot be null");
+            }
+
+            if (!(p instanceof PSSParameterSpec)) {
+                throw new InvalidAlgorithmParameterException
+                        ("parameters must be type PSSParameterSpec");
+            }
+
+            // no need to validate again if same as current signature parameters
+            PSSParameterSpec params = (PSSParameterSpec) p;
+            if (params == this.pssParams) return params;
+
+            // now sanity check the parameter values
+            if (!(params.getMGFAlgorithm().equalsIgnoreCase("MGF1"))) {
+                throw new InvalidAlgorithmParameterException("Only supports MGF1");
+
+            }
+
+            if (params.getTrailerField() != PSSParameterSpec.TRAILER_FIELD_BC) {
+                throw new InvalidAlgorithmParameterException
+                        ("Only supports TrailerFieldBC(1)");
+            }
+
+            AlgorithmParameterSpec algSpec = params.getMGFParameters();
+            if (!(algSpec instanceof MGF1ParameterSpec)) {
+                throw new InvalidAlgorithmParameterException
+                        ("Only support MGF1ParameterSpec");
+            }
+
+            MGF1ParameterSpec mgfSpec = (MGF1ParameterSpec)algSpec;
+
+            String msgHashAlg = params.getDigestAlgorithm()
+                    .toLowerCase(Locale.ROOT).replaceAll("-", "");
+            if (msgHashAlg.equals("sha")) {
+                msgHashAlg = "sha1";
+            }
+            String mgf1HashAlg = mgfSpec.getDigestAlgorithm()
+                    .toLowerCase(Locale.ROOT).replaceAll("-", "");
+            if (mgf1HashAlg.equals("sha")) {
+                mgf1HashAlg = "sha1";
+            }
+
+            if (!mgf1HashAlg.equals(msgHashAlg)) {
+                throw new InvalidAlgorithmParameterException
+                        ("MGF1 hash must be the same as message hash");
+            }
+
+            return params;
+        }
+
+        /**
+         * Sign hash using CNG API with HCRYPTKEY. Used by RSASSA-PSS.
+         */
+        private native static byte[] signPssHash(byte[] hash,
+                int hashSize, int saltLength, String hashAlgorithm,
+                long hCryptProv, long nCryptKey)
+                throws SignatureException;
+
+        /**
+         * Verify a signed hash using CNG API with HCRYPTKEY. Used by RSASSA-PSS.
+         * This method is not used now. See {@link #fallbackSignature}.
+         */
+        private native static boolean verifyPssSignedHash(byte[] hash, int hashSize,
+                byte[] signature, int signatureSize,
+                int saltLength, String hashAlgorithm,
+                long hCryptProv, long hKey) throws SignatureException;
+    }
+
     // initialize for signing. See JCA doc
     @Override
     protected void engineInitVerify(PublicKey key)
@@ -298,7 +550,9 @@
      */
     protected void resetDigest() {
         if (needsReset) {
-            messageDigest.reset();
+            if (messageDigest != null) {
+                messageDigest.reset();
+            }
             needsReset = false;
         }
     }
--- a/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/SunMSCAPI.java	Fri Jun 22 13:20:55 2018 +0200
+++ b/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/SunMSCAPI.java	Fri Jun 22 21:42:00 2018 +0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2005, 2016, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2005, 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -105,6 +105,8 @@
                         return new RSASignature.MD5();
                     } else if (algo.equals("MD2withRSA")) {
                         return new RSASignature.MD2();
+                    } else if (algo.equals("RSASSA-PSS")) {
+                        return new RSASignature.PSS();
                     }
                 } else if (type.equals("KeyPairGenerator")) {
                     if (algo.equals("RSA")) {
@@ -178,6 +180,10 @@
                            new String[] { "1.2.840.113549.1.1.13", "OID.1.2.840.113549.1.1.13" },
                            attrs));
                 putService(new ProviderService(p, "Signature",
+                        "RSASSA-PSS", "sun.security.mscapi.RSASignature$PSS",
+                        new String[] { "1.2.840.113549.1.1.10", "OID.1.2.840.113549.1.1.10" },
+                        attrs));
+                putService(new ProviderService(p, "Signature",
                            "MD5withRSA", "sun.security.mscapi.RSASignature$MD5",
                            null, attrs));
                 putService(new ProviderService(p, "Signature",
--- a/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp	Fri Jun 22 13:20:55 2018 +0200
+++ b/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp	Fri Jun 22 21:42:00 2018 +0800
@@ -57,6 +57,18 @@
 #define SIGNATURE_EXCEPTION "java/security/SignatureException"
 #define OUT_OF_MEMORY_ERROR "java/lang/OutOfMemoryError"
 
+#define SS_CHECK(Status) \
+        if (Status != ERROR_SUCCESS) { \
+            ThrowException(env, SIGNATURE_EXCEPTION, Status); \
+            __leave; \
+        }
+
+//#define PP(fmt, ...) \
+//        fprintf(stdout, "SSPI (%ld): ", __LINE__); \
+//        fprintf(stdout, fmt, ##__VA_ARGS__); \
+//        fprintf(stdout, "\n"); \
+//        fflush(stdout)
+
 extern "C" {
 
 /*
@@ -64,6 +76,18 @@
  */
 DEF_STATIC_JNI_OnLoad
 
+//void dump(LPSTR title, PBYTE data, DWORD len)
+//{
+//    printf("==== %s ====\n", title);
+//    for (DWORD i = 0; i < len; i++) {
+//        if (i != 0 && i % 16 == 0) {
+//            printf("\n");
+//        }
+//        printf("%02X ", *(data + i) & 0xff);
+//    }
+//    printf("\n");
+//}
+
 /*
  * Throws an arbitrary Java exception with the given message.
  */
@@ -146,6 +170,37 @@
    return algId;
 }
 
+/*
+ * Maps the name of a hash algorithm to a CNG Algorithm Identifier.
+ */
+LPCWSTR MapHashIdentifier(JNIEnv *env, jstring jHashAlgorithm) {
+
+    const char* pszHashAlgorithm = NULL;
+    LPCWSTR id = NULL;
+
+    if ((pszHashAlgorithm = env->GetStringUTFChars(jHashAlgorithm, NULL))
+            == NULL) {
+        return id;
+    }
+
+    if ((strcmp("SHA", pszHashAlgorithm) == 0) ||
+        (strcmp("SHA1", pszHashAlgorithm) == 0) ||
+        (strcmp("SHA-1", pszHashAlgorithm) == 0)) {
+
+        id = BCRYPT_SHA1_ALGORITHM;
+    } else if (strcmp("SHA-256", pszHashAlgorithm) == 0) {
+        id = BCRYPT_SHA256_ALGORITHM;
+    } else if (strcmp("SHA-384", pszHashAlgorithm) == 0) {
+        id = BCRYPT_SHA384_ALGORITHM;
+    } else if (strcmp("SHA-512", pszHashAlgorithm) == 0) {
+        id = BCRYPT_SHA512_ALGORITHM;
+    }
+
+    if (pszHashAlgorithm)
+        env->ReleaseStringUTFChars(jHashAlgorithm, pszHashAlgorithm);
+
+    return id;
+}
 
 /*
  * Returns a certificate chain context given a certificate context and key
@@ -561,7 +616,6 @@
         ::CryptReleaseContext((HCRYPTPROV) hCryptProv, NULL);
 }
 
-
 /*
  * Class:     sun_security_mscapi_RSASignature
  * Method:    signHash
@@ -693,6 +747,94 @@
 }
 
 /*
+ * Class:     sun_security_mscapi_RSASignature_PSS
+ * Method:    signPssHash
+ * Signature: ([BIILjava/lang/String;JJ)[B
+ */
+JNIEXPORT jbyteArray JNICALL Java_sun_security_mscapi_RSASignature_00024PSS_signPssHash
+  (JNIEnv *env, jclass clazz, jbyteArray jHash,
+        jint jHashSize, jint saltLen, jstring jHashAlgorithm, jlong hCryptProv,
+        jlong hCryptKey)
+{
+    jbyteArray jSignedHash = NULL;
+
+    jbyte* pHashBuffer = NULL;
+    jbyte* pSignedHashBuffer = NULL;
+    NCRYPT_KEY_HANDLE hk = NULL;
+
+    __try
+    {
+        SS_CHECK(::NCryptTranslateHandle(
+                NULL,
+                &hk,
+                hCryptProv,
+                hCryptKey,
+                NULL,
+                0));
+
+        // Copy hash from Java to native buffer
+        pHashBuffer = new (env) jbyte[jHashSize];
+        if (pHashBuffer == NULL) {
+            __leave;
+        }
+        env->GetByteArrayRegion(jHash, 0, jHashSize, pHashBuffer);
+
+        BCRYPT_PSS_PADDING_INFO pssInfo;
+        pssInfo.pszAlgId = MapHashIdentifier(env, jHashAlgorithm);
+        pssInfo.cbSalt = saltLen;
+
+        if (pssInfo.pszAlgId == NULL) {
+            ThrowExceptionWithMessage(env, SIGNATURE_EXCEPTION,
+                    "Unrecognised hash algorithm");
+            __leave;
+        }
+
+        DWORD dwBufLen = 0;
+        SS_CHECK(::NCryptSignHash(
+                hk,
+                &pssInfo,
+                (BYTE*)pHashBuffer, jHashSize,
+                NULL, 0, &dwBufLen,
+                BCRYPT_PAD_PSS
+                ));
+
+        pSignedHashBuffer = new (env) jbyte[dwBufLen];
+        if (pSignedHashBuffer == NULL) {
+            __leave;
+        }
+
+        SS_CHECK(::NCryptSignHash(
+                hk,
+                &pssInfo,
+                (BYTE*)pHashBuffer, jHashSize,
+                (BYTE*)pSignedHashBuffer, dwBufLen, &dwBufLen,
+                BCRYPT_PAD_PSS
+                ));
+
+        // Create new byte array
+        jbyteArray temp = env->NewByteArray(dwBufLen);
+
+        // Copy data from native buffer
+        env->SetByteArrayRegion(temp, 0, dwBufLen, pSignedHashBuffer);
+
+        jSignedHash = temp;
+    }
+    __finally
+    {
+        if (pSignedHashBuffer)
+            delete [] pSignedHashBuffer;
+
+        if (pHashBuffer)
+            delete [] pHashBuffer;
+
+        if (hk != NULL)
+            ::NCryptFreeObject(hk);
+    }
+
+    return jSignedHash;
+}
+
+/*
  * Class:     sun_security_mscapi_RSASignature
  * Method:    verifySignedHash
  * Signature: ([BIL/java/lang/String;[BIJJ)Z
@@ -798,6 +940,85 @@
 }
 
 /*
+ * Class:     sun_security_mscapi_RSASignature_PSS
+ * Method:    verifyPssSignedHash
+ * Signature: ([BI[BIILjava/lang/String;JJ)Z
+ */
+JNIEXPORT jboolean JNICALL Java_sun_security_mscapi_RSASignature_00024PSS_verifyPssSignedHash
+  (JNIEnv *env, jclass clazz,
+        jbyteArray jHash, jint jHashSize,
+        jbyteArray jSignedHash, jint jSignedHashSize,
+        jint saltLen, jstring jHashAlgorithm,
+        jlong hCryptProv, jlong hKey)
+{
+    jbyte* pHashBuffer = NULL;
+    jbyte* pSignedHashBuffer = NULL;
+    jboolean result = JNI_FALSE;
+    NCRYPT_KEY_HANDLE hk = NULL;
+
+    __try
+    {
+        SS_CHECK(::NCryptTranslateHandle(
+                NULL,
+                &hk,
+                hCryptProv,
+                hKey,
+                NULL,
+                0));
+
+        // Copy hash and signedHash from Java to native buffer
+        pHashBuffer = new (env) jbyte[jHashSize];
+        if (pHashBuffer == NULL) {
+            __leave;
+        }
+        env->GetByteArrayRegion(jHash, 0, jHashSize, pHashBuffer);
+
+        pSignedHashBuffer = new (env) jbyte[jSignedHashSize];
+        if (pSignedHashBuffer == NULL) {
+            __leave;
+        }
+        env->GetByteArrayRegion(jSignedHash, 0, jSignedHashSize,
+            pSignedHashBuffer);
+
+        BCRYPT_PSS_PADDING_INFO pssInfo;
+        pssInfo.pszAlgId = MapHashIdentifier(env, jHashAlgorithm);
+        pssInfo.cbSalt = saltLen;
+
+        if (pssInfo.pszAlgId == NULL) {
+            ThrowExceptionWithMessage(env, SIGNATURE_EXCEPTION,
+                    "Unrecognised hash algorithm");
+            __leave;
+        }
+
+        // For RSA, the hash encryption algorithm is normally the same as the
+        // public key algorithm, so AT_SIGNATURE is used.
+
+        // Verify the signature
+        if (::NCryptVerifySignature(hk, &pssInfo,
+                (BYTE *) pHashBuffer, jHashSize,
+                (BYTE *) pSignedHashBuffer, jSignedHashSize,
+                NCRYPT_PAD_PSS_FLAG) == ERROR_SUCCESS)
+        {
+            result = JNI_TRUE;
+        }
+    }
+
+    __finally
+    {
+        if (pSignedHashBuffer)
+            delete [] pSignedHashBuffer;
+
+        if (pHashBuffer)
+            delete [] pHashBuffer;
+
+        if (hk != NULL)
+            ::NCryptFreeObject(hk);
+    }
+
+    return result;
+}
+
+/*
  * Class:     sun_security_mscapi_RSAKeyPairGenerator
  * Method:    generateRSAKeyPair
  * Signature: (ILjava/lang/String;)Lsun/security/mscapi/RSAKeyPair;
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/sun/security/mscapi/InteropWithSunRsaSign.java	Fri Jun 22 21:42:00 2018 +0800
@@ -0,0 +1,171 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+
+/**
+ * @test
+ * @bug 8205445
+ * @summary Interop test between SunMSCAPI and SunRsaSign on RSASSA-PSS
+ * @requires os.family == "windows"
+ */
+
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.PrivateKey;
+import java.security.PublicKey;
+import java.security.SecureRandom;
+import java.security.Signature;
+import java.security.spec.MGF1ParameterSpec;
+import java.security.spec.PSSParameterSpec;
+import java.util.Random;
+
+public class InteropWithSunRsaSign {
+
+    private static final SecureRandom NOT_SECURE_RANDOM = new SecureRandom() {
+        Random r = new Random();
+        @Override
+        public void nextBytes(byte[] bytes) {
+            r.nextBytes(bytes);
+        }
+    };
+
+    private static boolean allResult = true;
+    private static byte[] msg = "hello".getBytes();
+
+    public static void main(String[] args) throws Exception {
+
+        matrix(new PSSParameterSpec(
+                "SHA-1",
+                "MGF1",
+                MGF1ParameterSpec.SHA1,
+                20,
+                PSSParameterSpec.TRAILER_FIELD_BC));
+
+        matrix(new PSSParameterSpec(
+                "SHA-256",
+                "MGF1",
+                MGF1ParameterSpec.SHA256,
+                32,
+                PSSParameterSpec.TRAILER_FIELD_BC));
+
+        matrix(new PSSParameterSpec(
+                "SHA-384",
+                "MGF1",
+                MGF1ParameterSpec.SHA384,
+                48,
+                PSSParameterSpec.TRAILER_FIELD_BC));
+
+        matrix(new PSSParameterSpec(
+                "SHA-512",
+                "MGF1",
+                MGF1ParameterSpec.SHA512,
+                64,
+                PSSParameterSpec.TRAILER_FIELD_BC));
+
+        // non-typical salt length
+        matrix(new PSSParameterSpec(
+                "SHA-1",
+                "MGF1",
+                MGF1ParameterSpec.SHA1,
+                17,
+                PSSParameterSpec.TRAILER_FIELD_BC));
+
+        if (!allResult) {
+            throw new Exception("Failed");
+        }
+    }
+
+    static void matrix(PSSParameterSpec pss) throws Exception {
+
+        System.out.printf("\n%10s%20s%20s%20s  %s\n", pss.getDigestAlgorithm(),
+                "KeyPairGenerator", "signer", "verifier", "result");
+        System.out.printf("%10s%20s%20s%20s  %s\n",
+                "-------", "----------------", "------", "--------", "------");
+
+        // KeyPairGenerator chooses SPI when getInstance() is called.
+        String[] provsForKPG = {"SunRsaSign", "SunMSCAPI"};
+
+        // "-" means no preferred provider. In this case, SPI is chosen
+        // when initSign/initVerify is called. Worth testing.
+        String[] provsForSignature = {"SunRsaSign", "SunMSCAPI", "-"};
+
+        int pos = 0;
+        for (String pg : provsForKPG) {
+            for (String ps : provsForSignature) {
+                for (String pv : provsForSignature) {
+                    System.out.printf("%10d%20s%20s%20s  ", ++pos, pg, ps, pv);
+                    try {
+                        boolean result = test(pg, ps, pv, pss);
+                        System.out.println(result);
+                        if (!result) {
+                            allResult = false;
+                        }
+                    } catch (Exception e) {
+                        if (pg.equals("-") || pg.equals(ps)) {
+                            // When Signature provider is automatically
+                            // chosen or the same with KeyPairGenerator,
+                            // this is an error.
+                            allResult = false;
+                            System.out.println("X " + e.getMessage());
+                        } else {
+                            // Known restriction: SunRsaSign and SunMSCAPI can't
+                            // use each other's private key for signing.
+                            System.out.println(e.getMessage());
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    static boolean test(String pg, String ps, String pv, PSSParameterSpec pss)
+            throws Exception {
+
+        KeyPairGenerator kpg = pg.length() == 1
+                ? KeyPairGenerator.getInstance("RSA")
+                :KeyPairGenerator.getInstance("RSA", pg);
+        kpg.initialize(
+                pss.getDigestAlgorithm().equals("SHA-512") ? 2048: 1024,
+                NOT_SECURE_RANDOM);
+        KeyPair kp = kpg.generateKeyPair();
+        PrivateKey pr = kp.getPrivate();
+        PublicKey pu = kp.getPublic();
+
+        Signature s = ps.length() == 1
+                ? Signature.getInstance("RSASSA-PSS")
+                : Signature.getInstance("RSASSA-PSS", ps);
+        s.initSign(pr);
+        s.setParameter(pss);
+        s.update(msg);
+        byte[] sig = s.sign();
+
+        Signature s2 = pv.length() == 1
+                ? Signature.getInstance("RSASSA-PSS")
+                : Signature.getInstance("RSASSA-PSS", pv);
+        s2.initVerify(pu);
+        s2.setParameter(pss);
+        s2.update(msg);
+
+        return s2.verify(sig);
+    }
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/sun/security/rsa/pss/InitAgain.java	Fri Jun 22 21:42:00 2018 +0800
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+import java.security.*;
+import java.security.spec.*;
+
+/**
+ * @test
+ * @bug 8205445
+ * @summary Make sure old state is cleared when init is called again
+ */
+public class InitAgain {
+
+    public static void main(String[] args) throws Exception {
+
+        byte[] msg = "hello".getBytes();
+
+        Signature s1 = Signature.getInstance("RSASSA-PSS");
+        Signature s2 = Signature.getInstance("RSASSA-PSS");
+
+        s1.setParameter(PSSParameterSpec.DEFAULT);
+        s2.setParameter(PSSParameterSpec.DEFAULT);
+
+        KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA");
+        kpg.initialize(1024);
+        KeyPair kp = kpg.generateKeyPair();
+
+        s1.initSign(kp.getPrivate());
+        s1.update(msg);
+        s1.initSign(kp.getPrivate());
+        s1.update(msg);
+        // Data digested in s1:
+        // Before this fix, msg | msg
+        // After this fix, msg
+
+        s2.initVerify(kp.getPublic());
+        s2.update(msg);
+        s2.initVerify(kp.getPublic());
+        s2.update(msg);
+        s2.initVerify(kp.getPublic());
+        s2.update(msg);
+        // Data digested in s2:
+        // Before this fix, msg | msg | msg
+        // After this fix, msg
+
+        if (!s2.verify(s1.sign())) {
+            throw new Exception();
+        }
+    }
+}