#include <errno.h>
#include <inttypes.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <fcntl.h>
#include <syslog.h>
#include <unistd.h>

#include <sys/ioctl.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/un.h>

#include <netinet/in.h>
#include <netinet/tcp.h>

#include <xen/xen.h>

#include "evtchnd.h"
#include "list.h"
#include "daemon.h"

/* ---------------------------------------------------------------------- */
/* state info                                                             */

/* One event channel port slot inside a domain's port table. */
struct port {
    struct conn      *conn;  /* connection that allocated the slot, NULL while the slot is free */
    struct port      *peer;  /* port this one is bound to, NULL while unbound */
    int              port;   /* index of the slot in domain->p[], 0 while the slot is free */
};

/*
 * A domain and its table of event channel ports.  Domains are created
 * on demand by get_domain() and freed by put_domain() once the
 * reference count drops to zero.
 */
struct domain {
    struct list_head list;       /* link in the global 'domains' list */
    int              domid;      /* domain id used for lookup in get_domain() */
    int              refcount;   /* references handed out by get_domain() */
    struct port      p[NR_EVENT_CHANNELS];  /* port slots; alloc_port() never hands out slot 0 */
};
/* all currently known domains */
static LIST_HEAD(domains);

/* State of one client connection accepted on the listening socket. */
struct conn {
    struct list_head list;     /* link in the global 'conns' list */
    int              fd;       /* connected socket */
    int              ports;    /* number of ports currently allocated by this connection */
    struct domain    *domain;  /* domain the connection belongs to (holds one reference) */
};
/* all open client connections */
static LIST_HEAD(conns);

/*
 * Number of the signal that asked for termination, 0 while running.
 * Written from the signal handler and polled by the main loop, hence
 * volatile sig_atomic_t rather than a plain int.
 */
static volatile sig_atomic_t termsig;
/* listening socket (unix or tcp) */
static int     slisten;

/* ---------------------------------------------------------------------- */
/* config info                                                            */

/* select() timeout of the main loop in seconds, 0 waits without timeout (option -t) */
static int     timeout        = 60;

/* ---------------------------------------------------------------------- */

/* forward declarations, bind_port_peer() uses both before they are defined */
static struct domain *get_domain(int domid);
static void put_domain(struct domain *domain);

/* ---------------------------------------------------------------------- */

/*
 * Signal handler: only records the signal number in termsig.
 * mainloop() leaves its loop once termsig is non-zero.
 */
static void catchsig(int sig)
{
    termsig = sig;
}

/*
 * Print the command line help to fp.  The values in brackets show the
 * current settings of the debug, detach, timeout and pidfile options.
 */
static void usage(FILE *fp)
{
    const char *debug_state  = debug      ? "on" : "off";
    const char *detach_state = dontdetach ? "on" : "off";
    const char *pid_path     = pidfile    ? pidfile : "-";

    fprintf(fp,
            "\n"
            "xen event channel emulation daemon\n"
            "\n"
            "usage: evtchnd [ options ]\n"
            "\n"
            "Options:\n"
            "  -h       print this text\n"
            "  -d       enable debug output                 [%s]\n"
            "  -F       do not fork into background         [%s]\n"
            "  -t sec   set network timeout                 [%i]\n"
            "  -p file  set pidfile                         [%s]\n"
            "  -l file  specify logfile\n"
            "\n",
            debug_state,
            detach_state,
            timeout,
            pid_path);
}

/*
 * Report a fatal condition on stderr and terminate the process with
 * exit status 1.  Does not return.
 */
static void panic(char *msg)
{
    FILE *out = stderr;

    fprintf(out, "panic: %s\n", msg);
    exit(1);
}

/* ---------------------------------------------------------------------- */

/*
 * Claim the lowest free port slot (starting at index 1) in the domain of
 * the connection.  The slot is marked as owned by conn and the per
 * connection port counter is bumped.  'reason' only shows up in the
 * debug output.
 *
 * Returns the claimed port, or NULL when every slot is taken.
 */
static struct port *alloc_port(struct conn *conn, char *reason)
{
    struct port *table = conn->domain->p;
    struct port *slot;
    int idx;

    /* search the first slot nobody owns, index 0 is never handed out */
    for (idx = 1; idx < NR_EVENT_CHANNELS; idx++) {
        if (NULL == table[idx].conn)
            break;
    }
    if (idx >= NR_EVENT_CHANNELS)
        return NULL;

