# Copyright 2009 Canonical Ltd.
#
# This file is part of desktopcouch.
#
#  desktopcouch is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# desktopcouch is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with desktopcouch.  If not, see <http://www.gnu.org/licenses/>.

"""All inter-tool communication."""

import logging
import hashlib

from twisted.internet import reactor
from twisted.internet.protocol import ServerFactory, ReconnectingClientFactory
from twisted.protocols import basic

import dbus

try:
    from desktopcouch.application.pair.couchdb_pairing.dbus_io import (
        get_remote_hostname)
except ImportError:
    logging.exception("Can't import dbus_io, because avahi not installed?")
    get_remote_hostname = lambda addr: None  # pylint: disable=C0103


# Digest constructor used for every seed/secret hash in this module.  Both
# ends of the exchange compare hexdigests produced with it, so they must
# agree on the algorithm.
hash_fn = hashlib.sha512  # pylint: disable=C0103,E1101


def dict_to_bytes(dictionary):
    """Convert a dictionary of string key/values into a string.

    Layout: the 7-character magic ``"CMbydi0"``, a 4-byte big-endian length
    of the payload, then the payload.  The payload is a sequence of records,
    each key followed by its value, every field prefixed by its 2-byte
    big-endian length.

    :param dictionary: mapping whose keys and values are all ``str``.
    :returns: the encoded string.
    :raises TypeError: if a key or value is not a ``str``.
    :raises ValueError: if a key or value is longer than 65535 characters,
        or the whole payload is longer than 2**32 - 1 characters; such
        lengths do not fit the fixed-width length prefixes.
    """
    parts = []

    def _append_field(field):
        """Append one length-prefixed field to ``parts``."""
        # Explicit checks instead of ``assert`` so validation survives -O.
        if not isinstance(field, str):
            raise TypeError("expected str, got %r" % (field,))
        length = len(field)
        if length > 0xFFFF:
            raise ValueError("field too long to encode: %d" % length)
        parts.append(chr(length >> 8))
        parts.append(chr(length & 255))
        parts.append(field)

    for key, value in dictionary.items():
        _append_field(key)
        _append_field(value)

    blob = "".join(parts)
    length = len(blob)
    if length > 0xFFFFFFFF:
        raise ValueError("payload too long to encode: %d" % length)
    # 4-byte big-endian payload size.
    blob_size = "".join(chr(length >> shift & 255) for shift in (24, 16, 8, 0))

    return "CMbydi0" + blob_size + blob


def bytes_to_dict(bytestring):
    """Convert a string from C{dict_to_bytes} back into a dictionary.

    :param bytestring: encoded string; it arrives from the network, so every
        length field is checked against the available data.
    :returns: dict of the decoded key/value strings.
    :raises ValueError: if the magic prefix is missing, the header is
        truncated, the declared payload size does not match the data, or a
        record runs past the end of the payload.
    """
    if bytestring[:7] != "CMbydi0":
        raise ValueError(
            "magic bytes missing.  Invalid string.  %r" % (bytestring[:10],))
    # Magic (7) + 4-byte size header.
    if len(bytestring) < 11:
        raise ValueError("bytes are corrupt; header truncated")
    bytestring = bytestring[7:]

    blob_size = 0
    for char in bytestring[:4]:
        blob_size = (blob_size << 8) + ord(char)

    blob = bytestring[4:]
    if blob_size != len(blob):
        raise ValueError("bytes are corrupt; expected %d, got %d" % (blob_size,
            len(blob)))

    def _read_field(cursor):
        """Read one 2-byte-length-prefixed field at ``cursor``.

        Returns ``(field, next_cursor)``."""
        if cursor + 2 > blob_size:
            raise ValueError(
                "bytes are corrupt; truncated length at offset %d" % cursor)
        length = (ord(blob[cursor]) << 8) + ord(blob[cursor + 1])
        start = cursor + 2
        end = start + length
        if end > blob_size:
            raise ValueError(
                "bytes are corrupt; field at offset %d overruns data" % cursor)
        return blob[start:end], end

    dictionary = {}
    blob_cursor = 0
    while blob_cursor < blob_size:
        key, blob_cursor = _read_field(blob_cursor)
        value, blob_cursor = _read_field(blob_cursor)
        dictionary[key] = value
    return dictionary


class ListenForInvitations:
    """Narrative "Alice".

    Owns the listening TCP socket for incoming invitations.  Every accepted
    connection is handed to a protocol built by a
    ProcessAnInvitationFactory."""

    def __init__(self, get_secret_from_user, on_close, hostid, oauth_data):
        """Create the factory and start listening on an OS-chosen port."""
        self.logging = logging.getLogger(self.__class__.__name__)

        factory = ProcessAnInvitationFactory(
            get_secret_from_user, on_close, hostid, oauth_data)
        self.factory = factory
        # Port 0 lets the OS pick a free port; see get_local_port().
        self.listening_port = reactor.listenTCP(  # pylint: disable=E1101
            0, factory)

    def get_local_port(self):
        """Return the port number we are listening on, so the caller can
        advertise it."""
        local_port = self.listening_port.getHost().port
        self.logging.info(
            "local port to receive invitations is %s", local_port)
        return local_port

    def close(self):
        """Stop listening; called from the UI when a window is destroyed and
        the connection is no longer needed."""
        port = self.listening_port
        port.stopListening()


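# FIXME: pylint reports W0223 (abstract method not overridden) for this
# class; either the class is missing part of the interface or this is a
# pylint false positive.  Investigate before removing the disable below.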
# pylint: disable=W0223
class ProcessAnInvitationProtocol(basic.LineReceiver):
    """Narrative "Alice".

    Listen for messages, and when we receive one, call the display callback
    function with the inviter details plus a key."""

    def __init__(self):
        """Initialize."""
        self.logging = logging.getLogger(self.__class__.__name__)
        # Hash of the secret, taken from the inviter's "secret_message".
        self.expected_hash = None
        # Seed from the inviter; mixed into the hash we send back.
        self.public_seed = None

    # FIXME: remove camel case
    def connectionMade(self):           # pylint: disable=C0103
        """Called when a connection is made.  No obligation here."""
        basic.LineReceiver.connectionMade(self)

    # FIXME: remove camel case
    def connectionLost(self, reason):   # pylint: disable=W0222,C0103
        """Called when a connection is lost."""
        self.logging.debug("connection lost")
        basic.LineReceiver.connectionLost(self, reason)

    # FIXME: remove camel case
    def lineReceived(self, rich_message):  # pylint: disable=C0103
        """Handler for receipt of a message from the Bob end.

        Decodes the line, removes and stores "secret_message" and
        "public_seed", removes "hostid", and treats the remaining entries as
        the remote OAuth data.  It then calls the factory's
        get_secret_from_user with the peer host, our two continuation
        callbacks, the remote host id and the remote OAuth data."""
        d = bytes_to_dict(rich_message)

        self.expected_hash = d.pop("secret_message")
        self.public_seed = d.pop("public_seed")
        remote_hostid = d.pop("hostid")
        remote_oauth = d
        # pylint: disable=E1101
        self.factory.get_secret_from_user(self.transport.getPeer().host,
                self.check_secret_from_user,
                self.send_secret_to_remote,
                remote_hostid, remote_oauth)
        # pylint: enable=E1101

    def _send_hashed_reply(self, secret_message):
        """Send one line containing our OAuth data, our host id and the
        hexdigest of public_seed + secret_message."""
        hashed = hash_fn()
        hashed.update(self.public_seed)
        hashed.update(secret_message)
        all_dict = dict()
        all_dict.update(self.factory.oauth_info)  # pylint: disable=E1101
        all_dict["hostid"] = self.factory.hostid  # pylint: disable=E1101
        all_dict["secret_message"] = hashed.hexdigest()
        self.sendLine(dict_to_bytes(all_dict))

    def send_secret_to_remote(self, secret_message):
        """A callback for the invitation protocol to start a new phase
        involving the other end getting the hash-digest of the public
        seed and a secret we receive as a parameter."""
        self._send_hashed_reply(secret_message)

    def check_secret_from_user(self, secret_message):
        """A callback for the invitation protocol to verify the secret
        that the user gives, against the hash we received over the
        network.

        On a match, sends the hashed reply, closes the connection and
        returns True; otherwise returns False."""
        hashed = hash_fn()
        hashed.update(secret_message)

        if hashed.hexdigest() != self.expected_hash:
            # Do not write the attempted secret to the log: it may be a
            # near-miss of the real one.
            self.logging.info("User secret is wrong.")
            return False

        self._send_hashed_reply(secret_message)
        self.logging.debug("User knew secret!")
        self.transport.loseConnection()
        return True
# pylint: enable=W0223


class ProcessAnInvitationFactory(ServerFactory):
    """Server-side factory for the "Alice" end of an invitation.

    Holds the configuration shared by every accepted connection.  Each
    connection gets a ProcessAnInvitationProtocol, which reads these values
    through its ``factory`` attribute."""

    protocol = ProcessAnInvitationProtocol

    def __init__(self, get_secret_from_user, on_close, hostid, oauth_info):
        """Store the per-listener configuration."""
        # Callback the protocol invokes once an invitation line arrives.
        self.get_secret_from_user = get_secret_from_user
        # Stored but not used by the code in this file.
        self.on_close = on_close
        # Our identity and OAuth data, sent back to the inviter.
        self.hostid = hostid
        self.oauth_info = oauth_info
        self.logging = logging.getLogger(self.__class__.__name__)


class SendInvitationProtocol(basic.LineReceiver):  # pylint: disable=W0223
    """Narrative "Bob".

    On connect, sends the hash of the secret, the public seed, our host id
    and our OAuth data.  It then expects the other end to reply with the
    hash of public_seed + secret."""

    def __init__(self):
        """Initialize."""
        self.logging = logging.getLogger(self.__class__.__name__)
        self.logging.debug("initialized")
        # Hash of public_seed + secret that the Alice end must send back.
        self.expected_hash_of_secret = None

    # FIXME: remove camel case
    def connectionMade(self):           # pylint: disable=C0103
        """Fire when a connection is made to the listener.

        Sends the invitation line and precomputes the reply we expect."""
        self.logging.debug("connection made")

        hashed = hash_fn()
        hashed.update(self.factory.secret_message)  # pylint: disable=E1101
        d = dict(secret_message=hashed.hexdigest(),
                public_seed=self.factory.public_seed,  # pylint: disable=E1101
                hostid=self.factory.local_hostid)      # pylint: disable=E1101
        d.update(self.factory.local_oauth_info)        # pylint: disable=E1101
        self.sendLine(dict_to_bytes(d))

        hashed = hash_fn()
        hashed.update(self.factory.public_seed)  # pylint: disable=E1101
        hashed.update(self.factory.secret_message)  # pylint: disable=E1101
        self.expected_hash_of_secret = hashed.hexdigest()

    # FIXME: remove camel case
    def lineReceived(self, rich_message):  # pylint: disable=C0103
        """Handler for receipt of a message from the Alice end."""
        d = bytes_to_dict(rich_message)
        # A reply without the hash counts as a wrong hash instead of raising
        # KeyError inside the reactor.
        message = d.pop("secret_message", None)

        if message is None or message != self.expected_hash_of_secret:
            self.logging.warning("Expected %r from invitation.",
                    self.expected_hash_of_secret)
            return

        # Pairing succeeded.  Stop the ReconnectingClientFactory before the
        # connection is dropped; otherwise clientConnectionLost schedules a
        # reconnect and the invitation is sent all over again.
        self.factory.stopTrying()  # pylint: disable=E1101

        remote_host = self.transport.getPeer().host
        try:
            remote_hostname = get_remote_hostname(remote_host)
        except dbus.exceptions.DBusException:
            remote_hostname = None
        remote_hostid = d.pop("hostid")
        self.factory.auth_complete_cb(  # pylint: disable=E1101
            remote_hostname, remote_hostid, d)
        self.transport.loseConnection()

    # FIXME: remove camel case
    def connectionLost(self, reason):   # pylint: disable=W0222,C0103
        """When a connected socket is broken, this is fired."""
        self.logging.info("connection lost.")
        basic.LineReceiver.connectionLost(self, reason)


class SendInvitationFactory(ReconnectingClientFactory):
    """Hold configuration values for all the connections, and fire off a
    protocol to handle the data sent and received."""

    protocol = SendInvitationProtocol

    def __init__(self, auth_complete_cb, secret_message, public_seed,
            on_close, local_hostid, local_oauth_info):
        """Initialize."""
        self.logging = logging.getLogger(self.__class__.__name__)
        # Called by the protocol once the remote end proves it knows the
        # secret.
        self.auth_complete_cb = auth_complete_cb
        self.secret_message = secret_message
        self.public_seed = public_seed
        # Stored but not used by the code in this file.
        self.on_close = on_close
        self.local_hostid = local_hostid
        self.local_oauth_info = local_oauth_info
        self.logging.debug("initialized")

    def close(self):
        """Called from the UI when a window is destroyed and we do not need
        this connection any more.

        Stops further reconnection attempts, cancelling any pending retry."""
        self.logging.debug("close requested; no further connection attempts")
        # FIXME: a connection that is currently open is not dropped here;
        # only reconnection is stopped.
        self.stopTrying()

    # FIXME: remove camel case
    # pylint: disable=C0103
    def clientConnectionFailed(self, connector, reason):
        """When we fail to connect to the listener, this is fired."""
        self.logging.warning("connect failed. %s", reason)
        ReconnectingClientFactory.clientConnectionFailed(self, connector,
                reason)
    # pylint: enable=C0103

    # FIXME: remove camel case
    def clientConnectionLost(self, connector, reason):  # pylint: disable=C0103
        """When a connected socket is broken, this is fired."""
        self.logging.info("connection lost. %s", reason)
        ReconnectingClientFactory.clientConnectionLost(self, connector, reason)


def start_send_invitation(host, port, auth_complete_cb, secret_message,
        public_seed, on_close, local_hostid, local_oauth):
    """Begin sending an invitation to ``host``:``port``.

    Builds a SendInvitationFactory holding the invitation's configuration,
    asks the reactor to open a TCP connection with it, and returns the
    factory."""
    invitation_factory = SendInvitationFactory(
        auth_complete_cb, secret_message, public_seed, on_close,
        local_hostid, local_oauth)
    reactor.connectTCP(  # pylint: disable=E1101
        host, port, invitation_factory)
    return invitation_factory
