/*****
*
* Copyright (C) 2001 Jeremie Brebec <flagg@ifrance.com>
* All Rights Reserved
*
* This file is part of the Prelude program.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by 
* the Free Software Foundation; either version 2, or (at your option)
* any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; see the file COPYING.  If not, write to
* the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
*
* Written by Jeremie Brebec <flagg@ifrance.com>
*
*****/


#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#include <libprelude/prelude-log.h>

#include "rules.h"
#include "rules-type.h"
#include "rules-parsing.h"



/**********************************************************
 * IP type                                                *
 **********************************************************/



/*
 * Duplicate an ip_t (address and netmask) into freshly malloc'ed memory.
 * Returns the copy, or NULL (after logging) when the allocation fails.
 */
ip_t *copy_ip(ip_t *ip)
{
	ip_t *dup = malloc(sizeof *dup);

	if ( dup == NULL ) {
		log(LOG_ERR, "memory exhausted.\n");
		return NULL;
	}

	dup->netmask = ip->netmask;
	dup->ip = ip->ip;

	return dup;
}



/* Release an ip_t allocated with malloc (e.g. by copy_ip() or parse_ip()); NULL is a no-op. */
void delete_ip(ip_t *ip)
{
	free(ip);
}



/* Return 1 when both address and netmask are identical, 0 otherwise. */
int equal_ip(ip_t *ip1, ip_t *ip2)
{
	if ( ip1->ip != ip2->ip )
		return 0;

	return ip1->netmask == ip2->netmask;
}



/*
 * Return nonzero when subip falls inside ip's network: the bits of
 * subip->ip selected by both netmasks must equal ip's network part
 * (ip->ip masked with ip->netmask).
 */
int match_ip(ip_t *ip, ip_t *subip)
{
	return (ip->ip & ip->netmask) == (subip->ip & subip->netmask & ip->netmask);
}



/*
 * The ip type never reports a partial overlap: always returns 0, so
 * split_ip() is not expected to be reached for this type.
 */
int intersection_ip(ip_t *ip1, ip_t *ip2)
{
	(void) ip1;
	(void) ip2;

	return 0;
}



/*
 * No-op: splitting is never expected for the ip type (intersection_ip()
 * always returns 0). The output pointers are left untouched.
 */
void split_ip(ip_t *ip1, ip_t *ip2, ip_t **inter, ip_t **r_ip1, ip_t **r_ip2)
{
	(void) ip1;
	(void) ip2;
	(void) inter;
	(void) r_ip1;
	(void) r_ip2;
}



/* Debug dump of an ip_t to stderr as "Type ip ip = <addr> ms = <netmask>". */
void print_ip(ip_t *ip)
{
	struct in_addr addr;

	/* inet_ntoa() returns a static buffer, hence one fprintf() per address */
	addr.s_addr = ip->ip;
	fprintf(stderr, "Type ip ip = %s", inet_ntoa(addr));

	addr.s_addr = ip->netmask;
	fprintf(stderr, " ms = %s\n", inet_ntoa(addr));
}



/**********************************************************
 * Segment type                                           *
 **********************************************************/


/*
 * Duplicate a segment (lo/hi bounds) into freshly malloc'ed memory.
 * Returns the copy, or NULL (after logging) when the allocation fails.
 */
segment_t *copy_segment(segment_t *segment)
{
	segment_t *dup = malloc(sizeof *dup);

	if ( dup == NULL ) {
		log(LOG_ERR, "memory exhausted.\n");
		return NULL;
	}

	dup->hi = segment->hi;
	dup->lo = segment->lo;

	return dup;
}



/* Release a segment allocated with malloc (e.g. by copy_segment()); NULL is a no-op. */
void delete_segment(segment_t *segment)
{
	free(segment);
}



/* Return 1 when both bounds are identical, 0 otherwise. */
int equal_segment(segment_t *segment1, segment_t *segment2)
{
	return !(segment1->lo != segment2->lo || segment1->hi != segment2->hi);
}



/* Return 1 when subsegment lies within segment (both bounds inclusive), 0 otherwise. */
int match_segment(segment_t *segment, segment_t *subsegment)
{
	if ( subsegment->lo < segment->lo )
		return 0;

	return subsegment->hi <= segment->hi;
}



/*
 * Return 1 when one segment starts inside the other (at or after its lo,
 * strictly before its hi) and extends at least up to the other's hi.
 *
 * NOTE(review): this is also true for two equal segments, and when the
 * starts coincide and one contains the other; presumably callers test
 * equal_segment()/match_segment() first — confirm.
 */
