#include "core/convar/cvar.h"
#include "ns_limits.h"
#include "dedicated/dedicated.h"
#include "core/tier0.h"
#include "engine/r2engine.h"
#include "client/r2client.h"
#include "core/math/vector.h"
#include "core/vanilla.h"

#include <cmath>
#include <cstdint>
#include <cstring>

// Per-file autohook setup; the AUTOHOOK definitions below are installed by AUTOHOOK_DISPATCH_MODULE in the
// ON_DLL_LOAD callbacks at the bottom of this file (macro semantics defined elsewhere — presumably a registry).
AUTOHOOK_INIT()

// "ns_exploitfixes_log": when enabled, BLOCKED_INFO logs every message this file blocks/corrects. Created in ServerExploitFixes.
ConVar* Cvar_ns_exploitfixes_log;
// "ns_should_log_all_clientcommands": when enabled, every stringcmd sent by a client is logged. Created in ServerExploitFixes.
ConVar* Cvar_ns_should_log_all_clientcommands;

// The engine's "sv_cheats" convar, looked up in ServerExploitFixes; gates the legacy Portal 2 command blocklist.
ConVar* Cvar_sv_cheats;

// Logs a blocked/corrected message (only when ns_exploitfixes_log is enabled) and evaluates to false, so call sites can
// write `return BLOCKED_INFO("reason (" << value << ")");`.
// `s` is a chain of stream insertions appended after BLOCK_PREFIX, which must be a local in scope at the call site.
// The lambda is invoked immediately and never escapes, so it captures by reference: this avoids copying BLOCK_PREFIX
// (a std::string in several callers) every time a message is blocked.
// The assembled text is passed as a "{}" argument because it contains client-controlled content (e.g. convar names)
// that must never be interpreted as a format string.
#define BLOCKED_INFO(s)                                                                                                                    \
	(                                                                                                                                      \
		[&]() -> bool                                                                                                                      \
		{                                                                                                                                  \
			if (Cvar_ns_exploitfixes_log->GetBool())                                                                                       \
			{                                                                                                                              \
				std::stringstream stream;                                                                                                  \
				stream << "ExploitFixes.cpp: " << BLOCK_PREFIX << s;                                                                       \
				spdlog::error("{}", stream.str());                                                                                         \
			}                                                                                                                              \
			return false;                                                                                                                  \
		}())

// block bad netmessages
// A server can request a screenshot from any client, so screenshot netmessages are refused outright.
// Only vanilla-compatibility mode keeps the stock behaviour.
// clang-format off
AUTOHOOK(CLC_Screenshot_WriteToBuffer, engine.dll + 0x22AF20,
bool, __fastcall, (void* thisptr, void* buffer)) // 48 89 5C 24 ? 57 48 83 EC 20 8B 42 10
// clang-format on
{
	const bool bVanilla = g_pVanillaCompatibility->GetVanillaCompatibility();
	return bVanilla ? CLC_Screenshot_WriteToBuffer(thisptr, buffer) : false;
}

// Reading side of the screenshot netmessage: refused unless running in vanilla-compatibility mode.
// clang-format off
AUTOHOOK(CLC_Screenshot_ReadFromBuffer, engine.dll + 0x221F00,
bool, __fastcall, (void* thisptr, void* buffer)) // 48 89 5C 24 ? 48 89 6C 24 ? 48 89 74 24 ? 57 48 83 EC 20 48 8B DA 48 8B 52 38
// clang-format on
{
	const bool bVanilla = g_pVanillaCompatibility->GetVanillaCompatibility();
	return bVanilla ? CLC_Screenshot_ReadFromBuffer(thisptr, buffer) : false;
}

// CmdKeyValues is unused in-game and is a big client=>server=>client exploit vector, so it is never parsed.
// clang-format off
AUTOHOOK(Base_CmdKeyValues_ReadFromBuffer, engine.dll + 0x220040,
bool, __fastcall, (void* thisptr, void* buffer)) // 40 55 48 81 EC ? ? ? ? 48 8D 6C 24 ? 48 89 5D 70
// clang-format on
{
	// The original reader is deliberately never called; the parameters only exist to match the hooked signature.
	NOTE_UNUSED(buffer);
	NOTE_UNUSED(thisptr);

	constexpr bool bAllowRead = false;
	return bAllowRead;
}