    slot = &table[idx];
    slot->port = idx;
    slot->conn = conn;
    conn->ports++;
    d1printf("%3d: alloc port %d, domain %d (%s)\n",
             conn->fd, slot->port, conn->domain->domid, reason);
    return slot;
}

/*
 * Try to connect local port p with port number 'port' of domain 'domid'.
 * Both ends are linked only when the remote port is allocated and
 * neither side has a peer yet; otherwise nothing is changed.  The
 * outcome is reported in the debug output only, the caller gets no
 * status back.
 *
 * 'port' is used as index into the remote port table without any
 * further check, so the caller has to pass an in-range value.
 */
static void bind_port_peer(struct port *p, int domid, int port)
{
    /* the reference is held for the duration of this call only */
    struct domain *remote = get_domain(domid);
    struct port *other = &remote->p[port];
    char *status;

    if (NULL == other->conn) {
        status = "peer not allocated";
    } else if (NULL != other->peer) {
        status = "peer already bound";
    } else if (NULL != p->peer) {
        status = "port already bound";
    } else {
        other->peer = p;
        p->peer = other;
        status = "ok";
    }

    d1printf("%3d: bind port %d domain %d  <->  port %d domain %d : %s\n",
             p->conn->fd,
             p->port, p->conn->domain->domid,
             port, domid, status);

    put_domain(remote);
}

/*
 * Break the link between p and its peer, if there is one.  Both ends
 * get their peer pointer cleared; an unbound port is left untouched.
 */
static void unbind_port(struct port *p)
{
    struct port *other = p->peer;

    if (NULL == other)
        return;

    d1printf("%3d: unbind port %d domain %d  <->  port %d domain %d\n",
             p->conn->fd,
             p->port, p->conn->domain->domid,
             other->port, other->conn->domain->domid);
    other->peer = NULL;
    p->peer = NULL;
}

/*
 * Send an EVTCHND_NOTIFY message carrying the port number of 'peer' to
 * the connection owning that port.
 *
 * Delivery is best effort: the sockets are non-blocking, so the write
 * can fail or come up short.  That is not fatal here, but it is logged
 * instead of being dropped silently.
 */
static void notify_send_peer(struct port *peer)
{
    struct evtchn_ioctl_msg msg;
    struct evtchnd_port *n = (void*)(&msg.data);
    ssize_t rc;

    memset(&msg, 0, sizeof(msg));
    n->port = peer->port;
    msg.ioctl = EVTCHND_NOTIFY;

    rc = write(peer->conn->fd, &msg, sizeof(msg));
    if (-1 == rc) {
        d1printf("%3d: notify write: %s\n",
                 peer->conn->fd, strerror(errno));
    } else if ((size_t)rc != sizeof(msg)) {
        d1printf("%3d: notify write: short write (%d of %d bytes)\n",
                 peer->conn->fd, (int)rc, (int)sizeof(msg));
    }
}

/*
 * Deliver a notification from port p to the port it is bound to.
 * Without a peer the request is only noted in the debug output.
 */
static void notify_port(struct port *p)
{
    struct port *target = p->peer;

    if (NULL == target) {
        d1printf("%3d: notify port %d domain %d  ->  unconnected\n",
                 p->conn->fd, p->port, p->conn->domain->domid);
        return;
    }

    notify_send_peer(target);
    d2printf("%3d: notify port %d domain %d  ->  port %d domain %d\n",
             p->conn->fd, p->port, p->conn->domain->domid,
             target->port, target->conn->domain->domid);
}

/*
 * Handler for unmask requests.  Intentionally empty: there is no mask
 * state kept for a port, so nothing has to be done.
 */
static void unmask_port(struct port *p)
{
    (void)p; /* unused */
}

static void release_port(struct port *p)
{
    d1printf("%3d: release port %d, domain %d\n",
	     p->conn->fd, p->port, p->conn->domain->domid);
    unbind_port(p);
    p->conn->ports--;
    p->port = 0;
    p->conn = 0;
}

/*
 * Return the domain with the given id and take a reference on it.  A
 * domain which is not on the list yet is created zero-initialised and
 * appended.  Running out of memory ends the process via panic().
 *
 * Every call must be balanced by put_domain().
 */