int intersection_segment(segment_t *segment1, segment_t *segment2)
{
	/* segment2 starts inside segment1 and reaches past (or to) its end */
	if ( segment2->lo >= segment1->lo && segment2->lo < segment1->hi &&
	     segment2->hi >= segment1->hi )
		return 1;

	/* mirror case: segment1 starts inside segment2 */
	return segment1->lo >= segment2->lo && segment1->lo < segment2->hi &&
	       segment1->hi >= segment2->hi;
}



/*
 * Split two overlapping segments into three newly allocated segments:
 * the common part (*inter) and what remains on each side
 * (*r_segment1, *r_segment2).
 *
 * When segment1->lo < segment2->lo:
 *   *r_segment1 = [segment1->lo, segment2->lo]
 *   *inter      = [segment2->lo, segment1->hi]
 *   *r_segment2 = [segment1->hi, segment2->hi]
 * otherwise:
 *   *r_segment1 = [segment2->hi, segment1->hi]
 *   *inter      = [segment1->lo, segment2->hi]
 *   *r_segment2 = [segment2->lo, segment1->lo]
 *
 * Presumably only called after intersection_segment() returned true — confirm.
 *
 * On allocation failure an error is logged, whatever was allocated is
 * released, and all three outputs are set to NULL.
 */
void split_segment(segment_t *segment1, segment_t *segment2,
                   segment_t **inter, segment_t **r_segment1, segment_t **r_segment2)
{
	*inter = malloc(sizeof(segment_t));
	*r_segment1 = copy_segment( segment1 );
	*r_segment2 = copy_segment( segment2 );

	if ( ! *inter || ! *r_segment1 || ! *r_segment2 ) {
		log(LOG_ERR, "memory exhausted.\n");

		/* don't leak the allocations that did succeed */
		free(*inter);
		free(*r_segment1);
		free(*r_segment2);
		*inter = NULL;
		*r_segment1 = NULL;
		*r_segment2 = NULL;
		return;
	}

	if ( segment1->lo < segment2->lo ) {
		(*r_segment2)->lo = segment1->hi;
		(*r_segment1)->hi = segment2->lo;
		(*inter)->lo = segment2->lo;
		(*inter)->hi = segment1->hi;
	} else {
		(*r_segment1)->lo = segment2->hi;
		(*r_segment2)->hi = segment1->lo;
		(*inter)->lo = segment1->lo;
		(*inter)->hi = segment2->hi;
	}
}



/* Debug dump of a segment to stderr as "Type segment [lo - hi]". */
void print_segment(segment_t *segment)
{
	fprintf(stderr, "Type segment [%u - %u]\n",
	        segment->lo, segment->hi);
}




/**********************************************************
 * Flags type                                             *
 **********************************************************/



/*
 * Duplicate a flags_t (flag bits and mask) into freshly malloc'ed memory.
 * Returns the copy, or NULL (after logging) when the allocation fails.
 */
flags_t *copy_flags(flags_t *flags)
{
	flags_t *dup = malloc(sizeof *dup);

	if ( dup == NULL ) {
		log(LOG_ERR, "memory exhausted.\n");
		return NULL;
	}

	dup->mask = flags->mask;
	dup->flags = flags->flags;

	return dup;
}



/* Release a flags_t allocated with malloc (e.g. by copy_flags()); NULL is a no-op. */
void delete_flags(flags_t *flags)
{
	free(flags);
}



/* Return 1 when both the flag bits and the mask are identical, 0 otherwise. */
int equal_flags(flags_t *flags1, flags_t *flags2)
{
	if ( flags1->flags != flags2->flags )
		return 0;

	return flags1->mask == flags2->mask;
}



/*
 * Return 1 when both flag sets agree on every bit covered by both masks,
 * and the common mask differs from subflags->mask. Otherwise return 0.
 *
 * NOTE(review): identical flag sets do not match (common == subflags->mask).
 * Flag sets whose masks are disjoint always match, because the agreement
 * test is vacuous. Confirm both outcomes are intended.
 */
int match_flags(flags_t *flags, flags_t *subflags)
{
	int common = flags->mask & subflags->mask;
	int agree = (flags->flags & common) == (subflags->flags & common);
	int mask_differs = common != subflags->mask;

	return agree && mask_differs;
}



/* The flags type never reports a partial overlap: always returns 0. */
int intersection_flags(flags_t *flags1, flags_t *flags2)
{
	(void) flags1;
	(void) flags2;

	return 0;
}



/* No-op for the flags type; the output pointers are left untouched. */
void split_flags(flags_t *flags1, flags_t *flags2,
                 flags_t **inter, flags_t **r_flags1, flags_t **r_flags2)
{
	(void) flags1;
	(void) flags2;
	(void) inter;
	(void) r_flags1;
	(void) r_flags2;
}



