inline GenClientContext, macro redefition, new functions JDK-8199569-branch
authorweijun
Tue, 15 May 2018 12:19:38 +0800 (2018-05-15)
branchJDK-8199569-branch
changeset 56554 9b381f73498a
parent 56553 3e490160d5ec
child 56555 0cd4e27a12cf
inline GenClientContext, macro redefition, new functions
src/java.security.jgss/windows/native/libsspi_bridge/sspi.cpp
--- a/src/java.security.jgss/windows/native/libsspi_bridge/sspi.cpp	Mon May 14 21:06:55 2018 +0800
+++ b/src/java.security.jgss/windows/native/libsspi_bridge/sspi.cpp	Tue May 15 12:19:38 2018 +0800
@@ -40,32 +40,40 @@
 
 #pragma comment(lib, "secur32.lib")
 
-//#define DEBUG
+#define DEBUG
 
 #ifdef DEBUG
 TCHAR _bb[256];
-#define SEC_SUCCESS(Status) ((Status) >= 0 ? TRUE: (FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM|FORMAT_MESSAGE_IGNORE_INSERTS,0,ss,MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),_bb,256,0),printf("SECURITY_STATUS: (%lx) %ls\n",ss,_bb),FALSE))
-#define P fprintf(stdout, "SSPI (%ld): \n", __LINE__); fflush(stdout);
-#define PP(s) fprintf(stdout, "SSPI (%ld): ", __LINE__); fprintf(stdout, "%s\n", s); fflush(stdout)
-#define PP1(s,n) fprintf(stdout, "SSPI (%ld): ", __LINE__); fprintf(stdout, s, n); fflush(stdout)
-#define PP2(s,n1,n2) fprintf(stdout, "SSPI (%ld): ", __LINE__); fprintf(stdout, s, n1, n2); fflush(stdout)
-#define PP3(s,n1,n2,n3) fprintf(stdout, "SSPI (%ld): ", __LINE__); fprintf(stdout, s, n1, n2, n3); fflush(stdout)
-BOOL debug = TRUE;
+#define SEC_SUCCESS(Status) \
+        ((Status) >= 0 ? TRUE: \
+        (FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM|FORMAT_MESSAGE_IGNORE_INSERTS, \
+            0, ss, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), _bb, 256, 0), \
+        printf("SECURITY_STATUS: (%lx) %ls\n", ss, _bb), \
+        FALSE))
+#define PP(fmt, ...) \
+        fprintf(stdout, "SSPI (%ld): ", __LINE__); \
+        fprintf(stdout, fmt, ##__VA_ARGS__); \
+        fflush(stdout)
 #else
 #define SEC_SUCCESS(Status) ((Status) >= 0)
-#define P
-#define PP(s)
-#define PP1(s,n)
-#define PP2(s,n1,n2)
-#define PP3(s,n1,n2,n3)
-BOOL debug = FALSE;
+#define PP(dmt, ...)
 #endif
 
-char KRB5_OID[9] = {(char)0x2a, (char)0x86, (char)0x48, (char)0x86, (char)0xf7, (char)0x12, (char)0x01, (char)0x02, (char)0x02};
-char KRB5_U2U_OID[10] = {(char)0x2a, (char)0x86, (char)0x48, (char)0x86, (char)0xf7, (char)0x12, (char)0x01, (char)0x02, (char)0x02, (char)0x03};
-char SPNEGO_OID[6] = {(char)0x2b, (char)0x06, (char)0x01, (char)0x05, (char)0x05, (char)0x02};
-char USER_NAME_OID[10] = {(char)0x2a, (char)0x86, (char)0x48, (char)0x86, (char)0xf7, (char)0x12, (char)0x01, (char)0x02, (char)0x01, (char)0x01};
-char HOST_SERVICE_NAME_OID[10] = {(char)0x2a, (char)0x86, (char)0x48, (char)0x86, (char)0xf7, (char)0x12, (char)0x01, (char)0x02, (char)0x01, (char)0x04};
+char KRB5_OID[9] = {
+        (char)0x2a, (char)0x86, (char)0x48, (char)0x86, (char)0xf7, (char)0x12,
+        (char)0x01, (char)0x02, (char)0x02};
+char SPNEGO_OID[6] = {
+        (char)0x2b, (char)0x06, (char)0x01, (char)0x05, (char)0x05, (char)0x02};
+char USER_NAME_OID[10] = {
+        (char)0x2a, (char)0x86, (char)0x48, (char)0x86, (char)0xf7, (char)0x12,
+        (char)0x01, (char)0x02, (char)0x01, (char)0x01};
+char HOST_SERVICE_NAME_OID[10] = {
+        (char)0x2a, (char)0x86, (char)0x48, (char)0x86, (char)0xf7, (char)0x12,
+        (char)0x01, (char)0x02, (char)0x01, (char)0x04};
+
+// gss_name_t is SecPkgCredentials_Names*
+// gss_cred_id_t is CredHandle*
+// gss_ctx_id_t is Context*
 
 typedef struct {
     TCHAR PackageName[20];
@@ -82,7 +90,15 @@
 __declspec(dllexport) OM_uint32 gss_release_name
                                 (OM_uint32 *minor_status,
                                 gss_name_t *name) {
-    return GSS_S_FAILURE;
+    if (name != NULL) {
+        SecPkgCredentials_Names* names = (SecPkgCredentials_Names*)name;
+        if (names->sUserName != NULL) {
+            delete[] names->sUserName;
+        }
+        delete names;
+        *name = GSS_C_NO_NAME;
+    }
+    return GSS_S_COMPLETE;
 }
 
 __declspec(dllexport) OM_uint32 gss_import_name
@@ -90,20 +106,49 @@
                                 gss_buffer_t input_name_buffer,
                                 gss_OID input_name_type,
                                 gss_name_t *output_name) {
+PP("");
+    if (input_name_buffer == NULL || input_name_buffer->value == NULL
+            || input_name_buffer->length == 0) {
+        return GSS_S_BAD_NAME;
+    }
+PP("");
     SecPkgCredentials_Names* names = new SecPkgCredentials_Names();
+    if (names == NULL) {
+        goto err;
+    }
     int len = (int)input_name_buffer->length;
+PP("%d", len);
     names->sUserName = new SEC_WCHAR[len + 1];
-    MultiByteToWideChar(CP_ACP, 0, (LPSTR)input_name_buffer->value, len, names->sUserName, len);
+PP("");
+    if (names->sUserName == NULL) {
+        goto err;
+    }
+PP("");
+    if (MultiByteToWideChar(CP_ACP, 0, (LPSTR)input_name_buffer->value, len,
+            names->sUserName, len) == 0) {
+        goto err;
+    }
+PP("");
     names->sUserName[len] = 0;
-    if (input_name_type->length == 10 && !memcmp(input_name_type->elements, HOST_SERVICE_NAME_OID, 10)) {
+    if (input_name_type != NULL && input_name_type->length == 10
+            && !memcmp(input_name_type->elements, HOST_SERVICE_NAME_OID, 10)) {
         for (int i = 0; i < len; i++) {
             if (names->sUserName[i] == '@') {
                 names->sUserName[i] = '/';
+                break;
             }
         }
     }
     *output_name = (gss_name_t) names;
     return GSS_S_COMPLETE;
+err:
+    if (names != NULL && names->sUserName != NULL) {
+        delete[] names->sUserName;
+    }
+    if (names != NULL) {
+        delete names;
+    }
+    return GSS_S_FAILURE;
 }
 
 __declspec(dllexport) OM_uint32 gss_compare_name
@@ -111,7 +156,19 @@
                                 gss_name_t name1,
                                 gss_name_t name2,
                                 int *name_equal) {
-    return GSS_S_FAILURE;
+    if (name1 == NULL || name2 == NULL) {
+        *name_equal = 0;
+        return GSS_S_BAD_NAME;
+    }
+
+    SecPkgCredentials_Names* names1 = (SecPkgCredentials_Names*)name1;
+    SecPkgCredentials_Names* names2 = (SecPkgCredentials_Names*)name2;
+    if (lstrcmp(names1->sUserName, names2->sUserName)) {
+        *name_equal = 0;
+    } else {
+        *name_equal = 1;
+    }
+    return GSS_S_COMPLETE;
 }
 
 __declspec(dllexport) OM_uint32 gss_canonicalize_name
@@ -119,13 +176,25 @@
                                 gss_name_t input_name,
                                 gss_OID mech_type,
                                 gss_name_t *output_name) {
-    return GSS_S_FAILURE;
+    SecPkgCredentials_Names* names1 = (SecPkgCredentials_Names*)input_name;
+    SecPkgCredentials_Names* names2 = new SecPkgCredentials_Names();
+    names2->sUserName = new SEC_WCHAR[lstrlen(names1->sUserName) + 1];
+    lstrcpy(names2->sUserName, names1->sUserName);
+    *output_name = (gss_name_t)names2;
+    return GSS_S_COMPLETE;
 }
 
 __declspec(dllexport) OM_uint32 gss_export_name
                                 (OM_uint32 *minor_status,
                                 gss_name_t input_name,
                                 gss_buffer_t exported_name) {
+    SecPkgCredentials_Names* names = (SecPkgCredentials_Names*)input_name;
+    int len = (int)wcslen(names->sUserName);
+    char* buffer = new char[len+1];
+    WideCharToMultiByte(CP_ACP, 0, names->sUserName, len, buffer, len, NULL, NULL);
+    buffer[len] = 0;
+    exported_name->length = len+1;
+    exported_name->value = buffer;
     return GSS_S_FAILURE;
 }
 
