/***************************************************************************
 *
 * Copyright (c) 2000, 2001, 2002, 2003, 2004 BalaBit IT Ltd, Budapest, Hungary
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation.
 *
 * Note that this permission is granted for only version 2 of the GPL.
 *
 * As an additional exemption you are allowed to compile & link against the
 * OpenSSL libraries as published by the OpenSSL project. See the file
 * COPYING for details.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 * $Id: conntrack.c,v 1.46 2004/04/20 13:50:00 bazsi Exp $
 *
 * Author  : yeti, bazsi
 * Auditor : bazsi
 * Last audited version: 1.1
 * Notes:
 *
 ***************************************************************************/
#include <zorp/conntrack.h>

#include <zorp/proxy.h>
#include <zorp/sockaddr.h>
#include <zorp/zorp.h>
#include <zorp/zpython.h>
#include <zorp/policy.h>
#include <zorp/registry.h>
#include <zorp/thread.h>
#include <zorp/log.h>
#include <zorp/modules.h>
#include <zorp/socket.h>
#include <zorp/streamfd.h>
#include <zorp/source.h>

#include <assert.h>

#if ENABLE_CONNTRACK
#include <zorp/packet.h>
#include <zorp/packstream.h>
#include <zorp/packsock.h>

/* Forward declarations: both helpers are defined near the end of this file
 * and register/unregister a stream with the global conntrack poll. */
static void z_conntrack_add_stream(ZStream *stream);
static void z_conntrack_remove_stream(ZStream *stream);


/* Held around accesses to conntrack_poll in z_conntrack_add_stream(),
 * z_conntrack_remove_stream() and z_conntrack_destroy(). */
GMutex *conntrack_poll_lock = NULL;
/* The poll object driven by the conntrack thread; NULL until the thread has
 * created it and again after z_conntrack_destroy(). */
static ZPoll *conntrack_poll = NULL;
/* Mutex/condition pair used by z_conntrack_init() to wait until the
 * conntrack thread has set up conntrack_poll. */
GMutex *conntrack_lock;
GCond *conntrack_started;

/* NOTE(review): these flags are not referenced anywhere in this file;
 * presumably shutdown-state bits kept for other users -- confirm. */
#define Z_CT_SHUT_FD	0x0001
#define Z_CT_SHUT_PROXY	0x0002

/* Forward declaration: used by the stream callbacks below, defined later. */
static void z_dgram_connection_shutdown(ZDGramConnection *self);


/**
 * z_dgram_connection_send:
 * @self: ZDGramConnection instance
 * @pack: ZPacket to hand over to the proxy side
 *
 * Write @pack to the conntrack end (ct_stream) of the connection's packet
 * stream pair.
 *
 * Returns: the status reported by z_stream_packet_send().
 **/
static GIOStatus
z_dgram_connection_send(ZDGramConnection *self, ZPacket *pack)
{
  GIOStatus status;

  status = z_stream_packet_send(self->ct_stream, pack, NULL);
  return status;
}

/**
 * z_dgram_connection_recv:
 * @self: ZDGramConnection instance
 * @pack: location where the received packet is stored
 *
 * Read the next packet written by the proxy from the conntrack end
 * (ct_stream) of the packet stream pair.
 *
 * The packet returned in *@pack is new, so the caller has to free it.
 *
 * Returns: the status reported by z_stream_packet_recv().
 */
static GIOStatus
z_dgram_connection_recv(ZDGramConnection *self, ZPacket **pack)
{
  GIOStatus status;

  z_enter();
  status = z_stream_packet_recv(self->ct_stream, pack, NULL);

  /* an EOF must never arrive together with a packet */
  if (status == G_IO_STATUS_EOF)
    {
      g_assert(!*pack);
    }

  z_leave();
  return status;
}