// Validates NET_SetConVar messages before the engine applies them.
// On the server frame thread the entry count must be in [1, 69] and every known convar must be FCVAR_USERINFO;
// otherwise every known convar must be FCVAR_REPLICATED. Convars that FindVar doesn't know are left untouched.
// Any failure drops the whole message (returns false without calling the original).
// clang-format off
AUTOHOOK(CClient_ProcessSetConVar, engine.dll + 0x75CF0,
bool, __fastcall, (void* pMsg)) // 48 8B D1 48 8B 49 18 48 8B 01 48 FF 60 10
// clang-format on
{
	constexpr int ENTRY_STR_LEN = 260;
	struct SetConVarEntry
	{
		char name[ENTRY_STR_LEN];
		char val[ENTRY_STR_LEN];
	};

	struct NET_SetConVar
	{
		void* vtable;
		void* unk1;
		void* unk2;
		void* m_pMessageHandler;
		SetConVarEntry* m_ConVars; // convar entry array
		void* unk5; // these 2 unks are just vector capacity or whatever
		void* unk6;
		int m_ConVars_count; // amount of cvar entries in array (this will not be out of bounds)
	};

	auto msg = (NET_SetConVar*)pMsg;
	bool bIsServerFrame = ThreadInServerFrameThread();

	std::string BLOCK_PREFIX =
		std::string {"NET_SetConVar ("} + (bIsServerFrame ? "server" : "client") + "): Blocked dangerous/invalid msg: ";

	if (bIsServerFrame)
	{
		constexpr int SETCONVAR_SANITY_AMOUNT_LIMIT = 69;
		if (msg->m_ConVars_count < 1 || msg->m_ConVars_count > SETCONVAR_SANITY_AMOUNT_LIMIT)
		{
			return BLOCKED_INFO("Invalid m_ConVars_count (" << msg->m_ConVars_count << ")");
		}
	}

	for (int iEntry = 0; iEntry < msg->m_ConVars_count; iEntry++)
	{
		SetConVarEntry* entry = msg->m_ConVars + iEntry;

		// Safety check for memory access. Not risking partially readable arrays, they all gotta be readable
		if (!CMemory(entry).IsMemoryReadable(sizeof(*entry)))
			return BLOCKED_INFO("Unreadable memory at " << (void*)entry);

		// Both strings must be null-terminated inside their fixed-size buffers before any str* function touches them
		const bool nameValid = memchr(entry->name, '\0', ENTRY_STR_LEN) != nullptr;
		const bool valValid = memchr(entry->val, '\0', ENTRY_STR_LEN) != nullptr;
		if (!nameValid || !valValid)
			return BLOCKED_INFO("Missing null terminators");

		ConVar* pVar = g_pCVar->FindVar(entry->name);
		if (!pVar)
			continue;

		// Force name to match case. FindVar presumably matches case-insensitively, so the lengths should already agree;
		// the size guard just makes sure the copy can never run past the fixed-size name buffer.
		const size_t iCanonicalNameSize = strlen(pVar->m_ConCommandBase.m_pszName) + 1;
		if (iCanonicalNameSize <= sizeof(entry->name))
			memcpy(entry->name, pVar->m_ConCommandBase.m_pszName, iCanonicalNameSize);

		// client => server messages may only set userinfo convars, server => client only replicated ones
		int iFlags = bIsServerFrame ? FCVAR_USERINFO : FCVAR_REPLICATED;
		if (!pVar->IsFlagSet(iFlags))
			return BLOCKED_INFO(
				"Invalid flags (" << std::hex << "0x" << pVar->m_ConCommandBase.m_nFlags << "), var is " << entry->name);
	}

	return CClient_ProcessSetConVar(msg);
}

