windows-nt/Source/XPSP1/NT/shell/ext/hnw/nconn32/protocol.cpp
2020-09-26 16:20:57 +08:00

537 lines
15 KiB
C++

//
// Protocol.cpp
//
// Install, uninstall, and related code for protocols such as TCP/IP.
//
// History:
//
// 2/02/1999 KenSh Created for JetNet
// 9/29/1999 KenSh Repurposed for Home Networking Wizard
//
#include "stdafx.h"
#include "NetConn.h"
#include "nconnwrap.h"
#include "ParseInf.h"
#include "TheApp.h"
#include "HookUI.h"
#include "NetCli.h"
// Local functions
//
// Forward declarations for helpers that are defined further down in this
// file but are called before their definitions.
//
VOID RemoveOrphanedProtocol(LPCSTR pszProtocolID);
HRESULT CreateNewProtocolBinding(LPCSTR pszProtocolDeviceID, LPSTR pszBuf, int cchBuf, LPCSTR pszClientBinding, LPCSTR pszServiceBinding);
// IsProtocolInstalled
//
// Returns TRUE if one or more instances of the given protocol
// (e.g. "MSTCP") is bound to a network adapter.
//
// Parameters:
//
//		pszProtocolDeviceID		device ID of the protocol, e.g. "MSTCP"
//		bExhaustive				if TRUE, additionally require that the
//								"<DeviceID>.Install" section of nettrans.inf
//								passes CheckInfSectionInstallation
//
// A device ID too long for the "<DeviceID>.Install" section name to fit in
// the local buffer cannot be verified, so the exhaustive check reports FALSE
// for it instead of overrunning the stack.
//
BOOL WINAPI IsProtocolInstalled(LPCTSTR pszProtocolDeviceID, BOOL bExhaustive)
{
    if (!IsProtocolBoundToAnyAdapter(pszProtocolDeviceID))
        return FALSE;

    if (bExhaustive)
    {
        TCHAR szInfSection[50];

        // wsprintf does no bounds checking: make sure the ID, the
        // ".Install" suffix and the terminating NUL all fit first.
        const int cchSuffix = lstrlen(".Install");
        if (lstrlen(pszProtocolDeviceID) + cchSuffix >= (int)_countof(szInfSection))
            return FALSE;

        wsprintf(szInfSection, "%s.Install", pszProtocolDeviceID);
        if (!CheckInfSectionInstallation("nettrans.inf", szInfSection))
            return FALSE;
    }

    return TRUE;
}
// InstallProtocol (public)
//
// Installs the given protocol via NETDI, and binds it to all adapters.
// The standard progress UI is (mostly) suppressed, and instead the given
// callback function is called so a custom progress UI can be implemented.
//
// If the protocol is already bound to at least one adapter the class
// installer is skipped and only the bindings are (re)checked.
//
// Returns a NETCONN_xxx result, defined in NetConn.h:
//   * NETCONN_USER_ABORT if g_bUserAbort is set after the class installer runs
//   * the result of HresultFromCCI if the class installer reported a failure
//   * otherwise the result of BindProtocolToAllAdapters when that is not
//     NETCONN_SUCCESS, else NETCONN_NEED_RESTART (fresh install) or
//     NETCONN_SUCCESS (already installed, nothing to change)
//
// Parameters:
//
//		pszProtocolID		device ID of the protocol, e.g. "MSTCP"
//		hwndParent			parent window helpful to use in NETDI call
//		pfnProgress			function to call with install progress reports
//		pvProgressParam		user-supplied parameter to pass to pfnProgress
//
// History:
//
//		 2/23/1999  KenSh     Created
//		 3/26/1999  KenSh     Check if already installed before reinstalling
//
HRESULT WINAPI InstallProtocol(LPCSTR pszProtocolID, HWND hwndParent, PROGRESS_CALLBACK pfnProgress, LPVOID pvProgressParam)
{
    HRESULT hr = NETCONN_SUCCESS;

    if (!IsProtocolBoundToAnyAdapter(pszProtocolID))
    {
        // Clear out any unreferenced class keys left by an earlier install
        RemoveOrphanedProtocol(pszProtocolID);

        BeginSuppressNetdiUI(hwndParent, pfnProgress, pvProgressParam);
        DWORD dwResult = CallClassInstaller16(hwndParent, SZ_CLASS_PROTOCOL, pszProtocolID);
        EndSuppressNetdiUI();

        if (g_bUserAbort)
        {
            hr = NETCONN_USER_ABORT;
        }
        else
        {
            HRESULT hrInstall = HresultFromCCI(dwResult);
            if (SUCCEEDED(hrInstall))
            {
                hr = NETCONN_NEED_RESTART;
            }
            else
            {
                // Report the installer failure instead of silently treating
                // it as success and binding a protocol that is not there.
                hr = hrInstall;
            }
        }

        // Total hack to work around JetNet bug 1193
        // DoDummyDialog(hwndParent);
    }

    if (SUCCEEDED(hr))
    {
        // Ensure the protocol is bound exactly once to every NIC
        HRESULT hr2 = BindProtocolToAllAdapters(pszProtocolID);
        if (hr2 != NETCONN_SUCCESS)
            hr = hr2;
    }

    return hr;
}
// InstallTCPIP (public)
//
// Convenience entry point: runs InstallProtocol for SZ_PROTOCOL_TCPIP,
// forwarding the parent window and progress callback unchanged, and hands
// back whatever InstallProtocol returns.
//
HRESULT WINAPI InstallTCPIP(HWND hwndParent, PROGRESS_CALLBACK pfnProgress, LPVOID pvProgressParam)
{
    const HRESULT hrInstall = InstallProtocol(SZ_PROTOCOL_TCPIP, hwndParent, pfnProgress, pvProgressParam);
    return hrInstall;
}
// RemoveProtocol (public)
//
// Strips every binding to the given protocol from each enumerated adapter's
// "Bindings" key, then deletes the protocol's key under Enum\Network and its
// class key references.
//
// Returns NETCONN_NEED_RESTART if at least one adapter binding value was
// processed for deletion, otherwise NETCONN_SUCCESS.
//
HRESULT WINAPI RemoveProtocol(LPCSTR pszProtocol)
{
    HRESULT hrResult = NETCONN_SUCCESS;

    //
    // Step 1: drop the adapters' references to the protocol
    //
    NETADAPTER* pAdapterList;
    const int cAdapterCount = EnumNetAdapters(&pAdapterList);

    for (int idxAdapter = 0; idxAdapter < cAdapterCount; ++idxAdapter)
    {
        NETADAPTER* pCurrent = pAdapterList + idxAdapter;

        LPTSTR* ppszMatches;
        const int cMatches = EnumMatchingNetBindings(pCurrent->szEnumKey, pszProtocol, &ppszMatches);

        if (cMatches > 0)
        {
            CRegistry regAdapterBindings;
            if (regAdapterBindings.OpenKey(HKEY_LOCAL_MACHINE, pCurrent->szEnumKey))
            {
                if (regAdapterBindings.OpenSubKey("Bindings"))
                {
                    int idxMatch = 0;
                    while (idxMatch < cMatches)
                    {
                        regAdapterBindings.DeleteValue(ppszMatches[idxMatch]);
                        hrResult = NETCONN_NEED_RESTART;
                        ++idxMatch;
                    }
                }
            }
        }

        // The match list is released whether or not anything matched
        NetConnFree(ppszMatches);
    }

    NetConnFree(pAdapterList);

    //
    // Step 2: delete the protocol's own enum key
    //
    CRegistry regEnumNetwork;
    if (regEnumNetwork.OpenKey(HKEY_LOCAL_MACHINE, "Enum\\Network"))
    {
        RegDeleteKeyAndSubKeys(regEnumNetwork.m_hKey, pszProtocol);
    }

    //
    // Step 3: delete the protocol's class key(s)
    //
    DeleteClassKeyReferences(SZ_CLASS_PROTOCOL, pszProtocol);

    return hrResult;
}
// IsProtocolBoundToAnyAdapter (public)
//
// Given a protocol ID, such as "MSTCP", reports whether at least one of the
// adapters returned by EnumNetAdapters has that protocol bound to it.
// Returns TRUE as soon as the first such adapter is found, FALSE if none is.
//
// History:
//
//		 3/26/1999  KenSh     Created
//
BOOL WINAPI IsProtocolBoundToAnyAdapter(LPCSTR pszProtocolID)
{
    BOOL bBound = FALSE;

    NETADAPTER* pAdapterList;
    const int cAdapterCount = EnumNetAdapters(&pAdapterList);

    // Stop scanning once a bound adapter has been seen
    for (int idx = 0; idx < cAdapterCount && !bBound; ++idx)
    {
        if (IsProtocolBoundToAdapter(pszProtocolID, &pAdapterList[idx]))
            bBound = TRUE;
    }

    NetConnFree(pAdapterList);
    return bBound;
}
// IsProtocolBoundToAdapter
//
// Given a protocol ID, such as "MSTCP", and an adapter struct, reports
// whether the adapter has any binding that matches the protocol. The result
// is the match count from EnumMatchingNetBindings converted to BOOL, so any
// nonzero count reads as "bound".
//
BOOL WINAPI IsProtocolBoundToAdapter(LPCSTR pszProtocolID, const NETADAPTER* pAdapter)
{
    LPSTR* ppszMatches;
    const int cMatches = EnumMatchingNetBindings(pAdapter->szEnumKey, pszProtocolID, &ppszMatches);

    // Only the count is of interest; release the list straight away
    NetConnFree(ppszMatches);

    const BOOL bBound = (BOOL)cMatches;
    return bBound;
}
// RemoveOrphanedProtocol
//
// Checks to see whether any instances of this protocol in the Class branch of
// the registry are unreferenced, and deletes them if so.
//
// pszProtocolID is the generic device ID of the protocol, e.g. "MSTCP"
//
// The subkeys of Class\NetTrans are walked from the highest index down to 0.
// Deleting a subkey while enumerating forward by index shifts the remaining
// keys down and causes the key after each deleted one to be skipped; going
// backwards means a deletion never disturbs an index still to be visited.
//
VOID RemoveOrphanedProtocol(LPCSTR pszProtocolID)
{
    // REVIEW: Should we first delete references in Enum that are not in use?

    CRegistry reg;
    if (!reg.OpenKey(HKEY_LOCAL_MACHINE, "System\\CurrentControlSet\\Services\\Class\\NetTrans", KEY_ALL_ACCESS))
        return;

    DWORD cSubKeys = 0;
    if (ERROR_SUCCESS != RegQueryInfoKey(reg.m_hKey, NULL, NULL, NULL, &cSubKeys,
                                         NULL, NULL, NULL, NULL, NULL, NULL, NULL))
        return;

    //
    // Enumerate the protocol instances under NetTrans (subkeys such as "0000")
    //
    for (DWORD iKey = cSubKeys; iKey-- > 0; )
    {
        // Extra room so that "\\Ndi" can always be appended to the name
        CHAR szSubKey[MAX_PATH + 8];
        DWORD cbSubKey = MAX_PATH;
        if (ERROR_SUCCESS != RegEnumKeyEx(reg.m_hKey, iKey, szSubKey, &cbSubKey, NULL, NULL, NULL, NULL))
            continue;

        //
        // Open the "Ndi" subkey so we can see what kind of protocol this is
        //
        lstrcpy(szSubKey + cbSubKey, "\\Ndi");

        CRegistry regNode;
        if (!regNode.OpenKey(reg.m_hKey, szSubKey, KEY_ALL_ACCESS))
            continue;

        CHAR szDeviceID[40];
        if (!regNode.QueryStringValue("DeviceID", szDeviceID, _countof(szDeviceID)))
            continue;

        regNode.CloseKey(); // close the key before we try to delete it

        if (0 != lstrcmpi(szDeviceID, pszProtocolID))
            continue;

        //
        // Found the right protocol, now check if it's referenced
        //
        // NOTE(review): szSubKey still carries the "\Ndi" suffix at this
        // point, exactly as in the original code -- confirm that is the form
        // IsNetClassKeyReferenced expects.
        //
        if (!IsNetClassKeyReferenced(szSubKey))
        {
            // Not referenced, so delete it
            szSubKey[cbSubKey] = '\0'; // strip "\Ndi", leaving just the subkey name
            RegDeleteKeyAndSubKeys(reg.m_hKey, szSubKey);
        }
    }
}
// BindProtocolToAdapter
//
// Creates a fresh client / (optional) sharing / protocol binding chain and
// records the new protocol instance as a value in the adapter's Bindings key.
//
// hkeyAdapterBindings is the open registry key that receives the new
// binding value (the caller in this file passes the adapter's
// "<EnumKey>\Bindings" key).
//
// pszProtocolDeviceID is the device ID of the protocol, e.g. "MSTCP"
//
// bEnableSharing determines whether file and printer sharing will be bound
// to the protocol when running through the given adapter. When set, two
// separate sharing instances are created: one is handed to the new client,
// the other to the new protocol binding.
//
// Returns NETCONN_NEED_RESTART on success, the failing HRESULT from any of
// the Create* steps, or NETCONN_UNKNOWN_ERROR if the binding value could not
// be written.
//
// History:
//
//		 3/26/1999  KenSh     Created
//		 4/09/1999  KenSh     Added bEnableSharing flag
//
HRESULT BindProtocolToAdapter(HKEY hkeyAdapterBindings, LPCSTR pszProtocolDeviceID, BOOL bEnableSharing)
{
    // NetTrans seems to always clone the Enum key, and also clone the Class key.
    // (the new clone of the Enum key points to the new clone of the Class key;
    // each new Enum's MasterCopy points to itself)
    // NetService and NetClient seem to clone the Enum key, but not the Class key.

    CHAR szClientBinding[MAX_PATH];
    CHAR szSharingForClient[MAX_PATH];
    CHAR szSharingForProtocol[MAX_PATH];
    CHAR szProtocolBinding[MAX_PATH];
    HRESULT hrStep;

    if (!bEnableSharing)
    {
        // No sharing: pass empty binding names down the chain
        szSharingForClient[0] = '\0';
        szSharingForProtocol[0] = '\0';
    }
    else
    {
        hrStep = CreateNewFilePrintSharing(szSharingForClient, _countof(szSharingForClient));
        if (FAILED(hrStep))
            return hrStep;

        hrStep = CreateNewFilePrintSharing(szSharingForProtocol, _countof(szSharingForProtocol));
        if (FAILED(hrStep))
            return hrStep;
    }

    hrStep = CreateNewClientForMSNet(szClientBinding, _countof(szClientBinding), szSharingForClient);
    if (FAILED(hrStep))
        return hrStep;

    hrStep = CreateNewProtocolBinding(pszProtocolDeviceID, szProtocolBinding, _countof(szProtocolBinding), szClientBinding, szSharingForProtocol);
    if (FAILED(hrStep))
        return hrStep;

    // Bind the new protocol to the adapter: the value name is the binding,
    // the data is an empty string (one byte for the terminator)
    const LONG lWrite = RegSetValueEx(hkeyAdapterBindings, szProtocolBinding, 0, REG_SZ, (CONST BYTE*)"", 1);
    if (lWrite != ERROR_SUCCESS)
        return NETCONN_UNKNOWN_ERROR;

    return NETCONN_NEED_RESTART;
}
// BindProtocolToAllAdapters_Helper
//
// Shared worker for BindProtocolToAllAdapters and BindProtocolToOnlyOneAdapter.
// Walks every adapter returned by EnumNetAdapters twice: the first pass adds
// a binding to each eligible adapter that has none, the second pass removes
// bindings that are on the wrong adapter, have no interface in common, or are
// duplicates.
//
// Parameters:
//
//		pszProtocolDeviceID		device ID of the protocol, e.g. "MSTCP"
//		pszAdapterKey			enum key of the single adapter to bind to, or
//								NULL to treat every adapter as a candidate
//		bIgnoreVirtualNics		if TRUE, an existing binding on a NIC_VIRTUAL
//								adapter is not removed for being on the wrong
//								adapter or lacking a matching interface
//								(duplicate bindings are still removed)
//
// Returns NETCONN_SUCCESS if nothing was changed, NETCONN_NEED_RESTART if a
// binding was removed, or the most recent non-NETCONN_SUCCESS result from
// BindProtocolToAdapter.
//
HRESULT BindProtocolToAllAdapters_Helper(LPCSTR pszProtocolDeviceID, LPCSTR pszAdapterKey, BOOL bIgnoreVirtualNics)
{
HRESULT hr = NETCONN_SUCCESS;
// Get LowerRange interfaces for protocol
CHAR szProtocolLower[100];
GetDeviceLowerRange(SZ_CLASS_PROTOCOL, pszProtocolDeviceID, szProtocolLower, _countof(szProtocolLower));
// For each adapter, ensure the protocol is bound exactly once
//
NETADAPTER* prgAdapters;
int cAdapters = EnumNetAdapters(&prgAdapters);
// Pass 0: add new bindings
// Pass 1: delete inappropriate bindings
for (int iPass = 0; iPass <= 1; iPass++)
{
for (int iAdapter = 0; iAdapter < cAdapters; iAdapter++)
{
NETADAPTER* pAdapter = &prgAdapters[iAdapter];
CRegistry regAdapter;
// Get UpperRange interfaces for adapter
CHAR szAdapterUpper[100];
GetDeviceUpperRange(SZ_CLASS_ADAPTER, pAdapter->szDeviceID, szAdapterUpper, _countof(szAdapterUpper));
// Check for a match between the protocol and the adapter
BOOL bMatchingInterface = CheckMatchingInterface(szProtocolLower, szAdapterUpper);
CHAR szRegKey[MAX_PATH];
wsprintf(szRegKey, "%s\\Bindings", pAdapter->szEnumKey);
// With no specific adapter requested, every adapter starts out as a candidate
BOOL bCorrectNic = (NULL == pszAdapterKey) || (0 == lstrcmpi(pAdapter->szEnumKey, pszAdapterKey));
if (!lstrcmpi(pszProtocolDeviceID, SZ_PROTOCOL_IPXSPX))
{
// Bind IPX/SPX to all non-broadband NICs by default (usually a max of 1)
if (pszAdapterKey == NULL)
{
if (IsAdapterBroadband(pAdapter))
bCorrectNic = FALSE;
}
// Don't bind IPX/SPX to Dial-Up (or other virtual) adapters (bugs 1163, 1164)
if (pAdapter->bNicType == NIC_VIRTUAL)
bCorrectNic = FALSE;
}
//
// Check the bindings of the current adapter, looking for this protocol
//
if (regAdapter.OpenKey(HKEY_LOCAL_MACHINE, szRegKey, KEY_ALL_ACCESS))
{
TCHAR szValueName[60];
// cFound counts the bindings to this protocol that have been kept so far
int cFound = 0;
DWORD iValue = 0;
for (;;)
{
DWORD cbValueName = _countof(szValueName);
if (ERROR_SUCCESS != RegEnumValue(regAdapter.m_hKey, iValue, szValueName, &cbValueName, NULL, NULL, NULL, NULL))
break;
// Binding value names look like "<DeviceID>\<instance>"; split at the
// backslash so the device ID part can be compared on its own.
// NOTE(review): a value name with no backslash ends the whole scan
// instead of being skipped -- confirm that is intended.
LPSTR pchSlash = strchr(szValueName, '\\');
if (pchSlash == NULL)
break;
*pchSlash = '\0';
if (0 == lstrcmpi(szValueName, pszProtocolDeviceID))
{
// Restore the full value name before it is used for removal
*pchSlash = '\\';
// If this isn't the right NIC, or if there's not a matching
// interface, or (optionally) if the NIC is virtual, then
// unbind the protocol.
BOOL bUnbindFromNic = !bCorrectNic || !bMatchingInterface;
if (bIgnoreVirtualNics && (pAdapter->bNicType == NIC_VIRTUAL))
bUnbindFromNic = FALSE;
if (bUnbindFromNic || // bound to the wrong adapter!
cFound != 0) // bound more than once to this NIC!
{
if (iPass == 1) // unbind on second pass only
{
// Remove the binding, then restart our search
// for matching protocols
// NOTE(review): the restart relies on the removal actually deleting
// the value; if it did not, the same value would be found again.
RemoveBindingFromParent(regAdapter.m_hKey, szValueName);
iValue = 0;
cFound = 0;
hr = NETCONN_NEED_RESTART;
continue;
}
}
cFound += 1;
}
iValue += 1;
}
// New bindings are only ever added on the first pass
if (bCorrectNic && iPass == 0)
{
if (cFound == 0)
// Protocol is not yet bound to the correct adapter
{
if (bMatchingInterface) // There's an interface in common
{
BOOL bExternalNic = IsAdapterBroadband(pAdapter);
// Enable file/printer sharing if:
// * Adapter is an ethernet or IRDA adapter
// * Adapter is not a broadband NIC
BOOL bEnableSharing = FALSE;
if ((pAdapter->bNetType == NETTYPE_LAN ||
pAdapter->bNetType == NETTYPE_IRDA) &&
!bExternalNic)
{
bEnableSharing = TRUE;
}
// Any result other than NETCONN_SUCCESS replaces the running result
HRESULT hr2 = BindProtocolToAdapter(regAdapter.m_hKey, pszProtocolDeviceID, bEnableSharing);
if (hr2 != NETCONN_SUCCESS)
hr = hr2;
}
}
}
}
}
}
NetConnFree(prgAdapters);
return hr;
}
// BindProtocolToOnlyOneAdapter (public)
//
// Forwards all three arguments unchanged to BindProtocolToAllAdapters_Helper
// and returns its result. pszAdapterKey names the adapter enum key the
// protocol should be bound to.
//
HRESULT WINAPI BindProtocolToOnlyOneAdapter(LPCSTR pszProtocolDeviceID, LPCSTR pszAdapterKey, BOOL bIgnoreVirtualNics)
{
    const HRESULT hrBind = BindProtocolToAllAdapters_Helper(pszProtocolDeviceID, pszAdapterKey, bIgnoreVirtualNics);
    return hrBind;
}
// BindProtocolToAllAdapters (public)
//
// Given the device ID ("MSTCP") of a protocol which is already installed,
// binds that protocol to all adapters, as well as to client for Microsoft
// Networks and File and Printer sharing.
//
// Implemented by calling BindProtocolToAllAdapters_Helper with no specific
// adapter key (NULL) and bIgnoreVirtualNics set to FALSE.
//
// History:
//
//		 3/26/1999  KenSh     Created
//		 4/23/1999  KenSh     Enable file sharing only on real NICs
//
HRESULT WINAPI BindProtocolToAllAdapters(LPCSTR pszProtocolDeviceID)
{
    const HRESULT hrBind = BindProtocolToAllAdapters_Helper(pszProtocolDeviceID, NULL, FALSE);
    return hrBind;
}
// CreateNewProtocolBinding
//
// Given a protocol ID such as "MSTCP", an optional client binding string
// such as "VREDIR\0000", and an optional service binding string such as
// "VSERVER\0000", creates a new protocol binding, and copies the name of
// the new binding into the buffer provided (e.g. "MSTCP\0001").
//
// Steps: clone the protocol's enum key, replace the clone's Bindings values
// with the client/service bindings, point the clone's MasterCopy at itself,
// clone the driver (class) key and point the clone's Driver value at it.
// For TCP/IP, any existing IPAddress / IPMask value in the new class key is
// reset to "0.0.0.0".
//
// Returns NETCONN_SUCCESS on success; the failing HRESULT from
// FindAndCloneNetEnumKey or OpenNetEnumKey; or NETCONN_UNKNOWN_ERROR if the
// cloned key name is too long for the local buffers, the Driver value cannot
// be read, or no new driver name was produced.
//
HRESULT CreateNewProtocolBinding(LPCSTR pszProtocolDeviceID, LPSTR pszBuf, int cchBuf, LPCSTR pszClientBinding, LPCSTR pszServiceBinding)
{
    HRESULT hr;

    if (FAILED(hr = FindAndCloneNetEnumKey(SZ_CLASS_PROTOCOL, pszProtocolDeviceID, pszBuf, cchBuf)))
    {
        ASSERT(FALSE);
        return hr;
    }

    // Now pszBuf contains a string of the form "MSTCP\0001"

    // The name is copied into fixed-size path buffers below; they are sized
    // for names shorter than MAX_PATH, so reject anything longer.
    if (lstrlen(pszBuf) >= MAX_PATH)
    {
        ASSERT(FALSE);
        return NETCONN_UNKNOWN_ERROR;
    }

    CHAR szBindings[MAX_PATH + 16];
    CRegistry regBindings;
    lstrcpy(szBindings, pszBuf);          // "MSTCP\0001"
    lstrcat(szBindings, "\\Bindings");    // "MSTCP\0001\Bindings"

    if (FAILED(hr = OpenNetEnumKey(regBindings, szBindings, KEY_ALL_ACCESS)))
    {
        ASSERT(FALSE);
        return hr;
    }

    // Delete existing bindings
    regBindings.DeleteAllValues();

    // Add the client and server bindings
    if (pszClientBinding != NULL && *pszClientBinding != '\0')
        regBindings.SetStringValue(pszClientBinding, "");
    if (pszServiceBinding != NULL && *pszServiceBinding != '\0')
        regBindings.SetStringValue(pszServiceBinding, "");

    // Change the MasterCopy to point to correct place
    // (regBindings is re-pointed at the cloned enum key itself from here on)
    CHAR szMasterCopy[MAX_PATH + 16];
    wsprintf(szMasterCopy, "Enum\\Network\\%s", pszBuf);
    if (regBindings.OpenKey(HKEY_LOCAL_MACHINE, szMasterCopy, KEY_ALL_ACCESS))
    {
        regBindings.SetStringValue("MasterCopy", szMasterCopy);
    }

    // Create a clone of the driver (a.k.a. class key)
    CHAR szExistingDriver[60];
    CHAR szNewDriver[60];
    if (!regBindings.QueryStringValue("Driver", szExistingDriver, _countof(szExistingDriver)))
    {
        // Without a Driver value there is nothing valid to clone
        ASSERT(FALSE);
        return NETCONN_UNKNOWN_ERROR;
    }

    szNewDriver[0] = '\0'; // so an untouched output buffer is detectable
    CloneNetClassKey(szExistingDriver, szNewDriver, _countof(szNewDriver));
    if (szNewDriver[0] == '\0')
    {
        ASSERT(FALSE);
        return NETCONN_UNKNOWN_ERROR;
    }

    // Change the Driver to point to the new class key
    CRegistry regEnumSubKey;
    VERIFY(SUCCEEDED(OpenNetEnumKey(regEnumSubKey, pszBuf, KEY_ALL_ACCESS)));
    regEnumSubKey.SetStringValue("Driver", szNewDriver);

    // If this is a new TCP/IP binding, ensure we don't have a static IP address
    if (0 == lstrcmpi(pszProtocolDeviceID, SZ_PROTOCOL_TCPIP))
    {
        // Large enough for the fixed prefix plus the longest driver name
        CHAR szFullClassKey[MAX_PATH];
        wsprintf(szFullClassKey, "System\\CurrentControlSet\\Services\\Class\\%s", szNewDriver);

        CRegistry regClassKey;
        VERIFY(regClassKey.OpenKey(HKEY_LOCAL_MACHINE, szFullClassKey));

        // Scratch buffer: the values are only probed for existence
        CHAR szExisting[100];
        if (regClassKey.QueryStringValue("IPAddress", szExisting, _countof(szExisting)))
            regClassKey.SetStringValue("IPAddress", "0.0.0.0");
        if (regClassKey.QueryStringValue("IPMask", szExisting, _countof(szExisting)))
            regClassKey.SetStringValue("IPMask", "0.0.0.0");
    }

    return NETCONN_SUCCESS;
}