#include "pch.h"

#include "exploitfixes.h"
#include "exploitfixes_utf8parser.h"
#include "nsmem.h"
#include "cvar.h"
#include "gameutils.h"

// Console variable that toggles exploit-fix logging. It is assigned in ExploitFixes::LoadCallback_Full,
// so it can still be null while already-installed hooks are running.
ConVar* ns_exploitfixes_log;

// True when blocked/corrected input should be logged. A convar that has not been created yet is treated
// like its default value ("1") instead of being dereferenced.
#define SHOULD_LOG (!ns_exploitfixes_log || ns_exploitfixes_log->m_Value.m_nValue > 0)

// Logs `s` (a chain of stream insertions) prefixed with "exploitfixes.cpp: " and a local named BLOCK_PREFIX,
// which the calling scope must declare. Always evaluates to false so hooks can `return BLOCKED_INFO(...)`.
// The lambda is invoked on the spot, so capturing by reference is safe and avoids copying captured locals
// (e.g. a std::string BLOCK_PREFIX).
#define BLOCKED_INFO(s) \
	( \
		[&]() -> bool \
		{ \
			if (SHOULD_LOG) \
			{ \
				std::stringstream stream; \
				stream << "exploitfixes.cpp: " << BLOCK_PREFIX << s; \
				spdlog::error(stream.str()); \
			} \
			return false; \
		}())

// Plain aggregate of three floats.
struct Float3
{
	float vals[3];

	// Overwrites every NaN component with 0; all other values (including infinities) are left as they are.
	void MakeValid()
	{
		for (int axis = 0; axis < 3; axis++)
		{
			float& component = vals[axis];
			if (!isnan(component))
				continue;

			component = 0;
		}
	}
};

// Declares a KHOOK named `name` on the engine.dll function located by the byte signature `pattern`.
// The generated hook takes (thisptr, buffer), ignores both, never calls the original function and
// always returns false.
// NOTE(review): presumably a false return makes the engine treat the net message as failed/unhandled —
// confirm against the engine's read/write callers.
#define BLOCK_NETMSG_FUNC(name, pattern)                                                                                                   \
	KHOOK(name, ("engine.dll", pattern), bool, __fastcall, (void* thisptr, void* buffer))                                                  \
	{                                                                                                                                      \
		return false;                                                                                                                      \
	}

// Servers can literally request a screenshot from any client, yeah no
// Both the write and the read side of the screenshot message are stubbed out.
BLOCK_NETMSG_FUNC(CLC_Screenshot_WriteToBuffer, "48 89 5C 24 ? 57 48 83 EC 20 8B 42 10");
BLOCK_NETMSG_FUNC(CLC_Screenshot_ReadFromBuffer, "48 89 5C 24 ? 48 89 6C 24 ? 48 89 74 24 ? 57 48 83 EC 20 48 8B DA 48 8B 52 38");

// This is unused ingame and a big exploit vector
// Only the read side is stubbed out here.
BLOCK_NETMSG_FUNC(Base_CmdKeyValues_ReadFromBuffer, "40 55 48 81 EC ? ? ? ? 48 8D 6C 24 ? 48 89 5D 70");

