/*****************************************************************************
 *	Hoard: A Fast, Scalable, and Memory-Efficient Allocator for 
 *	Shared-Memory Multiprocessors.
 *
 *	Portions: Copyright (c) 1998-2001, The University of Texas at Austin.
 *
 *	Portions of this are free software; you can redistribute it and/or
 *	modify it under the terms of the GNU Library General Public License
 *	as published by the Free Software Foundation, http://www.fsf.org.
 *
 *	These portions are distributed in the hope that it will be useful,
 *	but WITHOUT ANY WARRANTY; without even the implied warranty of
 *	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *	Library General Public License for more details.
 ****************************************************************************/

#define WIN32_LEAN_AND_MEAN
#include <windows.h>

#include "config.h"
#include "ntwrapper.h"


// Names of the release and debug Microsoft C runtime DLLs whose heap entry
// points are looked up (via GetModuleHandle/GetProcAddress) and patched.
const char *RlsCRTLibraryName = "MSVCRT.DLL";
const char *DbgCRTLibraryName = "MSVCRTD.DLL";

// IA-32 near jmp: one opcode byte (0xE9) followed by a 32-bit displacement
// that is relative to the address of the NEXT instruction, i.e. the address
// of the jmp plus its total length.
#define IAX86_NEARJMP_OPCODE		0xe9
#define IAX86_NEARJMP_SIZE		5
// Displacement to encode in a near jmp located at `from` that lands on `to`.
// NOTE(review): the result is truncated to 32 bits, so this is only valid
// when `to` is within +/-2GB of `from` (always true for 32-bit x86 code).
#define MakeIAX86Offset(to,from)	((unsigned)((char*)(to)-(char*)(from)) - IAX86_NEARJMP_SIZE)

// One entry of a patch table: a CRT export to redirect, the routine it is
// redirected to, and enough state to undo the redirection.
typedef struct
{
	const char    *import;       // exported name of the CRT routine to patch
	FARPROC        replacement;  // routine the patched CRT entry jumps to
	FARPROC        original;     // GetProcAddress result for `import` (NULL if not exported)
	unsigned char  codebytes[5]; // original code bytes overwritten by the 5-byte jmp
} PATCH;

/*---------------------------------------------------------------------------
*/
// Routines of the RELEASE CRT library (MSVCRT.DLL) that this memory manager
// replaces. Entries are patched in table order; `original` and `codebytes`
// are filled in at patch time.
static PATCH rls_patches[] =
{
	// heap inspection / maintenance
	{ "_expand",   (FARPROC)lh__expand,   NULL },
	{ "_heapchk",  (FARPROC)lh__heapchk,  NULL },
	{ "_heapmin",  (FARPROC)lh__heapmin,  NULL },
	{ "_heapset",  (FARPROC)lh__heapset,  NULL },
	{ "_heapwalk", (FARPROC)lh__heapwalk, NULL },

	// EDB: operator new, new[], delete, delete[] (MSVC x86 mangled names).
	{ "??2@YAPAXI@Z",  (FARPROC)lh_malloc, NULL },
	{ "??_U@YAPAXI@Z", (FARPROC)lh_malloc, NULL },
	{ "??3@YAXPAX@Z",  (FARPROC)lh_free,   NULL },
	{ "??_V@YAXPAX@Z", (FARPROC)lh_free,   NULL },

	// the nothrow variants of new and new[].
	{ "??2@YAPAXIABUnothrow_t@std@@@Z",  (FARPROC)lh_new_nothrow, NULL },
	{ "??_U@YAPAXIABUnothrow_t@std@@@Z", (FARPROC)lh_new_nothrow, NULL },

	// core allocation entry points
	{ "_msize",  (FARPROC)lh__msize,  NULL },
	{ "calloc",  (FARPROC)lh_calloc,  NULL },
	{ "malloc",  (FARPROC)lh_malloc,  NULL },
	{ "realloc", (FARPROC)lh_realloc, NULL },
	{ "free",    (FARPROC)lh_free,    NULL },
};

