/*****
*
* Copyright (C) 2002, 2003 Yoann Vandoorselaere <yoann@prelude-ids.org>
* All Rights Reserved
*
* This file is part of the Prelude program.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by 
* the Free Software Foundation; either version 2, or (at your option)
* any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; see the file COPYING.  If not, write to
* the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
*
* Written by Yoann Vandoorselaere <yoann@prelude-ids.org>
*
*****/


/*
 * Many thanks go to Michael Samuel <michael@miknet.net>
 * and Vincent Glaume <glaume@enseirb.fr> for helping with / working on this.
 */

/*
 * This code is EXPERIMENTAL and should be considered as such.
 * The code is not in a finished state, and should be cleaned up.
 */

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/time.h>
#include <inttypes.h>
#include <netinet/in.h>
#include <assert.h>
#include <ctype.h>

#include "packet.h"

#include <libprelude/list.h>
#include <libprelude/timer.h>
#include <libprelude/prelude-log.h>
#include <libprelude/extract.h>
#include <libprelude/prelude-io.h>
#include <libprelude/prelude-message.h>
#include <libprelude/prelude-getopt.h>
#include <libprelude/plugin-common.h>

#include "rules.h"
#include "packet-decode.h"
#include "rules-default.h"
#include "tcp-stream.h"
#include "plugin-protocol.h" /* port list */

/*
 * Classic min/max helpers. Each argument may be evaluated twice, so never
 * pass expressions with side effects. Guarded because some system headers
 * (e.g. <sys/param.h>) define MIN/MAX too.
 */
#ifndef MIN
 #define MIN(x, y) ( ((x) < (y)) ? (x) : (y) )
#endif

#ifndef MAX
 #define MAX(x, y) ( ((x) > (y)) ? (x) : (y) )
#endif

/* Number of buckets in host_hash; host_key() reduces its hash modulo this. */
#define HASH_SIZE 1024


/* NOTE(review): not referenced in the visible part of this file. */
#define ACK_ACCEPT_WINDOW 0


/*
 * Per-direction state bits, OR'ed together in stream_t.state.
 * LISTEN is the zero value (no bit set).
 * ALREADY_GOT_SACK is set by check_for_valid_sack() once a packet with
 * SACK information has been seen from that side.
 */
#define LISTEN       0x00
#define SYN_SENT     0x01
#define SYN_RCVD     0x02
#define ACK_RCVD     0x04
#define CLOSE_WAIT   0x08
#define CLOSED       0x10
#define ALREADY_GOT_SACK     0x20


/*
 * tcp_stream_t.option bits: set when the corresponding handshake packet
 * (client SYN / server SYN-ACK) had packet->tcp_sack set.
 */
#define OPTION_CLIENT_SACK_PERMITTED 0x01
#define OPTION_SERVER_SACK_PERMITTED 0x02

/*
 * dprint(fmt, ...): debug trace to stderr, prefixed with file, line and
 * function. Compiled out entirely (arguments are not evaluated) unless
 * DEBUG is defined. The disabled form still expands to a single statement,
 * so it is safe as the body of an unbraced if/else.
 *
 * Uses the standard C99 __VA_ARGS__ / __func__ rather than the GNU-only
 * named variadic parameter and __FUNCTION__.
 */
#ifdef DEBUG
 #define dprint(...) do {                                                                         \
         fprintf(stderr, "%s:%d:%s - ", __FILE__, __LINE__, __func__); fprintf(stderr, __VA_ARGS__); \
 } while (0)

#else

 #define dprint(...) do { } while (0)
#endif



/*
 * One queued piece of stream payload awaiting reassembly.
 */
typedef struct {
        /* link in stream_t.datalist, kept sorted by decreasing offset */
        struct list_head list;
        /* number of payload bytes at data */
        uint16_t len;
        /* position of data, relative to the owning stream_t's isn */
        unsigned int offset;
        /* points into packet's payload (no copy is made) */
        const unsigned char *data;
        /* packet holding data; the chunk keeps a reference taken with packet_lock() */
        packet_container_t *packet;
} data_chunk_t;



/*
 * State of one direction of a TCP connection, i.e. what one endpoint sends.
 * Data offsets stored for this side are relative to isn.
 */
typedef struct {
        /* base sequence number for offsets; moved forward after each reassembly flush */
        uint32_t isn;
        /* offset (relative to isn) where this side's FIN sits; meaningful while CLOSE_WAIT is set */
        uint32_t fin_expected;
        /* how many bytes of the PEER's data this side has acknowledged, relative to the peer's isn */
        uint32_t offset_acked;

        /* last out-of-window sequence number seen while desynchronized */
        uint32_t resync_seq;
        /* set/cleared by is_old_retransmission() */
        int desynced;
        
        /* compared against packet->paws_tsval; NOTE(review): no assignment visible in this part of the file */
        uint32_t ts_recent;
        /* endpoint address as read with align_uint32() from the IP header */
        uint32_t addr;
        /* endpoint port as returned by extract_uint16() */
        uint16_t port;
        /* last window advertised by this side (th_win) */
        uint16_t win;
        /* LISTEN / SYN_SENT / SYN_RCVD / ACK_RCVD / CLOSE_WAIT / CLOSED / ALREADY_GOT_SACK bits */
        uint8_t state;
        /* queued data_chunk_t sent by this side */
        struct list_head datalist;
} stream_t;


/*
 * A tracked TCP connection: both directions plus bookkeeping. Entries are
 * stored in host_hash, chained through prev/next.
 */
typedef struct tcp_stream {
        /* side that sent the initial SYN (source of that packet) */
        stream_t client;
        /* side the initial SYN was sent to */
        stream_t server;

        /* OPTION_*_SACK_PERMITTED bits */
        uint8_t option;
        
        /* session expiry timer; its callback is tcp_stream_expire_cb() */
        prelude_timer_t timer;

        /* host_key() of the connection, i.e. its host_hash bucket index */
        unsigned int key_cache;
        struct tcp_stream *prev;
        struct tcp_stream *next;
} tcp_stream_t;


/* NOTE(review): not referenced in the visible part of this file. */
static int tcp_stream_enabled = 0;
/* value given to timer_expire() for each new stream; presumably seconds — confirm against libprelude timer */
static int session_expire_time = 120;
/* STREAM_FROM_* bitmask selecting which direction(s) get their data queued */
static int tcp_stream_reasm_from = STREAM_FROM_CLIENT;

/* round-robin cursor into get_random_flush_point()'s table */
static unsigned int flush_point_index;
/* NOTE(review): not referenced in the visible part; presumably the ports subject to reassembly */
static port_list_t *reasm_port = NULL;
/* tracked connections, bucketed by host_key() */
static tcp_stream_t *host_hash[HASH_SIZE];



/*
 * Return the next reassembly flush threshold, in acknowledged bytes.
 *
 * Values are taken round-robin from a fixed table, advancing the
 * file-scope flush_point_index, so the amount of data needed to trigger a
 * reassembly varies from one flush to the next (presumably so segment
 * boundaries are hard to predict — NOTE(review): confirm intent).
 */
static int get_random_flush_point(void) 
{
        static const uint8_t flush_point[] = {
                128, 217, 189, 130, 240, 221, 134, 129,
                250, 232, 141, 131, 144, 177, 201, 130,
                230, 190, 177, 142, 130, 200, 173, 129,
                250, 244, 174, 151, 201, 190, 180, 198,
                220, 201, 142, 185, 219, 129, 194, 140,
                145, 191, 197, 183, 199, 220, 231, 245,
                233, 135, 143, 158, 174, 194, 200, 180,
                201, 142, 153, 187, 173, 199, 143, 201
        };
        const unsigned int count = sizeof(flush_point) / sizeof(flush_point[0]);

        return flush_point[flush_point_index++ % count];
}



#ifdef DEBUG

/*
 * Debug helper: dump len bytes of buf to stderr. Printable characters,
 * newlines and tabs are written verbatim; any other byte is written as
 * "?(N)" with N its decimal value. Needs <ctype.h> for isprint().
 */
static void print_ascii(const unsigned char *buf, int len) 
{
        int i;
        FILE *fd = stderr;
        
        for ( i = 0; i < len; i++ ) {
                if ( isprint(buf[i]) || buf[i] == '\n' || buf[i] == '\t')
                        fprintf(fd, "%c", buf[i]);
                else
                        fprintf(fd, "?(%d)", buf[i]);
        }
}

#endif