// Validates a NET_SetConVar message before handing it to the original handler.
//
// Checks performed:
//   * server side only: the entry count must be within [1, SETCONVAR_SANITY_AMOUNT_LIMIT]
//   * every entry must be readable memory, and both of its fixed-size strings must be null-terminated
//   * server side only: a convar that exists locally must carry FCVAR_USERINFO
// Entries naming a convar that does not exist locally are passed through untouched.
//
// Returns false (message blocked) on the first violation, otherwise whatever the original handler returns.
KHOOK(CClient_ProcessSetConVar, ("engine.dll", "48 8B D1 48 8B 49 18 48 8B 01 48 FF 60 10"), bool, __fastcall, (void* pMsg))
{
	constexpr int ENTRY_STR_LEN = 260;
	struct SetConVarEntry
	{
		char name[ENTRY_STR_LEN];
		char val[ENTRY_STR_LEN];
	};

	struct NET_SetConVar
	{
		void* vtable;
		void* unk1;
		void* unk2;
		void* m_pMessageHandler;
		SetConVarEntry* m_ConVars; // convar entry array
		void* unk5; // these 2 unks are just vector capacity or whatever
		void* unk6;
		int m_ConVars_count; // amount of cvar entries in array (this will not be out of bounds)
	};

	auto msg = (NET_SetConVar*)pMsg;

	bool areWeServer;

	{
		// Figure out if we are the client or the server
		//	To do this, we utilize the msg's m_pMessageHandler pointer
		//	m_pMessageHandler points to a virtual class that handles all net messages
		//	The first virtual table function of our m_pMessageHandler will differ if it is IServerMessageHandler or IClientMessageHandler
		void* msgHandlerVTableFirstFunc = **(void****)(msg->m_pMessageHandler);
		static auto engineBaseAddress = (uintptr_t)GetModuleHandleA("engine.dll");
		auto offset = uintptr_t(msgHandlerVTableFirstFunc) - engineBaseAddress;

		constexpr uintptr_t CLIENTSTATE_FIRST_VFUNC_OFFSET = 0x8A15C;
		areWeServer = offset != CLIENTSTATE_FIRST_VFUNC_OFFSET;
	}

	// Picked from two literals so that no std::string is allocated for every message that passes validation.
	const char* BLOCK_PREFIX = areWeServer ? "NET_SetConVar (server): Blocked dangerous/invalid msg: "
										   : "NET_SetConVar (client): Blocked dangerous/invalid msg: ";

	if (areWeServer)
	{
		constexpr int SETCONVAR_SANITY_AMOUNT_LIMIT = 69;
		if (msg->m_ConVars_count < 1 || msg->m_ConVars_count > SETCONVAR_SANITY_AMOUNT_LIMIT)
		{
			return BLOCKED_INFO("Invalid m_ConVars_count (" << msg->m_ConVars_count << ")");
		}
	}

	for (int i = 0; i < msg->m_ConVars_count; i++)
	{
		SetConVarEntry* entry = msg->m_ConVars + i;

		// Safety check for memory access
		if (!NSMem::IsMemoryReadable(entry, sizeof(*entry)))
		{
			return BLOCKED_INFO("Unreadable memory at " << (void*)entry); // Not risking that one, they all gotta be readable
		}

		// Both strings must contain a terminator somewhere inside their fixed-size buffers
		const bool nameTerminated = memchr(entry->name, '\0', ENTRY_STR_LEN) != nullptr;
		const bool valTerminated = memchr(entry->val, '\0', ENTRY_STR_LEN) != nullptr;
		if (!nameTerminated || !valTerminated)
			return BLOCKED_INFO("Missing null terminators");

		auto realVar = g_pCVar->FindVar(entry->name);
		if (!realVar)
			continue; // Unknown convar: nothing to normalise and no flags to check

		// Force name to match case. The lookup matched entry->name, so the real name is expected to have the
		// same length; the bound only guards the fixed-size buffer should that assumption ever break.
		const size_t realNameSize = strlen(realVar->m_ConCommandBase.m_pszName) + 1;
		if (realNameSize <= sizeof(entry->name))
			memcpy(entry->name, realVar->m_ConCommandBase.m_pszName, realNameSize);

		// TODO: The client direction should probably have some sanity checks too, but none were found that are consistent
		if (areWeServer && !ConVar::IsFlagSet(realVar, FCVAR_USERINFO)) // ConVar MUST be userinfo var
		{
			return BLOCKED_INFO(
				"Invalid flags (" << std::hex << "0x" << realVar->m_ConCommandBase.m_nFlags << "), var is " << entry->name);
		}
	}

	return oCClient_ProcessSetConVar(msg);
}

