/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2006  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "AsyncConnector.h"
#include "AsyncConnectorListener.h"
#include "Reactor.h"
#include "InetAddr.h"
#include "SymbolicInetAddr.h"
#include "IntrusivePtr.h"
#include "DnsResolver.h"
#include "MonotonicTimer.h"
#include "TimeStamp.h"
#include "TimeDelta.h"
#include "ServerReachabilityDB.h"
#include "AutoClosingSAP.h"
#include <ace/config-lite.h>
#include <ace/SOCK_Stream.h>
#include <ace/OS_NS_sys_socket.h>
#include <boost/static_assert.hpp>
#include <cassert>
#include <stddef.h>
#include <errno.h>

using namespace std;

/**
 * Per-address state for one connection operation.  initiate() creates one
 * record for every address the symbolic name resolved to.
 */
class AsyncConnector::IpRecord
{
public:
	// The address to connect to.
	InetAddr resolved_addr;
	// Reachability data for this address: loaded from ServerReachabilityDB in
	// initiate() and overwritten by updateReachabilityInfo().
	ServerReachability reachability;
	// Reactor registration for peer's handle; empty when not registered.
	ReactorHandlerId io_handler_id;
	// Per-address timer: NEXT_ATTEMPT_TIMEOUT while connecting, or
	// PAUSE_BEFORE_RETRY while waiting to retry; empty when none is armed.
	ReactorTimerId timer_id;
	// Socket of the current attempt; an invalid handle means no attempt is
	// in progress.
	ACE_SOCK_Stream peer;
	// When the current (or most recent) attempt was started.
	TimeStamp conn_start_time;
	// Number of attempts still allowed for this address.
	int remaining_conn_attempts;
	
	IpRecord(InetAddr const& resolved_addr,
		ServerReachability const& reachability);
	
	~IpRecord();
	
	// True while a connection attempt has an open socket.
	bool isInProgress() const { return peer.get_handle() != ACE_INVALID_HANDLE; }
	
	// True while no socket is open but a timer is armed (waiting to retry).
	bool isPaused() const { return !isInProgress() && timer_id; }
};


/**
 * Ranking used by findCandidateForConnection() to pick the next address to
 * try.  operator() returns true if lhs is a better candidate than rhs.
 */
class AsyncConnector::IpRecordQualityCompare
{
public:
	bool operator()(IpRecord const& lhs, IpRecord const& rhs) const;
};


/**
 * Creates an idle connector: no operation in progress, no reactor attached
 * and a zero reference count.
 */
AsyncConnector::AsyncConnector()
:	m_isInProgress(false),
	m_pReactor(0),
	m_refCounter(0)
{
}

/**
 * Cancels any operation still in progress through abort(), which does not
 * notify the listener.
 */
AsyncConnector::~AsyncConnector()
{
	abort();
}

/**
 * Starts connecting to addr.  The outcome is reported to listener through
 * onConnectionEstablished() or onConnectionFailed().  Any operation already
 * in progress is aborted first.
 *
 * addr is resolved synchronously.  If that yields no address, the listener
 * is told NAME_RESOLUTION_FAILED before this function returns.  Otherwise
 * the function:
 * - creates one IpRecord per address, seeded from ServerReachabilityDB;
 * - registers an overall timer with reactor, using timeout;
 * - makes the first connection attempt.
 */
void
AsyncConnector::initiate(Listener& listener, Reactor& reactor,
	SymbolicInetAddr const& addr, TimeDelta const* timeout)
{
	abort();
	m_observerLink.setObserver(&listener);
	m_isInProgress = true;
	
	vector<InetAddr> addrs(DnsResolver::resolve(addr));
	if (addrs.empty()) {
		handleConnFailure(AsyncConnectorError::NAME_RESOLUTION_FAILED, 0);
		return;
	}
	
	// Indexed in parallel with addrs below; this assumes get() returns one
	// entry per address, in the same order.
	vector<ServerReachability> history(
		ServerReachabilityDB::instance()->get(addrs)
	);
	
	size_t const num_addrs = addrs.size();
	vector<IpRecord> records;
	records.reserve(num_addrs);
	for (size_t idx = 0; idx != num_addrs; ++idx) {
		records.push_back(IpRecord(addrs[idx], history[idx]));
	}
	m_ipRecords.swap(records);
	
	m_pReactor = &reactor;
	// Keep this a temporary.  nextConnectionAttempt() may end up in
	// abort(), which asserts that m_refCounter has dropped back to zero.
	m_timeoutTimerId = m_pReactor->registerTimer(
		IntrusivePtr<EventHandlerBase>(this), timeout
	);
	
	nextConnectionAttempt();
}

