diff options
Diffstat (limited to 'NorthstarDLL/hooks.cpp')
-rw-r--r-- | NorthstarDLL/hooks.cpp | 390 |
1 files changed, 272 insertions, 118 deletions
diff --git a/NorthstarDLL/hooks.cpp b/NorthstarDLL/hooks.cpp index 72ae727a..cca1d986 100644 --- a/NorthstarDLL/hooks.cpp +++ b/NorthstarDLL/hooks.cpp @@ -1,61 +1,193 @@ #include "pch.h" -#include "hooks.h" -#include "hookutils.h" -#include "sigscanning.h" #include "dedicated.h" +#include <iostream> #include <wchar.h> #include <iostream> #include <vector> #include <fstream> #include <sstream> #include <filesystem> +#include <Psapi.h> -typedef LPSTR (*GetCommandLineAType)(); -LPSTR GetCommandLineAHook(); +AUTOHOOK_INIT() -typedef HMODULE (*LoadLibraryExAType)(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags); -HMODULE LoadLibraryExAHook(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags); +// called from the ON_DLL_LOAD macros +__dllLoadCallback::__dllLoadCallback( + eDllLoadCallbackSide side, const std::string dllName, DllLoadCallbackFuncType callback, std::string uniqueStr, std::string reliesOn) +{ + // parse reliesOn array from string + std::vector<std::string> reliesOnArray; -typedef HMODULE (*LoadLibraryAType)(LPCSTR lpLibFileName); -HMODULE LoadLibraryAHook(LPCSTR lpLibFileName); + if (reliesOn.length() && reliesOn[0] != '(') + { + reliesOnArray.push_back(reliesOn); + } + else + { + // follows the format (tag, tag, tag) + std::string sCurrentTag; + for (int i = 1; i < reliesOn.length(); i++) + { + if (!isspace(reliesOn[i])) + { + if (reliesOn[i] == ',' || reliesOn[i] == ')') + { + reliesOnArray.push_back(sCurrentTag); + sCurrentTag = ""; + } + else + sCurrentTag += reliesOn[i]; + } + } + } -typedef HMODULE (*LoadLibraryExWType)(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags); -HMODULE LoadLibraryExWHook(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags); + switch (side) + { + case eDllLoadCallbackSide::UNSIDED: + { + AddDllLoadCallback(dllName, callback, uniqueStr, reliesOnArray); + break; + } -typedef HMODULE (*LoadLibraryWType)(LPCWSTR lpLibFileName); -HMODULE LoadLibraryWHook(LPCWSTR lpLibFileName); + case eDllLoadCallbackSide::CLIENT: + { + AddDllLoadCallbackForClient(dllName, callback, uniqueStr, reliesOnArray); + break; + } -GetCommandLineAType GetCommandLineAOriginal; -LoadLibraryExAType LoadLibraryExAOriginal; -LoadLibraryAType LoadLibraryAOriginal; -LoadLibraryExWType LoadLibraryExWOriginal; -LoadLibraryWType LoadLibraryWOriginal; + case eDllLoadCallbackSide::DEDICATED_SERVER: + { + AddDllLoadCallbackForDedicatedServer(dllName, callback, uniqueStr, reliesOnArray); + break; + } + } +} -void InstallInitialHooks() +void __fileAutohook::Dispatch() { - if (MH_Initialize() != MH_OK) - spdlog::error("MH_Initialize (minhook initialization) failed"); + for (__autohook* hook : hooks) + hook->Dispatch(); +} + +void __fileAutohook::DispatchForModule(const char* pModuleName) +{ + const int iModuleNameLen = strlen(pModuleName); + + for (__autohook* hook : hooks) + if ((hook->iAddressResolutionMode == __autohook::OFFSET_STRING && !strncmp(pModuleName, hook->pAddrString, iModuleNameLen)) || + (hook->iAddressResolutionMode == __autohook::PROCADDRESS && !strcmp(pModuleName, hook->pModuleName))) + hook->Dispatch(); +} - HookEnabler hook; - ENABLER_CREATEHOOK( - hook, reinterpret_cast<void*>(&GetCommandLineA), &GetCommandLineAHook, reinterpret_cast<LPVOID*>(&GetCommandLineAOriginal)); - ENABLER_CREATEHOOK( - hook, reinterpret_cast<void*>(&LoadLibraryExA), &LoadLibraryExAHook, reinterpret_cast<LPVOID*>(&LoadLibraryExAOriginal)); - ENABLER_CREATEHOOK(hook, reinterpret_cast<void*>(&LoadLibraryA), &LoadLibraryAHook, reinterpret_cast<LPVOID*>(&LoadLibraryAOriginal)); - ENABLER_CREATEHOOK( - hook, reinterpret_cast<void*>(&LoadLibraryExW), &LoadLibraryExWHook, reinterpret_cast<LPVOID*>(&LoadLibraryExWOriginal)); - ENABLER_CREATEHOOK(hook, reinterpret_cast<void*>(&LoadLibraryW), &LoadLibraryWHook, reinterpret_cast<LPVOID*>(&LoadLibraryWOriginal)); +ManualHook::ManualHook(const char* funcName, LPVOID func) : pHookFunc(func), ppOrigFunc(nullptr) +{ + const int iFuncNameStrlen = strlen(funcName); + pFuncName = new char[iFuncNameStrlen]; + memcpy(pFuncName, funcName, iFuncNameStrlen); } -LPSTR GetCommandLineAHook() +ManualHook::ManualHook(const char* funcName, LPVOID* orig, LPVOID func) : pHookFunc(func), ppOrigFunc(orig) +{ + const int iFuncNameStrlen = strlen(funcName); + pFuncName = new char[iFuncNameStrlen]; + memcpy(pFuncName, funcName, iFuncNameStrlen); +} + +bool ManualHook::Dispatch(LPVOID addr, LPVOID* orig) +{ + if (orig) + ppOrigFunc = orig; + + if (MH_CreateHook(addr, pHookFunc, ppOrigFunc) == MH_OK) + { + if (MH_EnableHook(addr) == MH_OK) + { + spdlog::info("Enabling hook {}", pFuncName); + return true; + } + else + spdlog::error("MH_EnableHook failed for function {}", pFuncName); + } + else + spdlog::error("MH_CreateHook failed for function {}", pFuncName); + + return false; +} + +// dll load callback stuff +// this allows for code to register callbacks to be run as soon as a dll is loaded, mainly to allow for patches to be made on dll load +struct DllLoadCallback +{ + std::string dll; + DllLoadCallbackFuncType callback; + std::string tag; + std::vector<std::string> reliesOn; + bool called; +}; + +// HACK: declaring and initialising this vector at file scope crashes on debug builds due to static initialisation order +// using a static var like this ensures that the vector is initialised lazily when it's used +std::vector<DllLoadCallback>& GetDllLoadCallbacks() +{ + static std::vector<DllLoadCallback> vec = std::vector<DllLoadCallback>(); + return vec; +} + +void AddDllLoadCallback(std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::vector<std::string> reliesOn) +{ + DllLoadCallback& callbackStruct = GetDllLoadCallbacks().emplace_back(); + + callbackStruct.dll = dll; + callbackStruct.callback = callback; + callbackStruct.tag = tag; + callbackStruct.reliesOn = reliesOn; + callbackStruct.called = false; +} + +void AddDllLoadCallbackForDedicatedServer( + std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::vector<std::string> reliesOn) +{ + if (!IsDedicatedServer()) + return; + + AddDllLoadCallback(dll, callback, tag, reliesOn); +} + +void AddDllLoadCallbackForClient(std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::vector<std::string> reliesOn) +{ + if (IsDedicatedServer()) + return; + + AddDllLoadCallback(dll, callback, tag, reliesOn); +} + +void MakeHook(LPVOID pTarget, LPVOID pDetour, void* ppOriginal, const char* pFuncName) +{ + char* pStrippedFuncName = (char*)pFuncName; + // strip & char from funcname + if (*pStrippedFuncName == '&') + pStrippedFuncName++; + + if (MH_CreateHook(pTarget, pDetour, (LPVOID*)ppOriginal) == MH_OK) + { + if (MH_EnableHook(pTarget) == MH_OK) + spdlog::info("Enabling hook {}", pStrippedFuncName); + else + spdlog::error("MH_EnableHook failed for function {}", pStrippedFuncName); + } + else + spdlog::error("MH_CreateHook failed for function {}", pStrippedFuncName); +} + +AUTOHOOK_ABSOLUTEADDR(_GetCommandLineA, GetCommandLineA, LPSTR, WINAPI, ()) { static char* cmdlineModified; static char* cmdlineOrg; if (cmdlineOrg == nullptr || cmdlineModified == nullptr) { - cmdlineOrg = GetCommandLineAOriginal(); + cmdlineOrg = _GetCommandLineA(); bool isDedi = strstr(cmdlineOrg, "-dedicated"); // well, this one has to be a real argument bool ignoreStartupArgs = strstr(cmdlineOrg, "-nostartupargs"); @@ -111,77 +243,86 @@ LPSTR GetCommandLineAHook() return cmdlineModified; } -// dll load callback stuff -// this allows for code to register callbacks to be run as soon as a dll is loaded, mainly to allow for patches to be made on dll load -struct DllLoadCallback -{ - std::string dll; - DllLoadCallbackFuncType callback; - bool called; -}; - -std::vector<DllLoadCallback*> dllLoadCallbacks; - -void AddDllLoadCallback(std::string dll, DllLoadCallbackFuncType callback) -{ - DllLoadCallback* callbackStruct = new DllLoadCallback; - callbackStruct->dll = dll; - callbackStruct->callback = callback; - callbackStruct->called = false; - - dllLoadCallbacks.push_back(callbackStruct); -} - -void AddDllLoadCallbackForDedicatedServer(std::string dll, DllLoadCallbackFuncType callback) -{ - if (!IsDedicatedServer()) - return; - - DllLoadCallback* callbackStruct = new DllLoadCallback; - callbackStruct->dll = dll; - callbackStruct->callback = callback; - callbackStruct->called = false; - - dllLoadCallbacks.push_back(callbackStruct); -} - -void AddDllLoadCallbackForClient(std::string dll, DllLoadCallbackFuncType callback) -{ - if (IsDedicatedServer()) - return; - - DllLoadCallback* callbackStruct = new DllLoadCallback; - callbackStruct->dll = dll; - callbackStruct->callback = callback; - callbackStruct->called = false; - - dllLoadCallbacks.push_back(callbackStruct); -} - +std::vector<std::string> calledTags; void CallLoadLibraryACallbacks(LPCSTR lpLibFileName, HMODULE moduleAddress) { - for (auto& callbackStruct : dllLoadCallbacks) + CModule cModule(moduleAddress); + + while (true) { - if (!callbackStruct->called && - strstr(lpLibFileName + (strlen(lpLibFileName) - callbackStruct->dll.length()), callbackStruct->dll.c_str()) != nullptr) + bool bDoneCalling = true; + + for (auto& callbackStruct : GetDllLoadCallbacks()) { - callbackStruct->callback(moduleAddress); - callbackStruct->called = true; + if (!callbackStruct.called && fs::path(lpLibFileName).filename() == fs::path(callbackStruct.dll).filename()) + { + bool bShouldContinue = false; + + if (!callbackStruct.reliesOn.empty()) + { + for (std::string tag : callbackStruct.reliesOn) + { + if (std::find(calledTags.begin(), calledTags.end(), tag) == calledTags.end()) + { + bDoneCalling = false; + bShouldContinue = true; + break; + } + } + } + + if (bShouldContinue) + continue; + + callbackStruct.callback(moduleAddress); + calledTags.push_back(callbackStruct.tag); + callbackStruct.called = true; + } } + + if (bDoneCalling) + break; } } void CallLoadLibraryWCallbacks(LPCWSTR lpLibFileName, HMODULE moduleAddress) { - for (auto& callbackStruct : dllLoadCallbacks) + CModule cModule(moduleAddress); + + while (true) { - std::wstring wcharStrDll = std::wstring(callbackStruct->dll.begin(), callbackStruct->dll.end()); - const wchar_t* callbackDll = wcharStrDll.c_str(); - if (!callbackStruct->called && wcsstr(lpLibFileName + (wcslen(lpLibFileName) - wcharStrDll.length()), callbackDll) != nullptr) + bool bDoneCalling = true; + + for (auto& callbackStruct : GetDllLoadCallbacks()) { - callbackStruct->callback(moduleAddress); - callbackStruct->called = true; + if (!callbackStruct.called && fs::path(lpLibFileName).filename() == fs::path(callbackStruct.dll).filename()) + { + bool bShouldContinue = false; + + if (!callbackStruct.reliesOn.empty()) + { + for (std::string tag : callbackStruct.reliesOn) + { + if (std::find(calledTags.begin(), calledTags.end(), tag) == calledTags.end()) + { + bDoneCalling = false; + bShouldContinue = true; + break; + } + } + } + + if (bShouldContinue) + continue; + + callbackStruct.callback(moduleAddress); + calledTags.push_back(callbackStruct.tag); + callbackStruct.called = true; + } } + + if (bDoneCalling) + break; } } @@ -208,65 +349,78 @@ void CallAllPendingDLLLoadCallbacks() } } -HMODULE LoadLibraryExAHook(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags) +// clang-format off +AUTOHOOK_ABSOLUTEADDR(_LoadLibraryExA, LoadLibraryExA, +HMODULE, WINAPI, (LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)) +// clang-format on { + HMODULE moduleAddress; + + // replace xinput dll with one that has ASLR if (!strncmp(lpLibFileName, "XInput1_3.dll", 14)) { - HMODULE moduleAddress = LoadLibraryExAOriginal("XInput9_1_0.dll", hFile, dwFlags); - if (moduleAddress) - { - CallLoadLibraryACallbacks(lpLibFileName, moduleAddress); - } - else + moduleAddress = _LoadLibraryExA("XInput9_1_0.dll", hFile, dwFlags); + + if (!moduleAddress) { MessageBoxA(0, "Could not find XInput9_1_0.dll", "Northstar", MB_ICONERROR); exit(-1); + + return nullptr; } - return moduleAddress; } else - { - HMODULE moduleAddress = LoadLibraryExAOriginal(lpLibFileName, hFile, dwFlags); - if (moduleAddress) - { - CallLoadLibraryACallbacks(lpLibFileName, moduleAddress); - } - return moduleAddress; - } + moduleAddress = _LoadLibraryExA(lpLibFileName, hFile, dwFlags); + + if (moduleAddress) + CallLoadLibraryACallbacks(lpLibFileName, moduleAddress); + + return moduleAddress; } -HMODULE LoadLibraryAHook(LPCSTR lpLibFileName) +// clang-format off +AUTOHOOK_ABSOLUTEADDR(_LoadLibraryA, LoadLibraryA, +HMODULE, WINAPI, (LPCSTR lpLibFileName)) +// clang-format on { - HMODULE moduleAddress = LoadLibraryAOriginal(lpLibFileName); + HMODULE moduleAddress = _LoadLibraryA(lpLibFileName); if (moduleAddress) - { CallLoadLibraryACallbacks(lpLibFileName, moduleAddress); - } return moduleAddress; } -HMODULE LoadLibraryExWHook(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags) +// clang-format off +AUTOHOOK_ABSOLUTEADDR(_LoadLibraryExW, LoadLibraryExW, +HMODULE, WINAPI, (LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)) +// clang-format on { - HMODULE moduleAddress = LoadLibraryExWOriginal(lpLibFileName, hFile, dwFlags); + HMODULE moduleAddress = _LoadLibraryExW(lpLibFileName, hFile, dwFlags); if (moduleAddress) - { CallLoadLibraryWCallbacks(lpLibFileName, moduleAddress); - } return moduleAddress; } -HMODULE LoadLibraryWHook(LPCWSTR lpLibFileName) +// clang-format off +AUTOHOOK_ABSOLUTEADDR(_LoadLibraryW, LoadLibraryW, +HMODULE, WINAPI, (LPCWSTR lpLibFileName)) +// clang-format on { - HMODULE moduleAddress = LoadLibraryWOriginal(lpLibFileName); + HMODULE moduleAddress = _LoadLibraryW(lpLibFileName); if (moduleAddress) - { CallLoadLibraryWCallbacks(lpLibFileName, moduleAddress); - } return moduleAddress; } + +void InstallInitialHooks() +{ + if (MH_Initialize() != MH_OK) + spdlog::error("MH_Initialize (minhook initialization) failed"); + + AUTOHOOK_DISPATCH() +} |