// Purpose: prevent invalid user CMDs
// Rejects a CLC_Move message when either command count is negative, when the combined command count exceeds
// NUMCMD_SANITY_LIMIT, or when the message length is not positive. Returns false for a rejected message,
// otherwise the result of the original handler.
KHOOK(CClient_ProcessUsercmds, ("engine.dll", "40 55 56 48 83 EC 58"), bool, __fastcall, (void* thisptr, void* pMsg))
{
	struct CLC_Move
	{
		BYTE gap0[24];
		void* m_pMessageHandler;
		int m_nBackupCommands;
		int m_nNewCommands;
		int m_nLength;
		// bf_read m_DataIn;
		// bf_write m_DataOut;
	};

	auto msg = (CLC_Move*)pMsg;

	const char* BLOCK_PREFIX = "ProcessUserCmds: ";

	if (msg->m_nBackupCommands < 0)
	{
		return BLOCKED_INFO("Invalid m_nBackupCommands (" << msg->m_nBackupCommands << ")");
	}

	if (msg->m_nNewCommands < 0)
	{
		return BLOCKED_INFO("Invalid m_nNewCommands (" << msg->m_nNewCommands << ")");
	}

	// Both counts are known to be non-negative here. They are compared against the limit without being added
	// together, because `new + backup` could overflow a signed int and wrap around to a value that passes.
	constexpr int NUMCMD_SANITY_LIMIT = 16;
	const bool tooManyCommands =
		msg->m_nNewCommands > NUMCMD_SANITY_LIMIT || msg->m_nBackupCommands > NUMCMD_SANITY_LIMIT - msg->m_nNewCommands;
	if (tooManyCommands)
	{
		return BLOCKED_INFO("Command count is too high (new: " << msg->m_nNewCommands << ", backup: " << msg->m_nBackupCommands << ")");
	}

	if (msg->m_nLength <= 0)
		return BLOCKED_INFO("Invalid message length (" << msg->m_nLength << ")");

	return oCClient_ProcessUsercmds(thisptr, pMsg);
}

// Sanitises a usercmd right after the game has read it.
//   buf       - passed through to the original reader untouched
//   pCmd_move - the command that was just read; it is fixed up in place
//   pCmd_from - the command it was read relative to; only used for the log message
// NaN components in the angle/position/movement vectors are replaced with 0. If the timing fields are bogus,
// the gameplay-affecting fields of the command are zeroed so the command does nothing.
KHOOK(ReadUsercmd, ("server.dll", "4C 89 44 24 ? 53 55 56 57"), void, __fastcall, (void* buf, void* pCmd_move, void* pCmd_from))
{
	// Let normal usercmd read happen first, it's safe
	oReadUsercmd(buf, pCmd_move, pCmd_from);

	// Now let's make sure the CMD we read isnt messed up to prevent numerous exploits (including server crashing)
	struct alignas(4) SV_CUserCmd
	{
		DWORD command_number;
		DWORD tick_count;
		float command_time;
		Float3 worldViewAngles;
		BYTE gap18[4];
		Float3 localViewAngles;
		Float3 attackangles;
		Float3 move;
		DWORD buttons;
		BYTE impulse;
		short weaponselect;
		DWORD meleetarget;
		BYTE gap4C[24];
		char headoffset;
		BYTE gap65[11];
		Float3 cameraPos;
		Float3 cameraAngles;
		BYTE gap88[4];
		int tickSomething;
		DWORD dword90;
		DWORD predictedServerEventAck;
		DWORD dword98;
		float frameTime;
	};

	auto cmd = (SV_CUserCmd*)pCmd_move;
	auto fromCmd = (SV_CUserCmd*)pCmd_from;

	// Fix invalid player angles
	cmd->worldViewAngles.MakeValid();
	cmd->attackangles.MakeValid();
	cmd->localViewAngles.MakeValid();

	// Fix invalid camera angles
	cmd->cameraPos.MakeValid();
	cmd->cameraAngles.MakeValid();

	// Fix invalid movement vector
	cmd->move.MakeValid();

	// Written as `command_time > 0` (not `!(command_time <= 0)`) so that a NaN command_time counts as bogus too.
	const bool timingIsSane = cmd->tick_count != 0 && cmd->command_time > 0;
	if (timingIsSane)
		return;

	{
		// The prefix is only built on this rare path, so valid usercmds no longer pay for a string allocation.
		std::string BLOCK_PREFIX =
			"ReadUsercmd (command_number delta: " + std::to_string(cmd->command_number - fromCmd->command_number) + "): ";

		BLOCKED_INFO(
			"Bogus cmd timing (tick_count: " << cmd->tick_count << ", frameTime: " << cmd->frameTime
											 << ", commandTime : " << cmd->command_time << ")");
	}

	// No simulation of bogus-timed cmds: fix any gameplay-affecting cmd properties
	// NOTE: Currently tickcount/frametime is set to 0, this ~shouldn't~ cause any problems
	cmd->worldViewAngles = cmd->localViewAngles = cmd->attackangles = cmd->cameraAngles = {0, 0, 0};
	cmd->tick_count = cmd->frameTime = 0;
	cmd->move = cmd->cameraPos = {0, 0, 0};
	cmd->buttons = 0;
	cmd->meleetarget = 0;
}

