windows-nt/Source/XPSP1/NT/sdktools/debuggers/dbgrpc/schan.cpp

1790 lines
48 KiB
C++
Raw Normal View History

2020-09-26 03:20:57 -05:00
//----------------------------------------------------------------------------
//
// Secure channel support.
// Code lifted from the SDK sample security\ssl.
//
// Copyright (C) Microsoft Corporation, 2000.
//
//----------------------------------------------------------------------------
#include "pch.hpp"
// Module handle for security.dll. NULL until LoadSecurityLibrary
// has loaded the DLL successfully.
HMODULE g_hSecurity;
// Local copy of the SSPI function table returned by the DLL's
// InitSecurityInterfaceA entry point.
SecurityFunctionTable g_SecurityFunc;
// Handle to the "MY" certificate store; opened on first use by
// CreateCredentials.
HCERTSTORE g_hMyCertStore;
enum
{
// Sequence value this file passes to StreamRead/StreamWrite for
// the handshake and shutdown traffic it generates itself.
SEQ_INTERNAL = 0xffffff00
};
//----------------------------------------------------------------------------
//
// Basic schannel support functions.
//
//----------------------------------------------------------------------------
//
// Debug aid: prints the buffer count of Desc followed by the type,
// size and address of each SecBuffer it holds, tagged with Name.
// The whole body is compiled out unless the #if is switched on.
//
void
DbgDumpBuffers(PCSTR Name, SecBufferDesc* Desc)
{
#if 0
    g_NtDllCalls.DbgPrint("%s desc %p has %d buffers\n",
                          Name, Desc, Desc->cBuffers);

    SecBuffer* Cur = Desc->pBuffers;
    for (ULONG Left = Desc->cBuffers; Left > 0; Left--, Cur++)
    {
        g_NtDllCalls.DbgPrint(" type %d, %X bytes at %p\n",
                              Cur->BufferType,
                              Cur->cbBuffer,
                              Cur->pvBuffer);
    }
#endif
}
// General tracing used by the credential and handshake code.
// With the condition at 0 both macros expand to nothing; changing it
// to 1 routes DSCHAN to g_NtDllCalls.DbgPrint and DumpBuffers to
// DbgDumpBuffers.  DSCHAN takes its arguments in double parentheses.
#if 0
#define DSCHAN(Args) g_NtDllCalls.DbgPrint Args
#define DumpBuffers(Name, Desc) DbgDumpBuffers(Name, Desc)
#else
#define DSCHAN(Args)
#define DumpBuffers(Name, Desc)
#endif
// Tracing for the data path, switched independently of the macros above.
#if 0
#define DSCHAN_IO(Args) g_NtDllCalls.DbgPrint Args
#define DumpBuffersIo(Name, Desc) DbgDumpBuffers(Name, Desc)
#else
#define DSCHAN_IO(Args)
#define DumpBuffersIo(Name, Desc)
#endif
//
// Loads security.dll and caches its SSPI function table in
// g_SecurityFunc.  Safe to call repeatedly; once g_hSecurity is set
// later calls return S_OK immediately.
//
// Returns S_OK on success, the InitDynamicCalls failure code,
// HRESULT_FROM_WIN32(ERROR_MOD_NOT_FOUND) when crypt32's CertOpenStore
// is unavailable, or a failure code describing why the DLL or its
// interface could not be obtained.
//
HRESULT
LoadSecurityLibrary(void)
{
    HRESULT Status;
    PSecurityFunctionTable pSecurityFunc;
    INIT_SECURITY_INTERFACE pInitSecurityInterface;

    if ((Status = InitDynamicCalls(&g_Crypt32CallsDesc)) != S_OK)
    {
        return Status;
    }

    if (g_hSecurity != NULL)
    {
        // Already loaded.
        return S_OK;
    }

    if (g_Crypt32Calls.CertOpenStore == NULL)
    {
        // Unable to load crypt32.dll.
        return HRESULT_FROM_WIN32(ERROR_MOD_NOT_FOUND);
    }

    g_hSecurity = LoadLibrary("security.dll");
    if (g_hSecurity == NULL)
    {
        return WIN32_LAST_STATUS();
    }

    pSecurityFunc = NULL;
    pInitSecurityInterface = (INIT_SECURITY_INTERFACE)
        GetProcAddress(g_hSecurity, "InitSecurityInterfaceA");
    if (pInitSecurityInterface != NULL)
    {
        pSecurityFunc = pInitSecurityInterface();
    }

    if (pSecurityFunc == NULL)
    {
        // Capture the error before FreeLibrary gets a chance to
        // overwrite the thread's last-error value.
        Status = WIN32_LAST_STATUS();
        if (SUCCEEDED(Status))
        {
            // Nothing recorded a last error, but the interface is
            // still unusable, so never report success here.
            Status = E_FAIL;
        }

        FreeLibrary(g_hSecurity);
        g_hSecurity = NULL;
        return Status;
    }

    memcpy(&g_SecurityFunc, pSecurityFunc, sizeof(g_SecurityFunc));
    return S_OK;
}
//
// Fills in *ScCreds and acquires an Schannel credential handle for it.
//
//   pszUserName   - If non-empty, a certificate whose subject contains
//                   this string is looked up in the "MY" store and
//                   attached to the credential.  Otherwise no
//                   certificate is attached.
//   fMachineStore - Selects the local-machine "MY" store rather than
//                   the default system "MY" store.  Only consulted the
//                   first time, when g_hMyCertStore is still NULL.
//   Server        - TRUE for inbound credentials, FALSE for outbound.
//   dwProtocol    - Value stored in grbitEnabledProtocols.
//   ScCreds       - Receives the SCHANNEL_CRED that was used.
//   phCreds       - Receives the credential handle.
//
// Returns the AcquireCredentialsHandle status, or a Win32-derived
// failure code if the store or the certificate could not be found.
//
HRESULT
CreateCredentials(LPSTR pszUserName,
                  BOOL fMachineStore,
                  BOOL Server,
                  ULONG dwProtocol,
                  SCHANNEL_CRED* ScCreds,
                  PCredHandle phCreds)
{
    TimeStamp tsExpiry;
    SECURITY_STATUS Status;
    PCCERT_CONTEXT pCertContext = NULL;

    // Open the "MY" certificate store on first use.
    if (g_hMyCertStore == NULL)
    {
        if (fMachineStore)
        {
            g_hMyCertStore = g_Crypt32Calls.
                CertOpenStore(CERT_STORE_PROV_SYSTEM,
                              X509_ASN_ENCODING,
                              0,
                              CERT_SYSTEM_STORE_LOCAL_MACHINE,
                              L"MY");
        }
        else
        {
            g_hMyCertStore = g_Crypt32Calls.
                CertOpenSystemStore(0, "MY");
        }

        if (!g_hMyCertStore)
        {
            Status = WIN32_LAST_STATUS();
            goto Exit;
        }
    }

    //
    // If a user name is specified, look for a certificate whose
    // subject contains it.  Otherwise create a credential with no
    // certificate.
    //
    if (pszUserName != NULL && *pszUserName)
    {
        pCertContext = g_Crypt32Calls.
            CertFindCertificateInStore(g_hMyCertStore,
                                       X509_ASN_ENCODING,
                                       0,
                                       CERT_FIND_SUBJECT_STR_A,
                                       pszUserName,
                                       NULL);
        if (pCertContext == NULL)
        {
            Status = WIN32_LAST_STATUS();
            goto Exit;
        }
    }

    //
    // Build the Schannel credential structure: protocol selection
    // plus the optional certificate.
    //
    ZeroMemory(ScCreds, sizeof(*ScCreds));
    ScCreds->dwVersion = SCHANNEL_CRED_VERSION;
    if (pCertContext != NULL)
    {
        ScCreds->cCreds = 1;
        ScCreds->paCred = &pCertContext;
    }
    ScCreds->grbitEnabledProtocols = dwProtocol;

    if (!Server)
    {
        if (pCertContext != NULL)
        {
            ScCreds->dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
        }
        else
        {
            ScCreds->dwFlags |= SCH_CRED_USE_DEFAULT_CREDS;
        }
        ScCreds->dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION;
    }

    //
    // Create an SSPI credential using the hardcoded Schannel
    // package name.
    //
    Status = g_SecurityFunc.AcquireCredentialsHandle(
        NULL,                   // Name of principal
        UNISP_NAME_A,           // Name of package
        Server ?                // Flags indicating use
            SECPKG_CRED_INBOUND :
            SECPKG_CRED_OUTBOUND,
        NULL,                   // Pointer to logon ID
        ScCreds,                // Package specific data
        NULL,                   // Pointer to GetKey() func
        NULL,                   // Value to pass to GetKey()
        phCreds,                // (out) Cred Handle
        &tsExpiry);             // (out) Lifetime (optional)

    //
    // paCred points at a local variable of this function.  The
    // caller keeps *ScCreds around after we return, so clear the
    // reference instead of leaving a dangling pointer behind.
    //
    ScCreds->cCreds = 0;
    ScCreds->paCred = NULL;

    //
    // Free the certificate context.  Schannel has already made its
    // own copy.
    //
    if (pCertContext)
    {
        g_Crypt32Calls.CertFreeCertificateContext(pCertContext);
    }

 Exit:
    DSCHAN(("CreateCredentials returns %X\n", Status));
    return Status;
}
//
// Fetches the peer certificate of an established context, builds a
// chain for it and checks that chain against the SSL chain policy.
//
//   Context       - Security context to query for the remote certificate.
//   pszServerName - Name the certificate is checked against.  When NULL
//                   or empty the check runs with AUTHTYPE_CLIENT and no
//                   name; otherwise with AUTHTYPE_SERVER.
//   dwCertFlags   - Passed through as the policy's fdwChecks.
//
// Returns S_OK when the policy reports no error, otherwise the policy
// error or the failure code of whichever step failed.
//
HRESULT
VerifyRemoteCertificate(PCtxtHandle Context,
                        PSTR pszServerName,
                        DWORD dwCertFlags)
{
    HRESULT Result;
    PCCERT_CONTEXT RemoteCert = NULL;
    PCCERT_CHAIN_CONTEXT Chain = NULL;
    PWSTR WideName = NULL;
    DWORD WideChars;
    CERT_CHAIN_PARA ChainPara;
    SSL_EXTRA_CERT_CHAIN_POLICY_PARA SslPara;
    CERT_CHAIN_POLICY_PARA PolicyPara;
    CERT_CHAIN_POLICY_STATUS PolicyStatus;

    // Ask the context for the certificate the peer presented.
    Result = g_SecurityFunc.
        QueryContextAttributes(Context,
                               SECPKG_ATTR_REMOTE_CERT_CONTEXT,
                               &RemoteCert);
    if (Result != S_OK)
    {
        goto Done;
    }
    if (RemoteCert == NULL)
    {
        Result = SEC_E_WRONG_PRINCIPAL;
        goto ReleaseCert;
    }

    // The policy check wants the server name as a wide string.
    if (pszServerName != NULL && *pszServerName)
    {
        WideChars = MultiByteToWideChar(CP_ACP, 0, pszServerName,
                                        -1, NULL, 0);
        WideName = (PWSTR)
            LocalAlloc(LMEM_FIXED, WideChars * sizeof(WCHAR));
        if (WideName == NULL)
        {
            Result = SEC_E_INSUFFICIENT_MEMORY;
            goto ReleaseCert;
        }

        WideChars = MultiByteToWideChar(CP_ACP, 0, pszServerName,
                                        -1, WideName, WideChars);
        if (WideChars == 0)
        {
            Result = SEC_E_WRONG_PRINCIPAL;
            goto ReleaseName;
        }
    }

    // Build the certificate chain.
    ZeroMemory(&ChainPara, sizeof(ChainPara));
    ChainPara.cbSize = sizeof(ChainPara);

    if (!g_Crypt32Calls.CertGetCertificateChain(NULL,
                                                RemoteCert,
                                                NULL,
                                                RemoteCert->hCertStore,
                                                &ChainPara,
                                                0,
                                                NULL,
                                                &Chain))
    {
        Result = WIN32_LAST_STATUS();
        goto ReleaseName;
    }

    // Describe the SSL-specific part of the policy check.
    ZeroMemory(&SslPara, sizeof(SslPara));
    SslPara.cbStruct = sizeof(SslPara);
    if (WideName == NULL)
    {
        SslPara.dwAuthType = AUTHTYPE_CLIENT;
    }
    else
    {
        SslPara.dwAuthType = AUTHTYPE_SERVER;
    }
    SslPara.fdwChecks = dwCertFlags;
    SslPara.pwszServerName = WideName;

    ZeroMemory(&PolicyPara, sizeof(PolicyPara));
    PolicyPara.cbSize = sizeof(PolicyPara);
    PolicyPara.pvExtraPolicyPara = &SslPara;

    ZeroMemory(&PolicyStatus, sizeof(PolicyStatus));
    PolicyStatus.cbSize = sizeof(PolicyStatus);

    // Validate the chain.
    if (g_Crypt32Calls.CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL,
                                                        Chain,
                                                        &PolicyPara,
                                                        &PolicyStatus))
    {
        // A zero dwError is the same value as S_OK.
        Result = PolicyStatus.dwError ? PolicyStatus.dwError : S_OK;
    }
    else
    {
        Result = WIN32_LAST_STATUS();
    }

    g_Crypt32Calls.CertFreeCertificateChain(Chain);

 ReleaseName:
    if (WideName != NULL)
    {
        LocalFree(WideName);
    }
 ReleaseCert:
    g_Crypt32Calls.CertFreeCertificateContext(RemoteCert);
 Done:
    DSCHAN(("VerifyRemoteCertificate returns %X\n", Result));
    return Result;
}
//----------------------------------------------------------------------------
//
// Schannel wrapper transport.
//
//----------------------------------------------------------------------------
// Evaluates to true when either half of the handle differs from -1,
// i.e. the handle is not in the all-ones state that this file sets up
// with SecInvalidateHandle.  Handle must be a pointer to the handle.
#define SecHandleIsValid(Handle) \
((Handle)->dwLower != -1 || (Handle)->dwUpper != -1)
//
// Creates an unconnected secure-channel transport.
//
//   ThisTransport - Index of this transport; also selects its name
//                   from g_DbgRpcTransportNames.
//   BaseTransport - Index of the transport used for the underlying
//                   stream, which is not created here.
//
// Both security handles start out invalid and unowned.
//
DbgRpcSecureChannelTransport::
DbgRpcSecureChannelTransport(ULONG ThisTransport,
                             ULONG BaseTransport)
{
    // Identity of this transport and of the one it wraps.
    m_ThisTransport = ThisTransport;
    m_BaseTransport = BaseTransport;
    m_Name = g_DbgRpcTransportNames[ThisTransport];

    // No stream, no buffered data and client role by default.
    m_Stream = NULL;
    m_BufferUsed = 0;
    m_Server = FALSE;

    // No credentials or context yet.
    m_OwnCreds = FALSE;
    m_OwnContext = FALSE;
    SecInvalidateHandle(&m_Creds);
    SecInvalidateHandle(&m_Context);
}
//
// Tears the transport down: runs the disconnect sequence matching the
// transport's role if a context exists, deletes the underlying stream,
// then releases the context and credentials this object owns.
//
DbgRpcSecureChannelTransport::~DbgRpcSecureChannelTransport(void)
{
    // The disconnect sequence needs the stream, so it must run
    // before the stream is deleted.
    if (SecHandleIsValid(&m_Context))
    {
        if (!m_Server)
        {
            DisconnectFromServer();
        }
        else
        {
            DisconnectFromClient();
        }
    }

    delete m_Stream;

    // Handles copied from another transport are not owned and are
    // left alone.
    if (m_OwnContext)
    {
        if (SecHandleIsValid(&m_Context))
        {
            g_SecurityFunc.DeleteSecurityContext(&m_Context);
        }
    }

    if (m_OwnCreds)
    {
        if (SecHandleIsValid(&m_Creds))
        {
            g_SecurityFunc.FreeCredentialsHandle(&m_Creds);
        }
    }
}
//
// Returns the number of parameter slots: two for this transport plus
// however many the underlying stream reports, if there is one.
//
ULONG
DbgRpcSecureChannelTransport::GetNumberParameters(void)
{
    ULONG StreamCount = 0;

    if (m_Stream != NULL)
    {
        StreamCount = m_Stream->GetNumberParameters();
    }

    return 2 + StreamCount;
}
//
// Retrieves the name and value of one parameter.
//
//   Index - 0 is the protocol, 1 is the certificate user, anything
//           else is forwarded to the underlying stream as Index - 2.
//   Name  - Receives the parameter name.
//   Value - Receives the parameter value.
//
// Name and Value are left untouched when the selected parameter has
// not been set, or when there is no stream for a forwarded index.
//
void
DbgRpcSecureChannelTransport::GetParameter(ULONG Index, PSTR Name, PSTR Value)
{
    switch(Index)
    {
    case 0:
        if (m_Protocol)
        {
            strcpy(Name, "Proto");

            // The protocol string is the parameter's value; it must
            // not overwrite the "Proto" name stored above.
            switch(m_Protocol)
            {
            case SP_PROT_PCT1:
                strcpy(Value, "PCT1");
                break;
            case SP_PROT_SSL2:
                strcpy(Value, "SSL2");
                break;
            case SP_PROT_SSL3:
                strcpy(Value, "SSL3");
                break;
            case SP_PROT_TLS1:
                strcpy(Value, "TLS1");
                break;
            }
        }
        break;

    case 1:
        if (m_User[0])
        {
            // The name records which store the user was looked up in.
            strcpy(Name, m_MachineStore ? "MachUser" : "CertUser");
            strcpy(Value, m_User);
        }
        break;

    default:
        if (m_Stream != NULL)
        {
            m_Stream->GetParameter(Index - 2, Name, Value);
        }
        break;
    }
}
//
// Clears the protocol, user name and store selection, makes sure an
// underlying stream exists (creating one from m_BaseTransport if
// needed) and resets that stream's parameters as well.
//
void
DbgRpcSecureChannelTransport::ResetParameters(void)
{
    m_MachineStore = FALSE;
    m_User[0] = 0;
    m_Protocol = 0;

    if (m_Stream == NULL)
    {
        m_Stream = DbgRpcNewTransport(m_BaseTransport);
        if (m_Stream == NULL)
        {
            // No stream could be created; nothing more to reset.
            return;
        }
    }

    m_Stream->ResetParameters();
}
//
// Applies one name/value parameter.
//
//   "proto"    - pct1, ssl2, ssl3 or tls1 (case-insensitive).
//   "machuser" - certificate user, looked up in the machine store.
//   "certuser" - certificate user, looked up in the default store.
//   other      - forwarded to the underlying stream.
//
// Returns FALSE if there is no underlying stream, if a value is
// missing or unrecognized, or if the stream rejects the parameter.
//
BOOL
DbgRpcSecureChannelTransport::SetParameter(PCSTR Name, PCSTR Value)
{
    if (m_Stream == NULL)
    {
        // Force all initialization to fail.
        return FALSE;
    }

    if (!_stricmp(Name, "proto"))
    {
        if (Value == NULL)
        {
            DbgRpcError("%s parameters: "
                        "the protocol name was not specified correctly\n",
                        m_Name);
            return FALSE;
        }

        if (!_stricmp(Value, "pct1"))
        {
            m_Protocol = SP_PROT_PCT1;
        }
        else if (!_stricmp(Value, "ssl2"))
        {
            m_Protocol = SP_PROT_SSL2;
        }
        else if (!_stricmp(Value, "ssl3"))
        {
            m_Protocol = SP_PROT_SSL3;
        }
        else if (!_stricmp(Value, "tls1"))
        {
            m_Protocol = SP_PROT_TLS1;
        }
        else
        {
            // Arguments follow the format string: transport name
            // first, then the offending value.
            DbgRpcError("%s parameters: unknown protocol '%s'\n",
                        m_Name, Value);
            return FALSE;
        }

        return TRUE;
    }

    // Both user parameters share everything except the store flag.
    BOOL MachineUser = !_stricmp(Name, "machuser");
    if (MachineUser || !_stricmp(Name, "certuser"))
    {
        if (Value == NULL)
        {
            DbgRpcError("%s parameters: "
                        "the user name was not specified correctly\n",
                        m_Name);
            return FALSE;
        }

        // Bounded copy that always leaves m_User terminated.
        m_User[0] = 0;
        strncat(m_User, Value, sizeof(m_User) - 1);
        m_MachineStore = MachineUser ? TRUE : FALSE;
        return TRUE;
    }

    // Not one of ours, let the underlying stream decide.
    if (!m_Stream->SetParameter(Name, Value))
    {
        return FALSE;
    }

    return TRUE;
}
//
// Produces a copy of this transport layered over a clone of the
// underlying stream.  The copy shares this transport's credential and
// context handles without owning them, and duplicates the parameter,
// size and role settings.
//
// Returns NULL if either the stream clone or the new transport could
// not be created.
//
DbgRpcTransport*
DbgRpcSecureChannelTransport::Clone(void)
{
    DbgRpcTransport* BaseCopy = m_Stream->Clone();
    if (BaseCopy == NULL)
    {
        return NULL;
    }

    DbgRpcSecureChannelTransport* Copy =
        new DbgRpcSecureChannelTransport(m_ThisTransport, m_BaseTransport);
    if (Copy == NULL)
    {
        // Nobody will take the stream clone, so drop it.
        delete BaseCopy;
        return NULL;
    }

    Copy->m_Stream = BaseCopy;

    // Handles are shared, ownership stays with this object.
    Copy->m_Creds = m_Creds;
    Copy->m_OwnCreds = FALSE;
    Copy->m_Context = m_Context;
    Copy->m_OwnContext = FALSE;

    // Settings are duplicated.
    Copy->m_Protocol = m_Protocol;
    strcpy(Copy->m_User, m_User);
    Copy->m_MachineStore = m_MachineStore;
    Copy->m_Sizes = m_Sizes;
    Copy->m_MaxChunk = m_MaxChunk;
    Copy->m_Server = m_Server;

    return Copy;
}
//
// Prepares this transport to act as a server: loads the security
// library, acquires inbound credentials from the configured user and
// protocol, then creates the server side of the underlying stream.
//
// Returns S_OK on success or the first failing step's status.
//
HRESULT
DbgRpcSecureChannelTransport::CreateServer(void)
{
    HRESULT Status = LoadSecurityLibrary();
    if (Status != S_OK)
    {
        return Status;
    }

    Status = CreateCredentials(m_User, m_MachineStore, TRUE,
                               m_Protocol, &m_ScCreds, &m_Creds);
    if (Status != S_OK)
    {
        return Status;
    }
    // The credential handle now belongs to this object.
    m_OwnCreds = TRUE;

    Status = m_Stream->CreateServer();
    if (Status == S_OK)
    {
        m_Server = TRUE;
    }
    return Status;
}
//
// Accepts a connection on the underlying stream and wraps it in a new
// secure-channel transport that borrows this server's credentials.
// The new transport runs the server-side handshake, queries the
// stream sizes and verifies the client's certificate before it is
// handed back.
//
//   ClientTrans - Receives the new transport on success.
//   Identity    - Passed to the underlying stream's AcceptConnection.
//
// Returns S_OK, E_OUTOFMEMORY, or the status of the failing step.
//
HRESULT
DbgRpcSecureChannelTransport::AcceptConnection(DbgRpcTransport** ClientTrans,
                                               PSTR Identity)
{
    DbgRpcTransport* Accepted;
    HRESULT Status = m_Stream->AcceptConnection(&Accepted, Identity);
    if (Status != S_OK)
    {
        return Status;
    }

    DbgRpcSecureChannelTransport* Secure =
        new DbgRpcSecureChannelTransport(m_ThisTransport, m_BaseTransport);
    if (Secure == NULL)
    {
        delete Accepted;
        return E_OUTOFMEMORY;
    }

    // From here on the accepted stream is released with Secure.
    Secure->m_Stream = Accepted;
    Secure->m_Creds = m_Creds;
    Secure->m_OwnCreds = FALSE;
    Secure->m_Server = TRUE;

    // Each step only runs if everything before it succeeded.
    Status = Secure->AuthenticateClientConnection();
    if (Status == S_OK)
    {
        Status = Secure->GetSizes();
    }
    if (Status == S_OK)
    {
        // Attempt to validate client certificate.
        Status = VerifyRemoteCertificate(&Secure->m_Context, NULL, 0);
    }

    if (Status != S_OK)
    {
        delete Secure;
        return Status;
    }

    *ClientTrans = Secure;
    return S_OK;
}
//
// Connects to a server: connects the underlying stream, loads the
// security library, acquires outbound credentials, performs the
// client-side handshake, queries the stream sizes and finally checks
// the server's certificate against the stream's server name.
//
// Returns S_OK on success or the status of the first failing step.
//
HRESULT
DbgRpcSecureChannelTransport::ConnectServer(void)
{
    HRESULT Status;

    if ((Status = m_Stream->ConnectServer()) != S_OK ||
        (Status = LoadSecurityLibrary()) != S_OK)
    {
        return Status;
    }

    Status = CreateCredentials(m_User, m_MachineStore, FALSE,
                               m_Protocol, &m_ScCreds, &m_Creds);
    if (Status != S_OK)
    {
        return Status;
    }
    // The credential handle now belongs to this object.
    m_OwnCreds = TRUE;

    Status = InitiateServerConnection(m_Stream->m_ServerName);
    if (Status == S_OK)
    {
        Status = AuthenticateServerConnection();
    }
    if (Status == S_OK)
    {
        Status = GetSizes();
    }
    if (Status == S_OK)
    {
        // Attempt to validate server certificate.
        // If this fails with CERT_E_CN_NO_MATCH it's most
        // likely that the server name wasn't given as a fully
        // qualified machine name.  The error is still reported.
        Status = VerifyRemoteCertificate(&m_Context,
                                         m_Stream->m_ServerName, 0);
    }

    return Status;
}
//
// Reads and decrypts data from the stream until Len bytes have been
// delivered or something goes wrong.
//
//   Seq    - Sequence value handed to DecryptMessage and StreamRead.
//   Buffer - Receives the decrypted bytes.
//   Len    - Number of bytes wanted.
//
// Returns the number of bytes stored in Buffer.  A value smaller than
// Len means the stream ran dry, decryption failed, or a decrypted
// message did not fit into what was left of Buffer.
//
// Encrypted bytes that arrive beyond the current message are kept in
// m_Buffer / m_BufferUsed for the next call.
//
ULONG
DbgRpcSecureChannelTransport::Read(ULONG Seq, PVOID Buffer, ULONG Len)
{
    SecBufferDesc Message;
    SecBuffer Buffers[4];
    DWORD Status;
    ULONG Complete;

    DSCHAN_IO(("Start read(%X) with %X bytes cached\n",
               Len, m_BufferUsed));

    //
    // Initialize security buffer structs
    //
    Message.ulVersion = SECBUFFER_VERSION;
    Message.cBuffers = 4;
    Message.pBuffers = Buffers;

    Complete = 0;
    while (Complete < Len)
    {
        //
        // Keep feeding DecryptMessage until it produces a complete
        // message or fails.
        //
        for (;;)
        {
            // Pass in the data we have so far.
            Buffers[0].pvBuffer = m_Buffer;
            Buffers[0].cbBuffer = m_BufferUsed;
            Buffers[0].BufferType = SECBUFFER_DATA;

            // Provide extra buffers for header, trailer
            // and possibly extra data.
            Buffers[1].BufferType = SECBUFFER_EMPTY;
            Buffers[2].BufferType = SECBUFFER_EMPTY;
            Buffers[3].BufferType = SECBUFFER_EMPTY;

            Status = g_SecurityFunc.DecryptMessage(&m_Context, &Message,
                                                   Seq, NULL);

            DSCHAN_IO(("Read DecryptMessage on %X bytes returns %X\n",
                       m_BufferUsed, Status));
            DumpBuffersIo("Read", &Message);

            if (Status == SEC_E_INCOMPLETE_MESSAGE)
            {
                DSCHAN_IO((" Missing %X bytes\n", Buffers[1].cbBuffer));

                ULONG Got = StreamRead(Seq, m_Buffer + m_BufferUsed,
                                       sizeof(m_Buffer) - m_BufferUsed);
                if (Got == 0)
                {
                    return Complete;
                }

                m_BufferUsed += Got;
                continue;
            }

            if (Status == SEC_I_RENEGOTIATE)
            {
                // The server wants to perform another handshake
                // sequence.  The buffers hold no application data in
                // this case, so after a successful handshake go back
                // and decrypt again rather than falling through.
                Status = AuthenticateServerConnection();
                if (Status == S_OK)
                {
                    continue;
                }
            }

            break;
        }

        if (Status != S_OK)
        {
            break;
        }

        // Buffers 0,1,2 should be header, data, trailer.
        DBG_ASSERT(Buffers[1].BufferType == SECBUFFER_DATA);

        DSCHAN_IO((" %X bytes of %X read\n",
                   Buffers[1].cbBuffer, Len));

        // The message size comes from the peer, so never copy more
        // than the caller still has room for.
        ULONG Plain = Buffers[1].cbBuffer;
        BOOL Fits = Plain <= Len - Complete;
        if (Fits)
        {
            memcpy((PUCHAR)Buffer + Complete, Buffers[1].pvBuffer, Plain);
            Complete += Plain;
        }

        // Check for extra data in buffer 3 and keep it for later.
        if (Buffers[3].BufferType == SECBUFFER_EXTRA)
        {
            DSCHAN_IO((" %X bytes extra\n", Buffers[3].cbBuffer));

            memmove(m_Buffer, Buffers[3].pvBuffer, Buffers[3].cbBuffer);
            m_BufferUsed = Buffers[3].cbBuffer;
        }
        else
        {
            m_BufferUsed = 0;
        }

        if (!Fits)
        {
            // Oversized message: report a short read instead of
            // writing past the end of the caller's buffer.
            break;
        }
    }

    DSCHAN_IO((" Read returns %X bytes\n", Complete));
    return Complete;
}
//
// Encrypts Len bytes from Buffer and sends them on the stream, split
// into as many messages as the working buffer requires.
//
//   Seq    - Sequence value handed to EncryptMessage and StreamWrite.
//   Buffer - Plaintext to send.
//   Len    - Number of bytes to send.
//
// Returns the number of plaintext bytes sent.  A value smaller than
// Len means encryption or the stream write failed, or the working
// buffer had no room left.
//
ULONG
DbgRpcSecureChannelTransport::Write(ULONG Seq, PVOID Buffer, ULONG Len)
{
    SecBufferDesc Message;
    SecBuffer Buffers[3];
    DWORD Status;
    ULONG Complete;

    DSCHAN_IO(("Start write(%X) with %X bytes cached\n",
               Len, m_BufferUsed));

    Message.ulVersion = SECBUFFER_VERSION;
    Message.cBuffers = 3;
    Message.pBuffers = Buffers;

    Complete = 0;
    while (Complete < Len)
    {
        ULONG Chunk;

        //
        // Part of the working buffer may be taken up by queued
        // read data.  If that leaves no room for any payload there
        // is nothing that can be sent; stop rather than let the
        // subtraction below wrap around or loop on empty chunks.
        //
        if (m_BufferUsed >= m_MaxChunk)
        {
            break;
        }

        //
        // Set up header, data and trailer buffers so
        // that EncryptMessage has room for everything
        // in one contiguous buffer.
        //
        Buffers[0].pvBuffer = m_Buffer + m_BufferUsed;
        Buffers[0].cbBuffer = m_Sizes.cbHeader;
        Buffers[0].BufferType = SECBUFFER_STREAM_HEADER;

        //
        // Data is encrypted in-place so copy data
        // from the user's buffer into the working buffer.
        // The chunk is limited by both the space available and
        // the bytes that have not been sent yet.
        //
        ULONG Room = m_MaxChunk - m_BufferUsed;
        ULONG Remaining = Len - Complete;
        Chunk = Remaining > Room ? Room : Remaining;

        DSCHAN_IO((" write %X bytes of %X\n", Chunk, Len));

        Buffers[1].pvBuffer =
            (PUCHAR)Buffers[0].pvBuffer + Buffers[0].cbBuffer;
        Buffers[1].cbBuffer = Chunk;
        Buffers[1].BufferType = SECBUFFER_DATA;

        memcpy(Buffers[1].pvBuffer, (PUCHAR)Buffer + Complete, Chunk);

        Buffers[2].pvBuffer =
            (PUCHAR)Buffers[1].pvBuffer + Buffers[1].cbBuffer;
        Buffers[2].cbBuffer = m_Sizes.cbTrailer;
        Buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;

        Status = g_SecurityFunc.EncryptMessage(&m_Context, 0, &Message, Seq);
        if (Status != S_OK)
        {
            break;
        }

        DumpBuffersIo("Write encrypt", &Message);

        // Send header, payload and trailer in one go, using the
        // sizes EncryptMessage left in the buffers.
        ULONG Total, Written;

        Total = Buffers[0].cbBuffer + Buffers[1].cbBuffer +
            Buffers[2].cbBuffer;
        Written = StreamWrite(Seq, Buffers[0].pvBuffer, Total);
        if (Written != Total)
        {
            break;
        }

        Complete += Chunk;
    }

    DSCHAN_IO((" Write returns %X bytes\n", Complete));
    return Complete;
}
//
// Queries the context's stream sizes into m_Sizes and derives
// m_MaxChunk, the largest payload that can be encrypted at once in
// the transport's data buffer.
//
// Returns S_OK or the QueryContextAttributes failure status.
//
HRESULT
DbgRpcSecureChannelTransport::GetSizes(void)
{
    HRESULT Status = g_SecurityFunc.
        QueryContextAttributes(&m_Context, SECPKG_ATTR_STREAM_SIZES,
                               &m_Sizes);
    if (Status != S_OK)
    {
        return Status;
    }

    // Start with what is left of the buffer once the header and
    // trailer are accounted for ...
    m_MaxChunk = sizeof(m_Buffer) - (m_Sizes.cbHeader + m_Sizes.cbTrailer);

    // ... then honor the package's own message limit if it is
    // the smaller of the two.
    if (m_Sizes.cbMaximumMessage < m_MaxChunk)
    {
        m_MaxChunk = m_Sizes.cbMaximumMessage;
    }

    return S_OK;
}
//
// Runs the server side of the handshake with a newly connected
// client by looping over AcceptSecurityContext until it completes
// or fails.
//
// Handshake data is read from and written to the stream using
// SEQ_INTERNAL.  On success m_Context holds the established context
// and any bytes the handshake did not consume are left at the start
// of m_Buffer with m_BufferUsed set accordingly.
//
// Returns SEC_E_OK on success, a read/write fault code if the stream
// fails, or the AcceptSecurityContext failure status.
//
HRESULT
DbgRpcSecureChannelTransport::AuthenticateClientConnection(void)
{
    TimeStamp tsExpiry;
    SECURITY_STATUS Status;
    SecBufferDesc InBuffer;
    SecBufferDesc OutBuffer;
    SecBuffer InBuffers[2];
    SecBuffer OutBuffers[1];
    BOOL fInitContext = TRUE;
    DWORD dwSSPIFlags;
    DWORD dwSSPIOutFlags = 0;

    dwSSPIFlags = ASC_REQ_SEQUENCE_DETECT |
        ASC_REQ_REPLAY_DETECT |
        ASC_REQ_CONFIDENTIALITY |
        ASC_REQ_EXTENDED_ERROR |
        ASC_REQ_ALLOCATE_MEMORY |
        ASC_REQ_STREAM |
        ASC_REQ_MUTUAL_AUTH;

    //
    // Set buffers for AcceptSecurityContext call
    //
    InBuffer.cBuffers = 2;
    InBuffer.pBuffers = InBuffers;
    InBuffer.ulVersion = SECBUFFER_VERSION;

    OutBuffer.cBuffers = 1;
    OutBuffer.pBuffers = OutBuffers;
    OutBuffer.ulVersion = SECBUFFER_VERSION;

    Status = SEC_I_CONTINUE_NEEDED;
    m_BufferUsed = 0;

    while ( Status == SEC_I_CONTINUE_NEEDED ||
            Status == SEC_E_INCOMPLETE_MESSAGE ||
            Status == SEC_I_INCOMPLETE_CREDENTIALS)
    {
        // Fetch more data when nothing is buffered or the last
        // call said the message was incomplete.
        if (0 == m_BufferUsed || Status == SEC_E_INCOMPLETE_MESSAGE)
        {
            ULONG Read = StreamRead(SEQ_INTERNAL, m_Buffer + m_BufferUsed,
                                    sizeof(m_Buffer) - m_BufferUsed);
            if (Read == 0)
            {
                Status = HRESULT_FROM_WIN32(ERROR_READ_FAULT);
                goto Exit;
            }
            else
            {
                m_BufferUsed += Read;
            }
        }

        //
        // InBuffers[1] is for getting extra data that
        // SSPI/SCHANNEL doesn't process on this
        // run around the loop.
        //
        InBuffers[0].pvBuffer = m_Buffer;
        InBuffers[0].cbBuffer = m_BufferUsed;
        InBuffers[0].BufferType = SECBUFFER_TOKEN;

        InBuffers[1].pvBuffer = NULL;
        InBuffers[1].cbBuffer = 0;
        InBuffers[1].BufferType = SECBUFFER_EMPTY;

        //
        // Initialize these so if we fail, pvBuffer contains NULL,
        // so we don't try to free random garbage at the quit
        //
        OutBuffers[0].pvBuffer = NULL;
        OutBuffers[0].cbBuffer = 0;
        OutBuffers[0].BufferType = SECBUFFER_TOKEN;

        // Start from a known value so a failing call that leaves
        // the flags untouched cannot make us act on stale bits.
        dwSSPIOutFlags = 0;

        Status = g_SecurityFunc.AcceptSecurityContext(
            &m_Creds,
            (fInitContext ? NULL : &m_Context),
            &InBuffer,
            dwSSPIFlags,
            SECURITY_NATIVE_DREP,
            (fInitContext ? &m_Context : NULL),
            &OutBuffer,
            &dwSSPIOutFlags,
            &tsExpiry);

        DSCHAN(("ASC on %X bytes returns %X\n",
                m_BufferUsed, Status));
        DumpBuffers("ASC in", &InBuffer);
        DumpBuffers("ASC out", &OutBuffer);

        if (SUCCEEDED(Status))
        {
            // A context now exists; later calls must update it
            // and this object is responsible for deleting it.
            fInitContext = FALSE;
            m_OwnContext = TRUE;
        }

        //
        // Send the output token on success, or on a failure that
        // carries an extended error for the peer.  The flags come
        // from AcceptSecurityContext, so the ASC_RET_* bit applies.
        //
        if ( Status == SEC_E_OK ||
             Status == SEC_I_CONTINUE_NEEDED ||
             (FAILED(Status) &&
              (0 != (dwSSPIOutFlags & ASC_RET_EXTENDED_ERROR))))
        {
            if (OutBuffers[0].cbBuffer != 0 &&
                OutBuffers[0].pvBuffer != NULL )
            {
                ULONG Written;

                DSCHAN((" write back %X bytes\n", OutBuffers[0].cbBuffer));

                //
                // Send response to the client if there is one
                //
                Written = StreamWrite(SEQ_INTERNAL, OutBuffers[0].pvBuffer,
                                      OutBuffers[0].cbBuffer);

                g_SecurityFunc.FreeContextBuffer( OutBuffers[0].pvBuffer );
                OutBuffers[0].pvBuffer = NULL;

                if (Written != OutBuffers[0].cbBuffer)
                {
                    Status = HRESULT_FROM_WIN32(ERROR_WRITE_FAULT);
                    goto Exit;
                }
            }
        }

        if ( Status == SEC_E_OK )
        {
            // Handshake finished.  Keep whatever was not consumed
            // so the first Read can pick it up.
            if ( InBuffers[1].BufferType == SECBUFFER_EXTRA )
            {
                DSCHAN_IO((" ASC returns with %X extra bytes\n",
                           InBuffers[1].cbBuffer));

                memmove(m_Buffer,
                        m_Buffer + (m_BufferUsed - InBuffers[1].cbBuffer),
                        InBuffers[1].cbBuffer);
                m_BufferUsed = InBuffers[1].cbBuffer;
            }
            else
            {
                m_BufferUsed = 0;
            }

            goto Exit;
        }
        else if (FAILED(Status) && (Status != SEC_E_INCOMPLETE_MESSAGE))
        {
            goto Exit;
        }

        // For an incomplete message or incomplete credentials the
        // buffered input is kept as it is for the next pass.
        if ( Status != SEC_E_INCOMPLETE_MESSAGE &&
             Status != SEC_I_INCOMPLETE_CREDENTIALS)
        {
            if ( InBuffers[1].BufferType == SECBUFFER_EXTRA )
            {
                DSCHAN_IO((" ASC loops with %X extra bytes\n",
                           InBuffers[1].cbBuffer));

                memmove(m_Buffer,
                        m_Buffer + (m_BufferUsed - InBuffers[1].cbBuffer),
                        InBuffers[1].cbBuffer);
                m_BufferUsed = InBuffers[1].cbBuffer;
            }
            else
            {
                //
                // prepare for next receive
                //
                m_BufferUsed = 0;
            }
        }
    }

 Exit:
    DSCHAN(("AuthClient returns %X\n", Status));
    return Status;
}
//
// Starts the client side of the handshake: asks
// InitializeSecurityContext for the first token and sends it to the
// server using SEQ_INTERNAL.
//
//   pszServerName - Target name passed to InitializeSecurityContext.
//
// Returns S_OK once the first token has been dealt with, the
// InitializeSecurityContext status if it was anything other than
// SEC_I_CONTINUE_NEEDED, or a write fault code if sending failed (in
// which case the new context is deleted again).
//
HRESULT
DbgRpcSecureChannelTransport::InitiateServerConnection(LPSTR pszServerName)
{
    SecBufferDesc TokenDesc;
    SecBuffer Token[1];
    DWORD ReqFlags;
    DWORD RetFlags;
    TimeStamp Expiry;
    SECURITY_STATUS Status;
    DWORD Sent;

    ReqFlags = ISC_REQ_SEQUENCE_DETECT |
        ISC_REQ_REPLAY_DETECT |
        ISC_REQ_CONFIDENTIALITY |
        ISC_REQ_EXTENDED_ERROR |
        ISC_REQ_ALLOCATE_MEMORY |
        ISC_REQ_STREAM |
        ISC_REQ_MUTUAL_AUTH;

    //
    // Initiate a ClientHello message and generate a token.
    //
    Token[0].pvBuffer = NULL;
    Token[0].BufferType = SECBUFFER_TOKEN;
    Token[0].cbBuffer = 0;

    TokenDesc.cBuffers = 1;
    TokenDesc.pBuffers = Token;
    TokenDesc.ulVersion = SECBUFFER_VERSION;

    Status = g_SecurityFunc.InitializeSecurityContextA(
        &m_Creds,
        NULL,
        pszServerName,
        ReqFlags,
        0,
        SECURITY_NATIVE_DREP,
        NULL,
        0,
        &m_Context,
        &TokenDesc,
        &RetFlags,
        &Expiry);

    DSCHAN(("First ISC returns %X\n", Status));
    DumpBuffers("First ISC out", &TokenDesc);

    if (Status == SEC_I_CONTINUE_NEEDED)
    {
        // The context exists from here on and belongs to us.
        m_OwnContext = TRUE;
        Status = S_OK;

        // Send response to server if there is one.
        if (Token[0].cbBuffer != 0 && Token[0].pvBuffer != NULL)
        {
            DSCHAN((" write back %X bytes\n", Token[0].cbBuffer));

            Sent = StreamWrite(SEQ_INTERNAL, Token[0].pvBuffer,
                               Token[0].cbBuffer);

            // The token is released whether or not it went out.
            g_SecurityFunc.FreeContextBuffer(Token[0].pvBuffer);
            Token[0].pvBuffer = NULL;

            if (Sent == 0)
            {
                if (m_OwnContext)
                {
                    g_SecurityFunc.DeleteSecurityContext(&m_Context);
                    SecInvalidateHandle(&m_Context);
                }
                Status = HRESULT_FROM_WIN32(ERROR_WRITE_FAULT);
            }
        }
    }

    DSCHAN(("InitServer returns %X\n", Status));
    return Status;
}
//
// Runs the client side of the handshake on the existing m_Context by
// looping over InitializeSecurityContext until it completes or fails.
// Handshake data is exchanged with the server through StreamRead and
// StreamWrite using SEQ_INTERNAL.
//
// On success, input bytes that the handshake did not consume are moved
// to the start of m_Buffer and m_BufferUsed is set to their count.
// On failure the context is deleted and invalidated if this object
// owns it.
//
// Returns SEC_E_OK on success, a read/write fault code if the stream
// fails, or the InitializeSecurityContext failure status.
//
HRESULT
DbgRpcSecureChannelTransport::AuthenticateServerConnection(void)
{
SecBufferDesc InBuffer;
SecBuffer InBuffers[2];
SecBufferDesc OutBuffer;
SecBuffer OutBuffers[1];
DWORD dwSSPIFlags;
DWORD dwSSPIOutFlags;
TimeStamp tsExpiry;
SECURITY_STATUS Status;
DWORD cbData;
ULONG ReadNeeded;
dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
ISC_REQ_REPLAY_DETECT |
ISC_REQ_CONFIDENTIALITY |
ISC_REQ_EXTENDED_ERROR |
ISC_REQ_ALLOCATE_MEMORY |
ISC_REQ_STREAM;
// Any previously buffered data is discarded before the handshake.
m_BufferUsed = 0;
// ReadNeeded is cleared to skip one stream read after new client
// credentials have been requested (see below).
ReadNeeded = 1;
//
// Loop until the handshake is finished or an error occurs.
//
Status = SEC_I_CONTINUE_NEEDED;
while(Status == SEC_I_CONTINUE_NEEDED ||
Status == SEC_E_INCOMPLETE_MESSAGE ||
Status == SEC_I_INCOMPLETE_CREDENTIALS)
{
//
// Read data from server.
//
if (0 == m_BufferUsed || Status == SEC_E_INCOMPLETE_MESSAGE)
{
if (ReadNeeded > 0)
{
cbData = StreamRead(SEQ_INTERNAL, m_Buffer + m_BufferUsed,
sizeof(m_Buffer) - m_BufferUsed);
if(cbData == 0)
{
Status = HRESULT_FROM_WIN32(ERROR_READ_FAULT);
break;
}
m_BufferUsed += cbData;
}
else
{
ReadNeeded = 1;
}
}
//
// Set up the input buffers. Buffer 0 is used to pass in data
// received from the server. Schannel will consume some or all
// of this. Leftover data (if any) will be placed in buffer 1 and
// given a buffer type of SECBUFFER_EXTRA.
//
InBuffers[0].pvBuffer = m_Buffer;
InBuffers[0].cbBuffer = m_BufferUsed;
InBuffers[0].BufferType = SECBUFFER_TOKEN;
InBuffers[1].pvBuffer = NULL;
InBuffers[1].cbBuffer = 0;
InBuffers[1].BufferType = SECBUFFER_EMPTY;
InBuffer.cBuffers = 2;
InBuffer.pBuffers = InBuffers;
InBuffer.ulVersion = SECBUFFER_VERSION;
//
// Set up the output buffers. These are initialized to NULL
// so as to make it less likely we'll attempt to free random
// garbage later.
//
OutBuffers[0].pvBuffer = NULL;
OutBuffers[0].BufferType= SECBUFFER_TOKEN;
OutBuffers[0].cbBuffer = 0;
OutBuffer.cBuffers = 1;
OutBuffer.pBuffers = OutBuffers;
OutBuffer.ulVersion = SECBUFFER_VERSION;
//
// Call InitializeSecurityContext.
//
Status = g_SecurityFunc.InitializeSecurityContextA(
&m_Creds,
&m_Context,
NULL,
dwSSPIFlags,
0,
SECURITY_NATIVE_DREP,
&InBuffer,
0,
NULL,
&OutBuffer,
&dwSSPIOutFlags,
&tsExpiry);
DSCHAN(("ISC on %X bytes returns %X\n",
m_BufferUsed, Status));
DumpBuffers("ISC in", &InBuffer);
DumpBuffers("ISC out", &OutBuffer);
//
// If InitializeSecurityContext was successful (or if the error was
// one of the special extended ones), send the contents of the output
// buffer to the server.
//
if(Status == SEC_E_OK ||
Status == SEC_I_CONTINUE_NEEDED ||
FAILED(Status) && (dwSSPIOutFlags & ISC_RET_EXTENDED_ERROR))
{
if(OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL)
{
DSCHAN((" write back %X bytes\n", OutBuffers[0].cbBuffer));
cbData = StreamWrite(SEQ_INTERNAL, OutBuffers[0].pvBuffer,
OutBuffers[0].cbBuffer);
if(cbData == 0)
{
// The write failed: release the token, drop the context if it
// is ours and leave directly, bypassing the cleanup after the loop.
g_SecurityFunc.FreeContextBuffer(OutBuffers[0].pvBuffer);
if (m_OwnContext)
{
g_SecurityFunc.DeleteSecurityContext(&m_Context);
SecInvalidateHandle(&m_Context);
}
Status = HRESULT_FROM_WIN32(ERROR_WRITE_FAULT);
goto Exit;
}
// Free output buffer.
g_SecurityFunc.FreeContextBuffer(OutBuffers[0].pvBuffer);
OutBuffers[0].pvBuffer = NULL;
}
}
//
// If InitializeSecurityContext returned SEC_E_INCOMPLETE_MESSAGE,
// then we need to read more data from the server and try again.
//
if(Status == SEC_E_INCOMPLETE_MESSAGE)
{
continue;
}
//
// If InitializeSecurityContext returned SEC_E_OK, then the
// handshake completed successfully.
//
if(Status == SEC_E_OK)
{
//
// If the "extra" buffer contains data, this is encrypted application
// protocol layer stuff. It needs to be saved. The application layer
// will later decrypt it with DecryptMessage.
//
if(InBuffers[1].BufferType == SECBUFFER_EXTRA)
{
DSCHAN_IO((" ISC returns with %X extra bytes\n",
InBuffers[1].cbBuffer));
// The unconsumed bytes are the last cbBuffer bytes of the input.
memmove(m_Buffer,
m_Buffer + (m_BufferUsed - InBuffers[1].cbBuffer),
InBuffers[1].cbBuffer);
m_BufferUsed = InBuffers[1].cbBuffer;
}
else
{
m_BufferUsed = 0;
}
//
// Bail out to quit
//
break;
}
//
// Check for fatal error.
//
if(FAILED(Status))
{
break;
}
//
// If InitializeSecurityContext returned SEC_I_INCOMPLETE_CREDENTIALS,
// then the server just requested client authentication.
//
if(Status == SEC_I_INCOMPLETE_CREDENTIALS)
{
DSCHAN(("Get new client credentials\n"));
//
// Try to pick up client credentials that match the
// issuers the server trusts.
//
GetNewClientCredentials();
//
// Now would be a good time perhaps to prompt the user to select
// a client certificate and obtain a new credential handle,
// but I don't have the energy nor inclination.
//
// As this is currently written, Schannel will send a "no
// certificate" alert to the server in place of a certificate.
// The server might be cool with this, or it might drop the
// connection.
//
// Go around again.
ReadNeeded = 0;
Status = SEC_I_CONTINUE_NEEDED;
continue;
}
//
// Copy any leftover data from the "extra" buffer, and go around
// again.
//
if ( InBuffers[1].BufferType == SECBUFFER_EXTRA )
{
DSCHAN((" ISC loops with %X extra bytes\n",
InBuffers[1].cbBuffer));
memmove(m_Buffer,
m_Buffer + (m_BufferUsed - InBuffers[1].cbBuffer),
InBuffers[1].cbBuffer);
m_BufferUsed = InBuffers[1].cbBuffer;
}
else
{
m_BufferUsed = 0;
}
}
// Delete the security context in the case of a fatal error.
if(FAILED(Status))
{
if (m_OwnContext)
{
g_SecurityFunc.DeleteSecurityContext(&m_Context);
SecInvalidateHandle(&m_Context);
}
}
Exit:
DSCHAN(("AuthServer returns %X\n", Status));
return Status;
}
//
// Tries to replace m_Creds with credentials built from a certificate
// in the "MY" store whose chain leads to one of the issuers the
// server says it trusts.
//
// The issuer list is read from m_Context.  Candidate chains are tried
// in turn until AcquireCredentialsHandle accepts one.  If none works,
// m_Creds is left unchanged.  Nothing is reported to the caller.
//
void
DbgRpcSecureChannelTransport::GetNewClientCredentials(void)
{
    CredHandle hCreds;
    SecPkgContext_IssuerListInfoEx IssuerListInfo;
    PCCERT_CHAIN_CONTEXT pChainContext;
    CERT_CHAIN_FIND_BY_ISSUER_PARA FindByIssuerPara;
    PCCERT_CONTEXT pCertContext;
    TimeStamp tsExpiry;
    SECURITY_STATUS Status;

    //
    // Read list of trusted issuers from schannel.
    //
    ZeroMemory(&IssuerListInfo, sizeof(IssuerListInfo));
    Status = g_SecurityFunc.QueryContextAttributes(&m_Context,
                                                   SECPKG_ATTR_ISSUER_LIST_EX,
                                                   (PVOID)&IssuerListInfo);
    if (Status != SEC_E_OK)
    {
        goto Exit;
    }

    //
    // Enumerate the client certificates.
    //
    ZeroMemory(&FindByIssuerPara, sizeof(FindByIssuerPara));
    FindByIssuerPara.cbSize = sizeof(FindByIssuerPara);
    FindByIssuerPara.pszUsageIdentifier = szOID_PKIX_KP_CLIENT_AUTH;
    FindByIssuerPara.dwKeySpec = 0;
    FindByIssuerPara.cIssuer = IssuerListInfo.cIssuers;
    FindByIssuerPara.rgIssuer = IssuerListInfo.aIssuers;

    pChainContext = NULL;

    while(TRUE)
    {
        // Find the next certificate chain.  The previous chain is
        // handed back in, which also releases it.
        pChainContext = g_Crypt32Calls.
            CertFindChainInStore(g_hMyCertStore,
                                 X509_ASN_ENCODING,
                                 0,
                                 CERT_CHAIN_FIND_BY_ISSUER,
                                 &FindByIssuerPara,
                                 pChainContext);
        if(pChainContext == NULL)
        {
            break;
        }

        // Get pointer to leaf certificate context.
        pCertContext = pChainContext->rgpChain[0]->rgpElement[0]->pCertContext;

        // Create schannel credential.
        m_ScCreds.cCreds = 1;
        m_ScCreds.paCred = &pCertContext;

        Status = g_SecurityFunc.AcquireCredentialsHandleA(
            NULL,                   // Name of principal
            UNISP_NAME_A,           // Name of package
            SECPKG_CRED_OUTBOUND,   // Flags indicating use
            NULL,                   // Pointer to logon ID
            &m_ScCreds,             // Package specific data
            NULL,                   // Pointer to GetKey() func
            NULL,                   // Value to pass to GetKey()
            &hCreds,                // (out) Cred Handle
            &tsExpiry);             // (out) Lifetime (optional)
        if(Status != SEC_E_OK)
        {
            continue;
        }

        // Destroy the old credentials.
        if (m_OwnCreds)
        {
            g_SecurityFunc.FreeCredentialsHandle(&m_Creds);
        }

        // XXX drewb - If this object was not the credential owner,
        // the real owner still holds the old handle.  The handle
        // acquired here is ours, though, so take ownership to make
        // sure it is released when this object goes away.
        m_Creds = hCreds;
        m_OwnCreds = TRUE;
        break;
    }

    // paCred pointed at a local variable; do not leave it dangling
    // in the member structure.
    m_ScCreds.cCreds = 0;
    m_ScCreds.paCred = NULL;

    // A chain is still held only when the loop ended on a match.
    if (pChainContext != NULL)
    {
        g_Crypt32Calls.CertFreeCertificateChain(pChainContext);
    }

    // The issuer array was allocated by the security package.
    if (IssuerListInfo.aIssuers != NULL)
    {
        g_SecurityFunc.FreeContextBuffer(IssuerListInfo.aIssuers);
    }

 Exit:
    DSCHAN(("GetNewClientCredentials returns %X\n", Status));
}
//
// Server-side shutdown: applies the SCHANNEL_SHUTDOWN control token
// to the context, asks AcceptSecurityContext for the resulting close
// notify message and sends it to the client using SEQ_INTERNAL.
//
// Failures are only traced; the context itself is not deleted here.
//
void
DbgRpcSecureChannelTransport::DisconnectFromClient(void)
{
    DWORD dwType;
    PBYTE pbMessage;
    DWORD cbMessage;
    DWORD cbData;
    SecBufferDesc OutBuffer;
    SecBuffer OutBuffers[1];
    DWORD dwSSPIFlags;
    DWORD dwSSPIOutFlags;
    TimeStamp tsExpiry;
    DWORD Status;

    //
    // Notify schannel that we are about to close the connection.
    //
    dwType = SCHANNEL_SHUTDOWN;

    OutBuffers[0].pvBuffer = &dwType;
    OutBuffers[0].BufferType = SECBUFFER_TOKEN;
    OutBuffers[0].cbBuffer = sizeof(dwType);

    OutBuffer.cBuffers = 1;
    OutBuffer.pBuffers = OutBuffers;
    OutBuffer.ulVersion = SECBUFFER_VERSION;

    Status = g_SecurityFunc.ApplyControlToken(&m_Context, &OutBuffer);
    if(FAILED(Status))
    {
        goto cleanup;
    }

    //
    // Build an SSL close notify message.
    //
    dwSSPIFlags = ASC_REQ_SEQUENCE_DETECT |
        ASC_REQ_REPLAY_DETECT |
        ASC_REQ_CONFIDENTIALITY |
        ASC_REQ_EXTENDED_ERROR |
        ASC_REQ_ALLOCATE_MEMORY |
        ASC_REQ_STREAM;

    OutBuffers[0].pvBuffer = NULL;
    OutBuffers[0].BufferType = SECBUFFER_TOKEN;
    OutBuffers[0].cbBuffer = 0;

    OutBuffer.cBuffers = 1;
    OutBuffer.pBuffers = OutBuffers;
    OutBuffer.ulVersion = SECBUFFER_VERSION;

    Status = g_SecurityFunc.AcceptSecurityContext(
        &m_Creds,
        &m_Context,
        NULL,
        dwSSPIFlags,
        SECURITY_NATIVE_DREP,
        NULL,
        &OutBuffer,
        &dwSSPIOutFlags,
        &tsExpiry);

    DSCHAN(("DisASC returns %X\n", Status));
    DumpBuffers("DisASC out", &OutBuffer);

    if(FAILED(Status))
    {
        goto cleanup;
    }

    pbMessage = (PBYTE)OutBuffers[0].pvBuffer;
    cbMessage = OutBuffers[0].cbBuffer;

    //
    // Send the close notify message to the client.
    //
    if (pbMessage != NULL && cbMessage != 0)
    {
        DSCHAN((" write back %X bytes\n", cbMessage));

        cbData = StreamWrite(SEQ_INTERNAL, pbMessage, cbMessage);
        if (cbData == 0)
        {
            Status = HRESULT_FROM_WIN32(ERROR_WRITE_FAULT);
        }
    }

    // Free the output buffer whether or not the write succeeded;
    // it was allocated by the security package on our behalf.
    if (pbMessage != NULL)
    {
        g_SecurityFunc.FreeContextBuffer(pbMessage);
    }

 cleanup:
    DSCHAN(("DisconnectFromClient returns %X\n", Status));
}
//
// Client-side shutdown: applies the SCHANNEL_SHUTDOWN control token
// to the context, asks InitializeSecurityContext for the resulting
// close notify message and sends it to the server using SEQ_INTERNAL.
//
// Failures are only traced; the context itself is not deleted here.
//
void
DbgRpcSecureChannelTransport::DisconnectFromServer(void)
{
    DWORD dwType;
    PBYTE pbMessage;
    DWORD cbMessage;
    DWORD cbData;
    SecBufferDesc OutBuffer;
    SecBuffer OutBuffers[1];
    DWORD dwSSPIFlags;
    DWORD dwSSPIOutFlags;
    TimeStamp tsExpiry;
    DWORD Status;

    //
    // Notify schannel that we are about to close the connection.
    //
    dwType = SCHANNEL_SHUTDOWN;

    OutBuffers[0].pvBuffer = &dwType;
    OutBuffers[0].BufferType = SECBUFFER_TOKEN;
    OutBuffers[0].cbBuffer = sizeof(dwType);

    OutBuffer.cBuffers = 1;
    OutBuffer.pBuffers = OutBuffers;
    OutBuffer.ulVersion = SECBUFFER_VERSION;

    Status = g_SecurityFunc.ApplyControlToken(&m_Context, &OutBuffer);
    if(FAILED(Status))
    {
        goto cleanup;
    }

    //
    // Build an SSL close notify message.
    //
    dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
        ISC_REQ_REPLAY_DETECT |
        ISC_REQ_CONFIDENTIALITY |
        ISC_REQ_EXTENDED_ERROR |
        ISC_REQ_ALLOCATE_MEMORY |
        ISC_REQ_STREAM;

    OutBuffers[0].pvBuffer = NULL;
    OutBuffers[0].BufferType = SECBUFFER_TOKEN;
    OutBuffers[0].cbBuffer = 0;

    OutBuffer.cBuffers = 1;
    OutBuffer.pBuffers = OutBuffers;
    OutBuffer.ulVersion = SECBUFFER_VERSION;

    Status = g_SecurityFunc.InitializeSecurityContextA(
        &m_Creds,
        &m_Context,
        NULL,
        dwSSPIFlags,
        0,
        SECURITY_NATIVE_DREP,
        NULL,
        0,
        &m_Context,
        &OutBuffer,
        &dwSSPIOutFlags,
        &tsExpiry);

    DSCHAN(("DisISC returns %X\n", Status));
    DumpBuffers("DisISC out", &OutBuffer);

    if(FAILED(Status))
    {
        goto cleanup;
    }

    pbMessage = (PBYTE)OutBuffers[0].pvBuffer;
    cbMessage = OutBuffers[0].cbBuffer;

    //
    // Send the close notify message to the server.
    //
    if(pbMessage != NULL && cbMessage != 0)
    {
        DSCHAN((" write back %X bytes\n", cbMessage));

        cbData = StreamWrite(SEQ_INTERNAL, pbMessage, cbMessage);
        if (cbData == 0)
        {
            Status = HRESULT_FROM_WIN32(ERROR_WRITE_FAULT);
        }
    }

    // Free the output buffer whether or not the write succeeded;
    // it was allocated by the security package on our behalf.
    if (pbMessage != NULL)
    {
        g_SecurityFunc.FreeContextBuffer(pbMessage);
    }

 cleanup:
    DSCHAN(("DisconnectFromServer returns %X\n", Status));
}