
/******************************************************************************
**
**  Copyright (C) 2005 Brian Wotring.
**
**  This program is free software; you can redistribute it and/or
**  modify it, however, you cannot sell it.
**
**  This program is distributed in the hope that it will be useful,
**  but WITHOUT ANY WARRANTY; without even the implied warranty of
**  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
**
**  You should have received a copy of the license attached to the
**  use of this software.  If not, view a current copy of the license
**  file here:
**
**      http://www.hostintegrity.com/osiris/LICENSE
**
******************************************************************************/

/*****************************************************************************
**
**  File:    message.c
**  Date:    February 17, 2002
**  
**  Author:  Brian Wotring
**  Purpose: message handling routines
**
******************************************************************************/

#include <stdlib.h>
#include <stdio.h>
#include <stdarg.h>
#include <string.h>

#include <fcntl.h>

#ifdef WIN32
#include <ws2tcpip.h>
#else
#include <netdb.h>
#include <sys/types.h>
#include <sys/socket.h>
#endif

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#ifdef HAVE_NETINET_IN_H
#include <netinet/in.h>
#endif

#include <openssl/ssl.h>

#include "utilities.h"
#include "socketapi.h"
#include "status.h"
#include "error.h"

#include "message.h"


/******************************************************************************
**
**    Function: initialize_message
**
**    Purpose:  clear every byte of the given MESSAGE and then stamp it
**              with the requested message type.  Any further per-message
**              setup belongs here.
**
**    Return:   the same pointer that was passed in, or NULL when the
**              argument is NULL (in which case nothing is touched).
**
******************************************************************************/

MESSAGE * initialize_message( MESSAGE *message, osi_uint16 type )
{
    /* nothing to prepare without a destination. */

    if( message == NULL )
    {
        return NULL;
    }

    memset( message, 0, sizeof( *message ) );
    SET_MESSAGE_TYPE( message, type );

    return message;
}


/******************************************************************************
**
**    Function: message_set_payload
**
**    Purpose:  copy payload_length bytes from payload into the data area of
**              message, and record the length and sequence number in the
**              header.  payload must point to at least payload_length
**              bytes.  A zero-length payload copies nothing, but payload
**              must still be non-NULL.
**
**    Return:   MESSAGE_OK if message and payload are non-NULL and
**              0 <= payload_length <= MAX_MESSAGE_DATA_SIZE.  Otherwise
**              OSI_ERROR_UNKNOWN is returned and message is left unmodified.
**
******************************************************************************/

int message_set_payload( MESSAGE *message, void *payload,
                         int payload_length, int sequence )
{
    /* reject negative lengths explicitly: they would otherwise pass the   */
    /* upper-bound check and be stored in the header as a bogus length.   */

    if( ( message == NULL ) || ( payload == NULL ) ||
        ( payload_length < 0 ) ||
        ( payload_length > MAX_MESSAGE_DATA_SIZE ) )
    {
        return OSI_ERROR_UNKNOWN;
    }

    SET_MESSAGE_LENGTH( message, payload_length );
    SET_MESSAGE_SEQUENCE( message, sequence );

    if( payload_length > 0 )
    {
        memcpy( message->data, payload, (size_t)payload_length );
    }

    return MESSAGE_OK;
}

/******************************************************************************
**
**    Function: osi_write_error_message
**
**    Purpose:  build an error message carrying the given error type and
**              text, and write it on the specified socket.  This is a
**              convenience method.  The text is truncated to fit the
**              OSI_ERROR message buffer.
**
**    Return:   MESSAGE_OK if the message was written.  OSI_ERROR_UNKNOWN
**              if error_message is NULL or the payload could not be
**              attached.  Otherwise, the code from osi_write_message().
**
******************************************************************************/

