/*
 * tls.c
 *
 * This file is part of msmtp, an SMTP client.
 *
 * Copyright (C) 2000, 2003, 2004
 * Martin Lambers <marlam@users.sourceforge.net>
 *
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation; either version 2 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 *   msmtp is released under the GPL with the additional exemption that
 *   compiling, linking, and/or using OpenSSL is allowed.
 */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <string.h>
#include <strings.h>
#include <stdlib.h>
#include <limits.h>
#include <time.h>
#include <errno.h>
extern int errno;

#ifdef HAVE_GNUTLS
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#endif /* HAVE_GNUTLS */
#ifdef HAVE_OPENSSL
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <openssl/err.h>
#include <openssl/rand.h>
#endif /* HAVE_OPENSSL */

#include "merror.h"
#include "tls.h"



/*
 * tls_clear()
 *
 * see tls.h
 */

void tls_clear(tls_t *tls)
{
    /* Reset the two state flags kept by this module: no TLS session is
     * active, and no trust (CA) file has been loaded. */
    tls->is_active = 0;
    tls->have_trust_file = 0;
}


/*
 * seed_prng()
 *
 * Seeds the OpenSSL random number generator if it is not already seeded.
 * Used error codes: TLS_ESEED
 */

#ifdef HAVE_OPENSSL
merror_t seed_prng(void)
{
    char randfile[512];
    time_t now;
    int sample;
    int budget;

    /* Most systems have /dev/random or other sources of random numbers that
     * OpenSSL can use to seed itself; in that case there is nothing to do.
     * The only system known to need manual seeding is DOS. */
    if (RAND_status())
    {
        return merror(EOK, NULL);
    }

    if (!RAND_file_name(randfile, sizeof(randfile)))
    {
        return merror(TLS_ESEED, 
                "no environment variables RANDFILE or HOME, "
                "or filename of rand file too long");
    }
    if (RAND_load_file(randfile, -1) < 1)
    {
        return merror(TLS_ESEED, "error reading %s", randfile);
    }

    /* Mix in the current time; there is little else "random" on DOS. */
    now = time(NULL);
    if (now < 0)
    {
        return merror(TLS_ESEED, "cannot get system time: %s", 
                strerror(errno));
    }
    RAND_seed((unsigned char *)&now, sizeof(time_t));

    /* Last resort when RANDFILE + time is not enough: feed up to 1024 values
     * from the (insecure) C library PRNG. This is still better than reusing
     * an unchanged RANDFILE; the documentation warns the user about it. */
    if (!RAND_status())
    {
        srand((unsigned int)(now % UINT_MAX));
        for (budget = 1024; !RAND_status() && budget > 0; budget--)
        {
            sample = rand();
            RAND_seed(&sample, sizeof(int));
        }
    }

    if (!RAND_status())
    {
        return merror(TLS_ESEED, 
                "random file + time + pseudo randomness is not enough, "
                "giving up");
    }

    /* Save a rand file for later use. Errors are ignored because nothing
     * useful can be done about them here. */
    (void)RAND_write_file(randfile);
    return merror(EOK, NULL);
}
#endif /* HAVE_OPENSSL */


/*
 * tls_lib_init()
 *
 * see tls.h
 */

merror_t tls_lib_init(void)
{
#ifdef HAVE_GNUTLS
    int rc = gnutls_global_init();

    if (rc == 0)
    {
        return merror(EOK, NULL);
    }
    return merror(TLS_ELIBFAILED, "%s", gnutls_strerror(rc));
#endif /* HAVE_GNUTLS */

#ifdef HAVE_OPENSSL
    merror_t seed_result;

    /* Initialize the OpenSSL library, then make sure its PRNG is seeded. */
    SSL_load_error_strings();
    SSL_library_init();
    seed_result = seed_prng();
    if (!merror_ok(seed_result))
    {
        return seed_result;
    }
    return merror(EOK, NULL);
#endif /* HAVE_OPENSSL */
}


/*
 * tls_is_active()
 *
 * see tls.h
 */

int tls_is_active(tls_t *tls)
{
    /* The flag is set by tls_start() on success and reset by tls_clear(). */
    int active = tls->is_active;

    return active;
}