/*
 * Walk one host_hash bucket looking for the connection this packet
 * belongs to, in either direction.
 *
 * Addresses and ports are matched as endpoint pairs: the packet must come
 * from (client addr, client port) and go to (server addr, server port), or
 * the reverse. Matching the address pair and the port pair independently
 * would let a packet of B:1000 -> A:2000 match the stream A:1000 -> B:2000.
 * Both connections land in the same bucket because host_key() is
 * XOR-symmetric.
 *
 * Returns the matching entry, or NULL.
 */
static tcp_stream_t *search(tcp_stream_t *bucket, const iphdr_t *ip, const tcphdr_t *tcp) 
{
        uint16_t sport, dport;
        uint32_t saddr, daddr;

        sport = extract_uint16(&tcp->th_sport);
        dport = extract_uint16(&tcp->th_dport);

        saddr = align_uint32(&ip->ip_src.s_addr);
        daddr = align_uint32(&ip->ip_dst.s_addr);
        
        while ( bucket != NULL ) {

                /* packet sent by the client */
                if ( saddr == bucket->client.addr && sport == bucket->client.port &&
                     daddr == bucket->server.addr && dport == bucket->server.port )
                        break;

                /* packet sent by the server */
                if ( saddr == bucket->server.addr && sport == bucket->server.port &&
                     daddr == bucket->client.addr && dport == bucket->client.port )
                        break;

                bucket = bucket->next;
        }
        
        return bucket;
}



/*
 * Push h at the head of the doubly linked hash chain *bucket.
 *
 * h->prev is reset explicitly because host_del() relies on a NULL prev to
 * recognize the chain head; this does not depend on h having been
 * zero-allocated.
 */
static void host_add(tcp_stream_t **bucket, tcp_stream_t *h) 
{
        h->prev = NULL;
        h->next = *bucket;

        if ( h->next )
                h->next->prev = h;

        *bucket = h;
}



/*
 * Unlink h from the doubly linked hash chain *bucket. When h is the chain
 * head (no prev), the bucket head is moved to h's successor. h itself is
 * not freed and its own prev/next pointers are left unchanged.
 */
static void host_del(tcp_stream_t **bucket, tcp_stream_t *h) 
{
        tcp_stream_t *before = h->prev;
        tcp_stream_t *after = h->next;

        if ( after != NULL )
                after->prev = before;

        if ( before == NULL )
                *bucket = after;
        else
                before->next = after;
}



/*
 * hash function done by Michael Samuel <michael@miknet.net>
 *
 * Map a packet to a host_hash bucket index in [0, HASH_SIZE). Source and
 * destination are folded together with XOR, so both directions of a
 * connection produce the same key.
 */
static unsigned int host_key(const iphdr_t *ip, const tcphdr_t *tcp) 
{
        uint16_t ports = align_uint16(&tcp->th_sport) ^ align_uint16(&tcp->th_dport);
        uint32_t addrs = align_uint32(&ip->ip_src.s_addr) ^ align_uint32(&ip->ip_dst.s_addr);
        uint32_t folded = addrs ^ (addrs >> 16);

        return (ports ^ folded) % HASH_SIZE;
}



/*
 * Tell which side of stream sent this packet.
 *
 * A bare SYN (SYN without ACK) is always attributed to the client.
 * Otherwise the packet is from the client only when its source address
 * and port both equal the recorded client endpoint; anything else is
 * attributed to the server.
 */
static int get_stream_direction(tcp_stream_t *stream, iphdr_t *ip, tcphdr_t *tcp) 
{
        int bare_syn = (tcp->th_flags & TH_SYN) && ! (tcp->th_flags & TH_ACK);

        if ( ! bare_syn ) {
                uint16_t src_port = extract_uint16(&tcp->th_sport);
                uint32_t src_addr = align_uint32(&ip->ip_src.s_addr);

                if ( src_port != stream->client.port || src_addr != stream->client.addr )
                        return STREAM_FROM_SERVER;
        }

        return STREAM_FROM_CLIENT;
}




/*
 * Empty the chunk queue of one direction of a tcp_stream_t: every chunk
 * drops the packet reference it holds, is unlinked and freed. The
 * stream_t itself is not freed.
 */
static void free_simplex_stream(stream_t *stream)
{
        struct list_head *pos, *next;

        list_for_each_safe(pos, next, &stream->datalist) {
                data_chunk_t *victim = list_entry(pos, data_chunk_t, list);

                packet_release(victim->packet);
                list_del(&victim->list);
                free(victim);
        }
}




/*
 * Inject reassembled stream data into packet and run it through the TCP
 * signature engine.
 *
 * packet: packet the data is attached to (the last packet that
 *         contributed to the reassembly).
 * buf:    heap buffer holding the reassembled payload. Ownership passes to
 *         the packet through tcp_allocated_data.
 * len:    number of valid bytes in buf.
 *
 * The IP total length and the TCP data offset are rewritten so the headers
 * describe the new payload. Always returns 0.
 */
static int inject_packet(packet_container_t *packet, unsigned char *buf, unsigned int len) 
{
        iphdr_t *ip;
        tcphdr_t *tcp;
        uint16_t nlen;
        uint32_t total_len;
        
        if ( packet->application_layer_depth != -1 ) 
                /*
                 * packet already have data, but we're going to
                 * re-add them, so decrease the packet depth before
                 * (in order to not get 2 data fields in the packet).
                 */
                packet->depth--;

        /*
         * setup some of the new packet field.
         */
        ip = packet->packet[packet->network_layer_depth].p.ip;
        tcp = packet->packet[packet->transport_layer_depth].p.tcp;

        /*
         * The TCP header is rewritten below to carry no options, so it is
         * sizeof(tcphdr_t) bytes long. The IP header is kept as-is, so its
         * real length (options included) must be used instead of assuming
         * 20 bytes. ip_len is only 16 bits wide: saturate instead of
         * silently wrapping on very large reassemblies.
         */
        total_len = (uint32_t) len + (uint32_t) (IP_HL(ip) * 4) + (uint32_t) sizeof(tcphdr_t);
        if ( total_len > UINT16_MAX )
                total_len = UINT16_MAX;

        nlen = htons((uint16_t) total_len);
        memcpy(&ip->ip_len, &nlen, sizeof(ip->ip_len));
        tcp->th_offx2 = (sizeof(tcphdr_t) / 4) << 4;
        
        /*
         * set the allocated_data pointer,
         * so that the reassembled data are freed when the packet
         * is destroyed.
         */
        packet->tcp_allocated_data = buf;

#ifdef DEBUG
        print_ascii(buf, len);
#endif
        
        /*
         * here we go. Hold an extra reference for the duration of the
         * analysis.
         */
        packet_lock(packet);
        
        dprint("Injecting reassembled data in packet.\n");
        capture_data(packet, buf, len);
        signature_engine_process_packet(signature_engine_get_tcp_root(), packet);

        packet_release(packet);

        return 0;
}




/*
 * Work out how much of chunk lies inside the acknowledged region of the
 * stream, i.e. below dst->offset_acked.
 *
 * On return *coff is the chunk's offset and *clen the number of bytes,
 * starting at the chunk->data value read before the call, that may be
 * copied out:
 *  - fully acknowledged: chunk untouched, returns 0;
 *  - partially acknowledged: *clen is cut down to the acknowledged
 *    prefix, and chunk is rewritten to describe only the unacknowledged
 *    tail, rebased to offset 0; returns 0;
 *  - not acknowledged at all: chunk->offset is rebased by subtracting
 *    dst->offset_acked and -1 is returned (nothing to copy).
 */
static int split_segment_if_needed(stream_t *dst, data_chunk_t *chunk, uint32_t *coff, uint32_t *clen)
{
        uint32_t acked = dst->offset_acked;
        uint32_t chunk_end;

        *coff = chunk->offset;
        *clen = chunk->len;
        chunk_end = *coff + *clen;

        if ( chunk_end <= acked )
                return 0;

        if ( *coff >= acked ) {
                /*
                 * This segment isn't acknowledged yet. However, the offset is now changed.
                 */
                dprint("%p not acked yet (%u bytes at offset=%u).\n", chunk, chunk->len, chunk->offset);
                chunk->offset -= acked;
                return -1;
        }

        /*
         * case where a segment is partly acked: copy the acked head,
         * keep the tail queued.
         */
        *clen -= chunk_end - acked;
        chunk->len -= *clen;
        chunk->offset = 0;
        chunk->data += *clen;

        return 0;
}