/**
 * z_dgram_connection_packet_in:
 * @stream: stream wrapping the network socket that became readable
 * @cond: unused
 * @s: ZDGramConnection instance
 *
 * Read callback of the per-connection network socket. One packet is read
 * from the socket and handed over to the proxy through
 * z_dgram_connection_send(). If reading or forwarding fails the connection
 * is shut down.
 *
 * Returns: TRUE to keep the callback installed, FALSE after the connection
 * was shut down.
 */
static gboolean
z_dgram_connection_packet_in(ZStream *stream, GIOCondition cond G_GNUC_UNUSED, gpointer s)
{
  ZDGramConnection *conn = (ZDGramConnection *) s;
  ZPacket *packet = NULL;
  gboolean keep_polling = TRUE;
  gint sock_fd, status;

  z_enter();
  sock_fd = z_stream_get_fd(stream);

  /* FIXME: fetch all packets from the socket buffer and send them to the
   * appropriate proxy, as the kernel might have buffered additional packets.
   */

  status = z_packsock_read(sock_fd, &packet, NULL);
  if (status == G_IO_STATUS_NORMAL)
    {
      /*LOG
        This message reports that an UDP packet received from network.
       */
      z_log(conn->fd_stream->name, CORE_DEBUG, 7, "Receiving raw UDP packet; fd='%d', count='%d'", sock_fd, packet->length);

      if (z_dgram_connection_send(conn, packet) != G_IO_STATUS_NORMAL)
        {
          /* the proxy did not take the packet, so it is still ours to free */
          z_packet_free(packet);
          z_dgram_connection_shutdown(conn);
          keep_polling = FALSE;
        }
    }
  else if (status != G_IO_STATUS_AGAIN)
    {
      /*LOG
        This message indicates that an error occurred during receiving an UDP packet from network.
       */
      z_log(conn->fd_stream->name, CORE_ERROR, 3, "Error receiving raw UDP packet; rc='%d', error='%s'", status, g_strerror(errno));
      z_dgram_connection_shutdown(conn);
      keep_polling = FALSE;
    }
  /* G_IO_STATUS_AGAIN: nothing to read right now, simply wait for the next event */

  z_leave();
  return keep_polling;
}

/**
 * z_dgram_connection_proxy_in:
 * @stream: unused
 * @cond: unused
 * @user_data: ZDGramConnection instance
 *
 * Read callback of the conntrack end of the packet stream pair, invoked
 * when the proxy has written a packet. The packet is fetched with
 * z_dgram_connection_recv() and written to the network socket. Any
 * receive status other than G_IO_STATUS_NORMAL, or a failed write, shuts
 * the connection down.
 *
 * Returns: TRUE to keep the callback installed, FALSE after the connection
 * was shut down.
 */

static gboolean
z_dgram_connection_proxy_in(ZStream *stream G_GNUC_UNUSED, GIOCondition cond G_GNUC_UNUSED,
			    gpointer user_data)
{
  ZDGramConnection *conn = (ZDGramConnection *) user_data;
  ZPacket *packet = NULL;
  gboolean keep_polling = TRUE;
  gint sock_fd, status;

  z_enter();
  if (z_dgram_connection_recv(conn, &packet) != G_IO_STATUS_NORMAL)
    {
      /* EOF read, or error occurred */
      z_dgram_connection_shutdown(conn);
      keep_polling = FALSE;
    }
  else if (packet != NULL)
    {
      sock_fd = z_stream_get_fd(conn->fd_stream);
      /*LOG
        This message reports that an UDP packet is sent to the network.
       */
      z_log(conn->fd_stream->name, CORE_DEBUG, 7, "Sending raw UDP packet; fd='%d', count='%d'", sock_fd, packet->length);
      status = z_packsock_write(sock_fd, packet, NULL);
      z_packet_free(packet);

      /* NOTE: we may not get G_IO_STATUS_AGAIN as the fd is not in
       * nonblocking mode */
      if (status != G_IO_STATUS_NORMAL)
        {
          /*LOG
            This message indicates that an error occurred during sending an UDP packet to the network.
           */
          z_log(conn->fd_stream->name, CORE_ERROR, 3, "Error sending raw UDP packet; fd='%d', rc='%d', error='%s'", sock_fd, status, g_strerror(errno));
          z_dgram_connection_shutdown(conn);
          keep_polling = FALSE;
        }
    }

  z_leave();
  return keep_polling;
}