int osi_write_error_message( osi_uint64 type, const char *error_message,
                             int connection_socket )
{
    MESSAGE message;
    OSI_ERROR error;

    int result = OSI_ERROR_UNKNOWN;

    if( error_message == NULL )
    {
        return result;
    }

    memset( &error, 0, sizeof( OSI_ERROR ) );
    initialize_message( &message, MESSAGE_TYPE_ERROR );

    /* populate the error structure, then push it into the message. */

    error.type = type;
    osi_strlcpy( error.message, error_message, sizeof( error.message ) );

    /* NOTE(review): wrap_error() is defined elsewhere; presumably it      */
    /* prepares the structure for the wire (e.g. byte order) - confirm.    */

    wrap_error( &error );

    result = message_set_payload( &message, (void *)&error,
                                  (int)sizeof( error ), 0 );

    /* never send an error message without its payload: an empty ERROR     */
    /* message is rejected by message_header_is_valid() on the other end.  */

    if( result != MESSAGE_OK )
    {
        return result;
    }

    return osi_write_message( &message, connection_socket );
}

/******************************************************************************
**
**    Function: osi_write_success_message
**
**    Purpose:  send an empty success message on the specified socket.
**              This is another convenience method.
**
**    Return:   whatever osi_write_message() returns: MESSAGE_OK when the
**              whole message was written, an error code otherwise.
**
******************************************************************************/

int osi_write_success_message( int connection_socket )
{
    MESSAGE success;

    /* initialize_message() hands back &success, so it can feed the write. */

    return osi_write_message( initialize_message( &success,
                                                  MESSAGE_TYPE_SUCCESS ),
                              connection_socket );
}

/******************************************************************************
**
**    Function: osi_ssl_write_error_message
**
**    Purpose:  build an error message carrying the given error type and
**              text, and write it on the specified SSL connection.  This
**              is a convenience method.  The text is truncated to fit the
**              OSI_ERROR message buffer.
**
**    Return:   MESSAGE_OK if the message was written.  OSI_ERROR_UNKNOWN
**              if error_message or ssl is NULL, or the payload could not
**              be attached.  Otherwise, the code from
**              osi_ssl_write_message().
**
******************************************************************************/

int osi_ssl_write_error_message( osi_uint64 type, const char *error_message,
                                 SSL *ssl )
{
    MESSAGE message;
    OSI_ERROR error;

    int result = OSI_ERROR_UNKNOWN;

    if( ( error_message == NULL ) || ( ssl == NULL ) )
    {
        return result;
    }

    memset( &error, 0, sizeof( OSI_ERROR ) );
    initialize_message( &message, MESSAGE_TYPE_ERROR );

    /* populate the error structure, then push it into the message. */

    error.type = type;
    osi_strlcpy( error.message, error_message, sizeof( error.message ) );

    /* NOTE(review): wrap_error() is defined elsewhere; presumably it      */
    /* prepares the structure for the wire (e.g. byte order) - confirm.    */

    wrap_error( &error );

    result = message_set_payload( &message, (void *)&error,
                                  (int)sizeof( error ), 0 );

    /* never send an error message without its payload: an empty ERROR     */
    /* message is rejected by message_header_is_valid() on the other end.  */

    if( result != MESSAGE_OK )
    {
        return result;
    }

    return osi_ssl_write_message( &message, ssl );
}

/******************************************************************************
**
**    Function: osi_ssl_write_success_message
**
**    Purpose:  send an empty success message on the specified SSL
**              connection.  This is another convenience method.
**
**    Return:   whatever osi_ssl_write_message() returns: MESSAGE_OK when
**              the whole message was written, an error code otherwise
**              (including when ssl is NULL).
**
******************************************************************************/

int osi_ssl_write_success_message( SSL *ssl )
{
    MESSAGE success;

    /* initialize_message() hands back &success, so it can feed the write. */

    return osi_ssl_write_message( initialize_message( &success,
                                                      MESSAGE_TYPE_SUCCESS ),
                                  ssl );
}

/******************************************************************************
**
**    Function: osi_read_message
**
**    Purpose:  read one complete message from a socket.  The message is
**              first cleared.  Its header is then read and validated, and
**              finally the payload whose size the header announces is read
**              into the message's data buffer.
**
**    Return:   MESSAGE_OK if the header and body were both read and the
**              header is valid.  Otherwise, the code returned by the first
**              step that failed (OSI_ERROR_UNKNOWN if message is NULL).
**
******************************************************************************/