/*
 * Tell whether split_segment_if_needed() cut chunk in two.
 *
 * cdata is the value chunk->data had before the split attempt; a split
 * advances chunk->data, so a different pointer means a split happened.
 * Note the inverted convention: returns 0 when the chunk WAS split, -1
 * when it was not.
 */
static int is_segment_splited(data_chunk_t *chunk, const unsigned char *cdata) 
{
        if ( cdata == chunk->data )
                return -1;

        return 0;
}



/*
 * Remember chunk->packet as the latest packet that contributed to the
 * reassembly, releasing the reference held on the previously remembered
 * one.
 *
 * An unsplit chunk is freed by the caller right after (free_unsplited_chunk()),
 * so its packet reference is simply handed over to *last. A split chunk
 * stays queued and keeps its own reference, so an extra one is taken here.
 */
static void store_last_packet(packet_container_t **last, data_chunk_t *chunk, const unsigned char *cdata)
{
        int chunk_stays_queued = (is_segment_splited(chunk, cdata) == 0);

        if ( *last != NULL )
                packet_release(*last);

        if ( chunk_stays_queued )
                packet_lock(chunk->packet);

        *last = chunk->packet;
}




/*
 * Unlink and free chunk once its data has been copied out, unless it was
 * split (cdata != chunk->data). In that case the chunk still describes
 * unacknowledged data and must stay queued. The chunk's packet reference
 * is not released here: store_last_packet() took it over.
 */
static void free_unsplited_chunk(data_chunk_t *chunk, const unsigned char *cdata) 
{
        if ( is_segment_splited(chunk, cdata) != 0 ) {
                list_del(&chunk->list);
                free(chunk);
        }
}




/*
 * Reassemble the acknowledged part of src's queued data and inject it.
 *
 * src: direction whose queued chunks are reassembled.
 * dst: opposite direction; dst->offset_acked is how many bytes of src's
 *      data (relative to src->isn) have been acknowledged.
 *
 * Chunks are walked in increasing offset order. Fully acknowledged chunks
 * are consumed, a partially acknowledged one is trimmed to its
 * unacknowledged tail, and unacknowledged ones are rebased and kept (see
 * split_segment_if_needed()). The copied bytes are injected into the last
 * contributing packet.
 *
 * Returns 0 (including when there is nothing to do), -1 if the buffer
 * could not be allocated.
 */
static int tcp_stream_reasm(stream_t *src, stream_t *dst)
{
        int ret;
        data_chunk_t *chunk;
        const unsigned char *cdata;
        unsigned char *bufptr, *buf;
        struct list_head *tmp, *bkp;
        packet_container_t *last_packet = NULL;
        uint32_t awaited_offset = 0, reasm_len = 0, clen, coff;
        
        /*
         * If the list is empty, we didn't see the acked packet.
         */
        if ( list_empty(&src->datalist) || dst->offset_acked == 0 ) {
                dprint("reassembly not triggered list_empty=%d, dst->off_acked=%u\n", list_empty(&src->datalist), dst->offset_acked);
                return 0;
        }
        
        dprint("dst->offset_acked=%u\n", dst->offset_acked);
        
        bufptr = buf = malloc(dst->offset_acked);
        if ( ! buf ) {
                log(LOG_ERR, "couldn't allocate %u bytes.\n", dst->offset_acked);
                return -1;
        }

        /*
         * Chunks are queued from highest to lowest offset, so walking the
         * list backward yields them in stream order.
         */
        list_for_each_reversed_safe(tmp, bkp, &src->datalist) {
                chunk = list_entry(tmp, data_chunk_t, list);

                cdata = chunk->data;

                ret = split_segment_if_needed(dst, chunk, &coff, &clen);
                if ( ret < 0 )
                        continue;
                
#ifdef DEBUG
                if ( awaited_offset && coff != awaited_offset )
                        dprint("%p is not the awaited packet (%u != %u)\n", chunk, coff, awaited_offset);
#endif                                
                dprint("blen=%u, copying %u bytes at offset %u from %p.\n", dst->offset_acked, clen, coff, chunk);
               
                store_last_packet(&last_packet, chunk, cdata); 

                /*
                 * Only the acknowledged part of each chunk is copied, and
                 * handle_overlap() trims overlaps on insertion, so this should
                 * never exceed buf. Keep a runtime guard too: assert() vanishes
                 * under NDEBUG and the data comes from the network.
                 */
                assert(reasm_len + clen <= dst->offset_acked);
                if ( reasm_len + clen > dst->offset_acked ) {
                        log(LOG_ERR, "reassembly overflow (%u + %u > %u), truncating.\n",
                            reasm_len, clen, dst->offset_acked);
                        clen = dst->offset_acked - reasm_len;
                }

                memcpy(buf, cdata, clen);
                
                buf += clen;
                reasm_len += clen;
                awaited_offset = coff + clen;

                free_unsplited_chunk(chunk, cdata);
        }

        if ( ! reasm_len ) {
                /*
                 * this condition might occur if the list only contain not yet ACKed packet.
                 * Drop any packet reference taken anyway so nothing leaks.
                 */
                if ( last_packet )
                        packet_release(last_packet);

                free(bufptr);
                return 0;
        }
        
        inject_packet(last_packet, bufptr, reasm_len);
        packet_release(last_packet);

        return 0;
}



/*
 * Tear down a connection. It is unlinked from host_hash and its timer is
 * destroyed. For each direction, queued acknowledged data is reassembled
 * and injected when the peer has ACK_RCVD set and that direction is
 * SYN_RCVD and not CLOSED. Both queues are then freed, and finally the
 * stream itself.
 *
 * Reassembly here is best effort; its result is not checked.
 * stream must not be used after this call. Always returns 0.
 */
static int tcp_stream_expire(tcp_stream_t *stream)
{
        dprint("[%p] - Killing hdb for this connection.\n", stream);

        host_del(&host_hash[stream->key_cache], stream);
        timer_destroy(&stream->timer);
        
        dprint("[%p] - client state = %d, server state = %d\n",
               stream, stream->client.state, stream->server.state);

        if ( (stream->server.state & ACK_RCVD) && (stream->client.state & SYN_RCVD) &&
             ! (stream->client.state & CLOSED) ) {
                dprint("[%p][client] - stream reassembly of %u bytes.\n", stream, stream->server.offset_acked);
                tcp_stream_reasm(&stream->client, &stream->server);
        }
        
        free_simplex_stream(&stream->client);

        if ( (stream->client.state & ACK_RCVD) && (stream->server.state & SYN_RCVD) &&
             ! (stream->server.state & CLOSED) ) {
                dprint("[%p][server] - stream reassembly of %u bytes.\n", stream, stream->client.offset_acked);
                tcp_stream_reasm(&stream->server, &stream->client);
        }
        
        free_simplex_stream(&stream->server);

        free(stream);

        return 0;
}




/*
 * Close one direction of stream after its FIN has been acknowledged.
 *
 * src: direction whose FIN was acknowledged.
 * dst: the acknowledging direction.
 * ack: the acknowledgment number; becomes src's new ISN.
 *
 * src's remaining acknowledged data is reassembled and injected, its queue
 * is freed and it is marked CLOSED. When dst was already CLOSED, dst's
 * remaining acknowledged data is flushed too and the whole stream is
 * destroyed; the caller must not touch stream afterwards in that case.
 * Always returns 0.
 */
static int tcp_stream_kill_one_side(tcp_stream_t *stream, stream_t *src, stream_t *dst, uint32_t ack) 
{
        /*
         * there could be chunk remaining in the reassembly list.
         */
        dprint("[%p][%p] - stream reassembly of %u bytes.\n", stream, dst, dst->offset_acked);
        
        tcp_stream_reasm(src, dst);
        free_simplex_stream(src);
        src->state |= CLOSED;
        src->isn = ack;
        dst->offset_acked = 0;

        if ( ! (dst->state & CLOSED) )
                return 0;

        dprint("[%p][%p] - stream reassembly of %u bytes.\n", stream, dst, src->offset_acked);
        tcp_stream_reasm(dst, src);

        /*
         * src's queue was already emptied above; only dst's remains.
         */
        free_simplex_stream(dst);
        timer_destroy(&stream->timer);
        host_del(&host_hash[stream->key_cache], stream);
        free(stream);
        dprint("Both connection side closed.\n");
        
        return 0;
}




static void tcp_stream_expire_cb(void *data) 
{
        dprint("Timer expiring session %p\n", data);
        tcp_stream_expire(data);
}