/**
 * z_dgram_connection_start:
 * @self: ZDGramConnection instance
 *
 * Begin forwarding packets between the proxy and the UDP socket: read
 * callbacks are installed on both streams and both streams are added to
 * the global conntrack poll.
 *
 * Three references are taken on @self: one for each installed callback
 * (released through the callback's destroy notify) and one more that is
 * presumably the one dropped by z_dgram_connection_shutdown().
 **/
void
z_dgram_connection_start(ZDGramConnection *self)
{
  GDestroyNotify drop_ref = (GDestroyNotify) z_dgram_connection_unref;

  z_enter();

  /* reference held for the lifetime of the running connection */
  z_dgram_connection_ref(self);

  /* network side: packets arriving on the socket */
  z_dgram_connection_ref(self);
  z_stream_set_callback(self->fd_stream, Z_STREAM_FLAG_READ, z_dgram_connection_packet_in, self, drop_ref);
  z_stream_set_cond(self->fd_stream, Z_STREAM_FLAG_READ, TRUE);

  /* proxy side: packets written by the proxy */
  z_dgram_connection_ref(self);
  z_stream_set_callback(self->ct_stream, Z_STREAM_FLAG_READ, z_dgram_connection_proxy_in, self, drop_ref);
  z_stream_set_cond(self->ct_stream, Z_STREAM_FLAG_READ, TRUE);

  z_conntrack_add_stream(self->fd_stream);
  z_conntrack_add_stream(self->ct_stream);

  z_leave();
}

/**
 * z_dgram_connection_shutdown:
 * @self: ZDGramConnection instance
 *
 * Finish forwarding packets between the proxy and an UDP session. This
 * function is called at teardown when the proxy exits.
 **/
static void
z_dgram_connection_shutdown(ZDGramConnection *self)
{
  z_enter();
  if (self->fd_stream)
    {
      z_stream_set_cond(self->fd_stream, Z_STREAM_FLAG_READ, FALSE);
      z_stream_shutdown(self->fd_stream, SHUT_RD, NULL);
      z_stream_shutdown(self->fd_stream, SHUT_WR, NULL);
      z_conntrack_remove_stream(self->fd_stream);
      z_stream_close(self->fd_stream, NULL);
      z_stream_unref(self->fd_stream);
      self->fd_stream = NULL;
    }
  
  if (self->ct_stream)
    {
      z_stream_set_cond(self->ct_stream, Z_STREAM_FLAG_READ, FALSE);
      z_stream_shutdown(self->ct_stream, SHUT_RD, NULL);
      z_stream_shutdown(self->ct_stream, SHUT_WR, NULL);
      z_conntrack_remove_stream(self->ct_stream);
      z_stream_close(self->ct_stream, NULL);
      z_stream_unref(self->ct_stream);
      self->ct_stream = NULL;
    }
  z_dgram_connection_unref(self);
  z_leave();
}


/**
 * z_dgram_connection_new:
 * @session_id: session_id to be presented in log messages
 * @remote: remote address of the new UDP connection
 * @local: local address of the new UDP connection
 * @type: client or server side connection (ZCS_TO_CLIENT selects the
 *        "client" suffix in the session id, anything else "server")
 * @tos: value handed to z_packsock_open() (presumably the IP type of
 *       service of the session -- confirm against packsock)
 * @proxy_stream: ZStream to be returned to the proxy
 *
 * Create a new datagram connection. @remote and @local specify the source
 * and the destination address of packets to be captured. A socket is opened
 * for the address pair, the address it is actually bound to is queried with
 * z_getsockname(), and a packet stream pair is created whose other end is
 * returned in @proxy_stream.
 *
 * The connection keeps its own reference to @remote; the local address
 * stored in the connection is the one returned by z_getsockname().
 *
 * Returns: The ZDGramConnection socket, or NULL if something gone wrong.
 */
 
