/*
    MiddleMan filtering proxy server
    Copyright (C) 2002-2004  Jason McLaughlin

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <stdarg.h>
#include "proto.h"

extern GLOBAL *global;

/*
 * Build a callback record from a handler function and the opaque argument
 * that will be passed back to it on every invocation.
 */
SocketCallBack::SocketCallBack(int (*func) (void *, int, char *), void *arg)
{
	this->arg = arg;
	this->func = func;
}

int SocketCallBack::Call(int len, char *buf)
{
	return this->func(this->arg, len, buf);
}

/*
 * Create an unconnected, unencrypted socket with empty I/O buffers,
 * zeroed traffic counters and no registered callbacks.
 */
Socket::Socket()
{
	this->type = Socket::SOCK_NORMAL;
	this->fd = -1;

	inbuf = NULL;
	outbuf = NULL;
	inbuf_len = 0;
	outbuf_len = 0;
	inbuf_reallen = 0;
	outbuf_reallen = 0;

	byteswritten = 0;
	bytesread = 0;

	frozen = FALSE;
	cbid = 0;
}

/*
 * Wrap an already-open descriptor. The Socket takes over fd (the destructor
 * closes it); all other state starts out empty, as in the default
 * constructor.
 */
Socket::Socket(int fd)
{
	this->fd = fd;
	this->type = Socket::SOCK_NORMAL;

	inbuf = NULL;
	outbuf = NULL;
	inbuf_len = 0;
	outbuf_len = 0;
	inbuf_reallen = 0;
	outbuf_reallen = 0;

	byteswritten = 0;
	bytesread = 0;

	frozen = FALSE;
	cbid = 0;
}

/*
 * Destroy the socket. The steps run in this order:
 *   1. Flush pending output (Flush() also empties both buffers).
 *   2. Shut down TLS, if enabled.
 *   3. Close the descriptor.
 *   4. Free the buffers.
 *   5. Add this socket's byte counters to the global "network" statistics.
 */
Socket::~Socket()
{
	Flush();

	#ifdef HAVE_SSL
	/*
	 * TLS must be torn down while fd is still open: SSL_shutdown() writes
	 * its close_notify alert through the descriptor. After close(fd) that
	 * write hits a dead descriptor or, in a multi-threaded process, a
	 * descriptor number that has already been reused for another
	 * connection.
	 *
	 * A single call sends our close_notify without waiting for the peer's
	 * reply. Repeating the call waits for that reply, which can block
	 * indefinitely on a blocking socket.
	 */
	if (type != Socket::SOCK_NORMAL) {
		SSL_shutdown(ssl);

		SSL_free(ssl);
		SSL_CTX_free(ctx);
	}
	#endif

	if (fd != -1)
		close(fd);

	FREE_AND_NULL(inbuf);
	FREE_AND_NULL(outbuf);

	if (byteswritten)
		global->stats.Increment("network", "bytes out", byteswritten);
	if (bytesread)
		global->stats.Increment("network", "bytes in", bytesread);

}

/*
 * Switch the connected descriptor to TLS.
 *
 * type - Socket::SOCK_SSLCLIENT performs a client handshake with
 *        SSL_connect(); any other value selects the server method, but the
 *        server side is not implemented and always fails.
 *
 * Returns TRUE once the handshake completed and this->type has been set.
 * Returns FALSE in all of these cases:
 *   - any OpenSSL step failed (the partially built SSL/SSL_CTX are freed and
 *     the members reset to NULL, so the destructor never sees them);
 *   - the socket is already encrypted;
 *   - the program was built without HAVE_SSL.
 */