@@ -141,8 +210,8 @@
     buffer[len] = 0;
     output_name_buffer->length = len+1;
     output_name_buffer->value = buffer;
-    PP1("Name found: %ls\n", names->sUserName);
-    PP2("%d [%s]", len, buffer);
+    PP("Name found: %ls\n", names->sUserName);
+    PP("%d [%s]", len, buffer);
     if (output_name_type != NULL) {
         gss_OID_desc* oid = new gss_OID_desc();
         oid->length = (OM_uint32)strlen(USER_NAME_OID);
@@ -158,7 +227,7 @@
     GetSystemTimeAsFileTime(&fnow);
     a = (ULARGE_INTEGER*)time;
     b = (ULARGE_INTEGER*)&fnow;
-    PP1("Difference %ld\n", (long)((a->QuadPart - b->QuadPart) / 10000000));
+    PP("Difference %ld\n", (long)((a->QuadPart - b->QuadPart) / 10000000));
     return (long)((a->QuadPart - b->QuadPart) / 10000000);
 }
 
@@ -178,7 +247,7 @@
     CredHandle* cred = new CredHandle();
     TimeStamp ts;
 	cred_usage = 0;
-    PP1("AcquireCredentialsHandle with %d\n", cred_usage);
+    PP("AcquireCredentialsHandle with %d\n", cred_usage);
     ss = AcquireCredentialsHandle(
             NULL,
             L"Kerberos",
@@ -204,7 +273,11 @@
 __declspec(dllexport) OM_uint32 gss_release_cred
                                 (OM_uint32 *minor_status,
                                 gss_cred_id_t *cred_handle) {
-    return GSS_S_FAILURE;
+    if (cred_handle && *cred_handle) {
+        FreeCredentialsHandle((CredHandle*)*cred_handle);
+        *cred_handle = GSS_C_NO_CREDENTIAL;
+    }
+    return GSS_S_COMPLETE;
 }
 
 __declspec(dllexport) OM_uint32 gss_inquire_cred
@@ -237,114 +310,6 @@
                 &pc->SecPkgContextSizes);
 }
 