#ifdef _DEBUG
// Routines of the DEBUG CRT library (MSVCRTD.DLL) that this memory manager
// replaces. Entries are patched in table order; `original` and `codebytes`
// are filled in at patch time.
//
// Each export must appear only ONCE: patching a routine twice makes the
// second entry save the first entry's jmp as its "original" code bytes,
// so unpatching could not restore the genuine routine.
static PATCH dbg_patches[] =
{
	// debug-heap API
	{"_calloc_dbg",					(FARPROC)lh__calloc_dbg,					0},
	{"_CrtCheckMemory",				(FARPROC)lh__CrtCheckMemory,				0},
	{"_CrtDoForAllClientObjects",	(FARPROC)lh__CrtDoForAllClientObjects,	0},
	{"_CrtDumpMemoryLeaks",			(FARPROC)lh__CrtDumpMemoryLeaks,			0},
	{"_CrtIsMemoryBlock",			(FARPROC)lh__CrtIsMemoryBlock,			0},
	{"_CrtIsValidHeapPointer",		(FARPROC)lh__CrtIsValidHeapPointer,		0},
	{"_CrtMemCheckpoint",			(FARPROC)lh__CrtMemCheckpoint,			0},
	{"_CrtMemDifference",			(FARPROC)lh__CrtMemDifference,			0},
	{"_CrtMemDumpAllObjectsSince",(FARPROC)lh__CrtMemDumpAllObjectsSince,0},
	{"_CrtMemDumpStatistics",		(FARPROC)lh__CrtMemDumpStatistics,		0},
	{"_CrtSetAllocHook",				(FARPROC)lh__CrtSetAllocHook,				0},
	{"_CrtSetBreakAlloc",			(FARPROC)lh__CrtSetBreakAlloc,			0},
	{"_CrtSetDbgFlag",				(FARPROC)lh__CrtSetDbgFlag,				0},
	{"_CrtSetDumpClient",			(FARPROC)lh__CrtSetDumpClient,			0},
	{"_expand",							(FARPROC)lh__expand,							0},
	{"_expand_dbg",					(FARPROC)lh__expand_dbg,					0},
	{"_free_dbg",						(FARPROC)lh__free_dbg,						0},
	{"_malloc_dbg",					(FARPROC)lh__malloc_dbg,					0},
	{"_msize_dbg",						(FARPROC)lh__msize_dbg,						0},
	{"_realloc_dbg",					(FARPROC)lh__realloc_dbg,					0},
	{"_heapchk",						(FARPROC)lh__heapchk,						0},
	{"_heapmin",						(FARPROC)lh__heapmin,						0},
	{"_heapset",						(FARPROC)lh__heapset,						0},
	{"_heapwalk",						(FARPROC)lh__heapwalk,						0},

	// core allocation entry points
	{"_msize",							(FARPROC)lh__msize,							0},
	{"calloc",							(FARPROC)lh_calloc,							0},
	{"malloc",							(FARPROC)lh_malloc,							0},
	{"realloc",							(FARPROC)lh_realloc,							0},
	{"free",								(FARPROC)lh_free,								0},

	// operator new, new[], delete, delete[].
	{"??2@YAPAXI@Z",                    (FARPROC)lh_malloc,                         0},
	{"??_U@YAPAXI@Z",                   (FARPROC)lh_malloc,                         0},
	{"??3@YAXPAX@Z",                    (FARPROC)lh_free,                           0},
	{"??_V@YAXPAX@Z",                   (FARPROC)lh_free,                           0},

	// the nothrow variants new, new[].
	{"??2@YAPAXIABUnothrow_t@std@@@Z",  (FARPROC)lh_new_nothrow,                    0},
	{"??_U@YAPAXIABUnothrow_t@std@@@Z", (FARPROC)lh_new_nothrow,                    0},

	// EDB: The debug versions of operator new & delete.
	{"??2@YAPAXIHPBDH@Z",               (FARPROC)lh_debug_operator_new,             0},
	{"??3@YAXPAXHPBDH@Z",               (FARPROC)lh_debug_operator_delete,          0},

	// EDB: And the nh_malloc_foo.
	{"_nh_malloc_dbg",                  (FARPROC)lh_nh_malloc_dbg,                  0},
};
#endif

/*---------------------------------------------------------------------------
*/
/*---------------------------------------------------------------------------
**	PatchIt - overwrite the first bytes of the CRT routine at patch->original
**	with "jmp patch->replacement", saving the overwritten bytes in
**	patch->codebytes so they can be restored later.
**
**	patch->original must be a non-NULL address returned by GetProcAddress.
**	If the code cannot be made writable, the routine is left untouched and
**	patch->original is cleared, so no later restore writes back code bytes
**	that were never saved.
*/
static void PatchIt(PATCH *patch)
{
	void *target = (void*)patch->original;
	DWORD oldProtect;

	// Make only the bytes we overwrite writable. VirtualProtect rounds the
	// range out to whole pages, so a patch straddling a page boundary is
	// covered too.
	if (!VirtualProtect(target, sizeof(patch->codebytes),
			PAGE_EXECUTE_READWRITE, &oldProtect))
	{
		patch->original = 0;
		return;
	}

	// patch CRT library original routine:
	//		save original code bytes for exit restoration
	//		write jmp <patch_routine> (opcode + 32-bit displacement) to original
	memcpy(patch->codebytes, target, sizeof(patch->codebytes));
	unsigned char *patchloc = (unsigned char*)target;
	*patchloc++ = IAX86_NEARJMP_OPCODE;
	*(unsigned*)patchloc = MakeIAX86Offset(patch->replacement, patch->original);

	// Restore the previous page protection, then make sure the processor
	// does not execute stale instructions for the modified bytes.
	VirtualProtect(target, sizeof(patch->codebytes), oldProtect, &oldProtect);
	FlushInstructionCache(GetCurrentProcess(), target, sizeof(patch->codebytes));
}