// prevent invalid user CMDs
// Rejects CLC_Move messages with a negative backup/new command count or a non-positive payload length before the
// engine parses them.
// clang-format off
AUTOHOOK(CClient_ProcessUsercmds, engine.dll + 0x1040F0,
bool, __fastcall, (void* thisptr, void* pMsg)) // 40 55 56 48 83 EC 58
// clang-format on
{
	struct CLC_Move
	{
		BYTE gap0[24];
		void* m_pMessageHandler;
		int m_nBackupCommands;
		int m_nNewCommands;
		int m_nLength;
		// bf_read m_DataIn;
		// bf_write m_DataOut;
	};

	// Consumed by BLOCKED_INFO
	const char* BLOCK_PREFIX = "ProcessUserCmds: ";
	const CLC_Move* pMove = static_cast<const CLC_Move*>(pMsg);

	if (pMove->m_nBackupCommands < 0)
		return BLOCKED_INFO("Invalid m_nBackupCommands (" << pMove->m_nBackupCommands << ")");

	if (pMove->m_nNewCommands < 0)
		return BLOCKED_INFO("Invalid m_nNewCommands (" << pMove->m_nNewCommands << ")");

	if (pMove->m_nLength <= 0)
		return BLOCKED_INFO("Invalid message length (" << pMove->m_nLength << ")");

	return CClient_ProcessUsercmds(thisptr, pMsg);
}

// Sanitizes each usercmd after the engine reads it, to prevent numerous exploits (including crashing the server):
// non-finite angle/camera/movement vectors are reset, and cmds with bogus timing have all gameplay-affecting fields zeroed.
// clang-format off
AUTOHOOK(ReadUsercmd, server.dll + 0x2603F0,
void, __fastcall, (void* buf, void* pCmd_move, void* pCmd_from)) // 4C 89 44 24 ? 53 55 56 57
// clang-format on
{
	// Let normal usercmd read happen first, it's safe
	ReadUsercmd(buf, pCmd_move, pCmd_from);

	// Now let's make sure the CMD we read isnt messed up
	struct alignas(4) SV_CUserCmd
	{
		DWORD command_number;
		DWORD tick_count;
		float command_time;
		Vector3 worldViewAngles;
		BYTE gap18[4];
		Vector3 localViewAngles;
		Vector3 attackangles;
		Vector3 move;
		DWORD buttons;
		BYTE impulse;
		short weaponselect;
		DWORD meleetarget;
		BYTE gap4C[24];
		char headoffset;
		BYTE gap65[11];
		Vector3 cameraPos;
		Vector3 cameraAngles;
		BYTE gap88[4];
		int tickSomething;
		DWORD dword90;
		DWORD predictedServerEventAck;
		DWORD dword98;
		float frameTime;
	};

	auto cmd = (SV_CUserCmd*)pCmd_move;

	// fix invalid player angles
	if (!cmd->worldViewAngles.IsValid())
		cmd->worldViewAngles.Init();

	if (!cmd->attackangles.IsValid())
		cmd->attackangles.Init();

	if (!cmd->localViewAngles.IsValid())
		cmd->localViewAngles.Init();

	// Fix invalid camera angles
	if (!cmd->cameraPos.IsValid())
		cmd->cameraPos.Init();
	if (!cmd->cameraAngles.IsValid())
		cmd->cameraAngles.Init();

	// Fix invalid movement vector
	if (!cmd->move.IsValid())
		cmd->move.Init();

	// Timing values must be finite and strictly positive. NaN compares false against everything, so a plain `x <= 0`
	// check would let a NaN frameTime/command_time through; non-finite values are rejected explicitly.
	const bool bValidFrameTime = std::isfinite(cmd->frameTime) && cmd->frameTime > 0;
	const bool bValidCommandTime = std::isfinite(cmd->command_time) && cmd->command_time > 0;
	if (bValidFrameTime && bValidCommandTime && cmd->tick_count != 0)
		return;

	// Cold path only: the log prefix is a heap-allocating string, so don't build it for every well-formed cmd
	auto fromCmd = (SV_CUserCmd*)pCmd_from;
	std::string BLOCK_PREFIX =
		"ReadUsercmd (command_number delta: " + std::to_string(cmd->command_number - fromCmd->command_number) + "): ";

	BLOCKED_INFO(
		"Bogus cmd timing (tick_count: " << cmd->tick_count << ", frameTime: " << cmd->frameTime
										 << ", commandTime : " << cmd->command_time << ")");

	// No simulation of bogus-timed cmds: fix any gameplay-affecting cmd properties
	// NOTE: Currently tickcount/frametime is set to 0, this ~shouldn't~ cause any problems
	cmd->worldViewAngles = cmd->localViewAngles = cmd->attackangles = cmd->cameraAngles = {0, 0, 0};
	cmd->tick_count = cmd->frameTime = 0;
	cmd->move = cmd->cameraPos = {0, 0, 0};
	cmd->buttons = 0;
	cmd->meleetarget = 0;
}