-SECURITY_STATUS GenClientContext(
-        Context *pc,
-        int flag,
-        BYTE *pIn,
-        size_t cbIn,
-        BYTE *pOut,
-        size_t *pcbOut,
-        BOOL *pfDone,
-        ULONG *pOutFlag,
-        TCHAR *pszTarget) {
-    SECURITY_STATUS ss;
-    TimeStamp Lifetime;
-    SecBufferDesc OutBuffDesc;
-    SecBuffer OutSecBuff;
-    SecBufferDesc InBuffDesc;
-    SecBuffer InSecBuff;
-
-    OutBuffDesc.ulVersion = SECBUFFER_VERSION;
-    OutBuffDesc.cBuffers = 1;
-    OutBuffDesc.pBuffers = &OutSecBuff;
-
-    OutSecBuff.cbBuffer = (unsigned long)*pcbOut;
-    OutSecBuff.BufferType = SECBUFFER_TOKEN;
-    OutSecBuff.pvBuffer = pOut;
-
-    PP2("TARGET: %ls %ls\n", pszTarget, pc->PackageName);
-    PP2("flag: %x [%ls]\n", flag, pszTarget);
-    if (pIn) {
-        InBuffDesc.ulVersion = SECBUFFER_VERSION;
-        InBuffDesc.cBuffers = 1;
-        InBuffDesc.pBuffers = &InSecBuff;
-
-        InSecBuff.cbBuffer = (unsigned long)cbIn;
-        InSecBuff.BufferType = SECBUFFER_TOKEN;
-        InSecBuff.pvBuffer = pIn;
-
-        ss = InitializeSecurityContext(
-                pc->phCred,
-                &pc->hCtxt,
-                pszTarget,
-                flag,
-                0,
-                SECURITY_NATIVE_DREP,
-                &InBuffDesc,
-                0,
-                &pc->hCtxt,
-                &OutBuffDesc,
-                pOutFlag,
-                &Lifetime);
-    } else {
-        if (!pc->phCred) {
-            PP("No credentials provided, acquire automatically");
-            ss = AcquireCredentialsHandle(
-                    NULL,
-                    pc->PackageName,
-                    SECPKG_CRED_OUTBOUND,
-                    NULL,
-                    NULL,
-                    NULL,
-                    NULL,
-                    pc->phCred,
-                    &Lifetime);
-            PP("end");
-            if (!(SEC_SUCCESS(ss))) {
-                PP("Failed");
-                return ss;
-            }
-        } else {
-            PP("Credentials OK");
-        }
-        ss = InitializeSecurityContext(
-                pc->phCred,
-                NULL,
-                pszTarget,
-                flag,
-                0,
-                SECURITY_NATIVE_DREP,
-                NULL,
-                0,
-                &pc->hCtxt,
-                &OutBuffDesc,
-                pOutFlag,
-                &Lifetime);
-    }
-
-    if (!SEC_SUCCESS(ss)) {
-        PP("InitializeSecurityContext Failed");
-        return ss;
-    }
-    //-------------------------------------------------------------------
-    //  If necessary, complete the token.
-
-    if ((SEC_I_COMPLETE_NEEDED == ss)
-            || (SEC_I_COMPLETE_AND_CONTINUE == ss)) {
-        ss = CompleteAuthToken(&pc->hCtxt, &OutBuffDesc);
-        if (!SEC_SUCCESS(ss)) {
-            return ss;
-        }
-    }
-
-    *pcbOut = OutSecBuff.cbBuffer;
-
-    *pfDone = !((SEC_I_CONTINUE_NEEDED == ss) ||
-            (SEC_I_COMPLETE_AND_CONTINUE == ss));
-
-    return ss;
-}
-
 Context* NewContext(TCHAR* PackageName) {
     SECURITY_STATUS ss;
     PSecPkgInfo pkgInfo;
@@ -358,7 +323,7 @@
     }
     out->phCred = NULL;
     out->cbMaxMessage = pkgInfo->cbMaxToken;
