#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <inttypes.h>
#include <errno.h>
#include <assert.h>

#include "xenner.h"
#include "hexdump.h"
#include "mm.h"

/* ------------------------------------------------------------------ */

/*
 * 32-bit (non-PAE) walk, top level: return a pointer to the page-directory
 * slot for @va inside the directory page named by cr3, or NULL when cr3
 * references a frame at or beyond xen->pg_total.
 */
uint32_t *find_pgd_32(struct xenvcpu *vcpu, uint32_t va)
{
    struct xenvm *xen = vcpu->vm;
    pfn_t mfn = addr_to_frame(vcpu->sregs.cr3);
    uint32_t *dir;
    int idx;

    if (mfn >= xen->pg_total)
	return NULL;

    dir = xen->memory + frame_to_addr(mfn);
    idx = PGD_INDEX_32(va);
    d3printf("%s  : va 0x%08" PRIx32 " -> mfn 0x%lx +0x%03x  [%08" PRIx32"]\n",
	     __FUNCTION__, va, mfn, idx, dir[idx]);
    return &dir[idx];
}

/*
 * 32-bit (non-PAE) walk, second level: return a pointer to the page-table
 * slot that translates @va.
 *
 * Returns NULL when cr3 is outside guest memory, when the page-directory
 * entry is not present, or when it references a frame at or beyond
 * xen->pg_total.
 */
uint32_t *find_pte_32(struct xenvcpu *vcpu, uint32_t va)
{
    struct xenvm *xen = vcpu->vm;
    pfn_t mfn;
    int slot;
    uint32_t *pgd, *pte;

    pgd  = find_pgd_32(vcpu, va);
    /* find_pgd_32() returns NULL for an out-of-range cr3; don't deref it */
    if (NULL == pgd)
	return NULL;
    if (!test_pgflag_32(*pgd, _PAGE_PRESENT))
	return NULL;
    mfn  = get_pgframe_32(*pgd);
    if (mfn >= xen->pg_total)
	return NULL;
    slot = PTE_INDEX_32(va);
    pte = xen->memory + frame_to_addr(mfn);
    d3printf("  %s: va 0x%08" PRIx32 " -> mfn 0x%lx +0x%03x  [%08" PRIx32"]\n",
	     __FUNCTION__, va, mfn, slot, pte[slot]);
    return pte + slot;
}

/*
 * Translate a 32-bit (non-PAE) guest virtual address to a host pointer.
 * Returns NULL if any level of the walk is missing or not present, or if
 * the final frame lies at or beyond xen->pg_total.
 */
void *find_pa_32(struct xenvcpu *vcpu, uint32_t va)
{
    struct xenvm *xen = vcpu->vm;
    uint32_t *entry = find_pte_32(vcpu, va);
    pfn_t mfn;
    int off;

    if (entry == NULL || !test_pgflag_32(*entry, _PAGE_PRESENT))
	return NULL;

    mfn = get_pgframe_32(*entry);
    if (mfn >= xen->pg_total)
	return NULL;

    off = addr_offset(va);
    d3printf("   %s: va 0x%08" PRIx32 " -> mfn 0x%lx +0x%03x\n",
	     __FUNCTION__, va, mfn, off);
    return xen->memory + frame_to_addr(mfn) + off;
}

/* ------------------------------------------------------------------ */

/*
 * PAE walk, top level: return a pointer to the top-level (pgd) entry for
 * @va in the table located via cr3, or NULL when cr3 references a frame
 * at or beyond xen->pg_total.
 *
 * NOTE(review): architecturally a PAE cr3 only needs 32-byte alignment,
 * but the addr_to_frame()/frame_to_addr() round trip presumably drops the
 * in-page offset, so this assumes the top-level table is page-aligned --
 * confirm against the code that sets up cr3.
 */