int osi_read_message( MESSAGE *message, int connection_socket )
{
    int status;

    if( message == NULL )
    {
        return OSI_ERROR_UNKNOWN;
    }

    initialize_message( message, 0 );

    status = read_message_header_from_socket( &( message->header ),
                                              connection_socket );
    if( status != MESSAGE_OK )
    {
        return status;
    }

    status = message_header_is_valid( &( message->header ) );
    if( status != MESSAGE_OK )
    {
        return status;
    }

    return read_message_body_from_socket( message, connection_socket );
}


/******************************************************************************
**
**    Function: read_message_header_from_socket
**
**    Purpose:  read exactly MESSAGE_HEADER_SIZE bytes from the socket into
**              header.  Headers travel in network byte order, so the type,
**              length and sequence fields are converted to host order.
**
**    Return:   MESSAGE_OK on a complete read.  OSI_ERROR_SOCKET_CLOSED when
**              no bytes were read.  OSI_ERROR_UNKNOWN for a NULL header, a
**              non-positive socket, a short read or a read error.
**
******************************************************************************/

int read_message_header_from_socket( MESSAGE_HEADER *header,
                                     int connection_socket )
{
    int received;

    if( ( header == NULL ) || ( connection_socket <= 0 ) )
    {
        return OSI_ERROR_UNKNOWN;
    }

    received = osi_read_bytes( connection_socket, (unsigned char *)header,
                               MESSAGE_HEADER_SIZE, MESSAGE_HEADER_SIZE );

    if( received == 0 )
    {
        return OSI_ERROR_SOCKET_CLOSED;
    }

    if( received != MESSAGE_HEADER_SIZE )
    {
        /* short read or read error. */
        return OSI_ERROR_UNKNOWN;
    }

    /* back to host byte order. */

    header->type     = ntohs( header->type );
    header->length   = ntohs( header->length );
    header->sequence = ntohs( header->sequence );

    return MESSAGE_OK;
}

/******************************************************************************
**
**    Function: read_message_body_from_socket
**
**    Purpose:  read the payload announced by message->header.length from
**              the socket into message->data.  A zero length means there
**              is no payload to read.  Lengths above MAX_MESSAGE_DATA_SIZE
**              are refused before any read is attempted.
**
**    Return:   MESSAGE_OK if the whole payload was read (or none was
**              needed).  OSI_ERROR_MESSAGE_PAYLOAD_TOO_LARGE for an
**              oversized length.  OSI_ERROR_SOCKET_CLOSED when no bytes
**              arrived.  OSI_ERROR_UNKNOWN otherwise.
**
******************************************************************************/

int read_message_body_from_socket( MESSAGE *message, int connection_socket )
{
    int expected;
    int received;

    if( ( message == NULL ) || ( connection_socket <= 0 ) )
    {
        return OSI_ERROR_UNKNOWN;
    }

    expected = message->header.length;

    /* empty payload: the header alone was the whole message. */

    if( expected == 0 )
    {
        return MESSAGE_OK;
    }

    /* guard against an insane header before touching the buffer. */

    if( expected > MAX_MESSAGE_DATA_SIZE )
    {
        return OSI_ERROR_MESSAGE_PAYLOAD_TOO_LARGE;
    }

    received = osi_read_bytes( connection_socket,
                               (unsigned char *)message->data,
                               sizeof( message->data ), expected );

    if( received == expected )
    {
        return MESSAGE_OK;
    }

    return ( received == 0 ) ? OSI_ERROR_SOCKET_CLOSED : OSI_ERROR_UNKNOWN;
}

/******************************************************************************
**
**    Function: osi_write_message
**
**    Purpose:  send the passed in message (header plus payload) on the
**              passed in socket descriptor.  The header is converted to
**              network byte order for transmission and converted back to
**              host order afterwards.  This leaves the caller's message
**              unchanged, so it can be reused or inspected.
**
**    Return:   MESSAGE_OK if the entire message was written.
**              OSI_ERROR_SOCKET_CLOSED if nothing was written.
**              OSI_ERROR_UNKNOWN otherwise (including a NULL message).
**
******************************************************************************/