/*
 * Allocate a zeroed tcp_stream_t for the connection described by ip/tcp.
 * The packet's source is recorded as the client and its destination as
 * the server. The session-expiry timer is armed and the stream is inserted
 * into host_hash.
 *
 * Ports, ISNs and states are left zeroed for the caller to fill in
 * (twh_got_client_syn() does so). pkt is unused.
 *
 * Returns the new stream, or NULL if allocation failed.
 */
static tcp_stream_t *tcp_stream_new(packet_container_t *pkt, const iphdr_t *ip, const tcphdr_t *tcp)
{
        tcp_stream_t *new;
        
        new = calloc(1, sizeof(*new));
        if ( ! new ) {
                /* sizeof yields size_t: print it through an explicitly matching type. */
                log(LOG_ERR, "malloc(%lu).\n", (unsigned long) sizeof(*new));
                return NULL;
        }

        new->client.addr = align_uint32(&ip->ip_src.s_addr);
        new->server.addr = align_uint32(&ip->ip_dst.s_addr);
        
        new->key_cache = host_key(ip, tcp);
        
        INIT_LIST_HEAD(&new->server.datalist);
        INIT_LIST_HEAD(&new->client.datalist);

        timer_func(&new->timer) = tcp_stream_expire_cb;
        timer_data(&new->timer) = new;
        timer_expire(&new->timer) = session_expire_time;
        timer_init(&new->timer);
        
        host_add(&host_hash[new->key_cache], new);

        dprint("TCP stream new at key=%u, ip=%p, %p\n", new->key_cache, ip, host_hash[new->key_cache]);
        
        return new;
}



/*
 * Look up the tracked connection this packet belongs to, in either
 * direction. Returns NULL when it is not tracked.
 *
 * "static inline" rather than "inline static": a storage-class specifier
 * not at the start of the declaration is obsolescent (C11 6.11.5).
 */
static inline tcp_stream_t *tcp_stream_search(const iphdr_t *ip, const tcphdr_t *tcp) 
{
        return search(host_hash[host_key(ip, tcp)], ip, tcp);
}




/*
 * Reconcile an already queued chunk (old) with a chunk about to be queued
 * (new) so the two no longer cover the same stream bytes. The data carried
 * by new always wins; only old is modified.
 *
 * The queue is sorted from highest to lowest offset, and this function
 * keeps it that way.
 *
 * Returns -1 in two cases: old was unlinked and freed because new covers
 * it completely, or memory for a split could not be allocated (old is then
 * left untouched). Returns 0 otherwise.
 */
static int handle_overlap(data_chunk_t *old, data_chunk_t *new) 
{
        unsigned int old_end = old->offset + old->len;
        unsigned int new_end = new->offset + new->len;

        /*
         * check if new segment begin before (or at the same place),
         * and that it end after (or at the same place) the old segment.
         */
        if ( new->offset <= old->offset && new_end >= old_end ) {
                /*
                 * new completly overlap old.
                 */
                dprint("new completly overlap old (old %u - %u) (new %u - %u).\n",
                       old->offset, old_end, new->offset, new_end);
                
                list_del(&old->list);
                packet_release(old->packet);
                free(old);
                
                return -1;
        }

        /*
         * check if the new segment begin after begining of the old segment,
         * and end after or at the same place as the old segment:
         * keep only the head of old.
         */
        else if ( new->offset > old->offset && new->offset < old_end && new_end >= old_end ) {
                                
                dprint("new overlap end of old segment (old %u - %u) (new %u - %u).\n",
                       old->offset, old_end, new->offset, new_end);
                
                old->len = new->offset - old->offset;
        }

        /*
         * check if the new segment begin before or at the same place,
         * and end in the middle of the old segment: keep only the tail of old.
         */
        else if ( new->offset <= old->offset && new_end > old->offset ) {
                                
                dprint("new overlap begining of old segment (old %u - %u) (new %u - %u).\n",
                       old->offset, old_end, new->offset, new_end);
                
                old->data += new_end - old->offset;
                old->len = old_end - new_end;
                old->offset = new_end;
        }
        
        /*
         * Check if the new segment begin after,
         * and that it end before the old segment.
         */
        else if ( new->offset > old->offset && new_end < old_end ) {
                data_chunk_t *head;

                /*
                 * new is in the middle of old. old is turned into the tail
                 * [new_end, old_end), and a new chunk "head" is created for
                 * [old->offset, new->offset).
                 *
                 * Splitting this way round keeps the queue sorted: head has a
                 * lower offset than old, so it belongs right after old. The
                 * caller (search_previous_stream_chunk()) then sees old's
                 * offset above new's and places new right after old, i.e.
                 * between old's tail and head.
                 */
                log(LOG_INFO, "new overlap in the middle of old (old %u - %u) (new %u - %u).\n",
                    old->offset, old_end, new->offset, new_end);
                
                head = malloc(sizeof(*head));
                if ( ! head ) {
                        log(LOG_ERR, "memory exhausted.\n");
                        return -1;
                }

                /*
                 * head shares old's packet, so it needs its own reference.
                 */
                packet_lock(old->packet);
                
                head->packet = old->packet;
                head->offset = old->offset;
                head->len = new->offset - old->offset;
                head->data = old->data;

                old->data += new_end - old->offset;
                old->len = old_end - new_end;
                old->offset = new_end;
                
                list_add(&head->list, &old->list);
        }

        return 0;
}





/*
 * Find where a new chunk must be queued in stream->datalist, trimming or
 * dropping overlapping queued chunks on the way (handle_overlap()).
 *
 * The queue is sorted from highest to lowest offset, so the search
 * normally stops after a single hop. Returns the list node the new chunk
 * must be linked right after, or NULL when it belongs at the very front
 * of the queue.
 */
static struct list_head *search_previous_stream_chunk(stream_t *stream, data_chunk_t *new) 
{
        data_chunk_t *queued;
        struct list_head *pos, *next, *insert_after = NULL;

        list_for_each_safe(pos, next, &stream->datalist) {
                queued = list_entry(pos, data_chunk_t, list);

                /*
                 * A negative result means queued was unlinked and freed
                 * (or, on memory exhaustion, left as is): skip it.
                 */
                if ( handle_overlap(queued, new) < 0 )
                        continue;

                if ( new->offset > queued->offset )
                        break;

                insert_after = pos;
        }

        return insert_after;
}




/*
 * Queue len bytes of stream data, found at data inside packet, at stream
 * offset off on src. The queue stays sorted from highest to lowest offset.
 * Already queued chunks that overlap are trimmed or dropped in favor of
 * the new data.
 *
 * A reference is taken on packet for as long as the chunk stays queued.
 * Returns 0 on success, or -1 on memory exhaustion; in that case the queue
 * has not been touched.
 */
static int insert_stream_chunk(packet_container_t *packet, stream_t *src,
                               const unsigned char *data, uint16_t len, uint32_t off) 
{
        data_chunk_t *chunk;
        struct list_head *prev;

        /*
         * Allocate before searching: the search may already drop or trim
         * overlapping queued chunks, which must not happen if the new data
         * cannot be stored.
         */
        chunk = malloc(sizeof(*chunk));
        if ( ! chunk ) {
                log(LOG_ERR, "memory exhausted.\n");
                return -1;
        }

        chunk->len = len;
        chunk->offset = off;
        chunk->data = data;
        chunk->packet = packet;
                
        prev = search_previous_stream_chunk(src, chunk);
        if ( ! prev )
                prev = &src->datalist;

        packet_lock(packet);
        list_add(&chunk->list, prev);

        return 0;
}





/*
 * Check that stream offset offset is no more than dst->win bytes past
 * what dst has already acknowledged (dst->offset_acked). The subtraction
 * is unsigned, so an offset below offset_acked wraps to a large value and
 * fails the check. Window scaling is not taken into account.
 *
 * Returns 0 when inside the window, -1 otherwise.
 */
static inline int is_window_respected(uint32_t offset, stream_t *dst) 
{
        return (offset - dst->offset_acked <= dst->win) ? 0 : -1;
}




/*
 * Decide whether a segment sent by src (sequence seq, len payload bytes)
 * falls outside what dst can accept. That is the case when it ends before
 * dst's acknowledged offset, or starts beyond dst's window.
 *
 * Such a segment flags src as desynchronized and remembers seq in
 * src->resync_seq (is_ack_window_valid() uses it to resynchronize); 0 is
 * returned. Otherwise src->desynced is cleared and -1 is returned.
 * Note the inverted convention: 0 means "old / out of window".
 */