#ifdef HAVE_OPENSSL
/*
 * openssl_hostname_match()
 *
 * Used only internally by tls_check_cert().
 *
 * Compare the certificate name 'pattern' with 'hostname', ignoring case.
 * If 'pattern' starts with "*.", the '*' stands for exactly one non-empty
 * left-most label of 'hostname': "*.example.com" matches "mail.example.com",
 * but neither "example.com" nor "a.b.example.com" (RFC 2818, 3.1).
 * Any other pattern, including one that merely starts with '*' (such as "*"
 * or "*example.com"), must be equal to 'hostname'.
 * Returns 1 on a match, 0 otherwise.
 */
static int openssl_hostname_match(const char *pattern, const char *hostname)
{
    const char *first_dot;

    if (pattern[0] == '*' && pattern[1] == '.')
    {
        first_dot = strchr(hostname, '.');
        return (first_dot != NULL && first_dot != hostname
                && strcasecmp(pattern + 1, first_dot) == 0);
    }
    return (strcasecmp(pattern, hostname) == 0);
}
#endif /* HAVE_OPENSSL */

/* 
 * tls_check_cert()
 *
 * If the 'verify' flag is set, perform a real verification of the peer's 
 * certificate. If this succeeds, the connection can be considered secure.
 * If the 'verify' flag is not set, perform only a few sanity checks of the
 * peer's certificate. You cannot trust the connection when this succeeds.
 * In both cases the peer's certificate must match 'hostname'.
 * Used error codes: TLS_ECERT
 */

