windows-nt/Source/XPSP1/NT/sdktools/profiler/shim2.c
2020-09-26 16:20:57 +08:00

578 lines
17 KiB
C

////////////////////////////////////////////////////////////////////////////////////
//
// File: shim2.c
//
// History: May-99 clupu Created.
// Aug-99 v-johnwh Various bug fixes.
// 23-Nov-99 markder Support for multiple shim DLLs, chaining
// of hooks, DLL loads/unloads. General clean-up.
// Jan-00 markder Windows 9x support added.
// Mar-00 a-batjar Changed to support whistler format on w2k
// May-00 v-johnwh Modified to work in the profiler
// Desc: Contains all code to facilitate hooking of APIs by replacing entries
// in the import tables of loaded modules.
//
////////////////////////////////////////////////////////////////////////////////////
#include <windows.h>
#include <stdlib.h>
#include <psapi.h>
#include <tlhelp32.h>
#include <imagehlp.h>
#include <stdio.h>
#include "shimdb.h"
#include "shim2.h"
// Bits stored in HOOKAPI.dwFlags. ConstructChain writes them and ResolveAPIs
// tests them to decide which entries still need their pfnOld resolved.
// Set on an entry once it has been linked into a hook chain.
#define HAF_RESOLVED 0x0001
// Set (replacing HAF_RESOLVED) on the last entry of a chain.
#define HAF_BOTTOM_OF_CHAIN 0x0002
// Function-pointer types whose signatures mirror the system APIs listed in the
// hook-index enum below.
// NOTE(review): none of these typedefs is referenced in the visible part of
// this file; they may be leftovers from the original shim engine - confirm
// before removing.
typedef PHOOKAPI (*PFNNEWGETHOOKAPIS)(DWORD dwGetProcAddress, DWORD dwLoadLibraryA, DWORD dwFreeLibrary, DWORD* pdwHookAPICount);
typedef LPSTR (*PFNGETCOMMANDLINEA)(VOID);
typedef LPWSTR (*PFNGETCOMMANDLINEW)(VOID);
typedef PVOID (*PFNGETPROCADDRESS)(HMODULE hMod, char* pszProc);
typedef HINSTANCE (*PFNLOADLIBRARYA)(LPCSTR lpLibFileName);
typedef HINSTANCE (*PFNLOADLIBRARYW)(LPCWSTR lpLibFileName);
typedef HINSTANCE (*PFNLOADLIBRARYEXA)(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);
typedef HINSTANCE (*PFNLOADLIBRARYEXW)(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);
typedef BOOL (*PFNFREELIBRARY)(HMODULE hLibModule);
// Build settings and limits
// Disable build warnings due to Print macro in free builds
// (C4002: too many arguments supplied to a function-like macro).
#pragma warning( disable : 4002 )
// Capacity of every fixed-size global table declared below.
#define MAX_MODULES 512
// Name of the export that _LoadPatchDll looks up in each shim DLL.
#define SHIM_GETHOOKAPIS "GetHookAPIs"
////////////////////////////////////////////////////////////////////////////////////
// API hook indices
//
// One enumerator per system API named by the PFN* typedefs above, numbered
// from 0 in declaration order. There is no trailing "count" enumerator.
// NOTE(review): these constants are not referenced in the visible part of
// this file - presumably they index a table of built-in hooks that lived in
// the original shim engine; confirm before relying on them.
////////////////////////////////////////////////////////////////////////////////////
enum
{
hookGetProcAddress,
hookLoadLibraryA,
hookLoadLibraryW,
hookLoadLibraryExA,
hookLoadLibraryExW,
hookFreeLibrary,
hookGetCommandLineA,
hookGetCommandLineW
};
////////////////////////////////////////////////////////////////////////////////////
// Global variables
////////////////////////////////////////////////////////////////////////////////////
// The first five globals form one table with a row per registered shim DLL.
// AddHookAPIs fills row g_nShimDllCount and then increments the count.
// Together the rows describe which API to hook with which stub function.