bool Socket::Encrypt(int type)
{
#ifdef HAVE_SSL
	int x = -1;

	/* layering a second TLS session on the raw fd is not supported, and
	   overwriting ssl/ctx here would leak the live session */
	if (this->type != Socket::SOCK_NORMAL)
		return FALSE;

	ctx = SSL_CTX_new((type == Socket::SOCK_SSLCLIENT) ? SSLv23_client_method() : SSLv23_server_method());

	if (ctx == NULL)
		return FALSE;

	ssl = SSL_new(ctx);
	if (ssl == NULL) {
		SSL_CTX_free(ctx);
		ctx = NULL;

		return FALSE;
	}

	if (type == Socket::SOCK_SSLCLIENT) {
		/* SSL_set_fd() returns 1 on success and 0 on failure */
		if (SSL_set_fd(ssl, fd) == 1) {
			SSL_set_connect_state(ssl);

			/* SSL_connect() returns 1 only on success: 0 means the
			   handshake failed but was shut down cleanly, < 0 a fatal
			   error - both must be treated as failure */
			x = SSL_connect(ssl);
		}
	} else {
		/* not done
		   SSL_CTX_use_certificate_file( ctx, "cert", SSL_FILETYPE_PEM )
		   SSL_CTX_use_PrivateKey_file( ctx, "key", SSL_FILETYPE_PEM )
		   SSL_CTX_set_verify( ctx, ( SSL_VERIFY_NONE ), ssl_verify_callback );
		   SSL_CTX_set_verify_depth( ctx, 4 );
		   SSL_CTX_set_options( ctx, SSL_OP_ALL );
		 */
	}

	if (x != 1) {
		SSL_free(ssl);
		SSL_CTX_free(ctx);
		/* don't leave dangling pointers behind */
		ssl = NULL;
		ctx = NULL;

		return FALSE;
	}

	this->type = type;

	return TRUE;
#else
	return FALSE;
#endif
}

/*
 * Register func(arg, len, buf) to be invoked for socket events whose bit is
 * set in events (READEVENT / WRITEEVENT / ERROREVENT).
 *
 * Returns the id to pass to CallBackRemove()/CallBackMask*().
 */
int Socket::CallBackAdd(int (*func) (void *, int, char *), void *arg, int events)
{
	/*
	 * callback_list holds SocketCallBack objects by value (push_back takes a
	 * copy), so a stack temporary is sufficient. A heap allocation here was
	 * copied and then never freed.
	 */
	SocketCallBack scb(func, arg);

	scb.id = cbid++;
	scb.events = events;

	callback_list.push_back(scb);

	return scb.id;
}

bool Socket::CallBackRemove(int id)
{
	SocketCallBackList::iterator item;

	for (item = callback_list.begin(); item != callback_list.end(); item++) {
		if (item->id == id) {
			callback_list.erase(item);

			return TRUE;
		}
	}

	ASSERT(0);

	return FALSE;
}

int Socket::CallBackEvent(int event, int len, char *buf)
{
	int ret;
	SocketCallBackList::iterator item;

	for (item = callback_list.begin(); item != callback_list.end(); item++) {
		if (item->events & event) {
			ret = item->Call(len, buf);
			if (ret == -1)
				return ret;
		}
	}

	return 0;
}


int Socket::CallBackMaskGet(int id)
{
	SocketCallBackList::iterator item;

	for (item = callback_list.begin(); item != callback_list.end(); item++)
		if (item->id == id)
			return item->events;

	ASSERT(0);

	return 0;
}


int Socket::CallBackMaskSet(int id, int mask)
{
	SocketCallBackList::iterator item;

	for (item = callback_list.begin(); item != callback_list.end(); item++) {
		if (item->id == id) {
			item->events = mask;

			return mask;
		}
	}

	ASSERT(0);

	return 0;
}

/*
 * Unfreeze the socket and push out as much buffered output as possible.
 * Afterwards both buffers are emptied: output that could not be sent, and
 * any unread input, are dropped.
 */
void Socket::Flush()
{
	frozen = FALSE;

	/*
	 * Keep going only while each Write() makes progress. Stopping on 0 as
	 * well as -1 matters: SSL_write() reports a closed connection with 0,
	 * and "!= -1" would then retry forever, e.g. from the destructor.
	 */
	while (outbuf_len > 0 && Write(NULL, -1) > 0)
		;

	Resize(0, OUTBUF);
	Resize(0, INBUF);
}

/*
 * Set the logical length of the input buffer (which == INBUF) or of the
 * output buffer (any other value) to newsize bytes.
 *
 * The current allocation is kept when it holds newsize bytes and is less
 * than one FILEBUF_ALIGNMENT step larger. A newsize of 0 otherwise goes
 * through Clear(). In every other case the buffer is xrealloc'd: to exactly
 * newsize in _DEBUG builds, or to newsize rounded with ALIGN2 otherwise.
 * Existing contents up to the new length are preserved by the realloc.
 */