int osi_write_message( MESSAGE *message, int connection_socket )
{
    int bytes_to_write = 0;
    int bytes_written  = 0;

    int result = OSI_ERROR_UNKNOWN;

    if( message == NULL )
    {
        return result;
    }

    /* must be computed while the length field is still in host order. */

    bytes_to_write = (int)( GET_TOTAL_MESSAGE_SIZE( message ) );

    /* convert header to network order for the wire. */

    SET_MESSAGE_TYPE( message, htons( GET_MESSAGE_TYPE( message ) ) );
    SET_MESSAGE_LENGTH( message, htons( GET_MESSAGE_LENGTH( message ) ) );
    SET_MESSAGE_SEQUENCE( message,
                          htons( GET_MESSAGE_SEQUENCE( message ) ) );

    bytes_written = osi_write_bytes( connection_socket,
                                     (unsigned char *)message,
                                     bytes_to_write );

    /* restore host order.  Otherwise a reused message (for example one    */
    /* whose payload is replaced via message_set_payload(), which does not */
    /* reset the type) would be byte-swapped twice on the next send.       */

    SET_MESSAGE_TYPE( message, ntohs( GET_MESSAGE_TYPE( message ) ) );
    SET_MESSAGE_LENGTH( message, ntohs( GET_MESSAGE_LENGTH( message ) ) );
    SET_MESSAGE_SEQUENCE( message,
                          ntohs( GET_MESSAGE_SEQUENCE( message ) ) );

    if( bytes_written == 0 )
    {
        result = OSI_ERROR_SOCKET_CLOSED;
    }

    else if( bytes_written == bytes_to_write )
    {
        result = MESSAGE_OK;
    }

    /* otherwise a short write or error: OSI_ERROR_UNKNOWN stands. */

    return result;
}

/*
 * Read one chunk of raw HTTP data from ssl into message->data and
 * NUL-terminate it.  The message is reset to MESSAGE_TYPE_HTTP_DATA first.
 * The header length is not filled in.
 *
 * Returns OSI_ERROR_NULL_ARGUMENT for a NULL message or ssl, MESSAGE_OK
 * when data was read, OSI_ERROR_SOCKET_CLOSED when the read returned 0,
 * and OSI_ERROR_MESSAGE_READ when it returned a negative value.
 */
int osi_ssl_read_http_message( MESSAGE *message, SSL *ssl )
{
    int count;

    if( ( message == NULL ) || ( ssl == NULL ) )
    {
        return OSI_ERROR_NULL_ARGUMENT;
    }

    initialize_message( message, MESSAGE_TYPE_HTTP_DATA );

    /* keep one byte back so the data can always be NUL-terminated. */

    count = osi_ssl_read( ssl, message->data, ( sizeof( message->data ) - 1 ) );

    if( count < 0 )
    {
        return OSI_ERROR_MESSAGE_READ;
    }

    if( count == 0 )
    {
        return OSI_ERROR_SOCKET_CLOSED;
    }

    message->data[count] = '\0';
    return MESSAGE_OK;
}

/*
 * SSL counterpart of osi_read_message().  Clear the message, read and
 * validate its header from ssl, then read the announced payload.
 *
 * Returns MESSAGE_OK on success, OSI_ERROR_UNKNOWN if message or ssl is
 * NULL, otherwise the code returned by the first step that failed.
 */
int osi_ssl_read_message( MESSAGE *message, SSL *ssl )
{
    int status;

    if( ( message == NULL ) || ( ssl == NULL ) )
    {
        return OSI_ERROR_UNKNOWN;
    }

    initialize_message( message, 0 );

    status = osi_ssl_read_message_header( &( message->header ), ssl );
    if( status != MESSAGE_OK )
    {
        return status;
    }

    status = message_header_is_valid( &( message->header ) );
    if( status != MESSAGE_OK )
    {
        return status;
    }

    return osi_ssl_read_message_body( message, ssl );
}