merror_t tls_check_cert(tls_t *tls, char *hostname, int verify)
{
#ifdef HAVE_GNUTLS
    char *error_msg;
    int status;
    const gnutls_datum *cert_list;
    unsigned int cert_list_size;
    unsigned int i;
    gnutls_x509_crt cert;
    time_t t1, t2;
    
    if (verify)
    {
        error_msg = "TLS certificate verification failed";
    }
    else
    {
        error_msg = "TLS certificate check failed";
    }
    
    /* If 'verify' is true, this function uses the trusted CAs in the
     * credentials structure, so one or more CA certificates must have been
     * installed. */
    status = gnutls_certificate_verify_peers(tls->session);
    if (status == GNUTLS_E_NO_CERTIFICATE_FOUND) 
    {
        return merror(TLS_ECERT, "%s: no certificate was sent", error_msg);
    }
    /* Any other negative value is a GnuTLS error code, not a bit mask of
     * status flags; testing its bits below would give a bogus reason. */
    if (status < 0)
    {
        return merror(TLS_ECERT, "%s: %s", error_msg, gnutls_strerror(status));
    }
    if (status & GNUTLS_CERT_REVOKED)
    {
        return merror(TLS_ECERT, 
                "%s: the certificate has been revoked", error_msg);
    }
    if (gnutls_certificate_type_get(tls->session) != GNUTLS_CRT_X509)
    {
        return merror(TLS_ECERT, 
                "%s: the certificate type is not X509", error_msg);
    }
    if (verify && (status & GNUTLS_CERT_INVALID))
    {
        return merror(TLS_ECERT, 
                "%s: the certificate is not trusted", error_msg);
    }
    /* An empty chain must be rejected as well: the loop below would not run
     * and the hostname would never be checked. */
    if (!(cert_list = gnutls_certificate_get_peers(tls->session, &cert_list_size))
            || cert_list_size == 0)
    {
        return merror(TLS_ECERT, "%s: no certificate was found", error_msg);
    }
    /* Needed to check times: */
    if ((t1 = time(NULL)) < 0)
    {
        return merror(TLS_ECERT, "%s: %s", error_msg, strerror(errno));
    }
    /* Check the certificate chain. All certificates in the chain must have
     * valid activation/expiration times. The first certificate in the chain is
     * the host's certificate; it must match the hostname.
     * Every exit path after a successful gnutls_x509_crt_init() releases
     * 'cert' again. */
    for (i = 0; i < cert_list_size; i++)
    {
        if (gnutls_x509_crt_init(&cert) < 0)
        {
            return merror(TLS_ECERT, 
                    "%s: cannot initialize certificate structure", error_msg);
        }
        if (gnutls_x509_crt_import(cert, &cert_list[i], GNUTLS_X509_FMT_DER) < 0) 
        {
            gnutls_x509_crt_deinit(cert);
            return merror(TLS_ECERT,
                    "%s: error parsing certificate %u of %u", 
                    error_msg, i + 1, cert_list_size);
        }
        /* Check hostname */
        if (i == 0 && !gnutls_x509_crt_check_hostname(cert, hostname)) 
        {
            gnutls_x509_crt_deinit(cert);
            return merror(TLS_ECERT, 
                    "%s: the host certificate's owner does not match "
                    "hostname %s", error_msg, hostname);
        }
        /* Check certificate times */
        if ((t2 = gnutls_x509_crt_get_activation_time(cert)) < 0)
        {
            gnutls_x509_crt_deinit(cert);
            return merror(TLS_ECERT, 
                    "%s: cannot get activation time for certificate "
                    "%u of %u", error_msg, i + 1, cert_list_size);
        }
        if (t2 > t1)
        {
            gnutls_x509_crt_deinit(cert);
            return merror(TLS_ECERT,
                    "%s: certificate %u of %u is not yet activated", 
                    error_msg, i + 1, cert_list_size);
        }
        if ((t2 = gnutls_x509_crt_get_expiration_time(cert)) < 0)
        {
            gnutls_x509_crt_deinit(cert);
            return merror(TLS_ECERT,
                    "%s: cannot get expiration time for certificate "
                    "%u of %u", error_msg, i + 1, cert_list_size);
        }
        if (t2 < t1)
        {
            gnutls_x509_crt_deinit(cert);
            return merror(TLS_ECERT, "%s: certificate %u of %u has expired", 
                    error_msg, i + 1, cert_list_size);
        }
        gnutls_x509_crt_deinit(cert);
    }

    return merror(EOK, NULL);

#endif /* HAVE_GNUTLS */

#ifdef HAVE_OPENSSL
    X509 *x509cert;
    X509_NAME *x509_subject;
    long status;
    char buf[257];
    int cn_len;
    char *error_msg;
    
    if (verify)
    {
        error_msg = "TLS certificate verification failed";
    }
    else
    {
        error_msg = "TLS certificate check failed";
    }

    /* Get certificate */
    if (!(x509cert = SSL_get_peer_certificate(tls->ssl)))
    {
        return merror(TLS_ECERT, "%s: no certificate was sent", error_msg);
    }
    
    /* Get result of OpenSSL's default verify function. Without 'verify'
     * (no trust file), a missing local issuer certificate is tolerated. */
    if ((status = SSL_get_verify_result(tls->ssl)) != X509_V_OK)
    {
        if (!(!verify && status == X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY))
        {
            X509_free(x509cert);
            return merror(TLS_ECERT, 
                    "%s: %s", error_msg, X509_verify_cert_error_string(status));
        }
    }

    /* Check if hostname matches certificate Common Name (CN) */
    if (!(x509_subject = X509_get_subject_name(x509cert)))
    {
        X509_free(x509cert);
        return merror(TLS_ECERT,
                "%s: cannot get certificate subject", error_msg);
    }
    if ((cn_len = X509_NAME_get_text_by_NID(
                x509_subject, NID_commonName, buf, sizeof(buf))) == -1)
    {
        X509_free(x509cert);
        return merror(TLS_ECERT, 
                "%s: cannot get certificate common name", error_msg);
    }
    /* The returned length counts every byte that was copied. If it differs
     * from strlen(), the CN contains an embedded NUL (for example
     * "www.example.com\0.attacker.net") that the C string comparison would
     * silently cut off; such a name can never legitimately match. */
    if (cn_len < 0 || (size_t)cn_len != strlen(buf)
            || !openssl_hostname_match(buf, hostname))
    {
        X509_free(x509cert);
        return merror(TLS_ECERT, 
                "%s: the certificate's owner does not match hostname %s",
                error_msg, hostname);
    }
    X509_free(x509cert);
    return merror(EOK, NULL);
#endif /* HAVE_OPENSSL */
}