// ensure that GetLocalBaseClient().m_bRestrictServerCommands is set correctly, which the return value of this function controls
// this is IsValveMod in source, but we're making it IsRespawnMod now since valve didn't make this one
// Also records the mod name in g_pModName as a side effect.
// clang-format off
AUTOHOOK(IsRespawnMod, engine.dll + 0x1C6360,
bool, __fastcall, (const char* pModName)) // 48 83 EC 28 48 8B 0D ? ? ? ? 48 8D 15 ? ? ? ? E8 ? ? ? ? 85 C0 74 63
// clang-format on
{
	// somewhat temp, store the modname here, since we don't have a proper ptr in engine to it rn.
	// Only allocate a new copy when the name actually changes, so repeated calls don't leak a buffer each time.
	// The previous buffer is intentionally not freed: other code may still hold a pointer to it.
	if (!g_pModName || strcmp(g_pModName, pModName) != 0)
	{
		const size_t iSize = strlen(pModName) + 1; // include the null terminator
		char* pModNameCopy = new char[iSize];
		memcpy(pModNameCopy, pModName, iSize);
		g_pModName = pModNameCopy;
	}

	if (g_pVanillaCompatibility->GetVanillaCompatibility())
		return false;

	return (!strcmp("r2", pModName) || !strcmp("r1", pModName)) && !CommandLine()->CheckParm("-norestrictservercommands");
}