uint64_t *find_pgd_pae(struct xenvcpu *vcpu, uint32_t va)
{
    struct xenvm *xen = vcpu->vm;
    pfn_t mfn = addr_to_frame(vcpu->sregs.cr3);
    uint64_t *top;
    int idx;

    if (mfn >= xen->pg_total)
	return NULL;

    top = xen->memory + frame_to_addr(mfn);
    idx = PGD_INDEX_PAE(va);
    d3printf("%s    : va 0x%08" PRIx32 " -> mfn 0x%lx +0x%03x  [%08" PRIx64 "]\n",
	     __FUNCTION__, va, mfn, idx, top[idx]);
    return &top[idx];
}

/*
 * PAE walk, second level: return a pointer to the pmd slot that covers @va.
 *
 * Returns NULL when cr3 is outside guest memory, when the top-level (pgd)
 * entry is not present, or when it references a frame at or beyond
 * xen->pg_total.
 */
uint64_t *find_pmd_pae(struct xenvcpu *vcpu, uint32_t va)
{
    struct xenvm *xen = vcpu->vm;
    pfn_t mfn;
    int slot;
    uint64_t *pgd, *pmd;

    pgd = find_pgd_pae(vcpu, va);
    /* find_pgd_pae() returns NULL for an out-of-range cr3; don't deref it */
    if (NULL == pgd)
	return NULL;
    if (!test_pgflag_pae(*pgd, _PAGE_PRESENT))
	return NULL;
    mfn = get_pgframe_pae(*pgd);
    if (mfn >= xen->pg_total)
	return NULL;
    slot = PMD_INDEX_PAE(va);
    pmd = xen->memory + frame_to_addr(mfn);
    d3printf("  %s  : va 0x%08" PRIx32 " -> mfn 0x%lx +0x%03x  [%08" PRIx64 "]\n",
	     __FUNCTION__, va, mfn, slot, pmd[slot]);
    return pmd + slot;
}

/*
 * PAE walk, third level: return a pointer to the pte slot translating @va,
 * or NULL if the pmd slot is unavailable, not present, or references a
 * frame at or beyond xen->pg_total.
 */
uint64_t *find_pte_pae(struct xenvcpu *vcpu, uint32_t va)
{
    struct xenvm *xen = vcpu->vm;
    uint64_t *pmd_entry, *table;
    pfn_t mfn;
    int idx;

    pmd_entry = find_pmd_pae(vcpu, va);
    if (pmd_entry == NULL || !test_pgflag_pae(*pmd_entry, _PAGE_PRESENT))
	return NULL;

    mfn = get_pgframe_pae(*pmd_entry);
    if (mfn >= xen->pg_total)
	return NULL;

    table = xen->memory + frame_to_addr(mfn);
    idx = PTE_INDEX_PAE(va);
    d3printf("    %s: va 0x%08" PRIx32 " -> mfn 0x%lx +0x%03x  [%08" PRIx64 "]\n",
	     __FUNCTION__, va, mfn, idx, table[idx]);
    return &table[idx];
}

/*
 * Translate a PAE guest virtual address to a host pointer.  Returns NULL
 * if the pte is unavailable or not present, or if the mapped frame lies
 * at or beyond xen->pg_total.
 */
void *find_pa_pae(struct xenvcpu *vcpu, uint32_t va)
{
    struct xenvm *xen = vcpu->vm;
    uint64_t *entry = find_pte_pae(vcpu, va);
    pfn_t mfn;
    int off;

    if (entry == NULL || !test_pgflag_pae(*entry, _PAGE_PRESENT))
	return NULL;

    mfn = get_pgframe_pae(*entry);
    if (mfn >= xen->pg_total)
	return NULL;

    off = addr_offset(va);
    d3printf("     %s: va 0x%08" PRIx32 " -> mfn 0x%lx +0x%03x\n",
	     __FUNCTION__, va, mfn, off);
    return xen->memory + frame_to_addr(mfn) + off;
}

