8210870: Libsunmscapi improved interactions
authorweijun
Mon, 08 Oct 2018 12:55:04 +0800
changeset 53330 e8bae92beee3
parent 53329 d845ee36da70
child 53331 b9149d907610
8210870: Libsunmscapi improved interactions Reviewed-by: valeriep, mschoene, rhalade
src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/KeyStore.java
src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp
--- a/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/KeyStore.java	Fri Oct 05 11:36:30 2018 -0700
+++ b/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/KeyStore.java	Mon Oct 08 12:55:04 2018 +0800
@@ -753,6 +753,7 @@
     /**
      * Generates a certificate chain from the collection of
      * certificates and stores the result into a key entry.
+     * This method is called by native code in libsunmscapi.
      */
     private void generateCertificateChain(String alias,
         Collection<? extends Certificate> certCollection)
@@ -775,13 +776,15 @@
         catch (Throwable e)
         {
             // Ignore the exception and skip this entry
-            // TODO - throw CertificateException?
+            // If e is thrown, remember to deal with it in
+            // native code.
         }
     }
 
     /**
      * Generates RSA key and certificate chain from the private key handle,
      * collection of certificates and stores the result into key entries.
+     * This method is called by native code in libsunmscapi.
      */
     private void generateRSAKeyAndCertificateChain(String alias,
         long hCryptProv, long hCryptKey, int keyLength,
@@ -807,12 +810,14 @@
         catch (Throwable e)
         {
             // Ignore the exception and skip this entry
-            // TODO - throw CertificateException?
+            // If e is thrown, remember to deal with it in
+            // native code.
         }
     }
 
     /**
      * Generates certificates from byte data and stores into cert collection.
+     * This method is called by native code in libsunmscapi.
      *
      * @param data Byte data.
      * @param certCollection Collection of certificates.
@@ -836,12 +841,14 @@
         catch (CertificateException e)
         {
             // Ignore the exception and skip this certificate
-            // TODO - throw CertificateException?
+            // If e is thrown, remember to deal with it in
+            // native code.
         }
         catch (Throwable te)
         {
             // Ignore the exception and skip this certificate
-            // TODO - throw CertificateException?
+            // If e is thrown, remember to deal with it in
+            // native code.
         }
     }
 
--- a/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp	Fri Oct 05 11:36:30 2018 -0700
+++ b/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp	Mon Oct 08 12:55:04 2018 +0800
@@ -511,6 +511,15 @@
                     // Create ArrayList to store certs in each chain
                     jobject jArrayList =
                         env->NewObject(clazzArrayList, mNewArrayList);
+                    if (jArrayList == NULL) {
+                        __leave;
+                    }
+
+                    // Cleanup the previous allocated name
+                    if (pszNameString) {
+                        delete [] pszNameString;
+                        pszNameString = NULL;
+                    }
 
                     for (unsigned int j=0; j < rgpChain->cElement; j++)
                     {
@@ -549,6 +558,9 @@
 
                         // Allocate and populate byte array
                         jbyteArray byteArray = env->NewByteArray(cbCertEncoded);
+                        if (byteArray == NULL) {
+                            __leave;
+                        }
                         env->SetByteArrayRegion(byteArray, 0, cbCertEncoded,
                             (jbyte*) pbCertEncoded);
 
@@ -557,30 +569,44 @@
                         env->CallVoidMethod(obj, mGenCert, byteArray, jArrayList);
                     }
 
-                    if (bHasNoPrivateKey)
-                    {
-                        // Generate certificate chain and store into cert chain
-                        // collection
-                        env->CallVoidMethod(obj, mGenCertChain,
-                            env->NewStringUTF(pszNameString),
-                            jArrayList);
-                    }
-                    else
+                    // Usually pszNameString should be non-NULL. It's either
+                    // the friendly name or an element from the subject name
+                    // or SAN.
+                    if (pszNameString)
                     {
-                        // Determine key type: RSA or DSA
-                        DWORD dwData = CALG_RSA_KEYX;
-                        DWORD dwSize = sizeof(DWORD);
-                        ::CryptGetKeyParam(hUserKey, KP_ALGID, (BYTE*)&dwData,
-                                &dwSize, NULL);
+                        if (bHasNoPrivateKey)
+                        {
+                            // Generate certificate chain and store into cert chain
+                            // collection
+                            jstring name = env->NewStringUTF(pszNameString);
+                            if (name == NULL) {
+                                __leave;
+                            }
+                            env->CallVoidMethod(obj, mGenCertChain,
+                                name,
+                                jArrayList);
+                        }
+                        else
+                        {
+                            // Determine key type: RSA or DSA
+                            DWORD dwData = CALG_RSA_KEYX;
+                            DWORD dwSize = sizeof(DWORD);
+                            ::CryptGetKeyParam(hUserKey, KP_ALGID, (BYTE*)&dwData,
+                                    &dwSize, NULL);
 
-                        if ((dwData & ALG_TYPE_RSA) == ALG_TYPE_RSA)
-                        {
-                            // Generate RSA certificate chain and store into cert
-                            // chain collection
-                            env->CallVoidMethod(obj, mGenRSAKeyAndCertChain,
-                                    env->NewStringUTF(pszNameString),
-                                    (jlong) hCryptProv, (jlong) hUserKey,
-                                    dwPublicKeyLength, jArrayList);
+                            if ((dwData & ALG_TYPE_RSA) == ALG_TYPE_RSA)
+                            {
+                                // Generate RSA certificate chain and store into cert
+                                // chain collection
+                                jstring name = env->NewStringUTF(pszNameString);
+                                if (name == NULL) {
+                                    __leave;
+                                }
+                                env->CallVoidMethod(obj, mGenRSAKeyAndCertChain,
+                                        name,
+                                        (jlong) hCryptProv, (jlong) hUserKey,
+                                        dwPublicKeyLength, jArrayList);
+                            }
                         }
                     }
                 }
@@ -726,6 +752,9 @@
 
         // Create new byte array
         jbyteArray temp = env->NewByteArray(dwBufLen);
+        if (temp == NULL) {
+            __leave;
+        }
 
         // Copy data from native buffer
         env->SetByteArrayRegion(temp, 0, dwBufLen, pSignedHashBuffer);
@@ -817,6 +846,9 @@
 
         // Create new byte array
         jbyteArray temp = env->NewByteArray(dwBufLen);
+        if (temp == NULL) {
+            __leave;
+        }
 
         // Copy data from native buffer
         env->SetByteArrayRegion(temp, 0, dwBufLen, pSignedHashBuffer);
@@ -1216,6 +1248,9 @@
         }
 
         jCertAliasChars = env->GetStringChars(jCertAliasName, NULL);
+        if (jCertAliasChars == NULL) {
+            __leave;
+        }
         memcpy(pszCertAliasName, jCertAliasChars, size * sizeof(WCHAR));
         pszCertAliasName[size] = 0; // append the string terminator
 
@@ -1646,7 +1681,9 @@
         }
 
         // Create new byte array
-        result = env->NewByteArray(dwBufLen);
+        if ((result = env->NewByteArray(dwBufLen)) == NULL) {
+            __leave;
+        }
 
         // Copy data from native buffer to Java buffer
         env->SetByteArrayRegion(result, 0, dwBufLen, (jbyte*) pData);
@@ -1697,7 +1734,9 @@
         }
 
         // Create new byte array
-        blob = env->NewByteArray(dwBlobLen);
+        if ((blob = env->NewByteArray(dwBlobLen)) == NULL) {
+            __leave;
+        }
 
         // Copy data from native buffer to Java buffer
         env->SetByteArrayRegion(blob, 0, dwBlobLen, (jbyte*) pbKeyBlob);
@@ -1726,6 +1765,13 @@
     __try {
 
         jsize length = env->GetArrayLength(jKeyBlob);
+        jsize headerLength = sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY);
+
+        if (length < headerLength) {
+            ThrowExceptionWithMessage(env, KEY_EXCEPTION, "Invalid BLOB");
+            __leave;
+        }
+
         if ((keyBlob = env->GetByteArrayElements(jKeyBlob, 0)) == NULL) {
             __leave;
         }
@@ -1752,7 +1798,9 @@
             exponentBytes[i] = ((BYTE*) &pRsaPubKey->pubexp)[j];
         }
 
-        exponent = env->NewByteArray(len);
+        if ((exponent = env->NewByteArray(len)) == NULL) {
+            __leave;
+        }
         env->SetByteArrayRegion(exponent, 0, len, exponentBytes);
     }
     __finally
@@ -1782,6 +1830,13 @@
     __try {
 
         jsize length = env->GetArrayLength(jKeyBlob);
+        jsize headerLength = sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY);
+
+        if (length < headerLength) {
+            ThrowExceptionWithMessage(env, KEY_EXCEPTION, "Invalid BLOB");
+            __leave;
+        }
+
         if ((keyBlob = env->GetByteArrayElements(jKeyBlob, 0)) == NULL) {
             __leave;
         }
@@ -1798,19 +1853,25 @@
             (RSAPUBKEY *) (keyBlob + sizeof(PUBLICKEYSTRUC));
 
         int len = pRsaPubKey->bitlen / 8;
+        if (len < 0 || len > length - headerLength) {
+            ThrowExceptionWithMessage(env, KEY_EXCEPTION, "Invalid key length");
+            __leave;
+        }
+
         modulusBytes = new (env) jbyte[len];
         if (modulusBytes == NULL) {
             __leave;
         }
-        BYTE * pbModulus =
-            (BYTE *) (keyBlob + sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY));
+        BYTE * pbModulus = (BYTE *) (keyBlob + headerLength);
 
         // convert from little-endian while copying from blob
         for (int i = 0, j = len - 1; i < len; i++, j--) {
             modulusBytes[i] = pbModulus[j];
         }
 
-        modulus = env->NewByteArray(len);
+        if ((modulus = env->NewByteArray(len)) == NULL) {
+            __leave;
+        }
         env->SetByteArrayRegion(modulus, 0, len, modulusBytes);
     }
     __finally
@@ -2027,7 +2088,9 @@
             }
         }
 
-        jBlob = env->NewByteArray(jBlobLength);
+        if ((jBlob = env->NewByteArray(jBlobLength)) == NULL) {
+            __leave;
+        }
         env->SetByteArrayRegion(jBlob, 0, jBlobLength, jBlobBytes);
 
     }