static struct domain *get_domain(int domid)
{
    struct list_head *pos;
    struct domain *dom;

    /* known domain? */
    list_for_each(pos, &domains) {
        dom = list_entry(pos, struct domain, list);
        if (dom->domid == domid) {
            dom->refcount++;
            return dom;
        }
    }

    /* no, set up a new one */
    dom = calloc(1, sizeof(*dom));
    if (NULL == dom)
        panic("oom");
    dom->domid = domid;
    list_add_tail(&dom->list, &domains);
    d1printf("  ?: new domain id %d\n", dom->domid);

    dom->refcount++;
    return dom;
}

/*
 * Drop one reference on the domain.  When the last reference is gone
 * the domain is taken off the list and freed, so the pointer must not
 * be used afterwards.
 */
static void put_domain(struct domain *domain)
{
    if (0 != --domain->refcount)
        return;

    d1printf("  ?: del domain id %d\n", domain->domid);
    list_del(&domain->list);
    free(domain);
}

/*
 * Create the bookkeeping for a freshly accepted socket.  The connection
 * starts out attached to domain 0, the socket is switched to
 * non-blocking mode and the connection is appended to the list of
 * connections.  Running out of memory ends the process via panic().
 */
static struct conn *new_conn(int fd)
{
    struct conn *c = calloc(1, sizeof(*c));

    if (NULL == c)
        panic("oom");
    d1printf("%3d: new\n", fd);

    c->fd = fd;
    c->domain = get_domain(0);
    fcntl(c->fd,F_SETFL,O_NONBLOCK);
    list_add_tail(&c->list, &conns);
    return c;
}

/*
 * Tear down a connection: release every port it owns in its domain,
 * drop its domain reference, close the socket and free the connection.
 * conn is invalid once this returns.
 */
static void del_conn(struct conn *conn)
{
    int idx;

    /* slot 0 is never allocated, so start at 1 */
    for (idx = 1; idx < NR_EVENT_CHANNELS; idx++) {
        struct port *slot = &conn->domain->p[idx];

        if (slot->conn == conn)
            release_port(slot);
    }
    put_domain(conn->domain);

    d1printf("%3d: del\n", conn->fd);
    close(conn->fd);
    list_del(&conn->list);
    free(conn);
}

/* ---------------------------------------------------------------------- */
/* main loop                                                              */

static void sock_data(struct conn *conn)
{
    struct evtchn_ioctl_msg req;
    struct evtchn_ioctl_msg rsp;
    int rc;

    rc = read(conn->fd, &req, sizeof(req));
    switch (rc) {
    case -1:
	d1printf("%3d: read: %s\n", conn->fd, strerror(errno));
	del_conn(conn);
	return;
    case 0:
	d1printf("%3d: EOF\n", conn->fd);
	del_conn(conn);
	return;
    default:
	/* don't return */
	break;
    }

    memset(&rsp, 0, sizeof(rsp));
    rsp.ioctl = req.ioctl;
    switch (req.ioctl) {
    case IOCTL_EVTCHN_BIND_UNBOUND_PORT:
    {
//	struct ioctl_evtchn_bind_unbound_port *io = (void*)req.data;
	struct port *p = alloc_port(conn, "unbound");
	rsp.retval = p->port;
	break;
    }
    case IOCTL_EVTCHN_BIND_INTERDOMAIN:
    {
	struct ioctl_evtchn_bind_interdomain *io = (void*)req.data;
	struct port *p = alloc_port(conn, "interdomain");
	if (io->remote_port >= NR_EVENT_CHANNELS) {
	    rsp.retval = -1;
	    rsp.error = EINVAL;
	} else {
	    bind_port_peer(p, io->remote_domain, io->remote_port);
	    rsp.retval = p->port;
	}
	break;
    }
    case IOCTL_EVTCHN_BIND_VIRQ:
    {
//	struct ioctl_evtchn_bind_virq *io = (void*)req.data;
	struct port *me = alloc_port(conn, "virq");
	rsp.retval = me->port;
	/* FIXME: link virq */
	break;
    }
    case IOCTL_EVTCHN_UNBIND:
    {
	struct ioctl_evtchn_unbind *io = (void*)req.data;
	struct port *p = conn->domain->p + io->port;
	if (io->port >= NR_EVENT_CHANNELS) {
	    rsp.retval = -1;
	    rsp.error = EINVAL;
	} else if (p->conn == conn) {
	    unbind_port(p);
	    release_port(p);
	}
	rsp.retval = 0;
	break;
    }
    case IOCTL_EVTCHN_NOTIFY:
    {
	struct ioctl_evtchn_notify *io = (void*)req.data;
	struct port *p = conn->domain->p + io->port;
	if (io->port >= NR_EVENT_CHANNELS) {
	    rsp.retval = -1;
	    rsp.error = EINVAL;
	} else if (p->conn == conn)
	    notify_port(p);
	return; /* no reply */
    }

    case EVTCHND_DOMID:
    {
	struct evtchnd_domid *io = (void*)req.data;
	if (0 == conn->ports) {
	    put_domain(conn->domain);
	    conn->domain = get_domain(io->domid);
	    rsp.retval = conn->domain->domid;
	} else {
	    rsp.retval = -1;
	    rsp.error = EINVAL;
	}
	break;
    }
    case EVTCHND_UNMASK:
    {
	struct evtchnd_port *io = (void*)req.data;
	struct port *p = conn->domain->p + io->port;
	if (io->port >= NR_EVENT_CHANNELS) {
	    rsp.retval = -1;
	    rsp.error = EINVAL;
	} else if (p->conn == conn)
	    unmask_port(p);
	return; /* no reply */
    }
    case EVTCHND_NOTIFY:
	/* send it back (for queue reordering) */
	memcpy(&rsp, &req, sizeof(rsp));
	break;

    default:
	d1printf("%3d: unknown request 0x%x\n", conn->fd, req.ioctl);
	rsp.retval = -1;
	rsp.error = ENOSYS;
	break;
    }
    write(conn->fd, &rsp, sizeof(rsp));
}