/* Debug dump of a flags_t to stderr as "Type flags flags:<flags> mask:<mask>". */
void print_flags(flags_t *flags)
{
	fprintf(stderr, "Type flags flags:%u mask:%u\n",
	        flags->flags, flags->mask);
}



/**********************************************************
 * Integer type
 **********************************************************/


/*
 * Duplicate an integer_t into freshly malloc'ed memory.
 * Returns the copy, or NULL (after logging) when the allocation fails.
 */
integer_t *copy_integer(integer_t *integer)
{
	integer_t *dup = malloc(sizeof *dup);

	if ( dup == NULL ) {
		log(LOG_ERR, "memory exhausted.\n");
		return NULL;
	}

	dup->num = integer->num;
	return dup;
}



/* Release an integer_t allocated with malloc (e.g. by copy_integer()); NULL is a no-op. */
void delete_integer(integer_t *integer)
{
	free(integer);
}



/* Return 1 when both integers hold the same value, 0 otherwise. */
int equal_integer(integer_t *integer1, integer_t *integer2)
{
	return integer2->num == integer1->num;
}



/* Integers have no range, so matching is plain equality: returns 1 or 0. */
int match_integer(integer_t *integer, integer_t *subinteger)
{
	return subinteger->num == integer->num;
}



/* The integer type never reports a partial overlap: always returns 0. */
int intersection_integer(integer_t *integer1, integer_t *integer2)
{
	(void) integer1;
	(void) integer2;

	return 0;
}



/*
 * No-op: splitting is never expected for the integer type
 * (intersection_integer() always returns 0). Outputs are left untouched.
 */
void split_integer(integer_t *integer1, integer_t *integer2,
                   integer_t **inter, integer_t **r_integer1, integer_t **r_integer2)
{
	(void) integer1;
	(void) integer2;
	(void) inter;
	(void) r_integer1;
	(void) r_integer2;
}



/* Debug dump of an integer_t to stderr as "Type integer num: <num>". */
void print_integer(integer_t *integer)
{
	fprintf(stderr, "Type integer num: %d\n",
	        integer->num);
}



/**********************************************************
 * Parsing
 **********************************************************/



/*
 * Parse an IPv4 address specification into a newly allocated ip_t.
 *
 * Accepted forms:
 *   "a.b.c.d"           host address (netmask INADDR_BROADCAST)
 *   "a.b.c.d/n"         prefix length, 0 <= n <= 32
 *   "a.b.c.d/m.m.m.m"   dotted netmask
 *
 * Address and netmask are stored in network byte order: inet_aton()
 * output, or htonl() of the prefix table. The address is stored already
 * masked with the netmask.
 *
 * Returns the new ip_t (release with delete_ip()). Returns NULL on
 * allocation failure (logged) or on invalid input; for invalid input the
 * error is reported through signature_parser_set_error().
 */
ip_t *parse_ip(const char *str)
{
	ip_t *ip;
	char *buf, *s, *end;
	long prefix;
	/* prefix length -> netmask, in host byte order */
	static const unsigned long netmasks[] = {
		0x0       , 0x80000000, 0xC0000000, 0xE0000000, 0xF0000000,
		0xF8000000, 0xFC000000, 0xFE000000, 0xFF000000, 0xFF800000,
		0xFFC00000, 0xFFE00000, 0xFFF00000, 0xFFF80000, 0xFFFC0000,
		0xFFFE0000, 0xFFFF0000, 0xFFFF8000, 0xFFFFC000, 0xFFFFE000,
		0xFFFFF000, 0xFFFFF800, 0xFFFFFC00, 0xFFFFFE00, 0xFFFFFF00,
		0xFFFFFF80, 0xFFFFFFC0, 0xFFFFFFE0, 0xFFFFFFF0, 0xFFFFFFF8,
		0xFFFFFFFC, 0xFFFFFFFE, 0xFFFFFFFF
	};

	/* strtok() modifies its input, so work on a private copy */
	buf = strdup(str);
	if ( ! buf ) {
		log(LOG_ERR, "memory exhausted.\n");
		return NULL;
	}

	ip = malloc(sizeof(ip_t));
	if ( ! ip ) {
		log(LOG_ERR, "memory exhausted.\n");
		free(buf);
		return NULL;
	}

	/* address part; strtok() yields NULL for "" or a string made only of '/' */
	s = strtok(buf, "/");
	if ( ! s || ! inet_aton(s, (struct in_addr *) &ip->ip) ) {
		signature_parser_set_error("Invalid IP %s", s ? s : str);
		goto err;
	}

	/* optional netmask part */
	s = strtok(NULL, "/");
	if ( ! s ) {
		ip->netmask = INADDR_BROADCAST;

	} else if ( strchr(s, '.') ) {
		/* netmask a.b.c.d */
		if ( ! inet_aton(s, (struct in_addr *) &ip->netmask) ) {
			signature_parser_set_error("Invalid Netmask %s", s);
			goto err;
		}

	} else {
		/*
		 * Numeric prefix length. Reject non-numeric text rather than
		 * letting atoi() turn it into /0, which would match any address.
		 */
		prefix = strtol(s, &end, 10);
		if ( end == s || *end != '\0' ) {
			signature_parser_set_error("Invalid Netmask %s", s);
			goto err;
		}

		if ( prefix < 0 || prefix > 32 ) {
			signature_parser_set_error("Netmask out of range (%ld)", prefix);
			goto err;
		}

		ip->netmask = htonl(netmasks[prefix]);
	}

	ip->ip &= ip->netmask;
	free(buf);

	return ip;

 err:
	free(ip);
	free(buf);
	return NULL;
}