// r2 is not flagged as a valve mod out of the box, which leaves m_bRestrictServerCommands false.
// That is terrible for security: a server could then run arbitrary concommands on clients, which is
// especially nasty now that script commands exist.
// This override reports "valve mod" unless -norestrictservercommands is on the command line.
KHOOK(IsValveMod, ("engine.dll", "48 83 EC 28 48 8B 0D ? ? ? ? 48 8D 15 ? ? ? ? E8 ? ? ? ? 85 C0 74 63"), bool, __fastcall, ())
{
	const bool restrictServerCommands = !CommandLine()->CheckParm("-norestrictservercommands");

	spdlog::info("ExploitFixes: Overriding IsValveMod to {}...", restrictServerCommands);

	return restrictServerCommands;
}

// Guards respawn's UTF8 parser against input that would crash it.
// As a side effect, launching multiplayer with "communities_enabled 1" no longer crashes.
// The string is only validated when the hook was called from one specific call site (found by pattern scan);
// an invalid string from that call site makes the hook return false instead of running the original parser.
KHOOK(
	CrashFunc_ParseUTF8,
	("engine.dll", "48 89 5C 24 ? 48 89 6C 24 ? 48 89 74 24 ? 57 41 54 41 55 41 56 41 57 48 83 EC 20 8B 1A"),
	bool,
	__fastcall,
	(INT64 * a1, DWORD* a2, char* strData))
{
	// Return address of the one call site whose input needs validating, resolved on first use
	static void* targetRetAddr = NSMem::PatternScan("engine.dll", "84 C0 75 2C 49 8B 16");

#ifdef _MSC_VER
	void* const callerAddr = _ReturnAddress();
#else
	void* const callerAddr = __builtin_return_address(0);
#endif

	const bool calledFromTarget = callerAddr == targetRetAddr;
	if (calledFromTarget && !ExploitFixes_UTF8Parser::CheckValid(a1, a2, strData))
	{
		const char* BLOCK_PREFIX = "ParseUTF8 Hook: ";
		return BLOCKED_INFO("Ignoring potentially-crashing utf8 string");
	}

	return oCrashFunc_ParseUTF8(a1, a2, strData);
}

// The game's GetEntByIndex (reached through ScriptGetEntByIndex) performs no upper bounds check on the
// index, so an index above the max entity count can be used to crash servers.
typedef void*(__fastcall* GetEntByIndexType)(int idx);
GetEntByIndexType GetEntByIndex; // trampoline to the original function

// Replacement for GetEntByIndex: forwards indices below 0x4000 to the original function and answers
// nullptr (after logging) for anything at or above that.
static void* GetEntByIndexHook(int idx)
{
	constexpr int ENT_INDEX_LIMIT = 0x4000;

	if (idx < ENT_INDEX_LIMIT)
		return GetEntByIndex(idx);

	spdlog::info("GetEntByIndex {} is out of bounds", idx);
	return nullptr;
}