/*
 * Read exactly MESSAGE_HEADER_SIZE bytes from ssl into header, then convert
 * the type, length and sequence fields from network to host byte order.
 *
 * Returns MESSAGE_OK on a complete read, OSI_ERROR_SOCKET_CLOSED when no
 * bytes were read, and OSI_ERROR_UNKNOWN for NULL arguments, short reads
 * or read errors.
 */
int osi_ssl_read_message_header( MESSAGE_HEADER *header, SSL *ssl )
{
    int received;

    if( ( header == NULL ) || ( ssl == NULL ) )
    {
        return OSI_ERROR_UNKNOWN;
    }

    received = osi_ssl_read_bytes( ssl, (unsigned char *)header,
                                   MESSAGE_HEADER_SIZE,
                                   MESSAGE_HEADER_SIZE );

    if( received == 0 )
    {
        return OSI_ERROR_SOCKET_CLOSED;
    }

    if( received != MESSAGE_HEADER_SIZE )
    {
        /* short read or read error. */
        return OSI_ERROR_UNKNOWN;
    }

    /* back to host byte order. */

    header->type     = ntohs( header->type );
    header->length   = ntohs( header->length );
    header->sequence = ntohs( header->sequence );

    return MESSAGE_OK;
}

/*
 * Read the payload announced by message->header.length from ssl into
 * message->data.  A zero length needs no read.  Lengths above
 * MAX_MESSAGE_DATA_SIZE are refused before reading.
 *
 * Returns MESSAGE_OK, OSI_ERROR_MESSAGE_PAYLOAD_TOO_LARGE,
 * OSI_ERROR_SOCKET_CLOSED (nothing read) or OSI_ERROR_UNKNOWN (NULL
 * arguments, short read, read error).
 */
int osi_ssl_read_message_body( MESSAGE *message, SSL *ssl )
{
    int expected;
    int received;

    if( ( message == NULL ) || ( ssl == NULL ) )
    {
        return OSI_ERROR_UNKNOWN;
    }

    expected = message->header.length;

    /* empty payload: the header alone was the whole message. */

    if( expected == 0 )
    {
        return MESSAGE_OK;
    }

    /* guard against an insane header before touching the buffer. */

    if( expected > MAX_MESSAGE_DATA_SIZE )
    {
        return OSI_ERROR_MESSAGE_PAYLOAD_TOO_LARGE;
    }

    received = osi_ssl_read_bytes( ssl, (unsigned char *)message->data,
                                   sizeof( message->data ), expected );

    if( received == expected )
    {
        return MESSAGE_OK;
    }

    return ( received == 0 ) ? OSI_ERROR_SOCKET_CLOSED : OSI_ERROR_UNKNOWN;
}


/*
 * Send the passed in message (header plus payload) over ssl.  The header
 * is converted to network byte order for transmission and converted back
 * to host order afterwards, so the caller's message is left unchanged and
 * can be reused or inspected.
 *
 * Returns MESSAGE_OK if the entire message was written,
 * OSI_ERROR_SOCKET_CLOSED if nothing was written, and OSI_ERROR_UNKNOWN
 * otherwise (including NULL message or ssl).
 */
int osi_ssl_write_message( MESSAGE *message, SSL *ssl )
{
    int bytes_to_write = 0;
    int bytes_written  = 0;

    int result = OSI_ERROR_UNKNOWN;

    if( ( message == NULL ) || ( ssl == NULL ) )
    {
        return result;
    }

    /* must be computed while the length field is still in host order. */

    bytes_to_write = (int)( GET_TOTAL_MESSAGE_SIZE( message ) );

    /* convert header to network order for the wire. */

    SET_MESSAGE_TYPE( message, htons( GET_MESSAGE_TYPE( message ) ) );
    SET_MESSAGE_LENGTH( message, htons( GET_MESSAGE_LENGTH( message ) ) );
    SET_MESSAGE_SEQUENCE( message,
                          htons( GET_MESSAGE_SEQUENCE( message ) ) );

    bytes_written = osi_ssl_write_bytes( ssl,
                                         (unsigned char *)message,
                                         bytes_to_write );

    /* restore host order.  Otherwise a reused message (for example one    */
    /* whose payload is replaced via message_set_payload(), which does not */
    /* reset the type) would be byte-swapped twice on the next send.       */

    SET_MESSAGE_TYPE( message, ntohs( GET_MESSAGE_TYPE( message ) ) );
    SET_MESSAGE_LENGTH( message, ntohs( GET_MESSAGE_LENGTH( message ) ) );
    SET_MESSAGE_SEQUENCE( message,
                          ntohs( GET_MESSAGE_SEQUENCE( message ) ) );

    if( bytes_written == 0 )
    {
        result = OSI_ERROR_SOCKET_CLOSED;
    }

    else if( bytes_written == bytes_to_write )
    {
        result = MESSAGE_OK;
    }

    /* otherwise a short write or error: OSI_ERROR_UNKNOWN stands. */

    return result;
}

