aboutsummaryrefslogtreecommitdiff
path: root/primedev/windows/libsys.cpp
blob: 0aff820b72b5cff86101617b4139265c2a5952a0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include "libsys.h"
#include "plugins/pluginmanager.h"

// File name of the XInput DLL that WLoadLibraryExA intercepts and loads a replacement for
#define XINPUT1_3_DLL "XInput1_3.dll"

// Function-pointer types matching the signatures of the kernel32 LoadLibrary* exports
typedef HMODULE (*WINAPI ILoadLibraryA)(LPCSTR lpLibFileName);
typedef HMODULE (*WINAPI ILoadLibraryExA)(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);
typedef HMODULE (*WINAPI ILoadLibraryW)(LPCWSTR lpLibFileName);
typedef HMODULE (*WINAPI ILoadLibraryExW)(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);

// Pointers to the original (un-detoured) LoadLibrary* functions.
// LibSys_Init resolves them from KERNEL32.DLL and passes their addresses to HookAttach;
// the W* detours below call through them to perform the real load.
ILoadLibraryA o_LoadLibraryA = nullptr;
ILoadLibraryExA o_LoadLibraryExA = nullptr;
ILoadLibraryW o_LoadLibraryW = nullptr;
ILoadLibraryExW o_LoadLibraryExW = nullptr;

//-----------------------------------------------------------------------------
// Purpose: Run detour callbacks for given HMODULE.
//          Modules it imports are processed first (recursively), then the
//          load-library callbacks and plugin dll-load notification are run for
//          the module itself. Each module is processed at most once and a null
//          handle is ignored.
//-----------------------------------------------------------------------------
void LibSys_RunModuleCallbacks(HMODULE hModule)
{
	// Modules that we have already run callbacks for.
	// Note: If we ever hook unloading modules, then this will need updating to handle removal etc.
	// NOTE(review): this static has no lock; if LoadLibrary is called concurrently from several threads it could race.
	static std::vector<HMODULE> vCalledModules;

	if (!hModule)
	{
		return;
	}

	// If we have already run callbacks for this module, don't run them again.
	// Recording it before recursing into imports also stops import cycles from recursing forever.
	if (std::find(vCalledModules.begin(), vCalledModules.end(), hModule) != vCalledModules.end())
	{
		return;
	}
	vCalledModules.push_back(hModule);

	// Run callbacks for all imported modules first.
	// Imports that aren't loaded yield a null handle, which is ignored above.
	CModule cModule(hModule);
	for (const std::string& svImport : cModule.GetImportedModules())
		LibSys_RunModuleCallbacks(GetModuleHandleA(svImport.c_str()));

	// Get module base name in ASCII as no one wants to deal with unicode.
	// The buffer is zero-initialised and the result checked so that a failed lookup
	// never hands uninitialised stack memory to the name-based callbacks below.
	CHAR szModuleName[MAX_PATH] {};
	if (GetModuleBaseNameA(GetCurrentProcess(), hModule, szModuleName, MAX_PATH) == 0)
	{
		spdlog::warn("GetModuleBaseNameA failed for module {} (error {})", static_cast<void*>(hModule), GetLastError());
		return;
	}

	// DevMsg(eLog::NONE, "%s\n", szModuleName);

	// Call callbacks
	CallLoadLibraryACallbacks(szModuleName, hModule);
	g_pPluginManager->InformDllLoad(hModule, fs::path(szModuleName));
}

//-----------------------------------------------------------------------------
// Load library callbacks

// Detour for LoadLibraryA: performs the real load through the original function,
// then runs module callbacks on the resulting handle (a null handle from a failed
// load is ignored by LibSys_RunModuleCallbacks). The handle is returned unchanged.
HMODULE WINAPI WLoadLibraryA(LPCSTR lpLibFileName)
{
	HMODULE const hLoaded = o_LoadLibraryA(lpLibFileName);
	LibSys_RunModuleCallbacks(hLoaded);
	return hLoaded;
}