// Number of rows currently in use in the per-shim-DLL tables below.
LONG g_nShimDllCount;
// Module handle of each shim DLL (used to avoid hooking the shims themselves).
HMODULE g_hShimDlls[MAX_MODULES];
// HOOKAPI array supplied by each shim DLL; the pointer is stored, not copied.
PHOOKAPI g_rgpHookAPIs[MAX_MODULES];
// Number of entries in the corresponding g_rgpHookAPIs array.
LONG g_rgnHookAPICount[MAX_MODULES];
// Optional include/exclude module list for each shim DLL (NULL = no filter).
// HookImports walks it as a sequence of NUL-terminated strings ended by an
// empty string; "%" switches to exclude mode and "*" after it excludes all.
LPTSTR g_rgnHookDllList[MAX_MODULES];
// Modules whose import tables have already been passed to HookImports, and
// how many of them there are; maintained by Shim2PatchNewModules.
HMODULE g_hHookedModules[MAX_MODULES];
LONG g_nHookedModuleCount;
// Defined elsewhere. When non-zero, import addresses are passed through
// ValidateAddress before being compared.
extern BOOL g_bIsWin9X;
// Module snapshot enumerated by Shim2PatchNewModules; valid only during that call.
HANDLE g_hSnapshot = NULL;
// Second module snapshot, enumerated by ValidateAddress while the snapshot
// above is still being walked; also valid only during Shim2PatchNewModules.
HANDLE g_hValidationSnapshot = NULL;
////////////////////////////////////////////////////////////////////////////////////
//
// Func: ValidateAddress
//
// Params: pfnOld Original API function pointer to validate.
//
// Return: pfnOld unchanged when it is recognised (see below); otherwise the
//         pointer-sized value stored one byte past pfnOld.
//
// Desc: Win9x thunks system API entry points for some reason. The
//       shim mechanism has to work around this to get to the
//       'real' pointer so that it can make valid comparisons.
//
//       An address is accepted as-is when
//         - some HOOKAPI entry has both pfnOld and pfnNew equal to it, or
//         - it lies inside [modBaseAddr, modBaseAddr + modBaseSize) of a
//           module listed in g_hValidationSnapshot.
//       Anything else is assumed to be a Win9x thunk. If the snapshot handle
//       is not valid the enumeration yields nothing, so every address takes
//       the thunk path.
//
PVOID ValidateAddress( PVOID pfnOld )
{
    MODULEENTRY32 ModuleEntry32;
    BOOL bRet;
    long i, j;

    // Make sure the address isn't a shim thunk
    // NOTE(review): the test requires pfnOld AND pfnNew of the same entry to
    // equal the address; kept exactly as in the original - confirm intent.
    for( i = g_nShimDllCount - 1; i >= 0; i-- )
    {
        for( j = 0; j < g_rgnHookAPICount[i]; j++ )
        {
            if( g_rgpHookAPIs[i][j].pfnOld == pfnOld &&
                g_rgpHookAPIs[i][j].pfnNew == pfnOld )
            {
                return pfnOld;
            }
        }
    }

    // Is the address inside any module of the validation snapshot?
    ModuleEntry32.dwSize = sizeof( ModuleEntry32 );
    for( bRet = Module32First( g_hValidationSnapshot, &ModuleEntry32 );
         bRet;
         bRet = Module32Next( g_hValidationSnapshot, &ModuleEntry32 ) )
    {
        PBYTE pbModuleStart = ModuleEntry32.modBaseAddr;
        PBYTE pbModuleEnd   = pbModuleStart + ModuleEntry32.modBaseSize;

        // The end address is exclusive: base + size is the first byte that
        // no longer belongs to the module.
        if( (PBYTE) pfnOld >= pbModuleStart && (PBYTE) pfnOld < pbModuleEnd )
        {
            return pfnOld;
        }
    }

    // Hack for Win9x: skip the first byte of the thunk and read the pointer
    // that follows (presumably a one-byte opcode followed by the real entry
    // point - TODO confirm against the Win9x thunk layout).
    return *(PVOID *)( ((PBYTE)pfnOld) + 1 );
}
////////////////////////////////////////////////////////////////////////////////////
//
// Func: ConstructChain
//
// Params: pfnOld       Original API function pointer to resolve.
//
//         DllListIndex Receives the index (into the per-shim-DLL tables such
//                      as g_rgnHookDllList) of the shim DLL whose HOOKAPI
//                      entry matched pfnOld last, i.e. the bottom of the
//                      chain. Set to 0 when no hook matches. Must not be NULL.
//
// Return: Top-of-chain PHOOKAPI structure, or NULL when no shim DLL hooks
//         pfnOld.
//
// Desc: Scans HookAPI arrays for pfnOld and either constructs the
//       chain or returns the top-of-chain PHOOKAPI if the chain
//       already exists. Shim DLLs are scanned from the most recently
//       registered to the oldest, and at most one entry per DLL is used.
//
//       While chaining, the pfnOld of every entry except the bottom one is
//       redirected to the pfnNew of the next hook down, so only the bottom
//       entry keeps the real API address.
//
PHOOKAPI ConstructChain( PVOID pfnOld ,DWORD* DllListIndex)
{
    LONG i, j;
    LONG nMatchDll = 0;             // shim DLL index of the latest match
    PHOOKAPI pTopHookAPI = NULL;
    PHOOKAPI pBottomHookAPI = NULL;

    *DllListIndex = 0;

    // Scan all HOOKAPI entries for corresponding function pointer
    for( i = g_nShimDllCount - 1; i >= 0; i-- )
    {
        for( j = 0; j < g_rgnHookAPICount[i]; j++ )
        {
            PHOOKAPI pHook = &( g_rgpHookAPIs[i][j] );

            if( pHook->pfnOld != pfnOld )
            {
                continue;
            }

            if( pTopHookAPI )
            {
                // Already hooked! Chain them together: the previous bottom
                // now calls this hook instead of the real API.
                pBottomHookAPI->pfnOld = pHook->pfnNew;
                pBottomHookAPI = pHook;
                pBottomHookAPI->pNextHook = pTopHookAPI;
                pBottomHookAPI->dwFlags = HAF_RESOLVED;
            }
            else
            {
                if( pHook->pNextHook )
                {
                    // Chaining has already been constructed.
                    *DllListIndex = i;
                    return (PHOOKAPI) pHook->pNextHook;
                }

                // Not hooked yet. Set to top of chain.
                pTopHookAPI = pHook;
                pTopHookAPI->pNextHook = pTopHookAPI;
                pTopHookAPI->dwFlags = HAF_RESOLVED;
                pBottomHookAPI = pTopHookAPI;
            }

            nMatchDll = i;

            // Only the first matching entry of each shim DLL takes part.
            break;
        }
    }

    if( pBottomHookAPI )
    {
        pBottomHookAPI->dwFlags = HAF_BOTTOM_OF_CHAIN;

        // Report the DLL that owns the bottom entry. The loop counter cannot
        // be used here: it is -1 once the scan has finished, which callers
        // would turn into an out-of-range table index. The bottom entry is
        // also the one later lookups match, so first and subsequent calls
        // report the same index.
        *DllListIndex = (DWORD) nMatchDll;
    }

    return pTopHookAPI;
} // ConstructChain
////////////////////////////////////////////////////////////////////////////////////
//
// Func: HookImports
//
// Params: dwBaseAddress Base address of the loaded module image to hook.
//
//         szModName     Name of that module; compared (case-insensitively)
//                       against the include/exclude list of the shim DLL
//                       reported by ConstructChain.
//
// Desc: Walks the import descriptors of the module at dwBaseAddress. For
//       every imported DLL that at least one HOOKAPI entry names, each
//       import slot is looked up with ConstructChain; when a hook chain
//       exists and the include/exclude list allows it, the slot is
//       overwritten with the pfnNew of the top of the chain.
//
//       Addresses are carried in DWORD/ULONG throughout, so this code
//       presumably targets 32-bit images only.
//
VOID HookImports(
    DWORD dwBaseAddress,
    LPTSTR szModName )
{
    PIMAGE_DOS_HEADER        pDosHeader = (PIMAGE_DOS_HEADER) dwBaseAddress;
    PIMAGE_NT_HEADERS        pNtHeaders;
    PIMAGE_IMPORT_DESCRIPTOR pImportDesc;
    DWORD                    dwImportRva;
    DWORD                    dwSavedProtect;
    DWORD                    dwIgnored;
    LONG                     nDll, nApi;

    // Locate the import directory; a module without one has nothing to hook.
    pNtHeaders  = (PIMAGE_NT_HEADERS)(dwBaseAddress + pDosHeader->e_lfanew);
    dwImportRva = pNtHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
    if( dwImportRva == 0 )
    {
        return;
    }

    // One descriptor per imported DLL; a zero FirstThunk ends the table.
    for( pImportDesc = (PIMAGE_IMPORT_DESCRIPTOR)(dwBaseAddress + dwImportRva);
         pImportDesc->FirstThunk != 0;
         pImportDesc++ )
    {
        LPTSTR            pszImportedDll = (LPTSTR)( dwBaseAddress + pImportDesc->Name );
        PIMAGE_THUNK_DATA pThunk;
        BOOL              bWanted = FALSE;

        // Does any registered hook target this imported DLL?
        for( nDll = 0; nDll < g_nShimDllCount && !bWanted; nDll++ )
        {
            for( nApi = 0; nApi < g_rgnHookAPICount[nDll]; nApi++ )
            {
                if( lstrcmpi( g_rgpHookAPIs[nDll][nApi].pszModule, pszImportedDll ) == 0 )
                {
                    bWanted = TRUE;
                    break;
                }
            }
        }

        if( !bWanted )
        {
            continue;
        }

        // Walk the address slots of this DLL; a zero slot ends the list.
        for( pThunk = (PIMAGE_THUNK_DATA)( dwBaseAddress + (DWORD)pImportDesc->FirstThunk );
             pThunk->u1.Ordinal != 0;
             pThunk++ )
        {
            DWORD    DllListIndex = 0;
            PVOID    pfnCurrent = (PVOID) pThunk->u1.Function;
            PHOOKAPI pChainTop;
            LPTSTR   pszEntry;

            if( g_bIsWin9X )
            {
                pfnCurrent = ValidateAddress( pfnCurrent );
            }

            pChainTop = ConstructChain( pfnCurrent, &DllListIndex );
            if( pChainTop == NULL )
            {
                continue;
            }

            // Check if we want to patch this API for this particular loaded
            // module. The list is a run of NUL-terminated strings ended by
            // an empty string: names before "%" are included, names after
            // it are excluded, and "*" after "%" excludes everything else.
            pszEntry = g_rgnHookDllList[DllListIndex];
            if( pszEntry != NULL )
            {
                BOOL bExcludeMode = FALSE;

                for( ; *pszEntry != 0; pszEntry += lstrlen( pszEntry ) + 1 )
                {
                    if( lstrcmpi( pszEntry, szModName ) == 0 )
                    {
                        break;
                    }
                    if( lstrcmpi( pszEntry, TEXT("%") ) == 0 )
                    {
                        bExcludeMode = TRUE;
                    }
                    if( bExcludeMode && lstrcmpi( pszEntry, TEXT("*") ) == 0 )
                    {
                        break;
                    }
                }

                // Exclude mode: stopping early means "excluded".
                // Include mode: reaching the end means "not listed".
                if( bExcludeMode ? ( *pszEntry != 0 ) : ( *pszEntry == 0 ) )
                {
                    continue;
                }
            }

            // Make the slot writable, store the hook, restore protection.
            if( VirtualProtect( &pThunk->u1.Function,
                                sizeof(DWORD),
                                PAGE_READWRITE,
                                &dwSavedProtect ) )
            {
                pThunk->u1.Function = (ULONG) pChainTop->pfnNew;

                VirtualProtect( &pThunk->u1.Function,
                                sizeof(DWORD),
                                dwSavedProtect,
                                &dwIgnored );
            }
        }
    }
} // HookImports
////////////////////////////////////////////////////////////////////////////////////
//
// Func: ResolveAPIs
//
// Desc: Each time a module is loaded, the pfnOld members of each HOOKAPI
// structure in g_rgpHookAPIs are resolved (by calling GetProcAddress).
//
VOID ResolveAPIs()
{
LONG i, j;
PVOID pfnOld = NULL;
PIMAGE_NT_HEADERS NtHeaders;
for (i = 0; i < g_nShimDllCount; i++)
{
for (j = 0; j < g_rgnHookAPICount[i]; j++ )
{
HMODULE hMod;
// We only care about HOOKAPIs at the bottom of a chain.
if( ( g_rgpHookAPIs[i][j].dwFlags & HAF_RESOLVED ) &&
! ( g_rgpHookAPIs[i][j].dwFlags & HAF_BOTTOM_OF_CHAIN ) )
continue;
if( ( hMod = GetModuleHandle(g_rgpHookAPIs[i][j].pszModule) ) != NULL)
{
pfnOld = GetProcAddress( hMod, g_rgpHookAPIs[i][j].pszFunctionName );
if( pfnOld == NULL )
{
// This is an ERROR. The hook DLL asked to patch a function
// that doesn't exist !!!
}
else
{
if( g_bIsWin9X )
pfnOld = ValidateAddress( pfnOld );
g_rgpHookAPIs[i][j].pfnOld = pfnOld;
}
}
}
}
} // ResolveAPIs
////////////////////////////////////////////////////////////////////////////////////
//
// Func: Shim2PatchNewModules
//
// Desc: Takes a snapshot of the modules loaded in the current process,
//       refreshes the pfnOld pointers via ResolveAPIs and calls HookImports
//       on every module that
//         - has a handle below 0x80000000,
//         - is not one of the registered shim DLLs, and
//         - has not been hooked by an earlier call.
//       It is also handed to each shim DLL's GetHookAPIs by _LoadPatchDll so
//       the shim can call back into it.
//
//       Two snapshots are taken: g_hSnapshot drives the loop below, while
//       g_hValidationSnapshot is enumerated by ValidateAddress during that
//       loop. Both are closed and reset to NULL before returning, on every
//       path.
//
void __stdcall Shim2PatchNewModules( VOID )
{
    LONG j;
    BOOL bRet;
    HMODULE hMod;
    MODULEENTRY32 ModuleEntry32;

    // Enumerate all the loaded modules and hook their import tables
    g_hSnapshot = CreateToolhelp32Snapshot( TH32CS_SNAPMODULE, 0 );
    g_hValidationSnapshot = CreateToolhelp32Snapshot( TH32CS_SNAPMODULE, 0 );

    // CreateToolhelp32Snapshot signals failure with INVALID_HANDLE_VALUE, not
    // NULL. Normalise to NULL so the globals keep their "NULL = no snapshot"
    // meaning and the cleanup below never closes an invalid handle.
    if( g_hSnapshot == INVALID_HANDLE_VALUE )
    {
        g_hSnapshot = NULL;
    }
    if( g_hValidationSnapshot == INVALID_HANDLE_VALUE )
    {
        g_hValidationSnapshot = NULL;
    }

    if( g_hSnapshot == NULL )
    {
        goto Cleanup;
    }

    // Resolve old APIs for loaded modules
    ResolveAPIs();

    ModuleEntry32.dwSize = sizeof( ModuleEntry32 );
    for( bRet = Module32First( g_hSnapshot, &ModuleEntry32 );
         bRet;
         bRet = Module32Next( g_hSnapshot, &ModuleEntry32 ) )
    {
        hMod = ModuleEntry32.hModule;

        // Modules at or above 0x80000000 are never hooked (presumably the
        // shared system area on Win9x - TODO confirm).
        if( hMod >= (HMODULE) 0x80000000 )
        {
            continue;
        }

        // we need to make sure we are not trying to shim ourselves
        for( j = 0; j < g_nShimDllCount; j++ )
        {
            if( hMod == g_hShimDlls[j] )
            {
                hMod = NULL;
                break;
            }
        }

        // Skip modules that an earlier call already hooked.
        for( j = 0; j < g_nHookedModuleCount; j++ )
        {
            if( hMod == g_hHookedModules[ j ] )
            {
                hMod = NULL;
                break;
            }
        }

        if( hMod )
        {
            HookImports( (DWORD) hMod, ModuleEntry32.szModule );

            // Remember the module only while the table has room. When it is
            // full the module is still hooked, it just is not recorded, so a
            // later call will process it again instead of writing past the
            // end of g_hHookedModules.
            if( g_nHookedModuleCount < MAX_MODULES )
            {
                g_hHookedModules[ g_nHookedModuleCount++ ] = hMod;
            }
        }
    }

Cleanup:
    if( g_hSnapshot )
    {
        CloseHandle( g_hSnapshot );
        g_hSnapshot = NULL;
    }
    if( g_hValidationSnapshot )
    {
        CloseHandle( g_hValidationSnapshot );
        g_hValidationSnapshot = NULL;
    }
} //PatchNewModules
////////////////////////////////////////////////////////////////////////////////////
//
// Func: AddHookAPIs
//
// Params: hShimDll         Handle of new shim DLL.
//
//         pHookAPIs        Pointer to new HOOKAPI array. Only the pointer is
//                          stored, so the array must stay valid for as long
//                          as hooks are in use.
//
//         dwCount          Number of entries in pHookAPIs.
//
//         szIncExclDllList Include/exclude module list for this shim DLL, or
//                          NULL for no filtering. Stored by pointer as well.
//
// Desc: Stores away the pointer returned by a shim DLL's GetHookAPIs
//       function in our global arrays, after clearing the dwFlags and
//       pNextHook members of every entry. When all MAX_MODULES slots are
//       already taken the call does nothing.
//
void AddHookAPIs( HMODULE hShimDll, PHOOKAPI pHookAPIs, DWORD dwCount,LPTSTR szIncExclDllList)
{
    DWORD i;
    LONG nSlot = g_nShimDllCount;

    // The global tables hold MAX_MODULES rows; registering one more would
    // write past the end of four arrays. The function has no way to report
    // failure, so the request is dropped.
    if( nSlot < 0 || nSlot >= MAX_MODULES )
    {
        return;
    }

    // Start every entry unchained and unresolved.
    for( i = 0; i < dwCount; i++ )
    {
        pHookAPIs[i].dwFlags = 0;
        pHookAPIs[i].pNextHook = NULL;
    }

    g_rgpHookAPIs[ nSlot ] = pHookAPIs;
    g_rgnHookAPICount[ nSlot ] = dwCount;
    g_hShimDlls[ nSlot ] = hShimDll;
    g_rgnHookDllList[ nSlot ] = szIncExclDllList;

    // Publish the row only after it is completely filled in.
    g_nShimDllCount = nSlot + 1;
} // AddHookAPIs
////////////////////////////////////////////////////////////////////////////////////
//
// Func: _LoadPatchDll
//
// Params: szPatchDll       Name (wide string) of shim DLL to be loaded.
//
//         szCmdLine        Passed unchanged as the first argument of the
//                          shim DLL's GetHookAPIs export.
//
//         szIncExclDllList Include/exclude module list for this shim DLL, or
//                          NULL. Handed to AddHookAPIs, which keeps the
//                          pointer, so the buffer must outlive the hooks.
//                          NOTE(review): declared LPSTR here but LPTSTR in
//                          AddHookAPIs - the two only agree in ANSI builds.
//
// Return: TRUE if successful, FALSE if not.
//
// Desc: Loads a shim DLL and retrieves the hooking information via GetHookAPIs.
//       The DLL is unloaded again when it lacks the export or reports no
//       hooks. On success the DLL stays loaded and its HOOKAPI array is
//       registered in the global tables.
//
BOOL _LoadPatchDll(
    LPWSTR szPatchDll,LPSTR szCmdLine,LPSTR szIncExclDllList)
{
    PHOOKAPI pHookAPIs = NULL;
    DWORD dwHookAPICount = 0;
    HMODULE hModHookDll;
    PFNGETHOOKAPIS pfnGetHookAPIs;

    // Refuse up front when there is nothing to load or no free row left in
    // the global tables; otherwise the DLL would be loaded and kept without
    // its hooks ever being registered safely.
    if( szPatchDll == NULL || g_nShimDllCount >= MAX_MODULES )
    {
        return FALSE;
    }

    hModHookDll = LoadLibraryW( szPatchDll );
    if( hModHookDll == NULL )
    {
        return FALSE;
    }

    pfnGetHookAPIs = (PFNGETHOOKAPIS) GetProcAddress( hModHookDll, SHIM_GETHOOKAPIS );
    if( pfnGetHookAPIs == NULL )
    {
        // Not a shim DLL: the required export is missing.
        FreeLibrary( hModHookDll );
        return FALSE;
    }

    // The shim receives Shim2PatchNewModules so it can request a re-patch.
    pHookAPIs = (*pfnGetHookAPIs)( szCmdLine, Shim2PatchNewModules, &dwHookAPICount );
    if( dwHookAPICount == 0 || pHookAPIs == NULL )
    {
        FreeLibrary( hModHookDll );
        return FALSE;
    }

    AddHookAPIs( hModHookDll, pHookAPIs, dwHookAPICount, szIncExclDllList );
    return TRUE;
} // _LoadPatchDll
// Re-enable build warnings: put C4002, disabled near the top of this file,
// back to its default state.
#pragma warning( default : 4002 )