/*
 * Parse a segment specification into a newly allocated segment_t:
 *   "x"   -> [x, x]
 *   "<x"  -> [0, x]
 *   ">x"  -> [x, ~0]   (upper bound is all bits set)
 *
 * The number is read with atoi(), so trailing garbage is ignored.
 * Returns NULL when the string does not start with '<', '>' or a digit,
 * when nothing follows the operator, or when allocation fails (logged).
 */
segment_t *parse_segment(const char *str)
{
	const char op = *str;
	segment_t *segment;
	int value;

	if ( op == '>' || op == '<' )
		str++;
	else if ( op < '0' || op > '9' )
		return NULL;

	if ( *str == '\0' )
		return NULL;

	value = atoi(str);

	segment = malloc(sizeof *segment);
	if ( segment == NULL ) {
		log(LOG_ERR, "memory exhausted.\n");
		return NULL;
	}

	if ( op == '>' ) {
		segment->lo = value;
		segment->hi = ~0; /* +oo ? */
	} else if ( op == '<' ) {
		segment->lo = 0;
		segment->hi = value;
	} else {
		segment->lo = value;
		segment->hi = value;
	}

	return segment;
}



/*
 * Parse str with atoi() into a newly allocated integer_t.
 * No validation is done: non-numeric input yields 0.
 * Returns NULL (after logging) when the allocation fails.
 */
integer_t *parse_integer(const char *str)
{
	integer_t *result = malloc(sizeof *result);

	if ( result == NULL ) {
		log(LOG_ERR, "memory exhausted.\n");
		return NULL;
	}

	result->num = atoi(str);
	return result;
}




/* Number of slots in rules_types[], i.e. the capacity of the type table. */
#define MAX_TYPE 30
/* Type table, indexed by the id returned from signature_engine_register_type(); zero-initialized (static). */
static generic_type_t rules_types[MAX_TYPE];



/*
 * Return the type structure stored under id, or NULL when id is outside
 * the table (valid ids are 0 .. MAX_TYPE - 1).
 *
 * An in-range id that was never registered yields a zero-filled slot
 * (rules_types is static) rather than NULL.
 */
generic_type_t *signature_engine_get_type_by_id(int id)
{
	/* id == MAX_TYPE would index one past the end of rules_types[] */
	if ( id < 0 || id >= MAX_TYPE )
		return NULL;

	return &rules_types[id];
}




/*
 * Register a new type: store priority and the callbacks in the next free
 * slot of rules_types[]. Ids are handed out sequentially starting at 0.
 *
 * Returns the new type's id (usable with signature_engine_get_type_by_id()),
 * or -1 when all MAX_TYPE slots are already taken.
 *
 * The id counter is a function-local static with no locking; presumably
 * all registrations happen during single-threaded initialization — confirm.
 */
int signature_engine_register_type(
        int priority, copy_f_t copy, delete_f_t delete, equal_f_t equal,
        match_f_t match, intersection_f_t intersection, split_f_t split,
        print_f_t print, match_packet_f_t match_packet)
{
	generic_type_t *type;
	static int type_max = -1;	/* id of the last registered type */

	/*
	 * type_max is the last used slot. A new slot exists only while
	 * type_max + 1 < MAX_TYPE; otherwise we would write rules_types[MAX_TYPE].
	 */
	if ( type_max >= MAX_TYPE - 1 )
		return -1;

	type_max++;

	type = &rules_types[type_max];

	type->id = type_max;
	type->priority = priority;

	type->copy = copy;
	type->delete = delete;
	type->equal = equal;
	type->match = match;
	type->intersection = intersection;
	type->split = split;
	type->match_packet = match_packet;
	type->print = print;

	return type_max;
}