ZDGramConnection *
z_dgram_connection_new(gchar *session_id, 
                       ZSockAddr *remote, ZSockAddr *local, 
                       gint type, 
                       gint tos,
                       ZStream **proxy_stream)
{
  char buf[MAX_SOCKADDR_STRING], buf2[MAX_SOCKADDR_STRING];
  ZDGramConnection *self;
  int fd;
 
  z_enter();
  fd = z_packsock_open(ZPS_ESTABLISHED, remote, local, ZSF_MARK_TPROXY, tos, NULL);
  if (fd < 0)
    {
      z_leave();
      return NULL;
    }

  /* from here on 'local' refers to the address reported by the socket, not
   * to the caller's argument */
  if (z_getsockname(fd, &local, 0) != G_IO_STATUS_NORMAL)
    {
      close(fd);
      z_leave();
      return NULL;
    }  

  /*LOG
    This message reports that a new DGram Connection is created.
   */
  /* the length limits follow the real buffer sizes instead of a
   * hard-coded constant, so they stay correct if MAX_SOCKADDR_STRING changes */
  z_log(NULL, CORE_DEBUG, 7, "Creating datagram connection; remote='%s', local='%s'",
	z_sockaddr_format(remote, buf, sizeof(buf)),
	(local == NULL ? "NULL" : z_sockaddr_format(local, buf2, sizeof(buf2))));

  self = g_new0(ZDGramConnection, 1);
  self->ref_cnt = 1;
  
  g_snprintf(self->session_id, sizeof(self->session_id), "%s/%s", session_id, type == ZCS_TO_CLIENT ? "client" : "server");
	     
  self->type = type;
  
  self->remote_addr = z_sockaddr_ref(remote);
  /* takes over the reference obtained from z_getsockname() */
  self->local_addr = local;
  
  /* buf is reused as scratch space for the stream names */
  g_snprintf(buf, sizeof(buf), "%s/pair", self->session_id);
  z_stream_packet_pair_new(buf, &self->ct_stream, proxy_stream);
  
  g_snprintf(buf, sizeof(buf), "%s/sock", self->session_id);
  self->fd_stream = z_stream_fd_new(fd, buf);

  z_leave();
  return self;
}

/**
 * z_dgram_connection_free:
 * @self: ZDGramConnection instance to destroy
 *
 * Release the address references held by @self and free the structure.
 * Both streams must already have been released (set to NULL), which is
 * enforced by an assertion.
 */
static void
z_dgram_connection_free(ZDGramConnection *self)
{
  z_enter();

  /* streams are expected to be gone by the time the last reference drops */
  g_assert(self->ct_stream == NULL && self->fd_stream == NULL);

  z_sockaddr_unref(self->local_addr);
  z_sockaddr_unref(self->remote_addr);
  g_free(self);

  z_leave();
}

/**
 * z_dgram_connection_ref:
 * @self: ZDGramConnection instance
 *
 * Take an additional reference on @self. The counter is modified while
 * holding the instance's ref_lock.
 **/
void
z_dgram_connection_ref(ZDGramConnection *self)
{
  z_enter();
  g_static_rec_mutex_lock(&self->ref_lock);

  /* referencing an already destroyed instance is a programming error */
  g_assert(self->ref_cnt > 0);
  self->ref_cnt += 1;

  g_static_rec_mutex_unlock(&self->ref_lock);
  z_leave();
}

/**
 * z_dgram_connection_unref:
 * @self: ZDGramConnection instance
 *
 * Drop one reference on @self; when the last reference is gone the
 * instance is destroyed with z_dgram_connection_free(). The counter is
 * modified while holding the instance's ref_lock, which is released before
 * the instance is freed.
 **/