/**
 * Cancels the operation in progress, if any, without notifying the listener.
 *
 * Teardown steps, in order:
 * 1. Detach the listener.
 * 2. Release every per-address socket, handler and timer.
 * 3. Release the overall timeout timer.
 * 4. Forget the last recorded error.
 *
 * When nothing is in progress, it only checks that no references remain.
 */
void
AsyncConnector::abort()
{
	if (!m_isInProgress) {
		assert(m_refCounter == 0); // otherwise we are leaking reactor registrations
		return;
	}
	
	m_observerLink.setObserver(0);
	// NOTE: the order presumably matters.  clearIpRecords() only unregisters
	// while m_pReactor is set, and unref() clears m_pReactor once the last
	// reference goes away.
	clearIpRecords();
	clearMainTimeout();
	m_lastError.clear();
	m_isInProgress = false;
	
	assert(m_refCounter == 0); // otherwise we are leaking reactor registrations
}

/**
 * Reactor read-readiness callback.  Like every other event on a connecting
 * socket, it is treated as a possible connect() completion.
 */
void
AsyncConnector::handleRead(ACE_HANDLE handle)
{
	return handleCompletion(handle);
}

/**
 * Reactor write-readiness callback.  Like every other event on a connecting
 * socket, it is treated as a possible connect() completion.
 */
void
AsyncConnector::handleWrite(ACE_HANDLE handle)
{
	return handleCompletion(handle);
}

/**
 * Reactor exceptional-condition callback.  Like every other event on a
 * connecting socket, it is treated as a possible connect() completion.
 */
void
AsyncConnector::handleExcept(ACE_HANDLE handle)
{
	return handleCompletion(handle);
}

/**
 * Reactor timer callback.
 *
 * m_timeoutTimerId is the overall deadline registered in initiate().  When
 * it fires, attempts still in progress are recorded as UNDEFINED.  The whole
 * operation then fails with the last recorded error, or with
 * CONNECTION_TIMEOUT / ETIMEDOUT if no error was recorded.
 *
 * Any other id belongs to a single address (armed by initiateConnection()
 * or abortConnection()).  That timer is unregistered, which makes a paused
 * address eligible again, and nextConnectionAttempt() is invoked.  An
 * attempt that is still connecting keeps running, so another address may
 * now be tried alongside it.
 */
void
AsyncConnector::handleTimeout(ReactorTimerId const& timer_id)
{
	if (timer_id == m_timeoutTimerId) {
		Error failure(m_lastError);
		markInProgressConnectionsAsUndefined();
		if (!failure.isSet()) {
			failure.set(AsyncConnectorError::CONNECTION_TIMEOUT, ETIMEDOUT);
		}
		handleConnFailure(failure.code, failure.errnum);
		return;
	}
	
	IpRecord* const rec = findIpRecordByTimerId(timer_id);
	assert(rec);
	if (m_pReactor) {
		m_pReactor->unregisterTimer(rec->timer_id);
		rec->timer_id = ReactorTimerId();
	}
	nextConnectionAttempt();
}

/**
 * Increments the reference count.  This is presumably called by
 * IntrusivePtr<EventHandlerBase>, which this file wraps around `this`
 * whenever it registers a handler or timer with the reactor.
 */
void
AsyncConnector::ref()
{
	++m_refCounter;
}

/**
 * Decrements the reference count.  When it reaches zero, the reactor
 * pointer is forgotten; presumably no reactor registrations remain at that
 * point.  abort() asserts that the count is zero.
 */
void
AsyncConnector::unref()
{
	--m_refCounter;
	if (m_refCounter == 0) {
		m_pReactor = 0;
	}
}

