/*
 * Copyright (c) 2015 Los Alamos National Security, LLC. All rights reserved.
 * Copyright (c) 2015 Cray Inc.  All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <getopt.h>
#include <poll.h>
#include <time.h>
#include <string.h>
#include <pthread.h>

#include <rdma/fabric.h>
#include <rdma/fi_domain.h>
#include <rdma/fi_errno.h>
#include <rdma/fi_endpoint.h>
#include <rdma/fi_cm.h>

#include <stdio.h>
#include <stdlib.h>
#include <inttypes.h>

#include "gnix_vc.h"
#include "gnix_cm_nic.h"
#include "gnix_hashtable.h"
#include "gnix_rma.h"
#include "gnix_ep.h"

#include <criterion/criterion.h>

/*
 * Fixture state shared by all tests in this file.  Everything here is
 * created in cancel_setup() and released in cancel_teardown().
 */
static struct fid_fabric *fab;
static struct fid_domain *dom;
/* two endpoints opened on the same domain */
static struct fid_ep *ep[2];
/* address vector bound to both endpoints */
static struct fid_av *av;
static struct fi_info *hints;
static struct fi_info *fi;
/* endpoint names obtained with fi_getname(), one malloc'd buffer per ep */
void *ep_name[2];
/* addresses written by fi_av_insert() for ep_name[0] and ep_name[1] */
size_t gni_addr[2];
/* msg_cq[i] is bound to ep[i] for both FI_SEND and FI_RECV */
static struct fid_cq *msg_cq[2];
static struct fi_cq_attr cq_attr;

/* size in bytes of each of the two registered buffers below */
#define BUF_SZ (8*1024)
char *target;
char *source;
/* rem_mr is the registration of target, loc_mr the registration of source */
struct fid_mr *rem_mr, *loc_mr;
/* key of rem_mr, as returned by fi_mr_key() */
uint64_t mr_key;

void cancel_setup(void)
{
	int ret = 0;
	struct fi_av_attr attr;
	size_t addrlen = 0;

	hints = fi_allocinfo();
	cr_assert(hints, "fi_allocinfo");

	hints->domain_attr->cq_data_size = 4;
	hints->mode = ~0;

	hints->fabric_attr->name = strdup("gni");

	ret = fi_getinfo(FI_VERSION(1, 0), NULL, 0, 0, hints, &fi);
	cr_assert(!ret, "fi_getinfo");

	ret = fi_fabric(fi->fabric_attr, &fab, NULL);
	cr_assert(!ret, "fi_fabric");

	ret = fi_domain(fab, fi, &dom, NULL);
	cr_assert(!ret, "fi_domain");

	attr.type = FI_AV_MAP;
	attr.count = 16;

	ret = fi_av_open(dom, &attr, &av, NULL);
	cr_assert(!ret, "fi_av_open");

	ret = fi_endpoint(dom, fi, &ep[0], NULL);
	cr_assert(!ret, "fi_endpoint");

	cq_attr.format = FI_CQ_FORMAT_CONTEXT;
	cq_attr.size = 1024;
	cq_attr.wait_obj = 0;

	ret = fi_cq_open(dom, &cq_attr, &msg_cq[0], 0);
	cr_assert(!ret, "fi_cq_open");

	ret = fi_cq_open(dom, &cq_attr, &msg_cq[1], 0);
	cr_assert(!ret, "fi_cq_open");

	ret = fi_ep_bind(ep[0], &msg_cq[0]->fid, FI_SEND | FI_RECV);
	cr_assert(!ret, "fi_ep_bind");

	ret = fi_getname(&ep[0]->fid, NULL, &addrlen);
	cr_assert(addrlen > 0);

	ep_name[0] = malloc(addrlen);
	cr_assert(ep_name[0] != NULL);

	ret = fi_getname(&ep[0]->fid, ep_name[0], &addrlen);
	cr_assert(ret == FI_SUCCESS);

	ret = fi_endpoint(dom, fi, &ep[1], NULL);
	cr_assert(!ret, "fi_endpoint");

	ret = fi_ep_bind(ep[1], &msg_cq[1]->fid, FI_SEND | FI_RECV);
	cr_assert(!ret, "fi_ep_bind");

	ep_name[1] = malloc(addrlen);
	cr_assert(ep_name[1] != NULL);

	ret = fi_getname(&ep[1]->fid, ep_name[1], &addrlen);
	cr_assert(ret == FI_SUCCESS);

	ret = fi_av_insert(av, ep_name[0], 1, &gni_addr[0], 0,
				NULL);
	cr_assert(ret == 1);

	ret = fi_av_insert(av, ep_name[1], 1, &gni_addr[1], 0,
				NULL);
	cr_assert(ret == 1);

	ret = fi_ep_bind(ep[0], &av->fid, 0);
	cr_assert(!ret, "fi_ep_bind");

	ret = fi_ep_bind(ep[1], &av->fid, 0);
	cr_assert(!ret, "fi_ep_bind");

	ret = fi_enable(ep[0]);
	cr_assert(!ret, "fi_ep_enable");

	ret = fi_enable(ep[1]);
	cr_assert(!ret, "fi_ep_enable");

	target = malloc(BUF_SZ);
	assert(target);

	source = malloc(BUF_SZ);
	assert(source);

	ret = fi_mr_reg(dom, target, BUF_SZ,
			FI_REMOTE_WRITE, 0, 0, 0, &rem_mr, &target);
	cr_assert_eq(ret, 0);

	ret = fi_mr_reg(dom, source, BUF_SZ,
			FI_REMOTE_WRITE, 0, 0, 0, &loc_mr, &source);
	cr_assert_eq(ret, 0);

	mr_key = fi_mr_key(rem_mr);
}

