/*
 * dsyslog - a dumb syslog (i.e. syslog for people who have a clue)
 * Copyright (c) 2008 William Pitcock <nenolod@dereferenced.org>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice is present in all copies.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <glib.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/poll.h>
#include <sys/stat.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>

#include <gnutls/gnutls.h>

#include "dsyslog.h"

/*
 * Per-output connection state for the "tls" output type.  One instance is
 * allocated by dsyslog_tls_output_handler and stored in output->opaque.
 */
struct dsyslog_tls_socket {
	gint sock;				/* TCP socket connected to output->host:output->port */
	struct sockaddr_in sa;			/* resolved IPv4 address/port of the log host */
	gnutls_session_t gnutls_session;	/* GnuTLS client session running over sock */
	gnutls_certificate_credentials_t gnutls_x509cred;	/* certificate credentials attached to gnutls_session */
	GQueue io_messages;			/* pending struct dsyslog_tls_message *, drained by dsyslog_tls_io_cb */
	GIOChannel *logchan;			/* GLib channel wrapping sock, used for main-loop watches */
	gint tag;				/* GLib source id of the handshake watch added when connecting */
};

/* One formatted log line waiting in dsyslog_tls_socket.io_messages. */
struct dsyslog_tls_message {
	gchar *message;	/* g_strdup'd copy of the formatted line */
	gint len;	/* number of bytes in message, excluding the terminating NUL */
};

/*
 * Release everything owned by a "tls" output's connection state
 * (output->opaque) and clear output->opaque.
 *
 * Besides closing the socket, this:
 *   - removes any pending main-loop watch that was registered with @output
 *     as its user data (handshake or I/O watch), so it cannot fire after
 *     teardown;
 *   - frees every message still waiting in the queue;
 *   - deinitialises the GnuTLS session before freeing the credentials it
 *     uses, then drops the GIOChannel.
 *
 * Each resource is checked first because a failed reconnect may already
 * have released the session, channel and socket.
 */
static void
dsyslog_tls_output_destructor(dsyslog_output_t *output)
{
	struct dsyslog_tls_socket *us;
	struct dsyslog_tls_message *msg;

	_ENTER;

	us = (struct dsyslog_tls_socket *) output->opaque;
	if (us == NULL) {
		_LEAVE;
	}

	/*
	 * The handshake and I/O watches are installed with @output as user
	 * data; remove them so they do not dereference freed state.
	 * NOTE(review): assumes no other code registers a main-loop source
	 * with this output as its user data -- confirm.
	 */
	while (g_source_remove_by_user_data(output))
		;

	while ((msg = g_queue_pop_head(&us->io_messages)) != NULL) {
		g_free(msg->message);
		g_slice_free(struct dsyslog_tls_message, msg);
	}

	/* The session must go before the credentials it references. */
	if (us->gnutls_session != NULL) {
		gnutls_deinit(us->gnutls_session);
	}
	if (us->gnutls_x509cred != NULL) {
		gnutls_certificate_free_credentials(us->gnutls_x509cred);
	}

	if (us->logchan != NULL) {
		g_io_channel_unref(us->logchan);
	}
	if (us->sock >= 0) {
		close(us->sock);
	}

	g_slice_free(struct dsyslog_tls_socket, us);
	output->opaque = NULL;

	_LEAVE;
}

static gboolean dsyslog_tls_handshake_cb(GIOChannel *source, GIOCondition cond, gpointer data);