static inline int is_old_retransmission(uint32_t seq, uint16_t len, stream_t *src, stream_t *dst) 
{
        uint32_t offset;

        offset = (seq - src->isn);
        if ( (offset + len) < dst->offset_acked ) {
                dprint("old seq=%u isn=%u (%u + %u) < %u\n", seq, src->isn, offset, len, dst->offset_acked);
                goto resync;
        }

        if ( is_window_respected(offset, dst) < 0 ) {
                dprint("win seq=%u isn=%u (%u - %u) > %u\n", seq, src->isn, offset, dst->offset_acked, dst->win);
                goto resync;
        }
        
        src->desynced = 0;
        
        return -1;

 resync:
        /* offset is relative to src->isn, so report that ISN. */
        dprint("we are desynchronized: offset=%u src->isn=%u\n", offset, src->isn);

        src->desynced = 1;
        src->resync_seq = seq;

        return 0;
}




/*
 * Queue the payload of packet, sent by src, for later reassembly.
 *
 * The following are ignored:
 *  - packets without payload;
 *  - directions not selected by tcp_stream_reasm_from;
 *  - segments entirely at or below what dst already acknowledged.
 *
 * Returns 0, or the result of insert_stream_chunk() when data is queued.
 */
static int update_stream_data(tcp_stream_t *stream, stream_t *src, stream_t *dst,
                              packet_container_t *packet, iphdr_t *ip, tcphdr_t *tcp,
                              int direction)
{
        uint16_t len;
        unsigned char *data;
        uint32_t offset, seq;
        
        seq = extract_uint32(&tcp->th_seq);
        len = extract_uint16(&ip->ip_len) - (IP_HL(ip) * 4) - (TH_OFF(tcp) * 4);

        if ( packet->application_layer_depth < 0 )
                /*
                 * no data in this packet.
                 */
                return 0;

        offset = seq - src->isn;
        
        if ( dst->state & ACK_RCVD && offset + len <= dst->offset_acked ) {
                /*
                 * don't emit an alert here, it might happen that an host retransmit
                 * because it didn't see the ACK. Even thought the ACK is already on the wire.
                 */
                return 0;
        }        
                
        if ( len == 0 || ! (tcp_stream_reasm_from & direction) )
                return 0;

        dprint("[%p] - add %u byte at offset=%u (seq=%u).\n", stream, len, offset, seq);

        /*
         * Never queue data that does not extend past the acknowledged offset.
         * This used to be an assert(), but seq is attacker controlled: for
         * instance with dst->offset_acked == 0, a segment whose offset + len
         * wraps to 0 would abort the process.
         */
        if ( offset + len <= dst->offset_acked )
                return 0;

        data = packet->packet[packet->application_layer_depth].p.data;

        return insert_stream_chunk(packet, src, data, len, offset);
}




/*
 * Start tracking a connection when the client's initial SYN is seen.
 *
 * The stream is created (client = packet source, server = destination).
 * The client's ISN, advertised window and both ports are recorded, the
 * client is marked SYN_SENT and the server LISTEN. Client SACK support is
 * noted when packet->tcp_sack is set.
 *
 * Returns the new stream, or NULL on allocation failure.
 */
static tcp_stream_t *twh_got_client_syn(packet_container_t *packet, iphdr_t *ip, tcphdr_t *tcp) 
{
        tcp_stream_t *stream = tcp_stream_new(packet, ip, tcp);
        stream_t *cli, *srv;

        if ( stream == NULL )
                return NULL;

        dprint("[%p] - isn=%u\n", stream, extract_uint32(&tcp->th_seq));

        cli = &stream->client;
        srv = &stream->server;

        srv->state = LISTEN;
        cli->state = SYN_SENT;

        cli->win = extract_uint16(&tcp->th_win);
        cli->isn = extract_uint32(&tcp->th_seq);
        cli->port = extract_uint16(&tcp->th_sport);
        srv->port = extract_uint16(&tcp->th_dport);

        if ( packet->tcp_sack )
                stream->option |= OPTION_CLIENT_SACK_PERMITTED;

        return stream;
}




/*
 * Process the server's SYN/ACK answering the client's SYN.
 *
 * The acknowledgment must be exactly client ISN + 1, otherwise the packet
 * is rejected with tcp_stream_unknown. On success:
 *  - the server's ISN and window are recorded;
 *  - the server is marked SYN_RCVD|ACK_RCVD and the client gains SYN_RCVD;
 *  - server SACK support is noted when packet->tcp_sack is set.
 * Returns 0.
 */
static int twh_got_server_syn_ack(packet_container_t *packet, tcp_stream_t *stream, tcphdr_t *tcp) 
{
        uint32_t seq, ack;
        stream_t *cli = &stream->client;
        stream_t *srv = &stream->server;

        dprint("[%p]\n", stream);

        seq = extract_uint32(&tcp->th_seq);
        ack = extract_uint32(&tcp->th_ack);

        if ( ack != cli->isn + 1 ) {
                log(LOG_INFO, "Invalid acknowledgment (server ack != client isn).\n");
                return tcp_stream_unknown;
        }

        srv->win = extract_uint16(&tcp->th_win);
        srv->isn = seq;
        srv->offset_acked = ack - cli->isn;
        srv->state = SYN_RCVD|ACK_RCVD;
        cli->state |= SYN_RCVD;

        if ( packet->tcp_sack )
                stream->option |= OPTION_SERVER_SACK_PERMITTED;

        return 0;
}





/*
 * Record a FIN sent by src. src is marked CLOSE_WAIT, and the offset where
 * the FIN sits (end of this segment's payload, relative to src->isn) is
 * remembered in src->fin_expected. status_got_ack() treats the FIN as
 * acknowledged once the peer acks fin_expected + 1. Returns 0.
 */
static int status_got_fin(tcp_stream_t *stream, stream_t *src, stream_t *dst, iphdr_t *ip, tcphdr_t *tcp) 
{
        uint32_t seg_offset, payload;

        dprint("[%p]\n", stream);

        seg_offset = extract_uint32(&tcp->th_seq) - src->isn;
        payload = extract_uint16(&ip->ip_len) - (IP_HL(ip) * 4) - (TH_OFF(tcp) * 4);

        src->state |= CLOSE_WAIT;
        src->fin_expected = seg_offset + payload;

        return 0;
}




/*
 * Handle a RST sent by src toward dst.
 *
 * If dst is in SYN_SENT (the RST answers its initial SYN), the RST is
 * honored only when it acknowledges dst's SYN (ack == dst->isn + 1).
 * Otherwise it is logged as evasive and tcp_stream_unknown is returned.
 * In any other state the RST is honored.
 *
 * Honoring a RST destroys the whole stream (tcp_stream_expire()), so the
 * caller must not use stream afterwards. Returns 0 or tcp_stream_unknown.
 */
static int status_got_rst(tcp_stream_t *stream, stream_t *src, stream_t *dst, tcphdr_t *tcp) 
{
        uint32_t ack;

        /*
         * In all states except SYN-SENT, all reset (RST) segments are validated
         * by checking their SEQ-fields.  A reset is valid if its sequence number
         * is in the window.  In the SYN-SENT state (a RST received in response
         * to an initial SYN), the RST is acceptable if the ACK field
         * acknowledges the SYN.
         *
         * The window part is not checked here: update_state_generic() checks
         * the window before dispatching a RST to this function.
         */
        if ( dst->state != SYN_SENT ) {
                tcp_stream_expire(stream);
                return 0;
        }

        /*
         * RST received in response to an initial SYN:
         * it must acknowledge that SYN.
         */
        ack = extract_uint32(&tcp->th_ack);
        if ( ack != dst->isn + 1 ) {
                log(LOG_INFO, "evasive RST detection throught invalid ack (ack=%u isn=%u).\n",
                    ack, dst->isn);
                return tcp_stream_unknown;
        }

        dprint("RST acknowledge initial sequence.\n");
        tcp_stream_expire(stream);

        return 0;
}



/*
 * Disabled helper (compiled out). It would store in *off the end offset
 * (offset + len) of the first queued chunk of stream, i.e. the chunk with
 * the highest offset given the queue's ordering, and return 0; it returns
 * -1 when the queue is empty.
 * NOTE(review): dead code; consider deleting it.
 */
#if 0
static int get_last_data_offset(stream_t *stream, uint32_t *off) 
{
        data_chunk_t *data;

        if ( list_empty(&stream->datalist) )
                return -1;
        
        data = list_entry(stream->datalist.next, data_chunk_t, list);
        *off = data->offset + data->len;
        
        return 0;
}
#endif