void Socket::Resize(size_t newsize, int which)
{
	if (which != INBUF) {
		if (outbuf != NULL && outbuf_reallen >= newsize && outbuf_reallen < newsize + FILEBUF_ALIGNMENT) {
			outbuf_len = newsize;
			return;
		}

		if (newsize == 0) {
			Clear(OUTBUF);
			return;
		}

#ifdef _DEBUG
		outbuf_reallen = newsize;
#else
		outbuf_reallen = ALIGN2(newsize, FILEBUF_ALIGNMENT);
#endif
		outbuf = (char *) xrealloc(outbuf, outbuf_reallen);
		outbuf_len = newsize;
		return;
	}

	if (inbuf != NULL && inbuf_reallen >= newsize && inbuf_reallen < newsize + FILEBUF_ALIGNMENT) {
		inbuf_len = newsize;
		return;
	}

	if (newsize == 0) {
		Clear(INBUF);
		return;
	}

#ifdef _DEBUG
	inbuf_reallen = newsize;
#else
	inbuf_reallen = ALIGN2(newsize, FILEBUF_ALIGNMENT);
#endif
	inbuf = (char *) xrealloc(inbuf, inbuf_reallen);
	inbuf_len = newsize;
}


/*
 * Empty the input buffer (which == INBUF) or the output buffer (any other
 * value). The memory is released only when the allocation is larger than
 * FILEBUF_ALIGNMENT. Smaller allocations are kept for reuse.
 */
void Socket::Clear(int which)
{
	if (which != INBUF) {
		outbuf_len = 0;

		if (outbuf_reallen > FILEBUF_ALIGNMENT) {
			FREE_AND_NULL(outbuf);
			outbuf_reallen = 0;
		}

		return;
	}

	inbuf_len = 0;

	if (inbuf_reallen > FILEBUF_ALIGNMENT) {
		FREE_AND_NULL(inbuf);
		inbuf_reallen = 0;
	}
}

int Socket::Read(char *buf, int len)
{
	return Read(buf, len, -1);
}

/*
 * Read up to len bytes into buf. Data already buffered in inbuf is served
 * first; then at most one read()/SSL_read() of up to BLOCKSIZE bytes is
 * issued. Network bytes that do not fit in buf are kept in inbuf.
 *
 * buf     - destination buffer; not used when len is -1
 * len     - maximum number of bytes to return, or -1 to do one network read
 *           and only append the result to inbuf
 * timeout - seconds to wait for readability with p_poll(), or -1 to skip
 *           the poll. The poll only happens when inbuf cannot already
 *           satisfy the request.
 *
 * Returns the number of bytes placed in buf if any were delivered.
 * Otherwise returns the raw read()/SSL_read() result: bytes appended to
 * inbuf when len is -1, 0 at end of stream, -1 on error.
 *
 * Also returns -1 when:
 *   - the poll does not report readability;
 *   - the poll reports POLLHUP/POLLERR;
 *   - a READEVENT callback returns -1 after the network read.
 * A failed network read raises ERROREVENT.
 */
int Socket::Read(char *buf, int len, int timeout)
{
	int pos = 0, x = 0;
	char b[BLOCKSIZE];
	struct pollfd pfd;

	if (timeout != -1 && (len == -1 || inbuf_len < len)) {
		pfd.fd = fd;
		pfd.events = POLLIN;

		x = p_poll(&pfd, 1, timeout * 1000);
		if (x != 1 || pfd.revents & (POLLHUP | POLLERR))
			return -1;
	}

	if (inbuf_len != 0 && len != -1) {
		memcpy(buf, inbuf, (inbuf_len < len) ? inbuf_len : len);
		if (inbuf_len <= len) {
			/* all buffered data used up, clear it */
			pos = inbuf_len;

			Clear(INBUF);

			/* buffer is exact size of requested read size, no need to do any network IO */
			if (pos == len) {
				CallBackEvent(READEVENT, pos, buf);
				return len;
			}
		} else {
			/* buffer larger than requested read size, no need to do any network IO */
			/* chop off used part of buffer */
			/* NOTE(review): unlike the other delivery paths this one raises no
			   READEVENT - confirm that is intended */
			memmove(inbuf, inbuf + len, inbuf_len - len);
			Resize(inbuf_len - len, INBUF);

			return len;
		}
	}

	if (type != Socket::SOCK_NORMAL) {
		#ifdef HAVE_SSL
		x = SSL_read(ssl, b, sizeof(b));
		#else
		/* cannot read an encrypted socket without TLS support; never use a
		   stale poll result as a byte count */
		x = -1;
		#endif
	} else
		x = read(fd, b, sizeof(b));
	if (x > 0) {
		bytesread += x;

		if (len == -1) {
			/* just buffer */
			Resize(inbuf_len + x, INBUF);
			memcpy(inbuf + inbuf_len - x, b, x);
		} else {
			memcpy(buf + pos, b, (x > len - pos) ? len - pos : x);
			if (x > len - pos) {
				/* buffer the rest */
				Resize(x - (len - pos), INBUF);
				memcpy(inbuf, &b[len - pos], x - (len - pos));
			}

			pos += (x > len - pos) ? len - pos : x;
		}
	}

	/*
	 * Two independent notifications, not a nested if/else:
	 *   - delivered data raises READEVENT, and the callback may abort the
	 *     read with -1;
	 *   - a failed network read raises ERROREVENT, including when nothing
	 *     was delivered.
	 */
	if (pos != 0 && CallBackEvent(READEVENT, pos, buf) == -1)
		return -1;

	if (x == -1)
		CallBackEvent(ERROREVENT, -1, NULL);

	return (pos != 0) ? pos : x;
}