static gboolean
dsyslog_tls_io_cb(GIOChannel *source, GIOCondition cond, gpointer data)
{
	dsyslog_output_t *output = data;
	struct dsyslog_tls_socket *us = output->opaque;
	struct dsyslog_tls_message *msg;
	gint ret;

	_ENTER;

	g_return_val_if_fail(source != NULL, FALSE);
	g_return_val_if_fail(data != NULL, FALSE);

	while ((msg = g_queue_pop_head(&us->io_messages)) != NULL) {
		ret = gnutls_record_send(us->gnutls_session, msg->message, msg->len);

		/* ret == 0, connection died. reset and try again. --nenolod */
		/* ret < 0, ssl error. we could recover, but the code would be more complex than it's worth here. --nenolod */
		if (ret <= 0) {
			struct hostent *hp;
			struct in_addr *in;
			static const gint cert_type_priority[2] = { GNUTLS_CRT_X509, 0 };

			g_queue_push_head(&us->io_messages, msg);

			gnutls_bye(us->gnutls_session, GNUTLS_SHUT_WR);
			gnutls_deinit(us->gnutls_session);

			g_io_channel_unref(us->logchan);
			close(us->sock);

			us->sock = socket(PF_INET, SOCK_STREAM, 0);
			if (us->sock < 0) {
				perror("logsocket");
				g_slice_free(struct dsyslog_tls_socket, us);
				_LEAVE FALSE;
			}

			us->sa.sin_family = AF_INET;
			if ((hp = gethostbyname(output->host)) == NULL) {
				close(us->sock);
				g_slice_free(struct dsyslog_tls_socket, us);
				_LEAVE FALSE;
			}

			in = (struct in_addr *)(hp->h_addr_list[0]);
			us->sa.sin_addr.s_addr = in->s_addr;
			us->sa.sin_port = htons(output->port);

			if (connect(us->sock, (struct sockaddr *) &us->sa, sizeof(us->sa)) < 0) {
				perror("logsocket:connect");
				close(us->sock);
				g_slice_free(struct dsyslog_tls_socket, us);

				_LEAVE FALSE;
			}

			gnutls_init(&us->gnutls_session, GNUTLS_CLIENT);
			gnutls_set_default_priority(us->gnutls_session);
			gnutls_certificate_type_set_priority(us->gnutls_session, cert_type_priority);
			gnutls_certificate_allocate_credentials(&us->gnutls_x509cred);
			gnutls_credentials_set(us->gnutls_session, GNUTLS_CRD_CERTIFICATE, us->gnutls_x509cred);
			gnutls_transport_set_ptr(us->gnutls_session, GINT_TO_POINTER(us->sock));

			us->logchan = g_io_channel_unix_new(us->sock);
			us->tag = g_io_add_watch(us->logchan, G_IO_OUT, dsyslog_tls_handshake_cb, output);

			_LEAVE FALSE;
		}
	}

	_LEAVE TRUE;
}

/*
 * Main-loop watch callback that drives the TLS handshake for the session
 * stored in output->opaque.
 *
 * If gnutls_handshake() returns GNUTLS_E_AGAIN or GNUTLS_E_INTERRUPTED, a
 * new watch is installed for the direction GnuTLS is blocked on, and this
 * callback runs again.  gnutls_record_get_direction() tells which:
 * 0 means reading (G_IO_IN); anything else means writing (G_IO_OUT).
 *
 * Any other result, including other handshake errors, hands the channel to
 * dsyslog_tls_io_cb.  A failed handshake presumably surfaces there as a send
 * error, which triggers a reconnect.
 *
 * The id of whichever watch is installed is recorded in us->tag.  Always
 * returns FALSE, because every path installs its own replacement watch.
 *
 * NOTE(review): nothing in this file verifies the server certificate: no
 * trust anchors are loaded and no gnutls_certificate_verify_peers* call is
 * made.  The channel is therefore encrypted but not authenticated.
 */
static gboolean
dsyslog_tls_handshake_cb(GIOChannel *source, GIOCondition cond, gpointer data)
{
	dsyslog_output_t *output;
	struct dsyslog_tls_socket *us;
	GIOCondition wait_for;
	gint ret;

	_ENTER;

	g_return_val_if_fail(source != NULL, FALSE);
	g_return_val_if_fail(data != NULL, FALSE);

	/* Dereference only after the guards above have validated @data. */
	output = data;
	us = output->opaque;

	ret = gnutls_handshake(us->gnutls_session);

	if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) {
		wait_for = (gnutls_record_get_direction(us->gnutls_session) == 0) ? G_IO_IN : G_IO_OUT;
		us->tag = g_io_add_watch(source, wait_for, dsyslog_tls_handshake_cb, output);
		_LEAVE FALSE;
	}

	us->tag = g_io_add_watch(source, G_IO_OUT, dsyslog_tls_io_cb, output);
	_LEAVE FALSE;
}

/*
 * Output handler for the "tls" output type.
 *
 * Formats @event as "<logcode>program: message", truncated to at most 1023
 * bytes, and appends a heap copy to the output's send queue.  The queue is
 * drained by dsyslog_tls_io_cb.
 *
 * On the first call for an output (output->opaque == NULL) it also:
 *   - defaults output->port to 514 if unset;
 *   - connects to output->host:output->port;
 *   - creates a GnuTLS client session and schedules its handshake on the
 *     main loop;
 *   - stores the state in output->opaque;
 *   - installs dsyslog_tls_output_destructor.
 * If no host is configured, the host cannot be resolved, or the connection
 * fails, the event is dropped.  Setup is attempted again on the next event.
 */