// RELOCATED FROM https://github.com/R2Northstar/NorthstarLauncher/commit/25dbf729cfc75107a0fcf0186924b58ecc05214b
// Rewrite of CLZSS::SafeUncompress to fix a vulnerability where malicious compressed payloads could cause the decompressor to try to read
// out of the bounds of the output buffer.
//
//   pInput    - compressed data, starting with an lzss_header_t (id + uncompressed size)
//   pOutput   - destination buffer
//   unBufSize - capacity of pOutput in bytes
// Returns the number of bytes written, or 0 when either pointer is null, the header is invalid, the payload
// would overrun pOutput, a back-reference points before the start of the output, or the decoded size does
// not match the size announced in the header.
// NOTE: the length of pInput is not known here, so reads from the input cannot be bounds-checked.
KHOOK(
	LZSS_SafeUncompress,
	("engine.dll", "48 89 5C 24 ? 48 89 6C 24 ? 48 89 74 24 ? 48 89 7C 24 ? 33 ED 41 8B F9"),
	uint32_t,
	__fastcall,
	(void* self, const unsigned char* pInput, unsigned char* pOutput, unsigned int unBufSize))
{
	static constexpr int LZSS_LOOKSHIFT = 4;

	struct lzss_header_t
	{
		uint32_t id, actualSize;
	};

	// Validate the pointers BEFORE reading the header through pInput
	if (pInput == NULL || pOutput == NULL)
		return 0;

	// memcpy instead of a pointer cast: no alignment assumption about the input buffer
	lzss_header_t header;
	memcpy(&header, pInput, sizeof(header));

	if (header.id != 'SSZL' || header.actualSize == 0 || header.actualSize > unBufSize)
		return 0;

	pInput += sizeof(lzss_header_t);

	uint32_t totalBytes = 0; // bytes written so far; never exceeds unBufSize
	int getCmdByte = 0, cmdByte = 0;

	for (;;)
	{
		// A fresh command byte is fetched every 8 tokens; each of its bits selects literal vs back-reference
		if (!getCmdByte)
			cmdByte = *pInput++;

		getCmdByte = (getCmdByte + 1) & 0x07;

		if (cmdByte & 0x01)
		{
			// Back-reference: 12 bits of distance followed by 4 bits of length
			int position = *pInput++ << LZSS_LOOKSHIFT;
			position |= (*pInput >> LZSS_LOOKSHIFT);
			position += 1;

			int count = (*pInput++ & 0x0F) + 1;
			if (count == 1)
				break; // a length of 1 marks the end of the stream

			// Ensure reference chunk exists entirely within our buffer
			if (static_cast<uint32_t>(position) > totalBytes)
				return 0;

			// Compare against the remaining space so the running total can never wrap around
			if (static_cast<uint32_t>(count) > unBufSize - totalBytes)
				return 0;
			totalBytes += count;

			// Byte-by-byte on purpose: source and destination may overlap
			unsigned char* pSource = pOutput - position;
			for (int i = 0; i < count; i++)
				*pOutput++ = *pSource++;
		}
		else
		{
			// Literal byte
			if (totalBytes >= unBufSize)
				return 0;
			totalBytes++;

			*pOutput++ = *pInput++;
		}
		cmdByte = cmdByte >> 1;
	}

	// Only a payload that decodes to exactly the announced size counts as success
	return totalBytes == header.actualSize ? totalBytes : 0;
}

//////////////////////////////////////////////////

void DoBytePatches()
{
	uintptr_t engineBase = (uintptr_t)GetModuleHandleA("engine.dll");
	uintptr_t serverBase = (uintptr_t)GetModuleHandleA("server.dll");

	// patches to make commands run from client/ui script still work
	// note: this is likely preventable in a nicer way? test prolly
	NSMem::BytePatch(engineBase + 0x4FB65, "EB 11");
	NSMem::BytePatch(engineBase + 0x4FBAC, "EB 16");

	// disconnect concommand
	{
		uintptr_t addr = engineBase + 0x5ADA2D;
		int val = *(int*)addr | FCVAR_SERVER_CAN_EXECUTE;
		NSMem::BytePatch(addr, (BYTE*)&val, sizeof(int));
	}

	{ // Dumb ANTITAMPER patches (they negatively impact performance and security)

		constexpr const char* ANTITAMPER_EXPORTS[] = {
			"ANTITAMPER_SPOTCHECK_CODEMARKER",
			"ANTITAMPER_TESTVALUE_CODEMARKER",
			"ANTITAMPER_TRIGGER_CODEMARKER",
		};

		// Prevent thesefrom actually doing anything
		for (auto exportName : ANTITAMPER_EXPORTS)
		{

			auto address = (uintptr_t)GetProcAddress(GetModuleHandleA("server.dll"), exportName);
			if (!address)
			{
				spdlog::warn("Failed to find AntiTamper function export \"{}\"", exportName);
			}
			else
			{
				// Just return, none of them have any args or are userpurge
				NSMem::BytePatch(address, "C3");

				spdlog::info("Patched AntiTamper function export \"{}\"", exportName);
			}
		}
	}
}