/*
 * Event loop of the daemon.  Waits with select() for activity on the
 * listening socket and on all client connections, accepts new clients
 * and hands readable connections to sock_data().  Runs until a signal
 * handler has set termsig.
 */
static void mainloop(void)
{
    struct list_head *item, *safe;
    struct conn *conn;
    struct timeval tv;
    fd_set rd;
    int max, fd;

    for (;!termsig;) {
        FD_ZERO(&rd);
        FD_SET(slisten,&rd);
        max = slisten;
        list_for_each(item, &conns) {
            conn = list_entry(item, struct conn, list);
            FD_SET(conn->fd,&rd);
            if (max < conn->fd)
                max = conn->fd;
        }

        /* a timeout of zero means: wait until something happens */
        tv.tv_sec  = timeout;
        tv.tv_usec = 0;
        if (-1 == select(max+1, &rd, NULL, NULL, timeout ? &tv : NULL)) {
            /* the loop condition picks up termsig after a signal */
            if (debug)
                perror("select");
            continue;
        }

        if (FD_ISSET(slisten,&rd)) {
            /* new connection */
            fd = accept(slisten,NULL,NULL);
            if (-1 == fd) {
                if (debug)
                    perror("accept");
            } else if (fd >= FD_SETSIZE) {
                /*
                 * FD_SET() on such a descriptor writes beyond the
                 * fd_set, so the client can't be served via select().
                 */
                d1printf("%3d: descriptor out of range for select, closing\n",
                         fd);
                close(fd);
            } else {
                new_conn(fd);
            }
        }

        /* sock_data() may delete the connection, thus the _safe variant */
        list_for_each_safe(item, safe, &conns) {
            conn = list_entry(item, struct conn, list);
            if (FD_ISSET(conn->fd,&rd))
                sock_data(conn);
        }
    }
}

/* ---------------------------------------------------------------------- */

/*
 * Create a tcp socket bound to EVTCHND_PORT on the loopback address.
 * The caller is responsible for calling listen() on it.
 *
 * Returns the socket, or -1 on failure or when the daemon was built
 * without EVTCHND_PORT.
 */
static int socket_tcp(void)
{
#ifdef EVTCHND_PORT
    struct sockaddr_in in;
    int sock, opt=1;

    if (-1 == (sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP))) {
        perror("socket(tcp)");
        return -1;
    }
    setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
    setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt));

    /* don't hand uninitialised padding to bind() */
    memset(&in, 0, sizeof(in));
    in.sin_family      = AF_INET;
    in.sin_port        = htons(EVTCHND_PORT);
    in.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    if (-1 == bind(sock, (struct sockaddr*) &in, sizeof(in))) {
        perror("bind(tcp)");
        close(sock); /* don't leak the descriptor */
        return -1;
    }

    d1printf("bound to tcp port %d\n", EVTCHND_PORT);
    return sock;