void
AsyncConnector::nextConnectionAttempt()
{
	while (true) {
		int potential_candidates = 0;
		IpRecord* ip_rec = findCandidateForConnection(potential_candidates);
		if (!ip_rec) {
			if (potential_candidates == 0) {
				handleConnFailure(m_lastError.code, m_lastError.errnum);
			}
			return;
		}
		
		if (initiateConnection(*ip_rec)) {
			return;
		}
	}
}

/**
 * Starts a non-blocking connect() to ip_rec.resolved_addr.
 *
 * Returns true if the attempt is pending.  In that case the socket is
 * registered with the reactor for all events, and a NEXT_ATTEMPT_TIMEOUT
 * timer is armed for the address.
 *
 * Returns false if the attempt failed immediately.  The failure is then
 * recorded in m_lastError; once the socket has been registered, it is also
 * passed through abortConnection().
 *
 * Every call consumes one of ip_rec.remaining_conn_attempts, including the
 * early-failure paths, so the retry loop in nextConnectionAttempt() always
 * terminates.
 */
bool
AsyncConnector::initiateConnection(IpRecord& ip_rec)
{
	assert(ip_rec.peer.get_handle() == ACE_INVALID_HANDLE);
	assert(!ip_rec.io_handler_id);
	assert(!ip_rec.timer_id);
	assert(ip_rec.remaining_conn_attempts > 0);
	
	// Decrement before the first failure exit.  nextConnectionAttempt()
	// keeps calling us until we return true or no candidate remains.  A
	// failure exit that left this counter untouched would make it select the
	// same record forever.
	--ip_rec.remaining_conn_attempts;
	
	if (!m_pReactor) {
		m_lastError.set(AsyncConnectorError::GENERIC_ERROR, 0);
		return false;
	}
	
	ip_rec.conn_start_time = MonotonicTimer::getTimestamp();
	
	if (ip_rec.peer.open(SOCK_STREAM, ip_rec.resolved_addr.get_type(), 0, 0) == -1
	    || ip_rec.peer.enable(ACE_NONBLOCK) == -1) {
		// Save errno first: close() may overwrite it.
		int const saved_errno = errno;
		ip_rec.peer.close();
		m_lastError.set(AsyncConnectorError::GENERIC_ERROR, saved_errno);
		return false;
	}
	
	try {
		IntrusivePtr<EventHandlerBase> handler(this);
		ip_rec.io_handler_id = m_pReactor->registerHandler(
			ip_rec.peer.get_handle(), handler, Reactor::ALL_EVENTS
		);
		// If this fires while the connect is still pending, handleTimeout()
		// lets another address be tried alongside this one.
		TimeDelta timeout = TimeDelta::fromMsec(NEXT_ATTEMPT_TIMEOUT);
		ip_rec.timer_id = m_pReactor->registerTimer(handler, &timeout);
	} catch (Reactor::Exception const&) {
		abortConnection(ip_rec, AsyncConnectorError::GENERIC_ERROR, 0);
		return false;
	}
	
	sockaddr* sa = reinterpret_cast<sockaddr*>(ip_rec.resolved_addr.get_addr());
	if (ACE_OS::connect(ip_rec.peer.get_handle(), sa, ip_rec.resolved_addr.get_size()) == -1) {
		int const err = errno;
		// On a non-blocking socket these mean "in progress".  The outcome is
		// delivered through the reactor to handleCompletion().
		if (err != EINPROGRESS && err != EWOULDBLOCK) {
			abortConnection(ip_rec, AsyncConnectorError::GENERIC_ERROR, err);
			return false;
		}
	}
	
	return true;
}

/**
 * Tears down the current attempt on ip_rec after a failure.
 *
 * The function:
 * - records the error in m_lastError;
 * - stores a FAILED result for the address via updateReachabilityInfo();
 * - unregisters the I/O handler and closes the socket.
 *
 * A retry is allowed only when all of these hold:
 * - the address has attempts left;
 * - its previously stored result was not FAILED;
 * - the error is CONNECTION_REFUSED or DESTINATION_UNREACHABLE.
 * In that case a PAUSE_BEFORE_RETRY timer is armed (or the existing one
 * rescheduled), and the address becomes eligible again when it fires (see
 * handleTimeout()).  Otherwise the per-address timer is dropped and the
 * address gets no further attempts.
 *
 * Requires m_pReactor to be non-null.
 */