static void
dsyslog_tls_output_handler(dsyslog_event_t *event, dsyslog_output_t *output)
{
	struct dsyslog_tls_socket *us;
	struct dsyslog_tls_message *msg;
	gchar msgbuf[1024];

	_ENTER;

	if (!output->opaque) {
		struct hostent *hp;
		struct in_addr *in;
		static const gint cert_type_priority[2] = { GNUTLS_CRT_X509, 0 };

		/*
		 * Check the configuration before allocating @us, so a missing host
		 * does not leak an allocation on every event.
		 */
		if (!output->host) {
			_LEAVE;
		}

		/*
		 * NOTE(review): 514 is the traditional syslog port; RFC 5425 assigns
		 * 6514 to syslog over TLS.  Kept as-is for existing configurations.
		 */
		if (!output->port) {
			output->port = 514;
		}

		us = g_slice_new0(struct dsyslog_tls_socket);

		us->sock = socket(PF_INET, SOCK_STREAM, 0);
		if (us->sock < 0) {
			perror("logsocket");
			g_slice_free(struct dsyslog_tls_socket, us);
			_LEAVE;
		}

		us->sa.sin_family = AF_INET;
		if ((hp = gethostbyname(output->host)) == NULL) {
			close(us->sock);
			g_slice_free(struct dsyslog_tls_socket, us);
			_LEAVE;
		}

		/* Only the first resolved address is used. */
		in = (struct in_addr *)(hp->h_addr_list[0]);
		us->sa.sin_addr.s_addr = in->s_addr;
		us->sa.sin_port = htons(output->port);

		if (connect(us->sock, (struct sockaddr *) &us->sa, sizeof(us->sa)) < 0) {
			perror("logsocket:connect");
			close(us->sock);
			g_slice_free(struct dsyslog_tls_socket, us);

			_LEAVE;
		}

		gnutls_init(&us->gnutls_session, GNUTLS_CLIENT);
		gnutls_set_default_priority(us->gnutls_session);
		gnutls_certificate_type_set_priority(us->gnutls_session, cert_type_priority);
		gnutls_certificate_allocate_credentials(&us->gnutls_x509cred);
		gnutls_credentials_set(us->gnutls_session, GNUTLS_CRD_CERTIFICATE, us->gnutls_x509cred);
		gnutls_transport_set_ptr(us->gnutls_session, GINT_TO_POINTER(us->sock));

		/* The handshake runs from the main loop; messages queue up meanwhile. */
		us->logchan = g_io_channel_unix_new(us->sock);
		us->tag = g_io_add_watch(us->logchan, G_IO_OUT, dsyslog_tls_handshake_cb, output);

		output->opaque = us;
		output->destructor = dsyslog_tls_output_destructor;
	}

	us = (struct dsyslog_tls_socket *) output->opaque;
	snprintf(msgbuf, sizeof(msgbuf), "<%d>%s: %s", event->logcode, event->program, event->message);

	/*
	 * NOTE(review): no delimiter or length prefix is added, so consecutive
	 * messages reach the peer back-to-back.  RFC 5425 specifies octet-counting
	 * framing ("LEN SP MSG").  Confirm what the receiver expects.
	 */
	msg = g_slice_new(struct dsyslog_tls_message);
	msg->message = g_strdup(msgbuf);
	msg->len = strlen(msgbuf);

	g_queue_push_tail(&us->io_messages, msg);

	_LEAVE;
}

/*
 * Module load hook: initialise GnuTLS, then register
 * dsyslog_tls_output_handler under the output type name "tls".
 */
void
_modinit(void)
{
	_ENTER;

	/*
	 * GnuTLS global initialisation is reference-counted, so calling it here
	 * is fine even if something else in the process initialised it too.
	 * The return value is deliberately not examined.
	 */
	(void) gnutls_global_init();

	dsyslog_output_type_register("tls", dsyslog_tls_output_handler);

	_LEAVE;
}

/*
 * Module unload hook, the mirror image of _modinit(): first withdraw the
 * "tls" output type, then release this module's GnuTLS reference.
 */
void
_modfini(void)
{
	_ENTER;

	/* Unregister before tearing GnuTLS down. */
	dsyslog_output_type_unregister("tls");

	/* Balances the gnutls_global_init() performed in _modinit(). */
	gnutls_global_deinit();

	_LEAVE;
}