void
z_dgram_connection_unref(ZDGramConnection *self)
{
  gboolean last_ref;

  z_enter();
  g_static_rec_mutex_lock(&self->ref_lock);
  g_assert(self->ref_cnt > 0);
  last_ref = (--self->ref_cnt == 0);
  g_static_rec_mutex_unlock(&self->ref_lock);

  if (last_ref)
    z_dgram_connection_free(self);

  z_leave();
}

/**
 * z_io_receive_packet_in:
 * @stream: the master (listening) stream that became readable
 * @cond: unused
 * @s: ZIOReceive instance
 *
 * This callback is registered as the read callback for the master UDP
 * socket listening for new connections.
 *
 * Packets are read from the socket until it would block, an error occurs,
 * or session_limit new sessions have been collected in this invocation.
 * A packet whose source and destination addresses both equal those of a
 * session already created in this invocation is sent to that session;
 * otherwise a new ZDGramConnection is created for it. Once reading has
 * finished, the registered callback is invoked for every new session and
 * the session is started.
 *
 * Returns: TRUE once processing has completed (even after a receive
 * error), FALSE only if @stream has no file descriptor.
 */
static gboolean
z_io_receive_packet_in(ZStream *stream, GIOCondition cond G_GNUC_UNUSED,
		      gpointer s)
{
  ZIOReceive *self = (ZIOReceive *) s;
  ZDGramConnection *new_sock;
  ZPacket *pack = NULL;
  ZStream *proxy_stream = NULL;
  ZSockAddr *from = NULL, *to = NULL;
  gint fd;
  gint rc;
  /* sessions created during this invocation; session_limit is always
   * positive, see z_io_receive_new() */
  struct 
  {
    ZStream *stream;
    ZDGramConnection *sock;
  } sessions[self->session_limit];
  gint num_sessions = 0, num_packets = 0, i;
  gint tos;
  
  z_enter();
  
  fd = z_stream_get_fd(stream);
  if (fd == -1)
    {
      z_log(self->session_id, CORE_ERROR, 1, "Internal error, master stream has no associated fd; stream='%p'", stream);
      z_leave();
      return FALSE;
    }
    
  while (num_sessions < self->session_limit)
    {
      rc = z_packsock_recv(fd, &pack, &from, &to, &tos, NULL);
      
      if (rc == G_IO_STATUS_AGAIN)
        {
          /* socket buffer drained */
          break;
        }
      if (rc != G_IO_STATUS_NORMAL)
        {
          z_log(self->session_id, CORE_ERROR, 1, "Error receiving datagram on listening stream; fd='%d'", fd);
          break;
        }
      num_packets++;

      /* look for a session created earlier in this batch with the very
       * same remote AND local address */
      for (i = 0; i < num_sessions; i++)
        {
          if (sessions[i].sock->remote_addr->salen == from->salen && 
              sessions[i].sock->local_addr->salen == to->salen &&
              memcmp(&sessions[i].sock->remote_addr->sa, &from->sa, from->salen) == 0 &&
              memcmp(&sessions[i].sock->local_addr->sa, &to->sa, to->salen) == 0)
            {
              if (z_dgram_connection_send(sessions[i].sock, pack) != G_IO_STATUS_NORMAL)
                {
                  /* FIXME: error */
                  z_packet_free(pack);
                }
              break;
            }
        }
      if (i == num_sessions)
        {
          /* not found */
          new_sock = z_dgram_connection_new(self->session_id, from, to, ZCS_TO_CLIENT, tos, &proxy_stream);
      
          if (new_sock)
            {
              if (z_dgram_connection_send(new_sock, pack) != G_IO_STATUS_NORMAL)
                {
                  /* FIXME: */
                  /* error sending to the proxy to be created */
                  /* NOTE(review): new_sock still owns both of its streams
                   * here, while z_dgram_connection_free() asserts that they
                   * are NULL, so dropping the last reference looks like it
                   * trips that assertion -- confirm and tear the streams
                   * down first. */
                  z_stream_unref(proxy_stream);
                  z_packet_free(pack);
                  z_dgram_connection_unref(new_sock);
                }
              else
                {
                  sessions[i].stream = proxy_stream;
                  sessions[i].sock = new_sock;
                  num_sessions++;
                }
            }
          else
            {
              z_packet_free(pack);
              z_log(self->session_id, CORE_ERROR, 3, "Error creating session socket, dropping packet;");
            }
        }
      z_sockaddr_unref(from);
      z_sockaddr_unref(to);
    }
  
  if (num_sessions == self->session_limit)
    {
      z_log(self->session_id, CORE_ERROR, 3, "Conntrack session limit reached, increase session_limit; session_limit='%d'", self->session_limit);
    }
  z_log(self->session_id, CORE_DEBUG, 6, "Conntrack packet processing ended; num_sessions='%d', num_packets='%d'", num_sessions, num_packets);
  
  /* announce every new session, then start forwarding for it; the local
   * references held in sessions[] are dropped afterwards */
  for (i = 0; i < num_sessions; i++)
    {
      self->callback(sessions[i].stream, sessions[i].sock->remote_addr, sessions[i].sock->local_addr, self->callback_data);
      z_stream_unref(sessions[i].stream);
      z_dgram_connection_start(sessions[i].sock);
      z_dgram_connection_unref(sessions[i].sock);
    }
  z_leave();
  return TRUE;
}