void
AsyncConnector::abortConnection(IpRecord& ip_rec, Error::Code err_code, int err_num)
{
	assert(m_pReactor);
	
	m_lastError.set(err_code, err_num);
	
	// Must be evaluated before updateReachabilityInfo() below overwrites
	// ip_rec.reachability with FAILED.
	bool const can_retry = (
		ip_rec.remaining_conn_attempts > 0 &&
		ip_rec.reachability.getConnectResult() != ServerReachability::FAILED &&
		(err_code == AsyncConnectorError::CONNECTION_REFUSED ||
		 err_code == AsyncConnectorError::DESTINATION_UNREACHABLE)
	);
	
	updateReachabilityInfo(ip_rec, ServerReachability::FAILED);
	
	if (ip_rec.io_handler_id) {
		m_pReactor->unregisterHandler(ip_rec.io_handler_id);
		ip_rec.io_handler_id = ReactorHandlerId();
	}
	ip_rec.peer.close();
	
	if (can_retry) {
		// Reuse the NEXT_ATTEMPT_TIMEOUT timer if there is one; otherwise
		// arm a new one.
		TimeDelta timeout(TimeDelta::fromMsec(PAUSE_BEFORE_RETRY));
		if (ip_rec.timer_id) {
			m_pReactor->rescheduleTimer(ip_rec.timer_id, &timeout);
		} else  {
			IntrusivePtr<EventHandlerBase> handler(this);
			ip_rec.timer_id = m_pReactor->registerTimer(handler, &timeout);
		}
	} else {
		if (ip_rec.timer_id) {
			m_pReactor->unregisterTimer(ip_rec.timer_id);
			ip_rec.timer_id = ReactorTimerId();
		}
		// Exclude this address from findCandidateForConnection().
		ip_rec.remaining_conn_attempts = 0;
	}
}

/**
 * Picks the best record on which to start a new attempt.
 *
 * Records that are connecting or have a timer armed are skipped, but they
 * are counted into potential_candidates because they may still succeed or
 * become eligible later.  The remaining records that still have attempts
 * left are ranked with IpRecordQualityCompare.
 *
 * Returns the best such record, or 0 if none is eligible.
 */
AsyncConnector::IpRecord*
AsyncConnector::findCandidateForConnection(int& potential_candidates)
{
	IpRecordQualityCompare is_better;
	IpRecord* best = 0;
	int busy = 0;
	
	size_t const count = m_ipRecords.size();
	for (size_t i = 0; i < count; ++i) {
		IpRecord& rec = m_ipRecords[i];
		bool const occupied = rec.isInProgress() || rec.timer_id;
		if (occupied) {
			++busy;
		} else if (rec.remaining_conn_attempts > 0
		           && (best == 0 || is_better(rec, *best))) {
			best = &rec;
		}
	}
	
	potential_candidates = busy;
	return best;
}

/**
 * Returns the record whose socket has the given handle, or 0 if none does.
 */
AsyncConnector::IpRecord*
AsyncConnector::findIpRecordByHandle(ACE_HANDLE handle)
{
	size_t const count = m_ipRecords.size();
	for (size_t i = 0; i != count; ++i) {
		IpRecord& rec = m_ipRecords[i];
		if (rec.peer.get_handle() == handle) {
			return &rec;
		}
	}
	return 0;
}

/**
 * Returns the record whose per-address timer has the given id, or 0 if none
 * does.
 */
AsyncConnector::IpRecord*
AsyncConnector::findIpRecordByTimerId(ReactorTimerId const& timer_id)
{
	size_t const count = m_ipRecords.size();
	for (size_t i = 0; i != count; ++i) {
		IpRecord& rec = m_ipRecords[i];
		if (rec.timer_id == timer_id) {
			return &rec;
		}
	}
	return 0;
}

void
AsyncConnector::clearIpRecords()
{
	vector<IpRecord>::iterator it = m_ipRecords.begin();
	vector<IpRecord>::iterator const end = m_ipRecords.end();
	for (; it != end; ++it) {
		if (m_pReactor) {
			if (it->io_handler_id) {
				m_pReactor->unregisterHandler(it->io_handler_id);
			}
			if (it->timer_id) {
				m_pReactor->unregisterTimer(it->timer_id);
			}
		}
		it->peer.close();
	}
	m_ipRecords.clear();
}