/*
 * Suite fixture: release everything cancel_setup() created.
 *
 * The memory regions and their buffers go first, then the endpoints, the
 * completion queues, the address vector, the domain and the fabric, and
 * finally the fi_info structures and the endpoint name buffers.  Every
 * fi_close() result is checked with cr_assert.
 */
void cancel_teardown(void)
{
	int ret = 0;

	/* MR closes are checked like every other close below */
	ret = fi_close(&loc_mr->fid);
	cr_assert(!ret, "failure in closing loc_mr.");

	ret = fi_close(&rem_mr->fid);
	cr_assert(!ret, "failure in closing rem_mr.");

	free(target);
	free(source);

	ret = fi_close(&ep[0]->fid);
	cr_assert(!ret, "failure in closing ep.");

	ret = fi_close(&ep[1]->fid);
	cr_assert(!ret, "failure in closing ep.");

	ret = fi_close(&msg_cq[0]->fid);
	cr_assert(!ret, "failure in send cq.");

	ret = fi_close(&msg_cq[1]->fid);
	cr_assert(!ret, "failure in recv cq.");

	ret = fi_close(&av->fid);
	cr_assert(!ret, "failure in closing av.");

	ret = fi_close(&dom->fid);
	cr_assert(!ret, "failure in closing domain.");

	ret = fi_close(&fab->fid);
	cr_assert(!ret, "failure in closing fabric.");

	fi_freeinfo(fi);
	fi_freeinfo(hints);
	free(ep_name[0]);
	free(ep_name[1]);
}

/*
 * Fill the first len bytes of buf with consecutive values: seed, seed + 1,
 * seed + 2, ...  Nothing is written when len is zero or negative.
 */
void cancel_init_data(char *buf, int len, char seed)
{
	int off = 0;

	while (off < len) {
		buf[off] = seed;
		seed++;
		off++;
	}
}

/*
 * Compare the first len bytes of buf1 (expected) and buf2 (actual).
 *
 * Returns 1 when all len bytes are equal (including len <= 0), otherwise
 * prints the index and both values of the first differing byte and
 * returns 0.
 */
int cancel_check_data(char *buf1, char *buf2, int len)
{
	int i;

	for (i = 0; i < len; i++) {
		if (buf1[i] != buf2[i]) {
			/*
			 * Print the raw byte values: without the casts a
			 * negative char is sign-extended by the default
			 * argument promotion and shows up as ffffffXX.
			 */
			printf("data mismatch, elem: %d, exp: %x, act: %x\n",
			       i, (unsigned char) buf1[i],
			       (unsigned char) buf2[i]);
			return 0;
		}
	}

	return 1;
}

/*******************************************************************************
 * Test fi_cancel functions
 ******************************************************************************/

/*
 * Declare the gnix_cancel suite: cancel_setup is its init hook and
 * cancel_teardown its fini hook; the suite is explicitly not disabled.
 */
TestSuite(gnix_cancel, .init = cancel_setup, .fini = cancel_teardown,
	  .disabled = false);

/*
 * Cancel a queued send.
 *
 * A send request is built by hand and placed on the tx queue of a freshly
 * allocated VC that is stored in the endpoint's VC hash table.  fi_cancel()
 * with the request's context must succeed, and msg_cq[0] must then yield
 * exactly one error entry describing the canceled request.
 */