/**
 * z_io_receive_start:
 * @self: ZIOReceive instance
 *
 * Open the listening packet socket on bind_addr, record the address it is
 * actually bound to in bound_addr, wrap the socket in a non-blocking
 * stream and register that stream with the conntrack poll so that
 * z_io_receive_packet_in() is called for incoming packets.
 *
 * Returns: TRUE if the listener was set up, FALSE if the socket could not
 * be opened or its address could not be queried.
 */
gboolean 
z_io_receive_start(ZIOReceive *self)
{
  gchar stream_name[128];
  gboolean started = FALSE;
  gint listen_fd;
  
  z_enter();
  listen_fd = z_packsock_open(ZPS_LISTEN, NULL, self->bind_addr, self->sock_flags, -1, NULL);
  if (listen_fd != -1)
    {
      if (z_getsockname(listen_fd, &self->bound_addr, 0) != G_IO_STATUS_NORMAL)
        {
          /* without a known bound address the socket is of no use */
          close(listen_fd);
        }
      else
        {
          g_snprintf(stream_name, sizeof(stream_name), "%s/ctlisten", self->session_id);
          self->fd_stream = z_stream_fd_new(listen_fd, stream_name);
          z_stream_set_nonblock(self->fd_stream, TRUE);

          z_stream_set_callback(self->fd_stream,
                                Z_STREAM_FLAG_READ,
                                z_io_receive_packet_in, self, NULL);
          z_stream_set_cond(self->fd_stream, Z_STREAM_FLAG_READ, TRUE);
          z_conntrack_add_stream(self->fd_stream);
          started = TRUE;
        }
    }

  z_leave();
  return started;
}

/**
 * z_io_receive_new:
 * @session_id: session_id to be presented in log messages
 * @bind_addr: address to bind to; the instance takes its own reference
 * @session_limit: maximum number of new sessions handled per read
 *                 callback; values <= 0 select the default of 10
 * @sock_flags: socket flags stored for opening the listening socket
 * @callback: function to call when a new connection is
 *            accepted. Will be called from the CT thread.
 * @user_data: data to pass to @callback.
 *
 * Allocate and initialize a ZIOReceive instance with a reference count of
 * one. No socket is opened here; that happens in z_io_receive_start().
 *
 * Returns: the new ZIOReceive instance.
 */