/**
 * Unregisters the overall timeout timer armed by initiate() and resets the
 * stored id.  This happens only if such a timer exists and a reactor is
 * still attached.
 */
void
AsyncConnector::clearMainTimeout()
{
	if (!m_pReactor || !m_timeoutTimerId) {
		return; // nothing to unregister, or no reactor to do it with
	}
	m_pReactor->unregisterTimer(m_timeoutTimerId);
	m_timeoutTimerId = ReactorTimerId();
}

/**
 * Handles a reactor event (via handleRead/handleWrite/handleExcept) on a
 * socket whose non-blocking connect() was started by initiateConnection().
 *
 * SO_ERROR is read to learn the outcome:
 * - zero: the socket is handed over via onConnEstablished();
 * - otherwise: the error is classified as timeout, refused, unreachable or
 *   generic.  The attempt is torn down by abortConnection() and the next
 *   attempt is scheduled.
 * If SO_ERROR itself cannot be read, errno is reported as a generic error.
 */
void
AsyncConnector::handleCompletion(ACE_HANDLE handle)
{
	IpRecord* const rec = findIpRecordByHandle(handle);
	assert(rec);
	
	int sock_err = 0;
	int sock_err_len = sizeof(sock_err);
	AsyncConnectorError::Code code = AsyncConnectorError::GENERIC_ERROR;
	
	if (rec->peer.get_option(SOL_SOCKET, SO_ERROR, &sock_err, &sock_err_len) != 0) {
		// Could not query the socket; report why.
		sock_err = errno;
	} else if (sock_err == 0) {
		onConnEstablished(*rec);
		return;
	} else if (sock_err == ETIMEDOUT) {
		code = AsyncConnectorError::CONNECTION_TIMEOUT;
	} else if (sock_err == ECONNREFUSED) {
		code = AsyncConnectorError::CONNECTION_REFUSED;
	} else if (sock_err == ENETUNREACH || sock_err == EHOSTUNREACH
	           || sock_err == EHOSTDOWN) {
		code = AsyncConnectorError::DESTINATION_UNREACHABLE;
	}
	
	abortConnection(*rec, code, sock_err);
	nextConnectionAttempt();
}

/**
 * Same as the three-argument overload, timestamped with the current
 * MonotonicTimer time.
 */
void
AsyncConnector::updateReachabilityInfo(
	IpRecord& ip_rec, ServerReachability::ConnectResult res)
{
	TimeStamp const now(MonotonicTimer::getTimestamp());
	updateReachabilityInfo(ip_rec, res, now);
}

/**
 * Stores the outcome res of the attempt on ip_rec, together with its
 * duration, in both ip_rec.reachability and ServerReachabilityDB.
 *
 * The duration is timestamp minus ip_rec.conn_start_time, in milliseconds,
 * and is clamped to zero if negative.  UNDEFINED outcomes shorter than one
 * second are not stored.
 */
void
AsyncConnector::updateReachabilityInfo(
	IpRecord& ip_rec, ServerReachability::ConnectResult res,
	TimeStamp const& timestamp)
{
	TimeDelta elapsed(timestamp - ip_rec.conn_start_time);
	unsigned long msec = 0;
	if (elapsed > TimeDelta::zero()) {
		msec = elapsed.toMsec();
	}
	
	// A very short UNDEFINED sample carries no useful information.
	bool const worth_storing = (res != ServerReachability::UNDEFINED || msec >= 1000);
	if (worth_storing) {
		ip_rec.reachability = ServerReachability(msec, res);
		ServerReachabilityDB::instance()->put(
			ip_rec.resolved_addr, ip_rec.reachability
		);
	}
}

void
AsyncConnector::markInProgressConnectionsAsUndefined()
{
	TimeStamp timestamp = MonotonicTimer::getTimestamp();
	vector<IpRecord>::iterator it = m_ipRecords.begin();
	vector<IpRecord>::iterator const end = m_ipRecords.end();
	for (; it != end; ++it) {
		if (it->isInProgress()) {
			updateReachabilityInfo(*it, ServerReachability::UNDEFINED, timestamp);
		}
	}
}