#else
    return -1;
#endif
}

/*
 * Create a unix stream socket bound to EVTCHND_PATH and make the socket
 * file accessible for everybody (mode 0666).  The caller is responsible
 * for calling listen() on it.
 *
 * Returns the socket, or -1 on failure or when the daemon was built
 * without EVTCHND_PATH.
 */
static int socket_unix(void)
{
#ifdef EVTCHND_PATH
    struct sockaddr_un un;
    int sock, len;

    if (-1 == (sock = socket(PF_UNIX, SOCK_STREAM, 0))) {
        perror("socket(unix)");
        return -1;
    }

    memset(&un, 0, sizeof(un));
    un.sun_family      = AF_UNIX;
    /* snprintf always terminates; a truncated path would be the wrong file */
    len = snprintf(un.sun_path, sizeof(un.sun_path), "%s", EVTCHND_PATH);
    if (len < 0 || (size_t)len >= sizeof(un.sun_path)) {
        fprintf(stderr, "socket path too long: %s\n", EVTCHND_PATH);
        close(sock);
        return -1;
    }
    if (-1 == bind(sock, (struct sockaddr*) &un, sizeof(un))) {
        perror("bind(unix)");
        close(sock); /* don't leak the descriptor */
        return -1;
    }
    /* best effort, the socket stays usable for the owner if this fails */
    if (-1 == chmod(EVTCHND_PATH, 0666))
        perror("chmod(unix)");

    d1printf("bound to unix socket %s\n", EVTCHND_PATH);
    return sock;
#else
    return -1;
#endif
}

/*
 * Daemon entry point: parse the command line, set up the listening
 * socket (unix socket preferred, tcp as fallback), detach, install the
 * signal handlers and run the main loop until a termination signal
 * arrives.
 */
int
main(int argc, char *argv[])
{
    struct sigaction act;
    int unix_bound = 0;
    int c;

    /* parse options */
    while (-1 != (c = getopt(argc,argv,"hdFt:p:l:"))) {
        switch (c) {
        case 'd':
            debug++;
            break;
        case 'F':
            dontdetach++;
            break;
        case 't':
        {
            char *end;
            long secs;

            /*
             * Strict parsing: a negative value would make select() fail
             * with EINVAL on every round of the main loop.
             */
            errno = 0;
            secs = strtol(optarg, &end, 10);
            if (end == optarg || '\0' != *end || 0 != errno ||
                secs < 0 || secs > INT_MAX) {
                fprintf(stderr, "invalid timeout: %s\n", optarg);
                usage(stderr);
                exit(1);
            }
            timeout = (int)secs;
            break;
        }
        case 'p':
            pidfile = optarg;
            break;
        case 'l':
            log_setfile(optarg);
            break;
        case 'h':
            usage(stdout);
            exit(0);
        default:
            usage(stderr);
            exit(1);
        }
    }

    /* unix socket first, tcp as fallback */
    slisten = socket_unix();
    if (-1 != slisten)
        unix_bound = 1;
    else
        slisten = socket_tcp();
    if (-1 == slisten)
        exit(1);

    if (-1 == listen(slisten, 8)) {
        perror("listen");
        exit(1);
    }

    /* fork into background, handle pidfile */
    daemonize();

    /* setup signal handler */
    memset(&act,0,sizeof(act));
    sigemptyset(&act.sa_mask);
    act.sa_handler = SIG_IGN;
    sigaction(SIGPIPE,&act,NULL);
    sigaction(SIGCHLD,&act,NULL);
    act.sa_handler = catchsig;
    sigaction(SIGTERM,&act,NULL);
    if (debug)
        sigaction(SIGINT,&act,NULL);

    /* go! */
    mainloop();

    /* cleanup */
    close(slisten);
#ifdef EVTCHND_PATH
    /* remove the socket file only if this process created it */
    if (unix_bound)
        unlink(EVTCHND_PATH);
#else
    (void)unix_bound;
#endif
    exit(0);
}