HMODULE WINAPI WLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
{
	HMODULE hModule;

	LPCSTR lpLibFileNameEnd = lpLibFileName + strlen(lpLibFileName);
	LPCSTR lpLibName = lpLibFileNameEnd - strlen(XINPUT1_3_DLL);

	// replace xinput dll with one that has ASLR
	if (lpLibFileName <= lpLibName && !strncmp(lpLibName, XINPUT1_3_DLL, strlen(XINPUT1_3_DLL) + 1))
	{
		const char* pszReplacementDll = "XInput1_4.dll";
		hModule = o_LoadLibraryExA(pszReplacementDll, hFile, dwFlags);

		if (!hModule)
		{
			pszReplacementDll = "XInput9_1_0.dll";
			spdlog::warn("Couldn't load XInput1_4.dll. Will try XInput9_1_0.dll. If on Windows 7 this is expected");
			hModule = o_LoadLibraryExA(pszReplacementDll, hFile, dwFlags);
		}

		if (!hModule)
		{
			spdlog::error("Couldn't load XInput9_1_0.dll");
			MessageBoxA(
				0, "Could not load a replacement for XInput1_3.dll\nTried: XInput1_4.dll and XInput9_1_0.dll", "Northstar", MB_ICONERROR);
			exit(EXIT_FAILURE);

			return nullptr;
		}

		spdlog::info("Successfully loaded {} as a replacement for XInput1_3.dll", pszReplacementDll);
	}
	else
	{
		hModule = o_LoadLibraryExA(lpLibFileName, hFile, dwFlags);
	}

	bool bShouldRunCallbacks =
		!(dwFlags & (LOAD_LIBRARY_AS_DATAFILE | LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE | LOAD_LIBRARY_AS_IMAGE_RESOURCE));
	if (bShouldRunCallbacks)
	{
		LibSys_RunModuleCallbacks(hModule);
	}

	return hModule;
}

// Detour for LoadLibraryW: performs the real load through the original function,
// then runs module callbacks on the resulting handle (a null handle from a failed
// load is ignored by LibSys_RunModuleCallbacks). The handle is returned unchanged.
HMODULE WINAPI WLoadLibraryW(LPCWSTR lpLibFileName)
{
	HMODULE const hLoaded = o_LoadLibraryW(lpLibFileName);
	LibSys_RunModuleCallbacks(hLoaded);
	return hLoaded;
}

// Detour for LoadLibraryExW: performs the real load through the original function
// and returns its handle unchanged. Module callbacks are skipped when dwFlags asks
// for the image to be mapped only as a data file or image resource.
HMODULE WINAPI WLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
{
	HMODULE const hLoaded = o_LoadLibraryExW(lpLibFileName, hFile, dwFlags);

	constexpr DWORD dwDataOnlyFlags = LOAD_LIBRARY_AS_DATAFILE | LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE | LOAD_LIBRARY_AS_IMAGE_RESOURCE;
	if ((dwFlags & dwDataOnlyFlags) == 0)
		LibSys_RunModuleCallbacks(hLoaded);

	return hLoaded;
}

//-----------------------------------------------------------------------------
// Purpose: Initilase dll load callbacks
//-----------------------------------------------------------------------------
void LibSys_Init()
{
	HMODULE hKernel = GetModuleHandleA("KERNEL32.DLL");

	o_LoadLibraryA = reinterpret_cast<ILoadLibraryA>(GetProcAddress(hKernel, "LoadLibraryA"));
	o_LoadLibraryExA = reinterpret_cast<ILoadLibraryExA>(GetProcAddress(hKernel, "LoadLibraryExA"));
	o_LoadLibraryW = reinterpret_cast<ILoadLibraryW>(GetProcAddress(hKernel, "LoadLibraryW"));
	o_LoadLibraryExW = reinterpret_cast<ILoadLibraryExW>(GetProcAddress(hKernel, "LoadLibraryExW"));

	HookAttach(&(PVOID&)o_LoadLibraryA, (PVOID)WLoadLibraryA);
	HookAttach(&(PVOID&)o_LoadLibraryExA, (PVOID)WLoadLibraryExA);
	HookAttach(&(PVOID&)o_LoadLibraryW, (PVOID)WLoadLibraryW);
	HookAttach(&(PVOID&)o_LoadLibraryExW, (PVOID)WLoadLibraryExW);
}