#include "pch.h"
#include "Memory.h"

#include <cstddef>
#include <cstdint>
#include <cstring>

// Handle of tier0.dll; defined elsewhere, and set lazily by LoadTier0Handle() when still null.
extern HMODULE hTier0Module;
// Address of tier0's exported "g_pMemAllocSingleton" (i.e. a pointer to its IMemAlloc*),
// or nullptr until LoadTier0Handle() resolves it.
IMemAlloc** g_ppMemAllocSingleton = nullptr;

void LoadTier0Handle()
{
	if (!hTier0Module)
		hTier0Module = GetModuleHandleA("tier0.dll");
	if (!hTier0Module)
		return;

	g_ppMemAllocSingleton = (IMemAlloc**)GetProcAddress(hTier0Module, "g_pMemAllocSingleton");
}

// Size in bytes of the static arena that malloc() serves before falling back to tier0.
const int STATIC_ALLOC_SIZE = 4096;

// Number of bytes of pStaticAllocBuf consumed so far (the arena's high-water mark).
size_t g_iStaticAllocated = 0;
// Arena block that realloc() may resize in place; nullptr when none is recorded.
void* g_pLastAllocated = nullptr;
// Backing storage for the arena. It is aligned to 16 bytes (the alignment MSVC's malloc
// guarantees on x64) so the arena's start is suitably aligned for any fundamental type;
// a plain char array only guarantees 1-byte alignment.
alignas(16) char pStaticAllocBuf[STATIC_ALLOC_SIZE];

// The allocator overrides below should never be used here, except possibly by LibraryLoadError (unconfirmed).

void* malloc(size_t n)
{
	// allocate into static buffer
	if (g_iStaticAllocated + n <= STATIC_ALLOC_SIZE)
	{
		void* ret = pStaticAllocBuf + g_iStaticAllocated;
		g_iStaticAllocated += n;
		return ret;
	}
	else
	{
		// try to fallback to g_pMemAllocSingleton
		if (!hTier0Module || !g_ppMemAllocSingleton)
			LoadTier0Handle();
		if (g_ppMemAllocSingleton && *g_ppMemAllocSingleton)
			return (*g_ppMemAllocSingleton)->m_vtable->Alloc(*g_ppMemAllocSingleton, n);
		else
			throw "Cannot allocate";
	}
}

// Replacement for the CRT free.
//
// p: a pointer from this file's malloc()/realloc(), or nullptr.
//
// nullptr is a no-op, as the C standard requires of free(). Blocks inside the static
// arena are intentionally not reclaimed. Any other pointer is handed to tier0's
// IMemAlloc::Free; if the tier0 singleton has not been resolved, the call does nothing.
void free(void* p)
{
	if (!p)
		return;

	// Compare as integers: relational operators on pointers into different objects
	// are unspecified in C++.
	const uintptr_t addr = reinterpret_cast<uintptr_t>(p);
	const uintptr_t base = reinterpret_cast<uintptr_t>(pStaticAllocBuf);
	// Inclusive upper bound: a zero-byte arena block may sit exactly at the buffer's end.
	if (addr >= base && addr <= base + static_cast<uintptr_t>(STATIC_ALLOC_SIZE))
		return;

	if (g_ppMemAllocSingleton && *g_ppMemAllocSingleton)
		(*g_ppMemAllocSingleton)->m_vtable->Free(*g_ppMemAllocSingleton, p);
}

// Replacement for the CRT realloc.
//
// old_ptr: a block from this file's malloc()/realloc(), or nullptr.
// size:    new size in bytes.
//
// - nullptr behaves like malloc(size).
// - Static-arena blocks: if old_ptr is g_pLastAllocated and the new size still fits in
//   the arena, the block is resized in place. Otherwise a new block is obtained from
//   malloc(size) and the old contents are copied into it, up to `size` bytes. The
//   old arena block is left in place, since arena space is never reclaimed.
// - Any other pointer is forwarded to tier0's IMemAlloc::Realloc. nullptr is returned
//   if the tier0 singleton has not been resolved.
void* realloc(void* old_ptr, size_t size)
{
	if (!old_ptr)
		return malloc(size);

	const uintptr_t base = reinterpret_cast<uintptr_t>(pStaticAllocBuf);
	const uintptr_t addr = reinterpret_cast<uintptr_t>(old_ptr);
	const size_t capacity = static_cast<size_t>(STATIC_ALLOC_SIZE);

	if (addr >= base && addr <= base + capacity)
	{
		const size_t offset = static_cast<size_t>(addr - base);

		// Nothing was allocated after this block and the new size fits: resize in place
		// (this also handles shrinking).
		if (old_ptr == g_pLastAllocated && size <= capacity - offset)
		{
			g_iStaticAllocated = offset + size;
			return old_ptr;
		}

		// The arena does not record block sizes. The distance from old_ptr to the
		// high-water mark is the exact old size for the latest block and an upper bound
		// otherwise. Copying min(that, size) therefore preserves every byte the caller
		// owned, without reading past the used part of the buffer.
		const size_t available = g_iStaticAllocated > offset ? g_iStaticAllocated - offset : 0;
		void* new_ptr = malloc(size);
		// The new block lies at or beyond the high-water mark (or on the tier0 heap),
		// so the source and destination ranges never overlap.
		memcpy(new_ptr, old_ptr, available < size ? available : size);
		return new_ptr;
	}

	if (g_ppMemAllocSingleton && *g_ppMemAllocSingleton)
		return (*g_ppMemAllocSingleton)->m_vtable->Realloc(*g_ppMemAllocSingleton, old_ptr, size);
	return nullptr;
}

// Global replacement for scalar `new`; delegates to this file's malloc().
//
// Zero-byte requests are bumped to one byte so that each call yields a distinct
// pointer, as the standard requires of operator new.
// Failure surfaces as whatever malloc() throws, not std::bad_alloc.
// NOTE(review): code catching std::bad_alloc will not see allocation failures.
void* operator new(size_t n)
{
	return malloc(n != 0 ? n : 1);
}

// Global replacement for scalar `delete`: hands the block back to this file's free().
void operator delete(void* p) noexcept
{
	free(p);
}