aboutsummaryrefslogtreecommitdiff
path: root/primedev/windows/libsys.cpp
blob: 501eae687447b82d61faef86cdb325685e627bad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include "libsys.h"
#include "plugins/pluginmanager.h"

#define XINPUT1_3_DLL "XInput1_3.dll"

typedef HMODULE (*WINAPI ILoadLibraryA)(LPCSTR lpLibFileName);
typedef HMODULE (*WINAPI ILoadLibraryExA)(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);
typedef HMODULE (*WINAPI ILoadLibraryW)(LPCWSTR lpLibFileName);
typedef HMODULE (*WINAPI ILoadLibraryExW)(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);

ILoadLibraryA o_LoadLibraryA = nullptr;
ILoadLibraryExA o_LoadLibraryExA = nullptr;
ILoadLibraryW o_LoadLibraryW = nullptr;
ILoadLibraryExW o_LoadLibraryExW = nullptr;

//-----------------------------------------------------------------------------
// Purpose: Run detour callbacks for given HMODULE
// Input  : hModule - handle returned by a LoadLibrary* call; null (a failed
//          load) is ignored
//-----------------------------------------------------------------------------
void LibSys_RunModuleCallbacks(HMODULE hModule)
{
	if (!hModule)
	{
		return;
	}

	// Get module base name in ASCII as no one wants to deal with unicode.
	// Zero-initialised so the buffer never holds indeterminate data.
	CHAR szModuleName[MAX_PATH] = {};
	if (GetModuleBaseNameA(GetCurrentProcess(), hModule, szModuleName, MAX_PATH) == 0)
	{
		// GetModuleBaseNameA returns 0 on failure and leaves the buffer undefined;
		// skip the callbacks rather than hand them a meaningless name
		return;
	}

	// DevMsg(eLog::NONE, "%s\n", szModuleName);

	// Call callbacks
	CallLoadLibraryACallbacks(szModuleName, hModule);
	g_pPluginManager->InformDllLoad(hModule, fs::path(szModuleName));
}

//-----------------------------------------------------------------------------
// Load library callbacks

//-----------------------------------------------------------------------------
// Purpose: LoadLibraryA hook; forwards to the original function, then runs
//          the module-load callbacks on the returned handle
//-----------------------------------------------------------------------------
HMODULE WINAPI WLoadLibraryA(LPCSTR lpLibFileName)
{
	const HMODULE hLoaded = o_LoadLibraryA(lpLibFileName);
	LibSys_RunModuleCallbacks(hLoaded); // a null handle (failed load) is ignored by the callee
	return hLoaded;
}

HMODULE WINAPI WLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
{
	HMODULE hModule;

	LPCSTR lpLibFileNameEnd = lpLibFileName + strlen(lpLibFileName);
	LPCSTR lpLibName = lpLibFileNameEnd - strlen(XINPUT1_3_DLL);

	// replace xinput dll with one that has ASLR
	if (lpLibFileName <= lpLibName && !strncmp(lpLibName, XINPUT1_3_DLL, strlen(XINPUT1_3_DLL) + 1))
	{
		const char* pszReplacementDll = "XInput1_4.dll";
		hModule = o_LoadLibraryExA(pszReplacementDll, hFile, dwFlags);

		if (!hModule)
		{
			pszReplacementDll = "XInput9_1_0.dll";
			spdlog::warn("Couldn't load XInput1_4.dll. Will try XInput9_1_0.dll. If on Windows 7 this is expected");
			hModule = o_LoadLibraryExA(pszReplacementDll, hFile, dwFlags);
		}

		if (!hModule)
		{
			spdlog::error("Couldn't load XInput9_1_0.dll");
			MessageBoxA(
				0, "Could not load a replacement for XInput1_3.dll\nTried: XInput1_4.dll and XInput9_1_0.dll", "Northstar", MB_ICONERROR);
			exit(EXIT_FAILURE);

			return nullptr;
		}

		spdlog::info("Successfully loaded {} as a replacement for XInput1_3.dll", pszReplacementDll);
	}
	else
	{
		hModule = o_LoadLibraryExA(lpLibFileName, hFile, dwFlags);
	}

	bool bShouldRunCallbacks =
		!(dwFlags & (LOAD_LIBRARY_AS_DATAFILE | LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE | LOAD_LIBRARY_AS_IMAGE_RESOURCE));
	if (bShouldRunCallbacks)
	{
		LibSys_RunModuleCallbacks(hModule);
	}

	return hModule;
}

//-----------------------------------------------------------------------------
// Purpose: LoadLibraryW hook; wide-string counterpart of WLoadLibraryA
//-----------------------------------------------------------------------------
HMODULE WINAPI WLoadLibraryW(LPCWSTR lpLibFileName)
{
	const HMODULE hResult = o_LoadLibraryW(lpLibFileName);
	LibSys_RunModuleCallbacks(hResult); // failed loads yield null, which the callee ignores
	return hResult;
}

//-----------------------------------------------------------------------------
// Purpose: LoadLibraryExW hook; forwards to the original and runs the
//          module-load callbacks unless the caller asked for a data-file or
//          image-resource load
//-----------------------------------------------------------------------------
HMODULE WINAPI WLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
{
	const HMODULE hLoaded = o_LoadLibraryExW(lpLibFileName, hFile, dwFlags);

	constexpr DWORD dwDataOnlyFlags = LOAD_LIBRARY_AS_DATAFILE | LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE | LOAD_LIBRARY_AS_IMAGE_RESOURCE;
	if ((dwFlags & dwDataOnlyFlags) == 0)
	{
		LibSys_RunModuleCallbacks(hLoaded);
	}

	return hLoaded;
}

//-----------------------------------------------------------------------------
// Purpose: Initilase dll load callbacks
//-----------------------------------------------------------------------------
void LibSys_Init()
{
	HMODULE hKernel = GetModuleHandleA("KERNEL32.DLL");

	o_LoadLibraryA = reinterpret_cast<ILoadLibraryA>(GetProcAddress(hKernel, "LoadLibraryA"));
	o_LoadLibraryExA = reinterpret_cast<ILoadLibraryExA>(GetProcAddress(hKernel, "LoadLibraryExA"));
	o_LoadLibraryW = reinterpret_cast<ILoadLibraryW>(GetProcAddress(hKernel, "LoadLibraryW"));
	o_LoadLibraryExW = reinterpret_cast<ILoadLibraryExW>(GetProcAddress(hKernel, "LoadLibraryExW"));

	HookAttach(&(PVOID&)o_LoadLibraryA, (PVOID)WLoadLibraryA);
	HookAttach(&(PVOID&)o_LoadLibraryExA, (PVOID)WLoadLibraryExA);
	HookAttach(&(PVOID&)o_LoadLibraryW, (PVOID)WLoadLibraryW);
	HookAttach(&(PVOID&)o_LoadLibraryExW, (PVOID)WLoadLibraryExW);
}