// Filters exploitable client commands before the game's own handler sees them.
// With sv_cheats enabled nothing is filtered. Otherwise the first argument is lower-cased and the command is
// swallowed (the hook returns true without calling the original) if it contains any blocked command name.
KHOOK(
	SpecialClientCommand,
	("server.dll", "48 89 5C 24 ? 48 89 74 24 ? 55 57 41 56 48 8D 6C 24 ? 48 81 EC ? ? ? ? 83 3A 00"),
	bool,
	__fastcall,
	(void* player, CCommand* command))
{

	static ConVar* sv_cheats = g_pCVar->FindVar("sv_cheats");

	// A missing sv_cheats convar is treated as "cheats off" rather than dereferenced
	if (sv_cheats && sv_cheats->GetBool())
		return oSpecialClientCommand(player, command); // Don't block anything if sv_cheats is on

	// These are mostly from Portal 2 (sigh)
	constexpr const char* blockedCommands[] = {
		"emit", // Sound-playing exploit (likely for Portal 2 coop devs testing splitscreen sound or something)

		// These both execute a command for every single entity for some reason, nice one valve
		"pre_go_to_hub",
		"pre_go_to_calibration",

		"end_movie", // Calls "__MovieFinished" script function, not sure exactly what this does but it certainly isn't needed
		"load_recent_checkpoint" // This is the instant-respawn exploit, literally just calls RespawnPlayer()
	};

	if (command->ArgC() > 0)
	{
		std::string cmdStr = command->Arg(0);

		// The text comes from the client, so bytes >= 0x80 are possible. tolower() is only defined for
		// unsigned char values (and EOF), hence the cast before the call.
		for (char& c : cmdStr)
			c = static_cast<char>(tolower(static_cast<unsigned char>(c)));

		for (const char* blockedCommand : blockedCommands)
		{
			if (cmdStr.find(blockedCommand) != std::string::npos)
			{
				// Block this command
				spdlog::warn("Blocked exploititive client command \"{}\".", cmdStr);
				return true;
			}
		}
	}

	return oSpecialClientCommand(player, command);
}

// Installs a single hook. On failure the error is logged, a message box is shown and the process exits.
void SetupKHook(KHook* hook)
{
	if (!hook->Setup())
	{
		spdlog::critical("\tFAILED to initialize all exploit patches.");

		// Force exit
		MessageBoxA(0, "FAILED to initialize all exploit patches.", "Northstar", MB_ICONERROR);
		exit(0);
	}

	spdlog::debug("KHook::Setup(): Hooked at {}", hook->targetFuncAddr);
}

// Installs every pending hook whose target module is the one that was just loaded (`baseAddress`) and removes
// those hooks from the pending list so they are not initialised a second time.
void ExploitFixes::LoadCallback_MultiModule(HMODULE baseAddress)
{

	spdlog::info("ExploitFixes::LoadCallback_MultiModule({}) ...", (void*)baseAddress);

	int hooksEnabled = 0;

	// erase() already returns the iterator to the next element, so the iterator is advanced manually and only
	// when nothing was erased. Advancing after an erase as well would skip the hook following each removed one.
	auto itr = KHook::_allHooks.begin();
	while (itr != KHook::_allHooks.end())
	{
		auto curHook = *itr;
		if (GetModuleHandleA(curHook->targetFunc.moduleName) != baseAddress)
		{
			++itr;
			continue;
		}

		SetupKHook(curHook);
		itr = KHook::_allHooks.erase(itr); // Prevent repeated initialization

		hooksEnabled++;
	}

	spdlog::info("\tEnabled {} hooks.", hooksEnabled);
}

// Final initialisation step: applies the byte patches, creates the logging convar, installs every hook that is
// still pending and hooks GetEntByIndex at a fixed offset from `baseAddress`.
void ExploitFixes::LoadCallback_Full(HMODULE baseAddress)
{
	spdlog::info("ExploitFixes::LoadCallback_Full ...");

	spdlog::info("\tByte patching...");
	DoBytePatches();

	// Created before the remaining hooks go live: their blocking paths read this convar when deciding
	// whether to log.
	ns_exploitfixes_log =
		new ConVar("ns_exploitfixes_log", "1", FCVAR_GAMEDLL, "Whether to log whenever exploitfixes.cpp blocks/corrects something");

	for (KHook* hook : KHook::_allHooks)
		SetupKHook(hook);

	spdlog::info("\tInitialized {} late exploit-patch hooks.", KHook::_allHooks.size());
	KHook::_allHooks.clear();

	HookEnabler hook;
	ENABLER_CREATEHOOK(hook, (char*)baseAddress + 0x2a8a50, &GetEntByIndexHook, reinterpret_cast<LPVOID*>(&GetEntByIndex));
}