rpcrt4: Try a lot harder to resuse existing connections by comparing inside the RpcQualityOfService and RpcAuthInfo objects.
Store a copy of the SEC_WINNT_AUTH_IDENTITY structure passed in to
RpcBindingSetAuthInfo(Ex) to enable us to do this for RpcAuthInfo objects.
diff --git a/dlls/rpcrt4/rpc_binding.c b/dlls/rpcrt4/rpc_binding.c
index 56dcc9e..611041c 100644
--- a/dlls/rpcrt4/rpc_binding.c
+++ b/dlls/rpcrt4/rpc_binding.c
@@ -78,6 +78,17 @@
return s;
}
+static LPWSTR RPCRT4_strndupAtoW(LPCSTR src, INT slen)
+{
+ DWORD len;
+ LPWSTR s;
+ if (!src) return NULL;
+ len = MultiByteToWideChar(CP_ACP, 0, src, slen, NULL, 0);
+ s = HeapAlloc(GetProcessHeap(), 0, len*sizeof(WCHAR));
+ MultiByteToWideChar(CP_ACP, 0, src, slen, s, len);
+ return s;
+}
+
LPWSTR RPCRT4_strndupW(LPCWSTR src, INT slen)
{
DWORD len;
@@ -967,9 +978,24 @@
return RPC_S_OK;
}
+static inline BOOL has_nt_auth_identity(ULONG AuthnLevel)
+{
+ switch (AuthnLevel)
+ {
+ case RPC_C_AUTHN_GSS_NEGOTIATE:
+ case RPC_C_AUTHN_WINNT:
+ case RPC_C_AUTHN_GSS_KERBEROS:
+ return TRUE;
+ default:
+ return FALSE;
+ }
+}
+
static RPC_STATUS RpcAuthInfo_Create(ULONG AuthnLevel, ULONG AuthnSvc,
CredHandle cred, TimeStamp exp,
- ULONG cbMaxToken, RpcAuthInfo **ret)
+ ULONG cbMaxToken,
+ RPC_AUTH_IDENTITY_HANDLE identity,
+ RpcAuthInfo **ret)
{
RpcAuthInfo *AuthInfo = HeapAlloc(GetProcessHeap(), 0, sizeof(*AuthInfo));
if (!AuthInfo)
@@ -981,6 +1007,51 @@
AuthInfo->cred = cred;
AuthInfo->exp = exp;
AuthInfo->cbMaxToken = cbMaxToken;
+ AuthInfo->identity = identity;
+
+ /* duplicate the SEC_WINNT_AUTH_IDENTITY structure, if applicable, to
+ * enable better matching in RpcAuthInfo_IsEqual */
+ if (identity && has_nt_auth_identity(AuthnSvc))
+ {
+ const SEC_WINNT_AUTH_IDENTITY_W *nt_identity = identity;
+ AuthInfo->nt_identity = HeapAlloc(GetProcessHeap(), 0, sizeof(*AuthInfo->nt_identity));
+ if (!AuthInfo->nt_identity)
+ {
+ HeapFree(GetProcessHeap(), 0, AuthInfo);
+ return ERROR_OUTOFMEMORY;
+ }
+
+ AuthInfo->nt_identity->Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
+ if (nt_identity->Flags & SEC_WINNT_AUTH_IDENTITY_UNICODE)
+ AuthInfo->nt_identity->User = RPCRT4_strndupW(nt_identity->User, nt_identity->UserLength);
+ else
+ AuthInfo->nt_identity->User = RPCRT4_strndupAtoW((const char *)nt_identity->User, nt_identity->UserLength);
+ AuthInfo->nt_identity->UserLength = nt_identity->UserLength;
+ if (nt_identity->Flags & SEC_WINNT_AUTH_IDENTITY_UNICODE)
+ AuthInfo->nt_identity->Domain = RPCRT4_strndupW(nt_identity->Domain, nt_identity->DomainLength);
+ else
+ AuthInfo->nt_identity->Domain = RPCRT4_strndupAtoW((const char *)nt_identity->Domain, nt_identity->DomainLength);
+ AuthInfo->nt_identity->DomainLength = nt_identity->DomainLength;
+ if (nt_identity->Flags & SEC_WINNT_AUTH_IDENTITY_UNICODE)
+ AuthInfo->nt_identity->Password = RPCRT4_strndupW(nt_identity->Password, nt_identity->PasswordLength);
+ else
+ AuthInfo->nt_identity->Password = RPCRT4_strndupAtoW((const char *)nt_identity->Password, nt_identity->PasswordLength);
+ AuthInfo->nt_identity->PasswordLength = nt_identity->PasswordLength;
+
+ if (!AuthInfo->nt_identity->User ||
+ !AuthInfo->nt_identity->Domain ||
+ !AuthInfo->nt_identity->Password)
+ {
+ HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->User);
+ HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->Domain);
+ HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->Password);
+ HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity);
+ HeapFree(GetProcessHeap(), 0, AuthInfo);
+ return ERROR_OUTOFMEMORY;
+ }
+ }
+ else
+ AuthInfo->nt_identity = NULL;
*ret = AuthInfo;
return RPC_S_OK;
}
@@ -997,12 +1068,60 @@
if (!refs)
{
FreeCredentialsHandle(&AuthInfo->cred);
+ if (AuthInfo->nt_identity)
+ {
+ HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->User);
+ HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->Domain);
+ HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->User);
+ HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity);
+ }
HeapFree(GetProcessHeap(), 0, AuthInfo);
}
return refs;
}
+BOOL RpcAuthInfo_IsEqual(const RpcAuthInfo *AuthInfo1, const RpcAuthInfo *AuthInfo2)
+{
+ if (AuthInfo1 == AuthInfo2)
+ return TRUE;
+
+ if (!AuthInfo1 || !AuthInfo2)
+ return FALSE;
+
+ if ((AuthInfo1->AuthnLevel != AuthInfo2->AuthnLevel) ||
+ (AuthInfo1->AuthnSvc != AuthInfo2->AuthnSvc))
+ return FALSE;
+
+ if (AuthInfo1->identity == AuthInfo2->identity)
+ return TRUE;
+
+ if (!AuthInfo1->identity || !AuthInfo2->identity)
+ return FALSE;
+
+ if (has_nt_auth_identity(AuthInfo1->AuthnSvc))
+ {
+ const SEC_WINNT_AUTH_IDENTITY_W *identity1 = AuthInfo1->nt_identity;
+ const SEC_WINNT_AUTH_IDENTITY_W *identity2 = AuthInfo2->nt_identity;
+ /* compare user names */
+ if (identity1->UserLength != identity2->UserLength ||
+ memcmp(identity1->User, identity2->User, identity1->UserLength))
+ return FALSE;
+ /* compare domain names */
+ if (identity1->DomainLength != identity2->DomainLength ||
+ memcmp(identity1->Domain, identity2->Domain, identity1->DomainLength))
+ return FALSE;
+ /* compare passwords */
+ if (identity1->PasswordLength != identity2->PasswordLength ||
+ memcmp(identity1->Password, identity2->Password, identity1->PasswordLength))
+ return FALSE;
+ }
+ else
+ return FALSE;
+
+ return TRUE;
+}
+
static RPC_STATUS RpcQualityOfService_Create(const RPC_SECURITY_QOS *qos_src, BOOL unicode, RpcQualityOfService **qos_dst)
{
RpcQualityOfService *qos = HeapAlloc(GetProcessHeap(), 0, sizeof(*qos));
@@ -1143,6 +1262,65 @@
return refs;
}
+BOOL RpcQualityOfService_IsEqual(const RpcQualityOfService *qos1, const RpcQualityOfService *qos2)
+{
+ if (qos1 == qos2)
+ return TRUE;
+
+ if (!qos1 || !qos2)
+ return FALSE;
+
+ TRACE("qos1 = { %ld %ld %ld %ld }, qos2 = { %ld %ld %ld %ld }\n",
+ qos1->qos->Capabilities, qos1->qos->IdentityTracking,
+ qos1->qos->ImpersonationType, qos1->qos->AdditionalSecurityInfoType,
+ qos2->qos->Capabilities, qos2->qos->IdentityTracking,
+ qos2->qos->ImpersonationType, qos2->qos->AdditionalSecurityInfoType);
+
+ if ((qos1->qos->Capabilities != qos2->qos->Capabilities) ||
+ (qos1->qos->IdentityTracking != qos2->qos->IdentityTracking) ||
+ (qos1->qos->ImpersonationType != qos2->qos->ImpersonationType) ||
+ (qos1->qos->AdditionalSecurityInfoType != qos2->qos->AdditionalSecurityInfoType))
+ return FALSE;
+
+ if (qos1->qos->AdditionalSecurityInfoType == RPC_C_AUTHN_INFO_TYPE_HTTP)
+ {
+ const RPC_HTTP_TRANSPORT_CREDENTIALS_W *http_credentials1 = qos1->qos->u.HttpCredentials;
+ const RPC_HTTP_TRANSPORT_CREDENTIALS_W *http_credentials2 = qos2->qos->u.HttpCredentials;
+
+ if (http_credentials1->Flags != http_credentials2->Flags)
+ return FALSE;
+
+ if (http_credentials1->AuthenticationTarget != http_credentials2->AuthenticationTarget)
+ return FALSE;
+
+ /* authentication schemes and server certificate subject not currently used */
+
+ if (http_credentials1->TransportCredentials != http_credentials2->TransportCredentials)
+ {
+ const SEC_WINNT_AUTH_IDENTITY_W *identity1 = http_credentials1->TransportCredentials;
+ const SEC_WINNT_AUTH_IDENTITY_W *identity2 = http_credentials2->TransportCredentials;
+
+ if (!identity1 || !identity2)
+ return FALSE;
+
+ /* compare user names */
+ if (identity1->UserLength != identity2->UserLength ||
+ memcmp(identity1->User, identity2->User, identity1->UserLength))
+ return FALSE;
+ /* compare domain names */
+ if (identity1->DomainLength != identity2->DomainLength ||
+ memcmp(identity1->Domain, identity2->Domain, identity1->DomainLength))
+ return FALSE;
+ /* compare passwords */
+ if (identity1->PasswordLength != identity2->PasswordLength ||
+ memcmp(identity1->Password, identity2->Password, identity1->PasswordLength))
+ return FALSE;
+ }
+ }
+
+ return TRUE;
+}
+
/***********************************************************************
* RpcRevertToSelf (RPCRT4.@)
*/
@@ -1317,7 +1495,7 @@
if (bind->AuthInfo) RpcAuthInfo_Release(bind->AuthInfo);
bind->AuthInfo = NULL;
r = RpcAuthInfo_Create(AuthnLevel, AuthnSvc, cred, exp, cbMaxToken,
- &bind->AuthInfo);
+ AuthIdentity, &bind->AuthInfo);
if (r != RPC_S_OK)
FreeCredentialsHandle(&cred);
return RPC_S_OK;
@@ -1433,7 +1611,7 @@
if (bind->AuthInfo) RpcAuthInfo_Release(bind->AuthInfo);
bind->AuthInfo = NULL;
r = RpcAuthInfo_Create(AuthnLevel, AuthnSvc, cred, exp, cbMaxToken,
- &bind->AuthInfo);
+ AuthIdentity, &bind->AuthInfo);
if (r != RPC_S_OK)
FreeCredentialsHandle(&cred);
return RPC_S_OK;
diff --git a/dlls/rpcrt4/rpc_binding.h b/dlls/rpcrt4/rpc_binding.h
index 127a2cf..69b4001 100644
--- a/dlls/rpcrt4/rpc_binding.h
+++ b/dlls/rpcrt4/rpc_binding.h
@@ -35,6 +35,11 @@
CredHandle cred;
TimeStamp exp;
ULONG cbMaxToken;
+ /* the auth identity pointer that the application passed us (freed by application) */
+ RPC_AUTH_IDENTITY_HANDLE *identity;
+ /* our copy of NT auth identity structure, if the authentication service
+ * takes an NT auth identity */
+ SEC_WINNT_AUTH_IDENTITY_W *nt_identity;
} RpcAuthInfo;
typedef struct _RpcQualityOfService
@@ -137,8 +142,10 @@
ULONG RpcAuthInfo_AddRef(RpcAuthInfo *AuthInfo);
ULONG RpcAuthInfo_Release(RpcAuthInfo *AuthInfo);
+BOOL RpcAuthInfo_IsEqual(const RpcAuthInfo *AuthInfo1, const RpcAuthInfo *AuthInfo2);
ULONG RpcQualityOfService_AddRef(RpcQualityOfService *qos);
ULONG RpcQualityOfService_Release(RpcQualityOfService *qos);
+BOOL RpcQualityOfService_IsEqual(const RpcQualityOfService *qos1, const RpcQualityOfService *qos2);
RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAssoc **assoc);
RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc, const RPC_SYNTAX_IDENTIFIER *InterfaceId, const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo, const RpcQualityOfService *QOS);
diff --git a/dlls/rpcrt4/rpc_transport.c b/dlls/rpcrt4/rpc_transport.c
index 1ff9cfd..b03d3ea 100644
--- a/dlls/rpcrt4/rpc_transport.c
+++ b/dlls/rpcrt4/rpc_transport.c
@@ -1485,10 +1485,10 @@
/* try to find a compatible connection from the connection pool */
EnterCriticalSection(&assoc->cs);
LIST_FOR_EACH_ENTRY(Connection, &assoc->connection_pool, RpcConnection, conn_pool_entry)
- if ((Connection->AuthInfo == AuthInfo) &&
- (Connection->QOS == QOS) &&
- !memcmp(&Connection->ActiveInterface, InterfaceId,
- sizeof(RPC_SYNTAX_IDENTIFIER)))
+ if (!memcmp(&Connection->ActiveInterface, InterfaceId,
+ sizeof(RPC_SYNTAX_IDENTIFIER)) &&
+ RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
+ RpcQualityOfService_IsEqual(Connection->QOS, QOS))
{
list_remove(&Connection->conn_pool_entry);
LeaveCriticalSection(&assoc->cs);