/*
 * Send and/or buffer outgoing data. Works in two phases.
 *
 * Phase 1: unless the socket is frozen, try to drain previously buffered
 * output (outbuf).
 *
 * Phase 2 (only when len != -1): if nothing is still buffered and the socket
 * is not frozen, try to send buf[0..len) directly. Whatever was not sent is
 * appended to outbuf. That is all of it when the send failed or was not
 * attempted. Call with len == -1 (buf unused) to only drain outbuf.
 *
 * Returns the result of the last send attempt: bytes sent, or -1 on error.
 * When phase 2 ran but did not send (data was only buffered), returns 0.
 * -1 can therefore be returned even though the data was accepted into
 * outbuf.
 *
 * Raises WRITEEVENT with (len, buf) unless the result is -1, in which case
 * ERROREVENT is raised instead.
 */
int Socket::Write(char *buf, int len)
{
	int x = 0, tmp;

	/* phase 1: drain previously buffered output */
	if (outbuf_len != 0 && frozen == FALSE) {
		/* NOTE(review): after SSL_ERROR_WANT_WRITE OpenSSL expects the retry
		   to use the same buffer address unless SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER
		   is set; outbuf may be memmoved/realloc'd between attempts - verify */
		if (type != Socket::SOCK_NORMAL) {
			#ifdef HAVE_SSL
			x = SSL_write(ssl,outbuf, outbuf_len);
			#endif
		} else
			x = write(fd, outbuf, outbuf_len);
		if (x > 0) {
			byteswritten += x;

			if (x == outbuf_len) {
				/* successfully wrote everything */
				Clear(OUTBUF);
			} else {
				/* partial write */
				memmove(outbuf, outbuf + x, outbuf_len - x);
				Resize(outbuf_len - x, OUTBUF);
			}
		}
	}

	/* phase 2: send or queue the caller's data */
	if (len != -1) {
		x = 0;

		/* only bypass the buffer when nothing is queued, to keep byte order */
		if (outbuf_len == 0 && frozen == FALSE) {
			if (type != Socket::SOCK_NORMAL) {
				#ifdef HAVE_SSL
				x = SSL_write(ssl, buf, len);
				#endif
			} else
				x = write(fd, buf, len);
		}

		if (x != len) {
			/* buffer data that couldn't be written immediately */
			/* an error (-1) counts as zero bytes sent, so everything is queued */
			tmp = outbuf_len;
			Resize(outbuf_len + len - ((x != -1) ? x : 0), OUTBUF);
			memcpy(outbuf + tmp, buf + ((x != -1) ? x : 0), len - ((x != -1) ? x : 0));
		}

		if (x != -1)
			byteswritten += x;
	}

	if (x != -1)
		CallBackEvent(WRITEEVENT, len, buf);
	else
		CallBackEvent(ERROREVENT, -1, NULL);

	return x;
}

int Socket::GetLine(char *buf, int len) {
	return GetLine(buf, len, -1);
}