/**
 * Finishes the whole operation successfully using ip_rec's socket.
 *
 * Steps, in order:
 * 1. Move the socket into a local AutoClosingSAP.  The assert checks that
 *    ip_rec.peer was left without a handle, so the later peer.close() in
 *    clearIpRecords() has nothing to close.
 * 2. Record reachability: UNDEFINED for the other attempts still in
 *    progress, CONNECTED for this address.
 * 3. Fetch the listener, then call abort(), which detaches it and releases
 *    all remaining sockets, handlers and timers.
 * 4. Hand the stream to the listener.
 */
void
AsyncConnector::onConnEstablished(IpRecord& ip_rec)
{
	AutoClosingSAP<ACE_SOCK_Stream> stream(ip_rec.peer);
	assert(ip_rec.peer.get_handle() == ACE_INVALID_HANDLE);
	
	// ip_rec no longer counts as in progress here (its handle was moved
	// out), so only the other addresses are marked UNDEFINED.
	markInProgressConnectionsAsUndefined();
	updateReachabilityInfo(ip_rec, ServerReachability::CONNECTED);
	
	Listener* listener = m_observerLink.getObserver();
	abort(); // this will detach the observer
	
	if (listener) {
		listener->onConnectionEstablished(stream);
	}
}

/**
 * Ends the whole operation with a failure.  The listener is captured, then
 * abort() detaches it and releases everything, and finally the listener is
 * notified with errcode.  The errno value (second parameter) is not
 * forwarded.
 */
void
AsyncConnector::handleConnFailure(AsyncConnectorError::Code errcode, int)
{
	Listener* const observer = m_observerLink.getObserver();
	abort(); // detaches the observer and tears everything down
	if (observer == 0) {
		return;
	}
	observer->onConnectionFailed(errcode);
}


/*======================= AsyncConnector::IpRecord ========================*/

/**
 * Creates a fresh record for one resolved address.  It has no socket and no
 * reactor registrations (default-constructed ids), and allows two
 * connection attempts.
 */
AsyncConnector::IpRecord::IpRecord(
	InetAddr const& resolved_addr, ServerReachability const& reachability)
:	resolved_addr(resolved_addr),
	reachability(reachability),
	remaining_conn_attempts(2)
{
}

/**
 * Does nothing.  Sockets are closed explicitly by clearIpRecords() and
 * abortConnection(), not here.
 */
AsyncConnector::IpRecord::~IpRecord()
{
}

/*================= AsyncConnector::IpRecordQualityCompare ================*/

/**
 * Returns true if lhs is a better candidate for the next attempt than rhs.
 *
 * The record with more remaining attempts always wins.  On a tie, the
 * recorded connect times are compared after adding a bias to rhs's time.
 * The bias depends on both records' last results (see the matrix).  For
 * example, an address that connected before is preferred over an UNDEFINED
 * one unless its recorded time is at least 5000 higher.
 */
bool
AsyncConnector::IpRecordQualityCompare::operator()(
	IpRecord const& lhs, IpRecord const& rhs) const
{
	int const lhs_attempts = lhs.remaining_conn_attempts;
	int const rhs_attempts = rhs.remaining_conn_attempts;
	if (lhs_attempts != rhs_attempts) {
		return lhs_attempts > rhs_attempts;
	}
	
	typedef ServerReachability SR;
	BOOST_STATIC_ASSERT(SR::UNDEFINED == 0);
	BOOST_STATIC_ASSERT(SR::CONNECTED == 1);
	BOOST_STATIC_ASSERT(SR::FAILED == 2);
	
	// bias[lhs_result][rhs_result] is added to rhs's time before comparing.
	// Positive entries favour lhs; the matrix is antisymmetric.
	static long const bias[3][3] = {
		/*            UNDEF    CONN  FAILED */
		/* UNDEF  */ {     0,  -5000,   8000 },
		/* CONN   */ {  5000,      0,  30000 },
		/* FAILED */ { -8000, -30000,      0 }
	};
	
	SR::ConnectResult const lhs_result = lhs.reachability.getConnectResult();
	SR::ConnectResult const rhs_result = rhs.reachability.getConnectResult();
	// Held as long so that adding a negative bias is signed arithmetic.
	long const lhs_time = lhs.reachability.getConnectTime();
	long const rhs_time = rhs.reachability.getConnectTime();
	
	return lhs_time < rhs_time + bias[lhs_result][rhs_result];
}