Test(gnix_cancel, cancel_ep_send)
{
	int ret;
	struct gnix_fid_ep *gnix_ep;
	struct gnix_fab_req *req;
	struct fi_cq_err_entry buf;
	struct gnix_vc *vc;
	/* context of the fake request; the same value is passed to fi_cancel */
	void *foobar_ptr = NULL;
	gnix_ht_key_t *key;

	/* simulate a posted request */
	gnix_ep = container_of(ep[0], struct gnix_fid_ep, ep_fid);
	req = _gnix_fr_alloc(gnix_ep);
	/* req is dereferenced right below, so fail cleanly if it is NULL */
	cr_assert(req != NULL, "_gnix_fr_alloc failed");

	req->msg.send_addr = 0xdeadbeef;
	req->msg.send_len = 128;
	req->user_context = foobar_ptr;
	req->type = GNIX_FAB_RQ_SEND;

	/* allocate, store vc */
	ret = _gnix_vc_alloc(gnix_ep, NULL, &vc);
	cr_assert(ret == FI_SUCCESS, "_gnix_vc_alloc failed");

	/* the VC is stored under the endpoint's own address */
	key = (gnix_ht_key_t *)&gnix_ep->my_name.gnix_addr;
	ret = _gnix_ht_insert(gnix_ep->vc_ht, *key, vc);
	cr_assert(!ret);

	/* make a dummy request */
	fastlock_acquire(&vc->tx_queue_lock);
	slist_insert_head(&req->slist, &vc->tx_queue);
	fastlock_release(&vc->tx_queue_lock);

	/* cancel simulated request */
	ret = fi_cancel(&ep[0]->fid, foobar_ptr);
	cr_assert(ret == FI_SUCCESS, "fi_cancel failed");

	/* check for event */
	ret = fi_cq_readerr(msg_cq[0], &buf, FI_SEND);
	cr_assert(ret == 1, "did not find one error event");

	/* the error entry must echo the address and length set on req */
	cr_assert(buf.buf == (void *) 0xdeadbeef, "buffer mismatch");
	cr_assert(buf.data == 0, "data mismatch");
	cr_assert(buf.err == FI_ECANCELED, "error code mismatch");
	cr_assert(buf.prov_errno == FI_ECANCELED, "prov error code mismatch");
	cr_assert(buf.len == 128, "length mismatch");
}

Test(gnix_cancel, cancel_ep_recv)
{
	int ret;
	struct fi_cq_err_entry buf;

	/* simulate a posted request */
	ret = fi_recv(ep[0], (void *) 0xdeadbeef, 128, 0, FI_ADDR_UNSPEC,
			(void *) 0xcafebabe);
	cr_assert(ret == FI_SUCCESS, "fi_recv failed");

	/* cancel simulated request */
	ret = fi_cancel(&ep[0]->fid, (void *) 0xcafebabe);
	cr_assert(ret == FI_SUCCESS, "fi_cancel failed");

	/* check for event */
	ret = fi_cq_readerr(msg_cq[0], &buf, FI_RECV);
	cr_assert(ret == 1, "did not find one error event");

	cr_assert(buf.buf == (void *) 0xdeadbeef, "buffer mismatch");
	cr_assert(buf.data == 0, "data mismatch");
	cr_assert(buf.err == FI_ECANCELED, "error code mismatch");
	cr_assert(buf.prov_errno == FI_ECANCELED, "prov error code mismatch");
	cr_assert(buf.len == 128, "length mismatch");
}

/*
 * fi_cancel() on ep[0] when this test has posted nothing: the call is
 * expected to return -FI_ENOENT.
 */
Test(gnix_cancel, cancel_ep_no_event)
{
	int rc = fi_cancel(&ep[0]->fid, NULL);

	cr_assert(rc == -FI_ENOENT, "fi_cancel failed");
}

/*
 * fi_cancel() on an endpoint whose domain pointer is NULL is expected to
 * return -FI_EDOMAIN.
 *
 * The domain pointer of ep[0] is cleared for the duration of the call and
 * put back before the result is asserted.
 */
Test(gnix_cancel, cancel_ep_no_domain)
{
	int ret;
	struct gnix_fid_ep *gnix_ep;
	struct gnix_fid_domain *gnix_dom;

	/* simulate a disconnected endpoint */
	gnix_ep = container_of(ep[0], struct gnix_fid_ep, ep_fid);
	gnix_dom = gnix_ep->domain;
	gnix_ep->domain = NULL;

	/* run test */
	ret = fi_cancel(&ep[0]->fid, NULL);

	/*
	 * reconnect before asserting, so the endpoint has its domain back
	 * whichever way the assertion goes
	 */
	gnix_ep->domain = gnix_dom;

	cr_assert(ret == -FI_EDOMAIN, "fi_cancel failed");
}