ZIOReceive *
z_io_receive_new(gchar *session_id, 
		 ZSockAddr *bind_addr, 
		 gint session_limit,
		 guint32 sock_flags,
		 ZReceiveAcceptFunc callback, 
		 gpointer user_data)
{
  ZIOReceive *receiver;
  
  z_enter();
  receiver = g_new0(ZIOReceive, 1);
  receiver->ref_cnt = 1;

  g_snprintf(receiver->session_id, sizeof(receiver->session_id), "%s", session_id);

  receiver->bind_addr = z_sockaddr_ref(bind_addr);
  receiver->sock_flags = sock_flags;

  /* fall back to the default when no usable limit was given */
  receiver->session_limit = (session_limit > 0) ? session_limit : 10;

  receiver->callback = callback;
  receiver->callback_data = user_data;

  z_leave();
  return receiver;
}

/**
 * z_io_receive_free:
 * @self: ZIOReceive instance 
 * 
 * This function frees a ZIOReceive instance together with the address
 * references it holds: the bind address taken in z_io_receive_new() and
 * the bound address obtained from z_getsockname() in z_io_receive_start().
 **/
void
z_io_receive_free(ZIOReceive *self)
{
  z_enter();
  z_sockaddr_unref(self->bind_addr);
  /* bound_addr is only set once z_io_receive_start() got far enough to
   * query the socket, so it may still be NULL here */
  if (self->bound_addr)
    {
      z_sockaddr_unref(self->bound_addr);
      self->bound_addr = NULL;
    }
  g_free(self);
  z_leave();
}

/**
 * z_io_receive_ref:
 * @self: ZIOReceive instance
 *
 * Take an additional reference on @self. No locking is performed, so all
 * reference counting on an instance has to happen in a single thread.
 **/
void
z_io_receive_ref(ZIOReceive *self)
{
  z_enter();
  /* a live instance always has at least one reference */
  g_assert(self->ref_cnt > 0);
  ++self->ref_cnt;
  z_leave();
}

/**
 * z_io_receive_unref:
 * @self: ZIOReceive instance
 *
 * Drop one reference on @self and destroy the instance with
 * z_io_receive_free() when that was the last one. No locking is performed,
 * so all reference counting on an instance has to happen in a single
 * thread.
 **/
void 
z_io_receive_unref(ZIOReceive *self)
{
  z_enter();
  g_assert(self->ref_cnt > 0);
  self->ref_cnt -= 1;
  if (self->ref_cnt == 0)
    z_io_receive_free(self);
  z_leave();
}

/**
 * z_io_receive_cancel:
 * @self: ZIOReceive instance
 * 
 * Cancel the specified ZIOReceive instance: the listening stream is
 * removed from the conntrack poll, closed and released. The intent is that
 * no callbacks are called after this function returns (see the FIXME
 * below).
 *
 * Calling this on an instance that has no listening stream (never started
 * successfully, or already cancelled) is a no-op.
 **/
void 
z_io_receive_cancel(ZIOReceive *self)
{
  /* FIXME: check if del_stream guarantees that no callbacks will be called.
   * If it doesn't we might free self while a callback is pending, otherwise
   * we can free it without problems
   */
  z_enter();
  if (self->fd_stream)
    {
      z_conntrack_remove_stream(self->fd_stream);
      z_stream_set_cond(self->fd_stream, Z_STREAM_FLAG_READ, FALSE);
      z_stream_close(self->fd_stream, NULL);
      z_stream_unref(self->fd_stream);
      self->fd_stream = NULL;
    }
  z_leave();
}

/* global conntrack code */

/**
 * z_conntrack_add_stream:
 * @stream: stream to add
 *
 * Register @stream with the global conntrack poll. The poll pointer is
 * read under conntrack_poll_lock; when no poll exists the call does
 * nothing.
 */
void 
z_conntrack_add_stream(ZStream *stream)
{
  z_enter();
  g_mutex_lock(conntrack_poll_lock);
  if (conntrack_poll != NULL)
    {
      z_poll_add_stream(conntrack_poll, stream);
    }
  g_mutex_unlock(conntrack_poll_lock);
  z_leave();
}