/******************************************************************************
**
**    Function: message_header_is_valid
**
**    Purpose:  sanity-check a received header against its type.  Some
**              types may carry a payload of any size, some must be empty
**              and others must carry a payload.  Unknown types are
**              rejected.
**
**    Return:   MESSAGE_OK if the header passes, OSI_ERROR_UNKNOWN
**              otherwise (including a NULL header).
**
******************************************************************************/

int message_header_is_valid( MESSAGE_HEADER *header )
{
    if( header == NULL )
    {
        return OSI_ERROR_UNKNOWN;
    }

    switch( (int)header->type )
    {
        /* payload of any size is acceptable. */

        case MESSAGE_TYPE_SESSION_KEY:
        case MESSAGE_TYPE_SUCCESS:
        case MESSAGE_TYPE_CONTROL_DATA_LAST:

            return MESSAGE_OK;

        /* must not carry a payload. */

        case MESSAGE_TYPE_START_SCAN:
        case MESSAGE_TYPE_STOP_SCAN:
        case MESSAGE_TYPE_STATUS_REQUEST:
        case MESSAGE_TYPE_DROP_CONFIG:

            return ( header->length == 0 ) ? MESSAGE_OK : OSI_ERROR_UNKNOWN;

        /* must carry a payload. */

        case MESSAGE_TYPE_STATUS_RESPONSE:
        case MESSAGE_TYPE_CONTROL_DATA:
        case MESSAGE_TYPE_CONTROL_DATA_FIRST:
        case MESSAGE_TYPE_SCAN_DATA_FIRST:
        case MESSAGE_TYPE_SCAN_DATA:
        case MESSAGE_TYPE_SCAN_DATA_LAST:
        case MESSAGE_TYPE_CONFIG_DATA:
        case MESSAGE_TYPE_CONFIG_DATA_FIRST:
        case MESSAGE_TYPE_CONFIG_DATA_LAST:
        case MESSAGE_TYPE_DB_DATA:
        case MESSAGE_TYPE_DB_DATA_FIRST:
        case MESSAGE_TYPE_DB_DATA_LAST:
        case MESSAGE_TYPE_CONTROL_REQUEST:
        case MESSAGE_TYPE_ERROR:

            return ( header->length > 0 ) ? MESSAGE_OK : OSI_ERROR_UNKNOWN;

        default:

            return OSI_ERROR_UNKNOWN;
    }
}

/******************************************************************************
**
**    Function: dump_message
**
**    Purpose:  print a one-line summary of a message through
**              osi_print_stdout(): its type name, payload length, sequence
**              number, and whether a payload is present.  Payload contents
**              are not printed.  A NULL message prints nothing.
**
**    Return:   none.
**
******************************************************************************/

void dump_message( MESSAGE *message )
{
    int length;

    if( message == NULL )
    {
        return;
    }

    length = message->header.length;

    osi_print_stdout(
            "[ message dump ] TYPE=%s ; LENGTH=%d ; SEQUENCE=%d ; %s",
            get_name_for_message_type( (int)GET_MESSAGE_TYPE( message ) ),
            length,
            (int)message->header.sequence,
            ( length > 0 ) ? "<PAYLOAD EXISTS>" : "<PAYLOAD IS NULL>" );
}