-    PP2("   QuerySecurityPackageInfo %ls goes %ld\n", PackageName, out->cbMaxMessage);
+    PP("   QuerySecurityPackageInfo %ls goes %ld\n", PackageName, out->cbMaxMessage);
     wcscpy(out->PackageName, PackageName);
     FreeContextBuffer(pkgInfo);
     return out;
@@ -401,6 +366,11 @@
                                 OM_uint32 *ret_flags,
                                 OM_uint32 *time_rec) {
     SECURITY_STATUS ss;
+    TimeStamp Lifetime;
+    SecBufferDesc InBuffDesc;
+    SecBuffer InSecBuff;
+    SecBufferDesc OutBuffDesc;
+    SecBuffer OutSecBuff;
 
     Context* pc;
     if (input_token->length == 0) {
@@ -420,23 +390,86 @@
     OM_uint32 minor;
     gss_buffer_desc tn;
     gss_display_name(&minor, target_name, &tn, NULL);
-    MultiByteToWideChar(CP_ACP, 0, (LPCCH)tn.value, (int)tn.length, outName, (int)tn.length);
+    MultiByteToWideChar(CP_ACP, 0, (LPCCH)tn.value, (int)tn.length,
+            outName, (int)tn.length);
     outName[tn.length] = 0;
 
     BOOL pfDone;
-    ss = GenClientContext(
-            pc, flagGss2Sspi(req_flags),
-            (BYTE*)input_token->value, input_token->length,
-            (BYTE*)output_token->value, &(output_token->length),
-            &pfDone, &outFlag,
-            (TCHAR*)outName);
-    if (ss == SEC_E_OK) FillContextAfterEstablished(pc);
-	outFlag = flagSspi2Gss(outFlag);
+    int flag = flagGss2Sspi(req_flags);
+
+    OutBuffDesc.ulVersion = SECBUFFER_VERSION;
+    OutBuffDesc.cBuffers = 1;
+    OutBuffDesc.pBuffers = &OutSecBuff;
+
+    OutSecBuff.cbBuffer = (ULONG)output_token->length;
+    OutSecBuff.BufferType = SECBUFFER_TOKEN;
+    OutSecBuff.pvBuffer = output_token->value;
+
+    if (input_token->value) {
+        InBuffDesc.ulVersion = SECBUFFER_VERSION;
+        InBuffDesc.cBuffers = 1;
+        InBuffDesc.pBuffers = &InSecBuff;
+
+        InSecBuff.BufferType = SECBUFFER_TOKEN;
+        InSecBuff.cbBuffer = (ULONG)input_token->length;
+        InSecBuff.pvBuffer = input_token->value;
+    } else {
+        if (!pc->phCred) {
+            PP("No credentials provided, acquire automatically");
+            ss = AcquireCredentialsHandle(
+                    NULL,
+                    pc->PackageName,
+                    SECPKG_CRED_OUTBOUND,
+                    NULL,
+                    NULL,
+                    NULL,
+                    NULL,
+                    pc->phCred,
+                    &Lifetime);
+            PP("end");
+            if (!(SEC_SUCCESS(ss))) {
+                PP("Failed");
+                return GSS_S_FAILURE;
+            }
+        } else {
+            PP("Credentials OK");
+        }
+    }
+    ss = InitializeSecurityContext(
+            pc->phCred,
+            input_token->value ? &pc->hCtxt : NULL,
+            outName,
+            flag,
+            0,
+            SECURITY_NATIVE_DREP,
+            input_token->value ? &InBuffDesc : NULL,
+            0,
+            &pc->hCtxt,
+            &OutBuffDesc,
+            &outFlag,
+            &Lifetime);
 
 	if (!SEC_SUCCESS(ss)) {
 		return GSS_S_FAILURE;
 	}
 
+    if ((SEC_I_COMPLETE_NEEDED == ss)
+            || (SEC_I_COMPLETE_AND_CONTINUE == ss)) {
+        ss = CompleteAuthToken(&pc->hCtxt, &OutBuffDesc);
+        if (!SEC_SUCCESS(ss)) {
+            return GSS_S_FAILURE;
+        }
+    }
+
+    output_token->length =  OutSecBuff.cbBuffer;
+
+    pfDone = !((SEC_I_CONTINUE_NEEDED == ss) ||
+                (SEC_I_COMPLETE_AND_CONTINUE == ss));
+
+    if (ss == SEC_E_OK) FillContextAfterEstablished(pc);
+
+	outFlag = flagSspi2Gss(outFlag);
+
     *ret_flags = (OM_uint32)outFlag;
     if (ss == SEC_I_CONTINUE_NEEDED) {
         return GSS_S_CONTINUE_NEEDED;
@@ -527,7 +560,7 @@
     BuffDesc.ulVersion = SECBUFFER_VERSION;
 
     SecBuff[0].BufferType = SECBUFFER_DATA;
-    SecBuff[0].cbBuffer = (unsigned long)message_buffer->length;
+    SecBuff[0].cbBuffer = (ULONG)message_buffer->length;
     SecBuff[0].pvBuffer = message_buffer->value;
 
     SecBuff[1].BufferType = SECBUFFER_TOKEN;
@@ -563,11 +596,11 @@
     BuffDesc.pBuffers = SecBuff;
 
     SecBuff[0].BufferType = SECBUFFER_TOKEN;
-    SecBuff[0].cbBuffer = (unsigned long)token_buffer->length;
+    SecBuff[0].cbBuffer = (ULONG)token_buffer->length;
     SecBuff[0].pvBuffer = token_buffer->value;
 
     SecBuff[1].BufferType = SECBUFFER_DATA;
-    SecBuff[1].cbBuffer = (unsigned long)message_buffer->length;
+    SecBuff[1].cbBuffer = (ULONG)message_buffer->length;
     SecBuff[1].pvBuffer = message_buffer->value;
 
     ss = VerifySignature(&pc->hCtxt, &BuffDesc, 0, &qop);
@@ -603,19 +636,23 @@
 
     SecBuff[0].BufferType = SECBUFFER_TOKEN;
     SecBuff[0].cbBuffer = pc->SecPkgContextSizes.cbSecurityTrailer;
-    output_message_buffer->value = SecBuff[0].pvBuffer = malloc(pc->SecPkgContextSizes.cbSecurityTrailer
-            + input_message_buffer->length + pc->SecPkgContextSizes.cbBlockSize);;
+    output_message_buffer->value = SecBuff[0].pvBuffer = malloc(
+            pc->SecPkgContextSizes.cbSecurityTrailer
+                    + input_message_buffer->length
+                    + pc->SecPkgContextSizes.cbBlockSize);;
 
     SecBuff[1].BufferType = SECBUFFER_DATA;
-    SecBuff[1].cbBuffer = (unsigned long)input_message_buffer->length;
+    SecBuff[1].cbBuffer = (ULONG)input_message_buffer->length;
     SecBuff[1].pvBuffer = malloc(SecBuff[1].cbBuffer);
-    memcpy(SecBuff[1].pvBuffer, input_message_buffer->value, input_message_buffer->length);
+    memcpy(SecBuff[1].pvBuffer, input_message_buffer->value,
+            input_message_buffer->length);
 
     SecBuff[2].BufferType = SECBUFFER_PADDING;
     SecBuff[2].cbBuffer = pc->SecPkgContextSizes.cbBlockSize;
     SecBuff[2].pvBuffer = malloc(SecBuff[2].cbBuffer);
 
-    ss = EncryptMessage(&pc->hCtxt, conf_req_flag ? 0 : SECQOP_WRAP_NO_ENCRYPT, &BuffDesc, 0);
+    ss = EncryptMessage(&pc->hCtxt, conf_req_flag ? 0 : SECQOP_WRAP_NO_ENCRYPT,
+            &BuffDesc, 0);
     *conf_state = conf_req_flag;
 
     if (!SEC_SUCCESS(ss)) {
@@ -657,9 +694,11 @@
     BuffDesc.ulVersion = SECBUFFER_VERSION;
 
     SecBuff[0].BufferType = SECBUFFER_STREAM;
-    SecBuff[0].cbBuffer = (unsigned long)input_message_buffer->length;
-    output_message_buffer->value = SecBuff[0].pvBuffer = malloc(input_message_buffer->length);
-    memcpy(SecBuff[0].pvBuffer, input_message_buffer->value, input_message_buffer->length);
+    SecBuff[0].cbBuffer = (ULONG)input_message_buffer->length;
+    output_message_buffer->value = SecBuff[0].pvBuffer
+            = malloc(input_message_buffer->length);
+    memcpy(SecBuff[0].pvBuffer, input_message_buffer->value,
+            input_message_buffer->length);
 
     SecBuff[1].BufferType = SECBUFFER_DATA;
     SecBuff[1].cbBuffer = 0;
@@ -689,7 +728,7 @@
     ULONG ccPackages;
     PSecPkgInfo packages;
     EnumerateSecurityPackages(&ccPackages, &packages);
-    PP1("EnumerateSecurityPackages returns %ld\n", ccPackages);
+    PP("EnumerateSecurityPackages returns %ld\n", ccPackages);
     // TODO: only return Kerberos, so no need to check input later
     PSecPkgInfo pkgInfo;
     SECURITY_STATUS ss = QuerySecurityPackageInfo(L"Negotiate", &pkgInfo);