/*
 * openssl_io_error()
 *
 * Used only internally by the OpenSSL code.
 * 
 * Construct an error line according to 'error_code' (which originates from an
 * SSL_read(), SSL_write() or SSL_connect() operation) and 'error_code2' (which
 * originates from an SSL_get_error() call with 'error_code' as its argument).
 * The line will read: "error_string: error_reason". 'error_string' is given by
 * the calling function, this function finds out 'error_reason'.
 * Returns a pointer to a static buffer that is overwritten by the next call,
 * so the result must be used (e.g. passed to merror()) before calling again.
 */

#ifdef HAVE_OPENSSL

/* OpenSSL error strings are max 120 characters long according to
 * ERR_error_string(3). The rest is for our own error string. */
#define OPENSSL_ERRSTR_BUFSIZE 256

char *openssl_io_error(int error_code, int error_code2, char *error_string)
{
    /* Private to this function; static so that the returned pointer stays
     * valid after return. */
    static char errstr_buf[OPENSSL_ERRSTR_BUFSIZE];
    /* Save errno before further library calls get a chance to change it. */
    int saved_errno = errno;
    unsigned long error_code3;
    char *error_reason;
    
    switch (error_code2)
    {
        case SSL_ERROR_SYSCALL:
            error_code3 = ERR_get_error();
            if (error_code3 == 0)
            {
                if (error_code == 0)
                {
                    error_reason = "a protocol violating EOF occured";
                }
                else if (error_code == -1)
                {
                    error_reason = strerror(saved_errno);
                }
                else
                {
                    error_reason = "unknown error";
                }
            }
            else
            {
                error_reason = ERR_error_string(error_code3, NULL);
            }
            break;
            
        case SSL_ERROR_ZERO_RETURN:
            error_reason = "the connection has been closed unexpectedly";
            break;

        case SSL_ERROR_SSL:
            /* An empty error queue would otherwise be rendered as the
             * meaningless "error:00000000:lib(0):func(0):reason(0)". */
            error_code3 = ERR_get_error();
            if (error_code3 == 0)
            {
                error_reason = "unknown error";
            }
            else
            {
                error_reason = ERR_error_string(error_code3, NULL);
            }
            break;
            
        default:
            /* probably SSL_ERROR_NONE */
            error_reason = "unknown error";
            break;
    }
    snprintf(errstr_buf, sizeof(errstr_buf), 
            "%s: %s", error_string, error_reason);
    return errstr_buf;
}
#endif /* HAVE_OPENSSL */


/*
 * tls_init()
 *
 * see tls.h
 */