/*---------------------------------------------------------------------------
*/
static bool PatchMeIn(void)
{
	// acquire the module handles for the CRT libraries (release and debug)
	HMODULE RlsCRTLibrary = GetModuleHandle(RlsCRTLibraryName);

#ifdef _DEBUG
	HMODULE DbgCRTLibrary = GetModuleHandle(DbgCRTLibraryName);
#endif

	HMODULE DefCRTLibrary = 
#ifdef _DEBUG
			DbgCRTLibrary? DbgCRTLibrary: 
#endif	
			RlsCRTLibrary;

	// assign function pointers for required CRT support functions
#if 0
	if (DefCRTLibrary)
	{
		lh_memcpy_ptr = (void(*)(void*,const void*,size_t))
				GetProcAddress(DefCRTLibrary, "memcpy");
		lh_memset_ptr = (void(*)(void*,int,size_t))
				GetProcAddress(DefCRTLibrary, "memset");
	}
#endif

	// patch all relevant Release CRT Library entry points
	unsigned i;
	bool patchedRls = false;
	if (RlsCRTLibrary)
		for (i = 0; i < sizeof(rls_patches) / sizeof(*rls_patches); i++)
			if (rls_patches[i].original = GetProcAddress(RlsCRTLibrary, rls_patches[i].import))
			{
				PatchIt(&rls_patches[i]);
				patchedRls = true;
			}

#ifdef _DEBUG
	// patch all relevant Debug CRT Library entry points
	bool patchedDbg = false;
	if (DbgCRTLibrary)
		for (i = 0; i < sizeof(dbg_patches) / sizeof(*dbg_patches); i++)
			if (dbg_patches[i].original = GetProcAddress(DbgCRTLibrary, dbg_patches[i].import))
			{
				PatchIt(&dbg_patches[i]);
				patchedDbg = true;
			}

	// no point in staying loaded if we didn't patch anything...
	return patchedRls || patchedDbg;
#else
	return patchedRls;
#endif
}

#if PATCH_ME_OUT
/*---------------------------------------------------------------------------
**	UnPatchIt - write the code bytes saved in patch->codebytes back to the
**	start of the CRT routine at patch->original, removing the jmp. Only
**	compiled when PATCH_ME_OUT is non-zero; the original author notes it is
**	not normally executed.
*/
static void UnPatchIt(PATCH *patch)
{
	void *target = (void*)patch->original;
	DWORD oldProtect;

	// Make only the patched bytes writable (VirtualProtect rounds the range
	// out to whole pages). If that fails, leave the jmp in place rather than
	// faulting on a write to read-only code.
	if (!VirtualProtect(target, sizeof(patch->codebytes),
			PAGE_EXECUTE_READWRITE, &oldProtect))
		return;

	// write original code bytes back to original CRT library routine
	memcpy(target, patch->codebytes, sizeof(patch->codebytes));

	// Restore the previous page protection and discard any stale
	// instructions for the modified bytes.
	VirtualProtect(target, sizeof(patch->codebytes), oldProtect, &oldProtect);
	FlushInstructionCache(GetCurrentProcess(), target, sizeof(patch->codebytes));
}
#endif

/*---------------------------------------------------------------------------
**	PatchMeOut - unpatch the CRT(s).
*/
static void PatchMeOut(void)
{
#if PATCH_ME_OUT
	// Undo patches in the reverse of the order they were applied (Debug
	// table before Release table, last entry first). If a routine was ever
	// patched twice, this restores its genuine code bytes last.
	// Entries with original == NULL were never patched and are skipped.

#ifdef _DEBUG
	// dbg_patches only exists in _DEBUG builds.
	for (unsigned i = sizeof(dbg_patches) / sizeof(*dbg_patches); i-- > 0; )
		if (dbg_patches[i].original)
			UnPatchIt(&dbg_patches[i]);
#endif

	for (unsigned i = sizeof(rls_patches) / sizeof(*rls_patches); i-- > 0; )
		if (rls_patches[i].original)
			UnPatchIt(&rls_patches[i]);
#endif
}


extern "C"
{
	// Exported data item referenced by the object file linked into the
	// application, so the executable references libhoard.dll as early as
	// possible; the intent is for this DLL's entry point to run first.
	__declspec(dllexport) int ReferenceMe;

	/*---------------------------------------------------------------------------
	**	Entry point of this Win32 module. It is selected with the linker option
	**	/entry:LibHoardMain in the makefile; this DLL must NOT be linked with
	**	the CRT.
	**
	**	DLL_PROCESS_ATTACH: turns off thread attach/detach notifications and
	**		returns whether any CRT routine was patched.
	**	DLL_PROCESS_DETACH: unpatches the CRT(s) and returns TRUE.
	**	Any other reason: returns FALSE.
	*/
	BOOL WINAPI LibHoardMain(HANDLE hinstDLL, DWORD fdwReason, LPVOID lpreserved)
	{
		if (fdwReason == DLL_PROCESS_ATTACH)
		{
			DisableThreadLibraryCalls((HMODULE)hinstDLL);
			return PatchMeIn();
		}

		if (fdwReason == DLL_PROCESS_DETACH)
		{
			PatchMeOut();
			return TRUE;
		}

		return FALSE;
	}
}	// extern "C"