// ratelimit stringcmds, and prevent remote clients from calling commands that they shouldn't
// Returning false drops the command without passing it to the original ExecuteStringCommand.
// clang-format off
AUTOHOOK(CGameClient__ExecuteStringCommand, engine.dll + 0x1022E0,
bool, __fastcall, (CBaseClient* self, uint32_t unknown, const char* pCommandString))
// clang-format on
{
	if (Cvar_ns_should_log_all_clientcommands->GetBool())
		spdlog::info("player {} (UID: {}) sent command: \"{}\"", self->m_Name, self->m_UID, pCommandString);

	if (!g_pServerLimits->CheckStringCommandLimits(self))
	{
		CBaseClient__Disconnect(self, 1, "Sent too many stringcmd commands");
		return false;
	}

	// verify the command we're trying to execute is FCVAR_GAMEDLL_FOR_REMOTE_CLIENTS, if it's a concommand
	char* commandBuf[1040]; // assumedly this is the size of CCommand since we don't have an actual constructor
	memset(commandBuf, 0, sizeof(commandBuf));
	CCommand tempCommand = *(CCommand*)&commandBuf;

	if (!CCommand__Tokenize(tempCommand, pCommandString, cmd_source_t::kCommandSrcCode) || !tempCommand.ArgC())
		return false;

	// Points into tempCommand's buffer, which lives until the end of this function
	const char* pCommandName = tempCommand.Arg(0);
	ConCommand* command = g_pCVar->FindCommand(pCommandName);

	// if the command doesn't exist pass it on to ExecuteStringCommand for script clientcommands and stuff
	if (command && !command->IsFlagSet(FCVAR_GAMEDLL_FOR_REMOTE_CLIENTS))
	{
		// ensure FCVAR_GAMEDLL concommands without FCVAR_GAMEDLL_FOR_REMOTE_CLIENTS can't be executed by remote clients
		if (IsDedicatedServer())
			return false;

		if (strcmp(self->m_UID, g_pLocalPlayerUserID))
			return false;
	}

	// check for and block abusable legacy portal 2 commands
	// these aren't actually concommands weirdly enough, they seem to just be hardcoded
	// FindVar can return null; if sv_cheats wasn't found, treat cheats as disabled and keep the blocklist active.
	if (!Cvar_sv_cheats || !Cvar_sv_cheats->GetBool())
	{
		constexpr const char* blockedCommands[] = {
			"emit", // Sound-playing exploit (likely for Portal 2 coop devs testing splitscreen sound or something)

			// These both execute a command for every single entity for some reason, nice one valve
			"pre_go_to_hub",
			"pre_go_to_calibration",

			"end_movie", // Calls "__MovieFinished" script function, not sure exactly what this does but it certainly isn't needed
			"load_recent_checkpoint" // This is the instant-respawn exploit, literally just calls RespawnPlayer()
		};

		const size_t iCmdLength = strlen(pCommandName);

		// Case-insensitive match against a lowercase blocklist entry. Each client-supplied byte is converted to
		// unsigned char before tolower: passing a negative char (any non-ASCII byte) to tolower is undefined behaviour.
		auto IsBlockedCommand = [&](const char* pBlocked) -> bool
		{
			if (iCmdLength != strlen(pBlocked))
				return false;

			for (size_t i = 0; i < iCmdLength; i++)
			{
				if (tolower(static_cast<unsigned char>(pCommandName[i])) != pBlocked[i])
					return false;
			}

			return true;
		};

		for (const char* pBlocked : blockedCommands)
		{
			// this is a command we need to block
			if (IsBlockedCommand(pBlocked))
				return false;
		}
	}

	return CGameClient__ExecuteStringCommand(self, unknown, pCommandString);
}

// prevent clients from crashing servers through overflowing CNetworkStringTableContainer::WriteBaselines
// Set to true before SendServerInfo runs; the byte patch applied to WriteBaselines on engine.dll load clears it on failure.
bool bWasWritingStringTableSuccessful;

// clang-format off
AUTOHOOK(CBaseClient__SendServerInfo, engine.dll + 0x104FB0,
void, __fastcall, (void* self))
// clang-format on
{
	// Assume success; the patched WriteBaselines flips this off if it overflows
	bWasWritingStringTableSuccessful = true;
	CBaseClient__SendServerInfo(self);

	if (bWasWritingStringTableSuccessful)
		return;

	CBaseClient__Disconnect(
		self, 1, "Overflowed CNetworkStringTableContainer::WriteBaselines, try restarting your client and reconnecting");
}

// return null when GetEntByIndex is passed an index outside [0, 0x4000)
// this is called from exactly 1 script clientcommand that can be given an arbitrary index, and going above 0x4000 crashes.
// Negative indices are rejected too: the index is client-controlled and nothing here shows the original handling them.
// clang-format off
AUTOHOOK(GetEntByIndex, server.dll + 0x2A8A50,
void*, __fastcall, (int i))
// clang-format on
{
	const int MAX_ENT_IDX = 0x4000;

	if (i < 0 || i >= MAX_ENT_IDX)
	{
		spdlog::warn("GetEntByIndex {} is out of bounds (max {})", i, MAX_ENT_IDX);
		return nullptr;
	}

	return GetEntByIndex(i);
}
// Rejects entity read info whose nNewEntity lies outside [0, 0x1000) before the engine uses it.
// clang-format off
AUTOHOOK(CL_CopyExistingEntity, engine.dll + 0x6F940,
bool, __fastcall, (void* a1))
// clang-format on
{
	struct CEntityReadInfo
	{
		BYTE gap[40];
		int nNewEntity;
	};

	// Value isn't sanitized in release builds for every game powered by the Source Engine 1, causing read/write
	// outside of array bounds. This defect has led to the achievement of a full-chain RCE exploit, so sanity-check
	// the value of m_nNewEntity here to prevent this behavior from happening.
	const int iNewEntity = static_cast<CEntityReadInfo*>(a1)->nNewEntity;
	const bool bInRange = iNewEntity >= 0 && iNewEntity < 0x1000;
	if (!bInRange)
		return false;

	return CL_CopyExistingEntity(a1);
}