/*
 * The only SACK insertion attack we can think of is if the attacker
 * sends a duplicated ACK containing data (for reinsertion), but no SACK
 * option even though SACK is enabled. In this case, we have to drop.
 *
 * Only non-advancing ACKs (offset_acked <= src->offset_acked) on a stream
 * where either side negotiated SACK are examined. The first SACK seen from
 * src sets ALREADY_GOT_SACK. After that, a non-advancing ACK without SACK
 * is logged.
 *
 * NOTE(review): despite the paragraph above, this function currently only
 * logs and always returns 0, so the caller never rejects on this basis.
 */
static int check_for_valid_sack(packet_container_t *packet, tcp_stream_t *stream,
                                stream_t *src, stream_t *dst, uint32_t ack, uint32_t offset_acked) 
{
        int already_got_sack;

        /* An advancing ACK is not a retransmission: nothing to check. */
        if ( offset_acked > src->offset_acked )
                return 0;

        /* Neither side negotiated SACK. */
        if ( (stream->option & (OPTION_CLIENT_SACK_PERMITTED|OPTION_SERVER_SACK_PERMITTED)) == 0 )
                return 0;

        already_got_sack = (src->state & ALREADY_GOT_SACK) != 0;

        if ( already_got_sack ) {
                if ( packet->tcp_sack )
                        return 0;

                log(LOG_INFO, "SACK permitted set, but no SACK in ACK retransmission ack=%u, already_got_sack=%u, sack=%u.\n",
                    ack, already_got_sack, packet->tcp_sack);
                return 0;
        }

        if ( packet->tcp_sack )
                src->state |= ALREADY_GOT_SACK;

        return 0;
}




/*
 * Validate the acknowledgment number ack sent by src for dst's data.
 *
 * The ACK is accepted when it is at most src->win bytes past what src
 * already acknowledged (modulo 2^32). Otherwise, if dst is flagged
 * desynchronized and ack is within src->win of dst->resync_seq (in either
 * direction), the IDS resynchronizes on ack: dst's acknowledged data is
 * flushed, its queue freed, dst->isn moves to ack and src->offset_acked
 * is reset.
 *
 * Returns 0 when the ACK is accepted (possibly after resync), -1 otherwise.
 */
static int is_ack_window_valid(tcphdr_t *tcp, stream_t *src, stream_t *dst, uint32_t ack) 
{
        uint32_t distance;

        distance = ack - (dst->isn + src->offset_acked);
        if ( distance <= src->win )
                return 0;

        dprint("Invalid ACK win (desynced=%d): %u - %u = %u <= %u\n",
               dst->desynced, ack, dst->isn, ack - dst->isn, src->win);

        if ( ! dst->desynced )
                return -1;

        /*
         * ack might be > resync_seq, as resync_seq isn't an acked
         * ISN. We rely on the window to tell me if it is valid.
         */
        if ( ack > dst->resync_seq )
                distance = ack - dst->resync_seq;
        else
                distance = dst->resync_seq - ack;

        if ( distance > src->win ) {
                dprint("Invalid ACK win %u - %u = %u <= %u\n",
                       ack, dst->resync_seq, ack - dst->resync_seq, src->win);
                return -1;
        }

        dprint("Resynchronised IDS (ack=%u).\n", ack);
        tcp_stream_reasm(dst, src);
        free_simplex_stream(dst);

        dst->isn = ack;
        src->offset_acked = 0;

        return 0;
}



/*
 * Process the acknowledgment carried by a packet sent by src, covering
 * data sent by dst.
 *
 * An ACK outside src's window, or one going backward, is rejected with
 * tcp_stream_unknown. Otherwise it is recorded in src->offset_acked
 * (relative to dst->isn) together with src->win. When enough bytes are
 * acknowledged (threshold from get_random_flush_point()), dst's
 * acknowledged data is reassembled and injected and dst's sequence space
 * is rebased on ack.
 *
 * Finally, if dst sent a FIN:
 *  - an ACK of exactly that FIN closes dst's side
 *    (tcp_stream_kill_one_side(), which may free stream);
 *  - an ACK beyond it is logged as evasive and the FIN is forgotten.
 *
 * Returns 0 on success, tcp_stream_unknown on a rejected ACK, -1 when
 * reassembly failed.
 */
static int status_got_ack(packet_container_t *packet, tcphdr_t *tcp, tcp_stream_t *stream, stream_t *src, stream_t *dst)
{
        int ret = 0, flushed = 0;
        uint32_t ack, off;
        
        ack = extract_uint32(&tcp->th_ack);

        /* 
         * Check if we may accept this ack, being careful of wrap around
         */
        ret = is_ack_window_valid(tcp, src, dst, ack);
        if ( ret < 0  )
                return tcp_stream_unknown;
        
        off = ack - dst->isn;
        
        ret = check_for_valid_sack(packet, stream, src, dst, ack, off);
        if ( ret < 0 )
                return tcp_stream_unknown;
        
        if ( off < src->offset_acked ) {
                log(LOG_INFO, "Strange: got an ack for already acked stuff : offset_acked=%u, last=%u\n", off, src->offset_acked);
                return tcp_stream_unknown;
        }
        
        src->state |= ACK_RCVD;
        src->offset_acked = off;
        src->win = extract_uint16(&tcp->th_win);
        
        dprint("[%p][%p] - acked %u byte (ack=%u - isn=%u).\n", stream, src, src->offset_acked, ack, dst->isn);

        if ( off >= get_random_flush_point() ) {
                dprint("[%p][%p] - enough byte acked, triggering reassembly ack=%u.\n", stream, src, ack);
                
                ret = tcp_stream_reasm(dst, src);
                if ( ret < 0 )
                        return -1;
                
                dst->isn = ack;
                src->offset_acked = 0;
                flushed = 1;
        }

        /*
         * off and dst->fin_expected are both relative to the ISN that was
         * in effect before any flush above, so they can be compared here.
         */
        if ( dst->state & CLOSE_WAIT && off >= (dst->fin_expected + 1) ) {
                if ( off == (dst->fin_expected + 1) ) {
                        dprint("[%p][%p] - FIN acked, triggering reassembly.\n", stream, src);
                        /* may free stream: nothing below may touch it. */
                        tcp_stream_kill_one_side(stream, dst, src, ack);
                } else {
                        /*
                         * we got a FIN which wasn't ACKED, back to previous state.
                         */
                        log(LOG_INFO, "evasive FIN detection: FIN not acked. (src ack = %u, expected=%u)\n", off, dst->fin_expected + 1);
                        dst->state &= ~CLOSE_WAIT;
                }

                return 0;
        }

        /*
         * The FIN is still pending. If dst->isn was just moved to ack, rebase
         * the FIN offset too. Otherwise it would later be compared against
         * offsets relative to the new ISN, and its acknowledgment would never
         * be recognized. off <= fin_expected here, so this cannot wrap.
         */
        if ( flushed && (dst->state & CLOSE_WAIT) )
                dst->fin_expected -= off;
        
        return 0;
}



/*
 * Name of the connection-terminating flag carried by @tcp, used to label
 * "evasive ... detection" log messages.
 *
 * RST is checked before FIN, so a segment carrying both is reported as
 * "RST". Returns NULL when neither flag is set.
 */
static const char *get_tcp_flags_string(tcphdr_t *tcp)
{
        const char *name = NULL;

        if ( tcp->th_flags & TH_RST )
                name = "RST";
        else if ( tcp->th_flags & TH_FIN )
                name = "FIN";

        return name;
}




/*
 * Per-packet state update for an already tracked stream. Shared by
 * update_client_state() and update_server_state().
 *
 * @packet:    container of the current packet; packet->paws_tsval is
 *             compared against (and then stored into) src->ts_recent.
 * @ip, @tcp:  IP and TCP headers of the current packet.
 * @stream:    the tracked stream the packet belongs to.
 * @src:       the side of @stream that sent this packet.
 * @dst:       the opposite side.
 * @direction: STREAM_FROM_CLIENT or STREAM_FROM_SERVER (as passed by the
 *             callers), forwarded to update_stream_data().
 *
 * Return value, in the order the checks run:
 *  - 0 when is_old_retransmission() returns 0 (nothing else is done);
 *  - tcp_stream_unknown when @dst has ACK_RCVD and is_window_respected()
 *    rejects the end offset of this segment;
 *  - -1 when packet->paws_tsval is lower than src->ts_recent;
 *  - status_got_rst()'s result for RST segments;
 *  - -1 when stream->server does not have SYN_RCVD yet;
 *  - the negative result of status_got_fin() or update_stream_data();
 *  - otherwise status_got_ack()'s result when ACK is set, else
 *    update_stream_data()'s result.
 */
