diff options
Diffstat (limited to 'NorthstarDLL/hooks.cpp')
-rw-r--r-- | NorthstarDLL/hooks.cpp | 79 |
1 files changed, 64 insertions, 15 deletions
diff --git a/NorthstarDLL/hooks.cpp b/NorthstarDLL/hooks.cpp index 3c5438cd..bb696cdb 100644 --- a/NorthstarDLL/hooks.cpp +++ b/NorthstarDLL/hooks.cpp @@ -18,23 +18,49 @@ AUTOHOOK_INIT() __dllLoadCallback::__dllLoadCallback(
eDllLoadCallbackSide side, const std::string dllName, DllLoadCallbackFuncType callback, std::string uniqueStr, std::string reliesOn)
{
+ // parse reliesOn array from string
+ std::vector<std::string> reliesOnArray;
+
+ if (reliesOn.length() && reliesOn[0] != '(')
+ {
+ reliesOnArray.push_back(reliesOn);
+ }
+ else
+ {
+ // follows the format (tag, tag, tag)
+ std::string sCurrentTag;
+ for (int i = 1; i < reliesOn.length(); i++)
+ {
+ if (!isspace(reliesOn[i]))
+ {
+ if (reliesOn[i] == ',' || reliesOn[i] == ')')
+ {
+ reliesOnArray.push_back(sCurrentTag);
+ sCurrentTag = "";
+ }
+ else
+ sCurrentTag += reliesOn[i];
+ }
+ }
+ }
+
switch (side)
{
case eDllLoadCallbackSide::UNSIDED:
{
- AddDllLoadCallback(dllName, callback, uniqueStr, reliesOn);
+ AddDllLoadCallback(dllName, callback, uniqueStr, reliesOnArray);
break;
}
case eDllLoadCallbackSide::CLIENT:
{
- AddDllLoadCallbackForClient(dllName, callback, uniqueStr, reliesOn);
+ AddDllLoadCallbackForClient(dllName, callback, uniqueStr, reliesOnArray);
break;
}
case eDllLoadCallbackSide::DEDICATED_SERVER:
{
- AddDllLoadCallbackForDedicatedServer(dllName, callback, uniqueStr, reliesOn);
+ AddDllLoadCallbackForDedicatedServer(dllName, callback, uniqueStr, reliesOnArray);
break;
}
}
@@ -100,7 +126,7 @@ struct DllLoadCallback std::string dll;
DllLoadCallbackFuncType callback;
std::string tag;
- std::string reliesOn;
+ std::vector<std::string> reliesOn;
bool called;
};
@@ -112,7 +138,7 @@ std::vector<DllLoadCallback>& GetDllLoadCallbacks() return vec;
}
-void AddDllLoadCallback(std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::string reliesOn)
+void AddDllLoadCallback(std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::vector<std::string> reliesOn)
{
DllLoadCallback& callbackStruct = GetDllLoadCallbacks().emplace_back();
@@ -124,7 +150,7 @@ void AddDllLoadCallback(std::string dll, DllLoadCallbackFuncType callback, std:: }
void AddDllLoadCallbackForDedicatedServer(
- std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::string reliesOn)
+ std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::vector<std::string> reliesOn)
{
if (!IsDedicatedServer())
return;
@@ -132,7 +158,7 @@ void AddDllLoadCallbackForDedicatedServer( AddDllLoadCallback(dll, callback, tag, reliesOn);
}
-void AddDllLoadCallbackForClient(std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::string reliesOn)
+void AddDllLoadCallbackForClient(std::string dll, DllLoadCallbackFuncType callback, std::string tag, std::vector<std::string> reliesOn)
{
if (IsDedicatedServer())
return;
@@ -233,13 +259,25 @@ void CallLoadLibraryACallbacks(LPCSTR lpLibFileName, HMODULE moduleAddress) {
if (!callbackStruct.called && fs::path(lpLibFileName).filename() == fs::path(callbackStruct.dll).filename())
{
- if (callbackStruct.reliesOn != "" &&
- std::find(calledTags.begin(), calledTags.end(), callbackStruct.reliesOn) == calledTags.end())
+ bool bShouldContinue = false;
+
+ if (!callbackStruct.reliesOn.empty())
{
- bDoneCalling = false;
- continue;
+ for (std::string tag : callbackStruct.reliesOn)
+ {
+ if (std::find(calledTags.begin(), calledTags.end(), tag) == calledTags.end())
+ {
+ bDoneCalling = false;
+ bShouldContinue = true;
+ break;
+ }
+ }
}
+
+ if (bShouldContinue)
+ continue;
+
callbackStruct.callback(moduleAddress);
calledTags.push_back(callbackStruct.tag);
callbackStruct.called = true;
@@ -261,13 +299,24 @@ void CallLoadLibraryWCallbacks(LPCWSTR lpLibFileName, HMODULE moduleAddress) {
if (!callbackStruct.called && fs::path(lpLibFileName).filename() == fs::path(callbackStruct.dll).filename())
{
- if (callbackStruct.reliesOn != "" &&
- std::find(calledTags.begin(), calledTags.end(), callbackStruct.reliesOn) == calledTags.end())
+ bool bShouldContinue = false;
+
+ if (!callbackStruct.reliesOn.empty())
{
- bDoneCalling = false;
- continue;
+ for (std::string tag : callbackStruct.reliesOn)
+ {
+ if (std::find(calledTags.begin(), calledTags.end(), tag) == calledTags.end())
+ {
+ bDoneCalling = false;
+ bShouldContinue = true;
+ break;
+ }
+ }
}
+ if (bShouldContinue)
+ continue;
+
callbackStruct.callback(moduleAddress);
calledTags.push_back(callbackStruct.tag);
callbackStruct.called = true;
|