/*
 * Map frames [@start, @start + @count) as large pages at @va by rewriting
 * consecutive pmd entries with @flags | _PAGE_PSE.  The frame counter
 * advances by PMD_COUNT_PAE per pmd slot written.
 *
 * NOTE(review): pmd++ is not bounded to the current pmd page; callers
 * presumably keep the region inside a single directory -- confirm.
 *
 * Always returns 0; an unresolvable pmd trips assert().
 */
int map_region_pse_pae(struct xenvcpu *vcpu, uint32_t va, uint32_t flags,
		       pfn_t start, pfn_t count)
{
    struct xenvm *xen = vcpu->vm;
    pfn_t mfn = start;
    pfn_t end = start + count;
    uint64_t *pmd;

    d1printf("%s: mfns 0x%lx +0x%lx at 0x%" PRIx32 "\n",
	     __FUNCTION__, start, count, va);
    pmd = find_pmd_pae(vcpu, va);
    assert(pmd);

    flags |= _PAGE_PSE;
    while (mfn < end) {
	d1printf("%s:   old: %016" PRIx64 "\n", __FUNCTION__, *pmd);
	*pmd = get_pgentry_pae(mfn, flags);
	d1printf("%s:   new: %016" PRIx64 "\n", __FUNCTION__, *pmd);
	mfn += PMD_COUNT_PAE;
	pmd++;
    }
    return 0;
}

/*
 * For each of the four PAE top-level (pgd) entries, point the matching one
 * of the four consecutive pmd slots starting at @va's slot at the frame
 * that pgd entry references, using @flags.  Non-present pgd entries
 * produce zeroed pmd slots.
 *
 * Always returns 0; missing tables trip assert().
 */
int map_linear_pgt_pae(struct xenvcpu *vcpu, uint32_t va, uint32_t flags)
{
    struct xenvm *xen = vcpu->vm;
    uint64_t *pgd;
    uint64_t *pmd;
    pfn_t mfn;
    int i;

    d1printf("%s: at 0x%" PRIx32 "\n", __FUNCTION__, va);
    /* va 0 selects slot 0, i.e. the base of the 4-entry PAE top level */
    pgd = find_pgd_pae(vcpu, 0);
    /* find_pgd_pae() returns NULL for an out-of-range cr3 */
    assert(pgd);
    pmd = find_pmd_pae(vcpu, va);
    assert(pmd);

    for (i = 0; i < 4; i++) {
	if (test_pgflag_pae(pgd[i], _PAGE_PRESENT)) {
	    mfn = get_pgframe_pae(pgd[i]);
	    d1printf("%s:    %d mfn 0x%lx\n", __FUNCTION__, i, mfn);
	    pmd[i] = get_pgentry_pae(mfn, flags);
	} else {
	    d1printf("%s:    %d !present\n", __FUNCTION__, i);
	    pmd[i] = 0;
	}
    }
    return 0;
}

/*
 * Point the pmd slot covering @va at page-table frame @table with @flags.
 * Always returns 0; an unresolvable pmd trips assert().
 */
int map_normal_pgt_pae(struct xenvcpu *vcpu,
		       uint32_t va, uint32_t flags, pfn_t table)
{
    struct xenvm *xen = vcpu->vm;
    uint64_t *pmd;

    d1printf("%s: mfn 0x%lx at 0x%" PRIx32 "\n",
	     __FUNCTION__, table, va);

    pmd = find_pmd_pae(vcpu, va);
    assert(pmd);

    pmd[0] = get_pgentry_pae(table, flags);
    return 0;
}

/*
 * Debug helper: hexdump the 8-byte pgd, pmd and pte entries involved in
 * translating @vaddr under the current PAE cr3.
 *
 * The walk stops at the first level that cannot be resolved and reports
 * it on stderr, rather than handing a NULL pointer to hexdump().  Later
 * levels are only looked up once the earlier level was found.
 */