merror_t tls_init(tls_t *tls, char *key_file, char *cert_file, char *trust_file)
{
#ifdef HAVE_GNUTLS
    int error_code;
    
    if ((error_code = gnutls_init(&tls->session, GNUTLS_CLIENT)) != 0)
    {
        return merror(TLS_ELIBFAILED, "cannot initialize TLS Session: %s",
                gnutls_strerror(error_code));
    }
    if ((error_code = gnutls_set_default_priority(tls->session)) != 0)
    {
        gnutls_deinit(tls->session);
        return merror(TLS_ELIBFAILED, "cannot set priorities on TLS Session: %s",
                gnutls_strerror(error_code));
    }
    if ((error_code = gnutls_certificate_allocate_credentials(&tls->cred)) < 0)
    {
        gnutls_deinit(tls->session);
        return merror(TLS_ELIBFAILED, "cannot allocate certificate for TLS Session: %s",
                gnutls_strerror(error_code));
    }
    if (key_file && cert_file)
    {
        if ((error_code = gnutls_certificate_set_x509_key_file(tls->cred, 
                        cert_file, key_file, GNUTLS_X509_FMT_PEM)) < 0)
        {
            gnutls_deinit(tls->session);
            gnutls_certificate_free_credentials(tls->cred);
            return merror(TLS_EFILE, "cannot set X509 key file %s and/or "
                    "X509 cert file %s for TLS Session: %s",
                    key_file, cert_file, gnutls_strerror(error_code));
        }
    }
    if (trust_file)
    {
        /* A positive return value is the number of certificates loaded.
         * 0 means the file contained none: this is not a GnuTLS error code,
         * and gnutls_strerror(0) would report "Success". */
        if ((error_code = gnutls_certificate_set_x509_trust_file(tls->cred, 
                        trust_file, GNUTLS_X509_FMT_PEM)) <= 0)
        {
            gnutls_deinit(tls->session);
            gnutls_certificate_free_credentials(tls->cred);
            return merror(TLS_EFILE, 
                    "cannot set X509 trust file %s for TLS Session: %s",
                    trust_file, (error_code == 0) 
                    ? "no certificates found" : gnutls_strerror(error_code));
        }
        tls->have_trust_file = 1;
    }
    if ((error_code = gnutls_credentials_set(tls->session, 
                    GNUTLS_CRD_CERTIFICATE, tls->cred)) < 0)
    {
        gnutls_deinit(tls->session);
        gnutls_certificate_free_credentials(tls->cred);
        return merror(TLS_ELIBFAILED, 
                "cannot set credentials for TLS Session: %s",
                gnutls_strerror(error_code));
    }
    return merror(EOK, NULL);
    
#endif /* HAVE_GNUTLS */

#ifdef HAVE_OPENSSL
    
    SSL_METHOD *ssl_method = NULL;
    
    if (!(ssl_method = SSLv23_client_method()))
    {
        return merror(TLS_ELIBFAILED, "cannot set TLS method");
    }
    tls->ssl_ctx = SSL_CTX_new(ssl_method);
    if (!tls->ssl_ctx)
    {
        return merror(TLS_ELIBFAILED, "cannot create TLS context: %s", 
                ERR_error_string(ERR_get_error(), NULL));
    }
    if (key_file && cert_file)
    {
        if (SSL_CTX_use_PrivateKey_file(tls->ssl_ctx, key_file, SSL_FILETYPE_PEM) != 1)
        {
            SSL_CTX_free(tls->ssl_ctx);
            tls->ssl_ctx = NULL;
            return merror(TLS_EFILE, "cannot load key file %s: %s", 
                    key_file, ERR_error_string(ERR_get_error(), NULL));
        }
        if (SSL_CTX_use_certificate_chain_file(tls->ssl_ctx, cert_file) != 1)
        {
            SSL_CTX_free(tls->ssl_ctx);
            tls->ssl_ctx = NULL;
            return merror(TLS_EFILE, "cannot load certificate file %s: %s", 
                    cert_file, ERR_error_string(ERR_get_error(), NULL));
        }
        /* The key was loaded before the certificate. A certificate that does
         * not fit an already loaded key does not make the load fail (OpenSSL
         * discards the key instead), so check the pair explicitly to report a
         * mismatch here rather than as an obscure handshake failure. */
        if (SSL_CTX_check_private_key(tls->ssl_ctx) != 1)
        {
            SSL_CTX_free(tls->ssl_ctx);
            tls->ssl_ctx = NULL;
            return merror(TLS_EFILE, 
                    "key file %s does not match certificate file %s: %s", 
                    key_file, cert_file, 
                    ERR_error_string(ERR_get_error(), NULL));
        }
    }
    if (trust_file)
    {
        if (SSL_CTX_load_verify_locations(tls->ssl_ctx, trust_file, NULL) != 1)
        {
            SSL_CTX_free(tls->ssl_ctx);
            tls->ssl_ctx = NULL;
            return merror(TLS_EFILE, "cannot load CA file %s: %s", 
                    trust_file, ERR_error_string(ERR_get_error(), NULL));
        }
        tls->have_trust_file = 1;
    }
    if (!(tls->ssl = SSL_new(tls->ssl_ctx)))
    {
        SSL_CTX_free(tls->ssl_ctx);
        tls->ssl_ctx = NULL;
        return merror(TLS_ELIBFAILED, "cannot create a TLS structure: %s", 
                ERR_error_string(ERR_get_error(), NULL));
    }
    return merror(EOK, NULL);

#endif /* HAVE_OPENSSL */
}


/*
 * tls_start()
 *
 * see tls.h
 */