// Installs the engine.dll hooks from this file and applies engine byte patches.
ON_DLL_LOAD("engine.dll", EngineExploitFixes, (CModule module))
{
	AUTOHOOK_DISPATCH_MODULE(engine.dll)

	// allow client/ui to run clientcommands despite restricting servercommands
	module.Offset(0x4FB65).Patch("EB 11");
	module.Offset(0x4FBAC).Patch("EB 16");

	// patch to clear bWasWritingStringTableSuccessful in CNetworkStringTableContainer::WriteBaselines if it fails.
	// Writes `mov byte ptr [rip+rel32], 0` (C6 05 rel32 imm8, 7 bytes). The flag is a 1-byte bool, so a dword store
	// (C7 05 rel32 imm32) would also overwrite the 3 bytes that follow it in memory.
	// The same 15-byte span as before is replaced: the 7-byte mov, 3 NOPs, then the original 5 NOPs at +10, so the
	// instruction boundary at +10 stays where it was.
	{
		constexpr int MOV_BYTE_LENGTH = 7; // C6 05 + rel32 + imm8
		CMemory addr = module.Offset(0x234ED2);

		// rip-relative displacement is measured from the end of the instruction
		int32_t iRelOffset =
			(int32_t)((uintptr_t)&bWasWritingStringTableSuccessful - addr.Offset(MOV_BYTE_LENGTH).GetPtr());

		addr.Patch("C6 05");
		addr.Offset(2).Patch((BYTE*)&iRelOffset, sizeof(iRelOffset));
		addr.Offset(6).Patch("00");

		addr.Offset(MOV_BYTE_LENGTH).NOP(3);
		addr.Offset(10).NOP(5);
	}
}

// Installs the server.dll hooks from this file, neuters exploitable/antitamper server functions, and creates this
// file's convars.
ON_DLL_LOAD_RELIESON("server.dll", ServerExploitFixes, ConVar, (CModule module))
{
	AUTOHOOK_DISPATCH_MODULE(server.dll)

	// CServerGameClients::ClientCommandKeyValues has no benefit and is forwarded to the client (a security issue), so
	// make it return immediately. This closes client=>server=>client; server=>client also has clientside patches.
	module.Offset(0x153920).Patch("C3");

	// Dumb ANTITAMPER patches (they negatively impact performance and security)
	constexpr const char* ANTITAMPER_EXPORTS[] = {
		"ANTITAMPER_SPOTCHECK_CODEMARKER",
		"ANTITAMPER_TESTVALUE_CODEMARKER",
		"ANTITAMPER_TRIGGER_CODEMARKER",
	};

	for (const char* pszExportName : ANTITAMPER_EXPORTS)
	{
		CMemory exportAddr = module.GetExportedFunction(pszExportName);
		if (!exportAddr)
			continue; // export not present in this build

		// Just return, none of them have any args or are userpurge
		exportAddr.Patch("C3");
		spdlog::info("Patched AntiTamper function export \"{}\"", pszExportName);
	}

	Cvar_ns_exploitfixes_log =
		new ConVar("ns_exploitfixes_log", "1", FCVAR_GAMEDLL, "Whether to log whenever ExploitFixes.cpp blocks/corrects something");
	Cvar_ns_should_log_all_clientcommands =
		new ConVar("ns_should_log_all_clientcommands", "0", FCVAR_NONE, "Whether to log all clientcommands");

	Cvar_sv_cheats = g_pCVar->FindVar("sv_cheats");
}