void dump_pgtables_pae(struct xenvcpu *vcpu, uint32_t vaddr)
{
    void *ptr;

    ptr = find_pgd_pae(vcpu, vaddr);
    if (NULL == ptr) {
	fprintf(stderr, "%s: no pgd for 0x%08" PRIx32 "\n",
		__FUNCTION__, vaddr);
	return;
    }
    hexdump("pgd", ptr, 8);

    ptr = find_pmd_pae(vcpu, vaddr);
    if (NULL == ptr) {
	fprintf(stderr, "%s: no pmd for 0x%08" PRIx32 "\n",
		__FUNCTION__, vaddr);
	return;
    }
    hexdump("pmd", ptr, 8);

    ptr = find_pte_pae(vcpu, vaddr);
    if (NULL == ptr) {
	fprintf(stderr, "%s: no pte for 0x%08" PRIx32 "\n",
		__FUNCTION__, vaddr);
	return;
    }
    hexdump("pte", ptr, 8);
}

/* ------------------------------------------------------------------ */

/*
 * 64-bit (4-level) walk, top level: return a pointer to the pgd slot for
 * @va in the table page named by cr3, or NULL when cr3 references a frame
 * at or beyond xen->pg_total.
 */
uint64_t *find_pgd_64(struct xenvcpu *vcpu, uint64_t va)
{
    struct xenvm *xen = vcpu->vm;
    pfn_t mfn = addr_to_frame(vcpu->sregs.cr3);
    uint64_t *top;
    int idx;

    if (mfn >= xen->pg_total)
	return NULL;

    top = xen->memory + frame_to_addr(mfn);
    idx = PGD_INDEX_64(va);
    d3printf("%s      : va 0x%08" PRIx64 " -> mfn 0x%lx +0x%03x  [%08" PRIx64 "]\n",
	     __FUNCTION__, va, mfn, idx, top[idx]);
    return &top[idx];
}

/*
 * 64-bit (4-level) walk, second level: return a pointer to the pud slot
 * covering @va.
 *
 * Returns NULL when cr3 is outside guest memory, when the pgd entry is not
 * present, or when it references a frame at or beyond xen->pg_total.
 */
uint64_t *find_pud_64(struct xenvcpu *vcpu, uint64_t va)
{
    struct xenvm *xen = vcpu->vm;
    pfn_t mfn;
    int slot;
    uint64_t *pgd, *pud;

    pgd = find_pgd_64(vcpu, va);
    /* find_pgd_64() returns NULL for an out-of-range cr3; don't deref it */
    if (NULL == pgd)
	return NULL;
    if (!test_pgflag_64(*pgd, _PAGE_PRESENT))
	return NULL;
    mfn = get_pgframe_64(*pgd);
    if (mfn >= xen->pg_total)
	return NULL;
    slot = PUD_INDEX_64(va);
    pud = xen->memory + frame_to_addr(mfn);
    d3printf("  %s    : va 0x%08" PRIx64 " -> mfn 0x%lx +0x%03x  [%08" PRIx64 "]\n",
	     __FUNCTION__, va, mfn, slot, pud[slot]);
    return pud + slot;
}

/*
 * 64-bit (4-level) walk, third level: return a pointer to the pmd slot
 * covering @va, or NULL if the pud slot is unavailable, not present, or
 * references a frame at or beyond xen->pg_total.
 */
uint64_t *find_pmd_64(struct xenvcpu *vcpu, uint64_t va)
{
    struct xenvm *xen = vcpu->vm;
    uint64_t *pud_entry, *table;
    pfn_t mfn;
    int idx;

    pud_entry = find_pud_64(vcpu, va);
    if (pud_entry == NULL || !test_pgflag_64(*pud_entry, _PAGE_PRESENT))
	return NULL;

    mfn = get_pgframe_64(*pud_entry);
    if (mfn >= xen->pg_total)
	return NULL;

    table = xen->memory + frame_to_addr(mfn);
    idx = PMD_INDEX_64(va);
    d3printf("    %s  : va 0x%08" PRIx64 " -> mfn 0x%lx +0x%03x  [%08" PRIx64 "]\n",
	     __FUNCTION__, va, mfn, idx, table[idx]);
    return &table[idx];
}