/**
 * z_conntrack_remove_stream:
 * @stream: stream to remove
 *
 * Unregister @stream from the global conntrack poll. The poll pointer is
 * read under conntrack_poll_lock; when no poll exists the call does
 * nothing.
 */
void 
z_conntrack_remove_stream(ZStream *stream)
{
  z_enter();
  g_mutex_lock(conntrack_poll_lock);
  if (conntrack_poll != NULL)
    {
      z_poll_remove_stream(conntrack_poll, stream);
    }
  g_mutex_unlock(conntrack_poll_lock);
  z_leave();
}


/**
 * z_conntrack_thread:
 * @s: unused
 *
 * Main thread function for conntrack. It creates the global conntrack
 * ZPoll, announces it to z_conntrack_init() and then keeps iterating the
 * poll, e.g. this cares about 1) accepting new incoming sessions,
 * 2) forwards packets between proxies and its corresponding fds
 *
 * The loop ends when the global poll pointer has been cleared (see
 * z_conntrack_destroy()) or the poll reports that it is no longer running.
 *
 * Returns: NULL.
 */
gpointer 
z_conntrack_thread(gpointer s G_GNUC_UNUSED)
{
  ZPoll *poll;
  
  z_enter();
  poll = z_poll_new();
  /* take the thread's own reference before anybody else can see the poll,
   * so that it stays valid for the loop below even if the global reference
   * is dropped right after publication */
  z_poll_ref(poll);

  /* publish the poll under the same mutex the waiter in z_conntrack_init()
   * uses to check for it */
  g_mutex_lock(conntrack_lock);
  conntrack_poll = poll;
  g_cond_signal(conntrack_started);
  g_mutex_unlock(conntrack_lock);
  
  while (conntrack_poll && z_poll_is_running(poll))
    z_poll_iter_timeout(poll, -1);

  z_poll_unref(poll);
  z_leave();
  return NULL;
}

/**
 * z_conntrack_init:
 * 
 * Global conntrack initialization called during Zorp startup. Creates the
 * synchronization objects, starts the conntrack thread and waits until
 * that thread has created the global conntrack poll.
 *
 * Returns: TRUE on success, FALSE if the conntrack thread could not be
 * created.
 **/
gboolean
z_conntrack_init(void)
{
  z_enter();
  conntrack_started = g_cond_new();
  conntrack_lock = g_mutex_new();
  conntrack_poll_lock = g_mutex_new();
  if (!z_thread_new("conntrack/thread", z_conntrack_thread, NULL))
    {
      z_log(NULL, CORE_ERROR, 2, "Error creating conntrack thread, initialization failed;");
      /* keep z_enter()/z_leave() balanced on the error path as well */
      z_leave();
      return FALSE;
    }

  /* wait for the thread to publish the poll; the loop guards against
   * wakeups that arrive before the poll is set */
  g_mutex_lock(conntrack_lock);
  while (!conntrack_poll)
    g_cond_wait(conntrack_started, conntrack_lock);
  g_mutex_unlock(conntrack_lock);
  z_leave();
  return TRUE;
}

/**
 * z_conntrack_destroy:
 *
 * Global conntrack deinitialization function called at Zorp shutdown.
 * When a conntrack poll exists, the global pointer to it is cleared under
 * conntrack_poll_lock, the poll is woken up and the global reference to it
 * is dropped.
 **/
void
z_conntrack_destroy(void)
{
  ZPoll *old_poll;
  
  z_enter();
  if (!conntrack_poll)
    {
      /* nothing to tear down */
      z_leave();
      return;
    }

  g_mutex_lock(conntrack_poll_lock);
  old_poll = conntrack_poll;
  conntrack_poll = NULL;
  /* wake the poll so the conntrack thread notices the cleared pointer */
  z_poll_wakeup(old_poll);
  z_poll_unref(old_poll);
  g_mutex_unlock(conntrack_poll_lock);

  z_leave();
}

#endif