/******************************************************************************
**
**    Function: dump_message_header
**
**    Purpose:  print the decoded header fields (type name, length,
**              sequence), followed by the first MESSAGE_HEADER_SIZE raw
**              bytes of the message in hexadecimal on standard output.
**              A NULL message prints nothing.
**
**    Return:   none.
**
******************************************************************************/

void dump_message_header( MESSAGE *message )
{
    unsigned char raw[MESSAGE_HEADER_SIZE];
    int position;

    if( message == NULL )
    {
        return;
    }

    /* snapshot the header bytes exactly as they sit in memory. */

    memcpy( raw, message, sizeof( raw ) );

    osi_print_stdout( " [ message header ]\n" );
    osi_print_stdout( "     type: %s",
                get_name_for_message_type( (int)GET_MESSAGE_TYPE( message ) ) );
    osi_print_stdout( "   length: %d", (int)message->header.length );
    osi_print_stdout( " sequence: %d", (int)message->header.sequence );
    fprintf( stdout, "  raw: 0x" );

    for( position = 0; position < MESSAGE_HEADER_SIZE; ++position )
    {
        fprintf( stdout, "%02X ", raw[position] );
    }

    fprintf( stdout, "\n" );
}


/******************************************************************************
**
**    Function: get_name_for_message_type
**
**    Purpose:  map a message type code to its human-readable name.
**              NOTE(review): the DB_DATA*, DROP_CONFIG and HTTP_DATA types
**              have no entry here and report as unknown.
**
**    Return:   the MESSAGE_STRING_* name for a recognized type,
**              MESSAGE_STRING_UNKNOWN otherwise.
**
******************************************************************************/

char * get_name_for_message_type( int type )
{
    char *name;

    switch( type )
    {
        case MESSAGE_TYPE_SUCCESS:
            name = MESSAGE_STRING_SUCCESS;
            break;

        case MESSAGE_TYPE_SESSION_KEY:
            name = MESSAGE_STRING_SESSION_KEY;
            break;

        case MESSAGE_TYPE_START_SCAN:
            name = MESSAGE_STRING_START_SCAN;
            break;

        case MESSAGE_TYPE_STOP_SCAN:
            name = MESSAGE_STRING_STOP_SCAN;
            break;

        case MESSAGE_TYPE_STATUS_REQUEST:
            name = MESSAGE_STRING_STATUS_REQUEST;
            break;

        case MESSAGE_TYPE_STATUS_RESPONSE:
            name = MESSAGE_STRING_STATUS_RESPONSE;
            break;

        case MESSAGE_TYPE_CONFIG_DATA:
            name = MESSAGE_STRING_CONFIG_DATA;
            break;

        case MESSAGE_TYPE_CONFIG_DATA_FIRST:
            name = MESSAGE_STRING_CONFIG_DATA_FIRST;
            break;

        case MESSAGE_TYPE_CONFIG_DATA_LAST:
            name = MESSAGE_STRING_CONFIG_DATA_LAST;
            break;

        case MESSAGE_TYPE_SCAN_DATA:
            name = MESSAGE_STRING_SCAN_DATA;
            break;

        case MESSAGE_TYPE_SCAN_DATA_FIRST:
            name = MESSAGE_STRING_SCAN_DATA_FIRST;
            break;

        case MESSAGE_TYPE_SCAN_DATA_LAST:
            name = MESSAGE_STRING_SCAN_DATA_LAST;
            break;

        case MESSAGE_TYPE_CONTROL_REQUEST:
            name = MESSAGE_STRING_CONTROL_REQUEST;
            break;

        case MESSAGE_TYPE_CONTROL_DATA:
            name = MESSAGE_STRING_CONTROL_DATA;
            break;

        case MESSAGE_TYPE_CONTROL_DATA_FIRST:
            name = MESSAGE_STRING_CONTROL_DATA_FIRST;
            break;

        case MESSAGE_TYPE_CONTROL_DATA_LAST:
            name = MESSAGE_STRING_CONTROL_DATA_LAST;
            break;

        case MESSAGE_TYPE_ERROR:
            name = MESSAGE_STRING_ERROR;
            break;

        default:
            name = MESSAGE_STRING_UNKNOWN;
            break;
    }

    return name;
}