/*
 * 64-bit (4-level) walk, bottom level: return a pointer to the pte slot
 * translating @va, or NULL if the pmd slot is unavailable, not present,
 * or references a frame at or beyond xen->pg_total.
 */
uint64_t *find_pte_64(struct xenvcpu *vcpu, uint64_t va)
{
    struct xenvm *xen = vcpu->vm;
    uint64_t *pmd_entry, *table;
    pfn_t mfn;
    int idx;

    pmd_entry = find_pmd_64(vcpu, va);
    if (pmd_entry == NULL || !test_pgflag_64(*pmd_entry, _PAGE_PRESENT))
	return NULL;

    mfn = get_pgframe_64(*pmd_entry);
    if (mfn >= xen->pg_total)
	return NULL;

    table = xen->memory + frame_to_addr(mfn);
    idx = PTE_INDEX_64(va);
    d3printf("      %s: va 0x%08" PRIx64 " -> mfn 0x%lx +0x%03x  [%08" PRIx64 "]\n",
	     __FUNCTION__, va, mfn, idx, table[idx]);
    return &table[idx];
}

/*
 * Translate a 64-bit guest virtual address to a host pointer.  Returns
 * NULL if the pte is unavailable or not present, or if the mapped frame
 * lies at or beyond xen->pg_total.
 */
void *find_pa_64(struct xenvcpu *vcpu, uint64_t va)
{
    struct xenvm *xen = vcpu->vm;
    uint64_t *entry = find_pte_64(vcpu, va);
    pfn_t mfn;
    int off;

    if (entry == NULL || !test_pgflag_64(*entry, _PAGE_PRESENT))
	return NULL;

    mfn = get_pgframe_64(*entry);
    if (mfn >= xen->pg_total)
	return NULL;

    off = addr_offset(va);
    d3printf("       %s: va 0x%08" PRIx64 " -> mfn 0x%lx +0x%03x\n",
	     __FUNCTION__, va, mfn, off);
    return xen->memory + frame_to_addr(mfn) + off;
}

/* ------------------------------------------------------------------ */

/*
 * Return the host address of machine frame @mfn, or NULL if the frame is
 * at or beyond xen->pg_total.
 */
void *mfn_to_ptr(struct xenvm *xen, uint64_t mfn)
{
    return (mfn < xen->pg_total) ? xen->memory + frame_to_addr(mfn) : NULL;
}

/*
 * Translate a guest physical address to a host pointer.  Guest frames are
 * offset by xen->mfn_guest.  Returns NULL when the resulting frame is
 * outside guest memory.
 */
void *guest_paddr_to_ptr(struct xenvm *xen, uint64_t addr)
{
    void *page = mfn_to_ptr(xen, xen->mfn_guest + addr_to_frame(addr));

    if (NULL == page)
	return NULL;
    return page + addr_offset(addr);
}

/*
 * Translate a guest virtual address to a host pointer using the page-table
 * format selected by the VM's mode.  Modes other than XENMODE_32 and
 * XENMODE_PAE use the 64-bit walk.  Returns NULL if the address is unmapped.
 */
void *guest_vaddr_to_ptr(struct xenvcpu *vcpu, uint64_t addr)
{
    if (XENMODE_32 == vcpu->vm->mode)
	return find_pa_32(vcpu, addr);
    if (XENMODE_PAE == vcpu->vm->mode)
	return find_pa_pae(vcpu, addr);
    return find_pa_64(vcpu, addr);
}