merror_t tls_start(tls_t *tls, int fd, char *hostname, int no_certcheck)
{
#ifdef HAVE_GNUTLS
    int error_code;
    merror_t e;
    
    gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)fd);
    if ((error_code = gnutls_handshake(tls->session)) < 0)
    {
        gnutls_deinit(tls->session);
        gnutls_certificate_free_credentials(tls->cred);
        return merror(TLS_EHANDSHAKE, 
                "TLS handshake failed: %s", gnutls_strerror(error_code));
    }
    if (!no_certcheck)
    {
        if (!merror_ok(e = tls_check_cert(tls, hostname, tls->have_trust_file)))
        {
            gnutls_deinit(tls->session);
            gnutls_certificate_free_credentials(tls->cred);
            return e;
        }
    }    
    tls->is_active = 1;
    return merror(EOK, NULL);
#endif /* HAVE_GNUTLS */

#ifdef HAVE_OPENSSL
    int error_code;
    merror_t e;
    
    /* On every failure the error is built BEFORE the SSL objects are freed:
     * building it may need the still-valid 'tls->ssl' (SSL_get_error()). */
    if (!SSL_set_fd(tls->ssl, fd))
    {
        e = merror(TLS_ELIBFAILED, "cannot set the file descriptor for TLS: %s", 
                ERR_error_string(ERR_get_error(), NULL));
        SSL_free(tls->ssl);
        SSL_CTX_free(tls->ssl_ctx);
        tls->ssl = NULL;
        tls->ssl_ctx = NULL;
        return e;
    }
    if ((error_code = SSL_connect(tls->ssl)) < 1)
    {
        e = merror(TLS_EIO, "%s", openssl_io_error(error_code, 
                    SSL_get_error(tls->ssl, error_code),
                    "TLS handshake failed"));
        SSL_free(tls->ssl);
        SSL_CTX_free(tls->ssl_ctx);
        tls->ssl = NULL;
        tls->ssl_ctx = NULL;
        return e;
    }
    if (!no_certcheck)
    {
        if (!merror_ok(e = tls_check_cert(tls, hostname, tls->have_trust_file)))
        {
            SSL_free(tls->ssl);
            SSL_CTX_free(tls->ssl_ctx);
            tls->ssl = NULL;
            tls->ssl_ctx = NULL;
            return e;
        }
    }
    tls->is_active = 1;
    return merror(EOK, NULL);
#endif /* HAVE_OPENSSL */
}


/*
 * tls_getchar()
 *
 * see tls.h
 */

merror_t tls_getchar(tls_t *tls, char *c, int *eof)
{
#ifdef HAVE_GNUTLS
    ssize_t error_code;
    
    error_code = gnutls_record_recv(tls->session, c, 1);
    if (error_code == 1)
    {
        *eof = 0;
        return merror(EOK, NULL);
    }
    else if (error_code == 0)
    {
        /* 0 bytes received: end of file */
        *eof = 1;
        return merror(EOK, NULL);
    }
    else /* error_code < 0 */
    {
        return merror(TLS_EIO, "cannot read from TLS connection: %s", 
                gnutls_strerror(error_code));
    }
    
#endif /* HAVE_GNUTLS */

#ifdef HAVE_OPENSSL
    
    int error_code;
    int error_code2;
    
    if ((error_code = SSL_read(tls->ssl, c, 1)) > 0)
    {
        *eof = 0;
        return merror(EOK, NULL);
    }
    error_code2 = SSL_get_error(tls->ssl, error_code);
    /* SSL_ERROR_ZERO_RETURN means the peer closed the TLS connection
     * cleanly. Report it as end of file, like the GnuTLS branch does for a
     * 0-byte read. (SSL_ERROR_NONE is kept for safety only; SSL_get_error()
     * does not return it for a result <= 0.) */
    if (error_code2 == SSL_ERROR_ZERO_RETURN || error_code2 == SSL_ERROR_NONE)
    {
        *eof = 1;
        return merror(EOK, NULL);
    }
    return merror(TLS_EIO, "%s", openssl_io_error(error_code, error_code2, 
                "cannot read from TLS connection"));
#endif /* HAVE_OPENSSL */
}