/*
 * Fetch one '\n'-terminated line.
 *
 * inbuf is scanned for a newline, and more data is pulled in with
 * Read(NULL, -1, timeout) until one appears or Read() returns <= 0. The
 * whole line, including '\n', is removed from inbuf even when it is longer
 * than buf.
 *
 * buf     - receives the line NUL-terminated. At most len - 1 characters are
 *           kept; the '\n' is included when it fits.
 * len     - size of buf. When len <= 0 (the conventional value is -1) the
 *           line is discarded and buf is never touched.
 * timeout - passed through to Read()
 *
 * Returns:
 *   - on success, min(line length incl. '\n', len);
 *   - when discarding, the discarded line's length;
 *   - otherwise the last Read() result (0 or negative). Any partial line
 *     then stays buffered.
 */
int Socket::GetLine(char *buf, int len, int timeout)
{
	int x = 0, pos = 0;
	char *ptr = NULL;

	do {
		/* only search the bytes not yet scanned on a previous pass */
		if (inbuf_len != 0 && ((ptr = (char *) memchr(inbuf + pos, '\n', inbuf_len - pos)))) {
			/* length of the line including its '\n' */
			int linelen = (ptr - inbuf) + 1;

			if (len > 0) {
				x = (linelen < len) ? linelen : len;

				memcpy(buf, inbuf, x);
				/* truncated copies lose their last byte to the terminator */
				buf[(x != len) ? x : x - 1] = '\0';
			} else {
				/* discard mode (and a zero-sized buf): writing a terminator
				   would land at buf[-1], so leave buf alone */
				x = linelen;
			}

			if (linelen == inbuf_len) {
				/* used up all buffered data */
				Clear(INBUF);
			} else {
				/* remove entire line from buffer, even if it's larger than requested read size */
				memmove(inbuf, ptr + 1, inbuf_len - linelen);
				Resize(inbuf_len - linelen, INBUF);
			}

			break;
		} else
			pos = inbuf_len;
	} while ((x = Read(NULL, -1, timeout)) > 0);

	return x;
}

/*
 * Put len bytes from buf back in front of any buffered input, so the next
 * Read()/GetLine() returns them first.
 *
 * Returns the number of bytes that were buffered before the call. A
 * non-positive len is a no-op.
 */
int Socket::PushFront(char *buf, int len)
{
	if (len <= 0)
		return inbuf_len;

	Resize(inbuf_len + len, INBUF);
	/*
	 * Shift the old contents right by len. The source [0, old) and the
	 * destination [len, len + old) overlap whenever old > len, which memcpy
	 * does not allow; memmove does.
	 */
	memmove(inbuf + len, inbuf, inbuf_len - len);
	memcpy(inbuf, buf, len);

	return inbuf_len - len;
}

/*
 * printf-style write. The arguments are formatted into an 8 KiB stack
 * buffer and the result is passed to Write(). Longer output is truncated to
 * the buffer size minus one byte.
 *
 * Returns whatever Write() returns.
 */
int Socket::PutSock(char *format, ...)
{
	int ret;
	char buf[8192];
	va_list va;

	va_start(va, format);
	ret = vsnprintf(buf, sizeof(buf), format, va);
	va_end(va);

	/*
	 * vsnprintf returns the length the complete output would have had,
	 * excluding the NUL. If that does not fit, only sizeof(buf) - 1
	 * characters were stored, followed by the terminator, which must not be
	 * sent. Some older C libraries report truncation as a negative value;
	 * treat that the same way.
	 */
	if (ret < 0 || (size_t) ret >= sizeof(buf))
		ret = sizeof(buf) - 1;

	return Write(buf, ret);
}
		

/*
 * Open a connection to host:port via net_connect() and adopt the resulting
 * descriptor.
 *
 * Returns the stored descriptor. Connected() treats -1 as "not connected".
 * NOTE(review): an fd that is already open is overwritten without being
 * closed - confirm callers never connect twice.
 */
int Socket::Connect(const char *host, int port)
{
	this->fd = net_connect(host, port, -1);

	return this->fd;
}

/* TRUE when the socket holds a descriptor (fd is not -1). */
bool Socket::Connected()
{
	if (this->fd != -1)
		return TRUE;

	return FALSE;
}

int Socket::WriteFilebuf(const Filebuf *filebuf)
{
	return Write(filebuf->data, filebuf->size);
}