/*
 * Translate byte offset @addr within the emu region (xen->pg_emu pages
 * starting at frame xen->mfn_emu) to a host pointer.
 *
 * Returns NULL when @addr is not inside the region, or when the region's
 * base frame lies outside guest memory.
 */
void *emu_paddr_to_ptr(struct xenvm *xen, uint64_t addr)
{
    void *base;

    /* valid offsets are [0, size): addr == size is already one past the end */
    if (addr >= frame_to_addr(xen->pg_emu))
	return NULL;
    base = mfn_to_ptr(xen, xen->mfn_emu);
    /* pointer arithmetic on NULL is undefined; propagate the failure */
    if (NULL == base)
	return NULL;
    return base + addr;
}

/*
 * Translate an address in the emu virtual window (based at xen->emu_vs)
 * to a host pointer by rebasing it and deferring to emu_paddr_to_ptr().
 * Addresses below emu_vs wrap to large offsets, which that range check rejects.
 */
void *emu_vaddr_to_ptr(struct xenvm *xen, uint64_t addr)
{
    uint64_t offset = addr - xen->emu_vs;

    return emu_paddr_to_ptr(xen, offset);
}

/* ------------------------------------------------------------------ */

/*
 * Copy @size bytes from guest virtual address @vaddr into host buffer
 * @dest.  The copy proceeds one guest page at a time, translating each
 * page through guest_vaddr_to_ptr(), so the guest range need not be
 * physically contiguous.
 *
 * Returns 0 on success, or -EFAULT at the first untranslatable page.
 * Bytes copied before that page have already been written to @dest.
 */
int copy_from_guest(struct xenvcpu *vcpu, void *dest, uint64_t vaddr, size_t size)
{
    void *src;
    size_t chunk;

    for (; size > 0; size -= chunk) {
	int offset = addr_offset(vaddr);

	/* stop each memcpy at the end of the current guest page */
	chunk = PAGE_SIZE - offset;
	if (chunk > size)
	    chunk = size;

	src = guest_vaddr_to_ptr(vcpu, vaddr);
	if (src == NULL) {
	    /* presumably referenced by the d1printf macro */
	    struct xenvm *xen = vcpu->vm;
	    d1printf("%s: vaddr 0x%08" PRIx64 " => -EFAULT\n",
		     __FUNCTION__, vaddr);
#if 0
	    /* for debugging: repeat the walk with verbose tracing */
	    vcpu->vm->debug += 3;
	    src = guest_vaddr_to_ptr(vcpu, vaddr);
	    vcpu->vm->debug -= 3;
#endif
	    return -EFAULT;
	}

	memcpy(dest, src, chunk);
	dest  += chunk;
	vaddr += chunk;
    }
    return 0;
}

/*
 * Copy @size bytes from host buffer @src to guest virtual address @vaddr.
 * The copy proceeds one guest page at a time, translating each page
 * through guest_vaddr_to_ptr(), so the guest range need not be physically
 * contiguous.
 *
 * Returns 0 on success, or -EFAULT at the first untranslatable page.
 * Guest pages before that one have already been written.
 */
int copy_to_guest(struct xenvcpu *vcpu, uint64_t vaddr, void *src, size_t size)
{
    void *dest;
    size_t chunk;

    for (; size > 0; size -= chunk) {
	int offset = addr_offset(vaddr);

	/* stop each memcpy at the end of the current guest page */
	chunk = PAGE_SIZE - offset;
	if (chunk > size)
	    chunk = size;

	dest = guest_vaddr_to_ptr(vcpu, vaddr);
	if (dest == NULL) {
	    /* presumably referenced by the d1printf macro */
	    struct xenvm *xen = vcpu->vm;
	    d1printf("%s: vaddr 0x%08" PRIx64 " => -EFAULT\n",
		     __FUNCTION__, vaddr);
	    return -EFAULT;
	}

	memcpy(dest, src, chunk);
	src   += chunk;
	vaddr += chunk;
    }
    return 0;
}