/*
 * tls_gets()
 *
 * see tls.h
 */

merror_t tls_gets(tls_t *tls, char *line, int size)
{
    /* Read at most size - 1 characters into 'line', stopping after a newline,
     * at end of file, or on a read error. 'line' is always NUL-terminated
     * when size >= 1; for size < 1 nothing is read or written. Returns the
     * result of the last tls_getchar() call, or EOK if none was made. */
    char c;
    int i = 0;
    int eof = 0;
    merror_t e = merror(EOK, NULL);

    if (size < 1)
    {
        /* No room even for the terminating NUL. */
        return e;
    }
    /* Check for room BEFORE reading, so that size == 1 cannot overflow. */
    while (i < size - 1)
    {
        if (!merror_ok(e = tls_getchar(tls, &c, &eof)) || eof)
        {
            break;
        }
        line[i++] = c;
        if (c == '\n')
        {
            break;
        }
    }
    line[i] = '\0';

    return e;
}


/*
 * tls_puts()
 *
 * see tls.h
 */

merror_t tls_puts(tls_t *tls, char *s)
{
#ifdef HAVE_GNUTLS
    ssize_t error_code;
    size_t count;
    size_t sent = 0;

    count = strlen(s);
    /* gnutls_record_send() may transmit fewer bytes than requested (at most
     * one TLS record per call), so keep sending until everything is written.
     * An empty string sends nothing and succeeds. */
    while (sent < count)
    {
        error_code = gnutls_record_send(tls->session, s + sent, count - sent);
        if (error_code < 0)
        {
            return merror(TLS_EIO, "cannot write to TLS connection: %s", 
                    gnutls_strerror((int)error_code));
        }
        if (error_code == 0)
        {
            /* No progress at all: give up instead of looping forever. */
            return merror(TLS_EIO, "cannot write to TLS connection: unknown error");
        }
        sent += (size_t)error_code;
    }
    return merror(EOK, NULL);

#endif /* HAVE_GNUTLS */

#ifdef HAVE_OPENSSL
    
    int error_code;
    size_t count;

    count = strlen(s);
    if (count < 1)
    {
        /* nothing to be done */
        return merror(EOK, NULL);
    }
    
    if ((error_code = SSL_write(tls->ssl, s, (int)count)) != (int)count)
    {
        return merror(TLS_EIO, "%s", openssl_io_error(error_code, 
                    SSL_get_error(tls->ssl, error_code),
                    "cannot write to TLS connection"));
    }

    return merror(EOK, NULL);
#endif /* HAVE_OPENSSL */
}


/*
 * tls_close()
 *
 * see tls.h
 */

void tls_close(tls_t *tls)
{
    /* Library resources are shut down and released here only when
     * tls_start() marked the session active; otherwise just reset the
     * structure. */
    if (!tls->is_active)
    {
        tls_clear(tls);
        return;
    }

#ifdef HAVE_GNUTLS
    gnutls_bye(tls->session, GNUTLS_SHUT_RDWR);
    gnutls_deinit(tls->session);
    gnutls_certificate_free_credentials(tls->cred);
#endif /* HAVE_GNUTLS */
#ifdef HAVE_OPENSSL
    SSL_shutdown(tls->ssl);
    SSL_free(tls->ssl);
    SSL_CTX_free(tls->ssl_ctx);
#endif /* HAVE_OPENSSL */

    tls_clear(tls);
}


/*
 * tls_lib_deinit()
 *
 * see tls.h
 */

void tls_lib_deinit(void)
{
#ifdef HAVE_GNUTLS
    /* Counterpart of gnutls_global_init() in tls_lib_init(). */
    gnutls_global_deinit();
#endif /* HAVE_GNUTLS */
#ifdef HAVE_OPENSSL
    /* Counterparts of SSL_load_error_strings() and SSL_library_init() in
     * tls_lib_init(): release the global tables they set up. (On OpenSSL
     * 1.1.0 and later these calls are no-ops.) */
    ERR_free_strings();
    EVP_cleanup();
#endif /* HAVE_OPENSSL */
}