static int update_state_generic(packet_container_t *packet, iphdr_t *ip, tcphdr_t *tcp,
                                tcp_stream_t *stream, stream_t *src, stream_t *dst, int direction) 
{
        int ret;
        uint16_t len;
        uint32_t seq;

        seq = extract_uint32(&tcp->th_seq);

        /*
         * TCP payload length: IP total length minus IP and TCP header lengths.
         * NOTE(review): no underflow check before truncation to uint16_t;
         * presumably the decoder has already validated the header lengths
         * against ip_len -- confirm.
         */
        len = extract_uint16(&ip->ip_len) - (IP_HL(ip) * 4) - (TH_OFF(tcp) * 4);
        
        ret = is_old_retransmission(seq, len, src, dst);
        if ( ret == 0 ) 
                return 0;
        
        /*
         * The receive window is only enforced once the receiving side has
         * had an ACK recorded (ACK_RCVD). An out-of-window RST/FIN is also
         * reported as a possible evasion attempt.
         */
        if ( dst->state & ACK_RCVD && is_window_respected((seq - src->isn) + len, dst) < 0 ) {
                const char *kind = NULL;

                kind = get_tcp_flags_string(tcp);
                
                if ( kind ) 
                        log(LOG_INFO, "evasive %s detection throught out of window sequence (off=%u, win=%u).\n",
                            kind, seq - src->isn, dst->win);
                
                log(LOG_INFO, "Not respecting window size ( (((seq=%u - src->isn=%u) + len=%u)=%u) - dst->offset_acked=%u) = %u) > dst->win=%u).\n",
                    extract_uint32(&tcp->th_seq), src->isn, len, (seq - src->isn) + len,
                    dst->offset_acked, ((seq - src->isn) + len) - dst->offset_acked, dst->win);
                
                return tcp_stream_unknown;
        }

        /*
         * PAWS-style check against the latest timestamp accepted from @src.
         * NOTE(review): this is a plain '<' comparison. If these are 32-bit
         * TSval counters, RFC 7323 compares them modulo 2^32, so a wrapped
         * timestamp would be wrongly rejected here -- confirm the types.
         */
        if ( packet->paws_tsval < src->ts_recent ) {
                /*
                 * packet is an old retransmission.
                 */
                return -1;
        }
        
        /* RST is fully handled by status_got_rst(); no data/timer update. */
        if ( tcp->th_flags & TH_RST )
                return status_got_rst(stream, src, dst, tcp);
        
        /*
         * Nothing is tracked until the server side has SYN_RCVD set
         * (presumably set while processing the server SYN/ACK -- confirm).
         */
        if ( ! (stream->server.state & SYN_RCVD) )
                return -1;
        
        if ( tcp->th_flags & TH_FIN ) {
                ret = status_got_fin(stream, src, dst, ip, tcp);
                if ( ret < 0 )
                        return ret;
        }
        
        ret = update_stream_data(stream, src, dst, packet, ip, tcp, direction);
        if ( ret < 0 )
                return ret;

        /*
         * reset when we know there is no error.
         */
        timer_reset(&stream->timer);        
        src->ts_recent = packet->paws_tsval;
        
        /*
         * make sure this one run last, cause it might free the stream.
         * (status_got_ack() can tear down one side of the stream, e.g. via
         * tcp_stream_kill_one_side(), so @stream must not be touched after.)
         */
        if ( tcp->th_flags & TH_ACK ) 
                ret = status_got_ack(packet, tcp, stream, src, dst);
        
        return ret;
}





/*
 * Handle a packet travelling from the server to the client of a tracked
 * stream.
 *
 * - A SYN after stream->server already has SYN_RCVD is never processed.
 *   If its sequence number differs from the recorded server ISN, it is
 *   logged together with the stream and packet endpoints, as a possible
 *   desynchronisation attempt. Either way tcp_stream_unknown is returned.
 * - A SYN/ACK (ECN bits ignored) is handed to twh_got_server_syn_ack().
 * - Any other packet is ignored (tcp_stream_unknown) until the server side
 *   has SYN_RCVD or ACK_RCVD, then goes through update_state_generic().
 */
static int update_server_state(packet_container_t *packet,
                               iphdr_t *ip, tcphdr_t *tcp, tcp_stream_t *stream)
{
        int flags_no_ecn;

        dprint("[%p] - client=%d, server=%d\n", stream, stream->client.state, stream->server.state);

        if ( (tcp->th_flags & TH_SYN) && (stream->server.state & SYN_RCVD) ) {
                struct in_addr addr;
                uint32_t syn_seq = extract_uint32(&tcp->th_seq);

                /* same ISN as before (presumably a retransmission): stay quiet */
                if ( syn_seq == stream->server.isn )
                        return tcp_stream_unknown;

                log(LOG_INFO, "got SYN in the middle of a connection (seq=%u, isn=%u)!\n", syn_seq, stream->server.isn);

                /* inet_ntoa() uses a static buffer: one address per log call */
                addr.s_addr = stream->server.addr;
                log(LOG_INFO, "strm %s:%d -> ", inet_ntoa(addr), stream->server.port);

                addr.s_addr = stream->client.addr;
                log(LOG_INFO, "%s:%d\n", inet_ntoa(addr), stream->client.port);

                addr.s_addr = align_uint32(&ip->ip_src.s_addr);
                log(LOG_INFO, "pkts %s:%d -> ", inet_ntoa(addr), extract_uint16(&tcp->th_sport));

                addr.s_addr = align_uint32(&ip->ip_dst.s_addr);
                log(LOG_INFO, "%s:%d\n", inet_ntoa(addr), extract_uint16(&tcp->th_dport));

                return tcp_stream_unknown;
        }

        flags_no_ecn = tcp->th_flags & ~(TH_ECNECHO|TH_CWR);
        if ( flags_no_ecn == (TH_SYN|TH_ACK) )
                return twh_got_server_syn_ack(packet, stream, tcp);

        if ( (stream->server.state & (SYN_RCVD|ACK_RCVD)) == 0 )
                return tcp_stream_unknown;

        return update_state_generic(packet, ip, tcp, stream, &stream->server, &stream->client, STREAM_FROM_SERVER);
}




/*
 * Handle a packet travelling from the client to the server of a tracked
 * stream. The client ISN is already known at this point.
 *
 * - Any SYN is not processed. If its sequence number differs from the
 *   recorded client ISN, it is logged as a possible desynchronisation
 *   attempt. Either way tcp_stream_unknown is returned.
 * - Otherwise returns -1 if the client side lacks SYN_RCVD, or the result
 *   of update_state_generic().
 */
static int update_client_state(packet_container_t *packet,
                               iphdr_t *ip, tcphdr_t *tcp, tcp_stream_t *stream) 
{
        dprint("[%p] - client=%d, server=%d\n", stream, stream->client.state, stream->server.state);

        if ( tcp->th_flags & TH_SYN ) {
                uint32_t syn_seq = extract_uint32(&tcp->th_seq);

                if ( syn_seq != stream->client.isn )
                        log(LOG_INFO, "got SYN in the middle of a connection ! (seq=%u, isn=%u)\n", syn_seq, stream->client.isn);

                return tcp_stream_unknown;
        }

        if ( stream->client.state & SYN_RCVD )
                return update_state_generic(packet, ip, tcp, stream, &stream->client, &stream->server, STREAM_FROM_CLIENT);

        return -1;
}



/*
 * Dispatch a packet of a tracked stream to the client- or server-side
 * handler, depending on get_stream_direction(), and return its result.
 */
static int update_existing_stream(tcp_stream_t *stream, packet_container_t *packet, iphdr_t *ip, tcphdr_t *tcp) 
{
        if ( get_stream_direction(stream, ip, tcp) == STREAM_FROM_CLIENT )
                return update_client_state(packet, ip, tcp, stream);

        return update_server_state(packet, ip, tcp, stream);
}




/*
 * Called for packets that match no tracked stream. Only a bare SYN (ECN
 * bits ignored) starts tracking, via twh_got_client_syn(). In that case
 * -1 is returned so the caller still analyzes the packet itself (the SYN
 * is not part of the reassembled data). Anything else gives
 * tcp_stream_unknown.
 */
static int create_new_stream_if_needed(packet_container_t *packet, iphdr_t *ip, tcphdr_t *tcp) 
{
        if ( (tcp->th_flags & ~(TH_ECNECHO|TH_CWR)) != TH_SYN )
                return tcp_stream_unknown;

        twh_got_client_syn(packet, ip, tcp);

        return -1;
}



/*
 * Decide whether reassembly applies to the packet's ports.
 *
 * Returns 0 when no port list was configured (reasm_port is NULL), or when
 * protocol_plugin_is_port_ok() returns 0 for the source port. Otherwise
 * returns protocol_plugin_is_port_ok()'s result for the destination port.
 * tcp_stream_store() treats a negative value as "do not reassemble".
 *
 * NOTE(review): this implies protocol_plugin_is_port_ok() returns 0 for a
 * port that is in the list -- confirm against plugin-protocol.h.
 */
static int is_port_reassembled(tcphdr_t *tcp) 
{
        if ( ! reasm_port )
                return 0;

        if ( protocol_plugin_is_port_ok(reasm_port, extract_uint16(&tcp->th_sport)) == 0 )
                return 0;

        return protocol_plugin_is_port_ok(reasm_port, extract_uint16(&tcp->th_dport));
}



/*
 * Entry point: feed one TCP packet to the reassembly engine.
 *
 * Returns -1 right away (so the packet is analyzed on its own, without
 * reassembly) when is_port_reassembled() rejects its ports. Otherwise it
 * updates the matching tracked stream, or possibly starts a new one, while
 * holding the timer critical region, and returns that step's result.
 */
int tcp_stream_store(packet_container_t *packet, iphdr_t *ip, tcphdr_t *tcp) 
{
        int status;
        tcp_stream_t *existing;

        if ( is_port_reassembled(tcp) < 0 )
                return -1;

        timer_lock_critical_region();

        existing = tcp_stream_search(ip, tcp);
        status = existing ? update_existing_stream(existing, packet, ip, tcp)
                          : create_new_stream_if_needed(packet, ip, tcp);

        timer_unlock_critical_region();

        return status;
}



/*
 * Return the tcp_stream_enabled flag (set to 1 by tcp_stream_enable(),
 * the callback of the "tcp-reasm" option).
 */
int tcp_stream_is_enabled(void) 
{
        const int enabled = tcp_stream_enabled;

        return enabled;
}




/*
 * Report the reassembly state of the stream that the packet in @pc
 * belongs to.
 *
 * Returns STREAM_STATELESS when no tracked stream matches. Otherwise it
 * returns the direction from get_stream_direction(), OR'ed with:
 *  - STREAM_ESTABLISHED when both client and server sides have ACK_RCVD;
 *  - STREAM_PACKET_REASSEMBLED if pc->tcp_allocated_data is set, else
 *    STREAM_PACKET_NOT_REASSEMBLED.
 *
 * The stream table is searched, and the stream read, inside the timer
 * critical region, as tcp_stream_store() does. The previous code released
 * the region on the not-found path without ever having taken it, and read
 * the stream without holding it.
 */
int tcp_stream_get_state(packet_container_t *pc) 
{
        int ret;
        iphdr_t *ip;
        tcphdr_t *tcp;
        tcp_stream_t *stream;

        ip = pc->packet[pc->network_layer_depth].p.ip;
        tcp = pc->packet[pc->transport_layer_depth].p.tcp;

        timer_lock_critical_region();

        stream = tcp_stream_search(ip, tcp);
        if ( ! stream ) {
                timer_unlock_critical_region();
                return STREAM_STATELESS;
        }

        ret = get_stream_direction(stream, ip, tcp);

        if ( stream->client.state & ACK_RCVD && stream->server.state & ACK_RCVD )
                ret |= STREAM_ESTABLISHED;

        timer_unlock_critical_region();

        if ( pc->tcp_allocated_data )
                ret |= STREAM_PACKET_REASSEMBLED;
        else
                ret |= STREAM_PACKET_NOT_REASSEMBLED;

        return ret;
}



/*
 * Callback for the "tcp-reasm" option: switch TCP stream reassembly on.
 * The option takes no argument, so @opt and @arg are unused.
 */
static int tcp_stream_enable(prelude_option_t *opt, const char *arg) 
{
        (void) opt;
        (void) arg;

        tcp_stream_enabled = 1;

        return prelude_option_success;
}



static int set_reasm_from_client_only(prelude_option_t *opt, const char *arg) 
{
        tcp_stream_reasm_from = STREAM_FROM_CLIENT;
        return prelude_option_success;
}



static int set_reasm_from_server_only(prelude_option_t *opt, const char *arg) 
{
        tcp_stream_reasm_from = STREAM_FROM_SERVER;
        return prelude_option_success;
}



static int set_reasm_from_both_direction(prelude_option_t *opt, const char *arg) 
{
        tcp_stream_reasm_from = STREAM_FROM_CLIENT|STREAM_FROM_SERVER;
        return prelude_option_success;
}




/*
 * Callback for "port-list": allocate a fresh port list into reasm_port and
 * fill it from @arg with protocol_plugin_add_string_port_to_list().
 *
 * Returns prelude_option_error if allocation or parsing fails.
 *
 * NOTE(review): any previously set list is overwritten without being
 * released, and on a parse failure reasm_port is left pointing at the
 * partially filled list -- confirm whether that is acceptable.
 */
static int set_reasm_port_list(prelude_option_t *opt, const char *arg) 
{
        reasm_port = protocol_plugin_port_list_new();

        if ( reasm_port && protocol_plugin_add_string_port_to_list(reasm_port, arg) >= 0 )
                return prelude_option_success;

        return prelude_option_error;
}




/*
 * Callback for "expire": set session_expire_time (in seconds) from @arg.
 *
 * @arg must be a decimal integer greater than zero. Leading blanks and
 * trailing blanks or line terminators are tolerated. Anything else makes
 * the option fail with prelude_option_error and leaves session_expire_time
 * untouched. That includes an empty string, trailing garbage such as "2m",
 * zero, a negative number, or a value above INT32_MAX.
 *
 * This used to be a bare atoi(), which silently turned "abc" into 0,
 * accepted negative values, and has undefined behavior on overflow.
 */
static int set_reasm_expire_time(prelude_option_t *opt, const char *arg) 
{
        long seconds;
        char *end;

        if ( ! arg )
                return prelude_option_error;

        seconds = strtol(arg, &end, 10);
        if ( end == arg )
                return prelude_option_error;

        /* tolerate trailing blanks / line terminators from config files */
        while ( *end == ' ' || *end == '\t' || *end == '\r' || *end == '\n' )
                end++;

        /*
         * On overflow strtol() returns LONG_MAX/LONG_MIN. With a 64-bit long
         * the range check rejects that. With a 32-bit long an overflowing
         * value ends up clamped to INT32_MAX.
         */
        if ( *end != '\0' || seconds <= 0 || seconds > INT32_MAX )
                return prelude_option_error;

        session_expire_time = seconds;
        return prelude_option_success;
}




void tcp_stream_init_config(void) 
{
        prelude_option_t *opt;
        
        opt = prelude_option_add(NULL, CLI_HOOK|CFG_HOOK|WIDE_HOOK, 't', "tcp-reasm",
                                 "Enable TCP stream reassembly (EXPERIMENTAL)", no_argument,
                                 tcp_stream_enable, NULL);

        prelude_option_set_priority(opt, option_run_first);
        
        prelude_option_add(opt, CLI_HOOK|CFG_HOOK, 'b', "both",
                           "Reassemble data in both direction", no_argument,
                           set_reasm_from_both_direction, NULL);
        
        prelude_option_add(opt, CLI_HOOK|CFG_HOOK, 'c', "client-only",
                           "Reassemble data from client only (default)", no_argument,
                           set_reasm_from_client_only, NULL);

        prelude_option_add(opt, CLI_HOOK|CFG_HOOK, 's', "server-only",
                           "Reassemble data from server only", no_argument,
                           set_reasm_from_server_only, NULL);

        prelude_option_add(opt, CLI_HOOK|CFG_HOOK, 'p', "port-list",
                           "Reassemble data on specified port only", required_argument,
                           set_reasm_port_list, NULL);
        
        prelude_option_add(opt, CLI_HOOK|CFG_HOOK, 'e', "expire",
                           "How much time an inactive session is kept (default is 120 seconds)",
                           required_argument, set_reasm_expire_time, NULL);

        srand(getpid());
        flush_point_index = rand();
}


























