# Copyright 2005 James Bunton <james@delx.cjb.net>
# Licensed for distribution under the GPL version 2, check COPYING for details

from twisted.internet import reactor
from twisted.internet.defer import Deferred
from twisted.protocols.basic import LineReceiver
from twisted.internet.protocol import ClientFactory
from twisted.python import log

from msn import checkParamLen, MSNMessage

from debug import LogEvent, INFO, WARN, ERROR
import config

import random

MAXAUTHCOOKIE = 2**32-1


class MSNFTReceive_Base:
	"""
	Common buffering logic for an incoming file transfer.

	Data passed to write() is forwarded to the consumer attached with
	writeTo(), or buffered until one is attached. close() and gotError()
	record the final state and notify the consumer: immediately if one is
	attached, otherwise when writeTo() is eventually called.

	The consumer is any object providing write(data), close() and error().
	"""
	# Public
	def __init__(self, filename, filesize, userHandle):
		self.consumer = None	# set by writeTo()
		self.finished = False	# set by close()
		self.error = False	# set by gotError()
		self.buffer = []	# data received before a consumer was attached
		self.filename, self.filesize, self.userHandle = filename, filesize, userHandle
	
	def removeMe(self):
		"""Release resources; the base version only drops the consumer.

		Subclasses override this to tear down their own state.
		"""
		self.consumer = None
	
	def accept(self, yes=True):
		"""Accept (or reject, if yes is false) the transfer. No-op here."""
		pass
	
	def writeTo(self, obj):
		"""Attach obj as the consumer and replay anything already received."""
		self.consumer = obj
		for data in self.buffer:
			self.consumer.write(data)
		self.buffer = []
		if self.finished:
			self.consumer.close()
		if self.error:
			self.consumer.error()


	# Private
	def write(self, data):
		"""Forward data to the consumer, or buffer it if none is attached."""
		if self.consumer:
			self.consumer.write(data)
		else:
			self.buffer.append(data)
	
	def close(self):
		"""Mark the transfer finished and tell the attached consumer."""
		# Grab the consumer first: the base removeMe() clears self.consumer,
		# which previously meant the consumer was never notified.
		consumer = self.consumer
		self.removeMe()
		self.finished = True
		if consumer:
			consumer.close()
	
	def gotError(self, ignored=None):
		"""Mark the transfer failed and tell the attached consumer.

		@param ignored: unused; lets this be used directly as a callback.
		"""
		# Same ordering concern as close(): remember the consumer before
		# removeMe() may drop it.
		consumer = self.consumer
		self.removeMe()
		self.error = True
		if consumer:
			consumer.error()


class MSNFTP_Ports:
	"""
	Pool of local TCP ports for listening to incoming MSNFTP connections.

	The pool covers config.ftLowPort .. config.ftHighPort, both inclusive.
	If the configured values are unusable, 6891 .. 6899 is used instead.
	"""
	def __init__(self):
		try:
			lowPort = int(config.ftLowPort)
			highPort = int(config.ftHighPort)
		except (AttributeError, TypeError, ValueError):
			# Missing, None or non-numeric config values all fall back to defaults.
			LogEvent(ERROR, "", "Invalid values for ftLowPort & ftHighPort. Using 6891 & 6899 respectively")
			lowPort = 6891
			highPort = 6899
		# Inclusive range: the high port is usable too (the old
		# xrange(highPort-lowPort) silently dropped it).
		self.ports = list(range(lowPort, highPort + 1))
		self.portFree = [True] * len(self.ports)
	
	def requestPort(self):
		"""Reserve and return a free port, or None if every port is in use."""
		for i, port in enumerate(self.ports):
			if self.portFree[i]:
				self.portFree[i] = False
				LogEvent(INFO, "", "Reserved a port")
				return port
		LogEvent(INFO, "", "Out of ports")
		return None

	def freePort(self, port):
		"""Return port to the pool; ports outside the pool are ignored."""
		if port in self.ports:
			self.portFree[self.ports.index(port)] = True
			LogEvent(INFO)

# Module-wide port pool shared by every MSNFTP_Receive: accept() reserves a
# port from it and removeMe() hands the port back.
msnports = MSNFTP_Ports()

class MSNFTP_Receive(ClientFactory, MSNFTReceive_Base):
	"""
	Receiving side of an MSNFTP file-transfer invitation.

	accept() answers the invitation through the switchboard. When accepting,
	it listens on a port reserved from msnports for the sender to connect.
	The connection is handled by MSNFTP_FileReceive, which writes the file
	data back into this object (write/close inherited from
	MSNFTReceive_Base). From there it reaches the consumer attached with
	writeTo().
	"""
	def __init__(self, filename, filesize, userHandle, iCookie, connectivity, switchboard):
		MSNFTReceive_Base.__init__(self, filename, filesize, userHandle)
		self.iCookie = iCookie
		self.switchboard = switchboard
		self.serverSocket = None	# listening port from reactor.listenTCP
		self.timeout = None	# DelayedCall aborting the wait for the sender
		self.authCookie = str(random.randint(1, MAXAUTHCOOKIE))
		self.port = None	# port reserved from msnports, if any
		self.d = None	# Deferred handed out by accept()
		LogEvent(INFO, self.switchboard.userHandle)
	
	def removeMe(self):
		"""
		Release the listening socket, timeout, reserved port and pending
		Deferred. Each resource is cleared after release, so calling this
		more than once is harmless.
		"""
		if self.serverSocket:
			self.serverSocket.stopListening()
			self.serverSocket = None
		if self.timeout and not self.timeout.called:
			self.timeout.cancel()
		self.timeout = None
		if self.port:
			msnports.freePort(self.port)
			self.port = None
		if self.d:
			# Clear self.d before firing so re-entrant calls from errbacks
			# cannot fire it twice. errback() needs an explicit value: with no
			# argument Twisted tries to capture the *current* exception, and
			# there is none on the timeout/close paths.
			d, self.d = self.d, None
			d.errback(Exception("MSNFTP file transfer aborted"))
		LogEvent(INFO, self.switchboard.userHandle)

	def accept(self, yes=True):
		"""
		Answer the file-transfer invitation through the switchboard.

		@param yes: True to accept the transfer, False to reject it.
		@return: when accepting, a Deferred that fires with None once the
				 sender connects, or errbacks if the transfer is aborted
				 first (e.g. by the 20 second timeout). Returns None when
				 rejecting or when no port could be reserved.
		"""
		LogEvent(INFO, self.switchboard.userHandle)
		if yes:
			# Reserve a port only when we will actually listen on it; a port
			# reserved for a rejected invitation was never released.
			self.port = msnports.requestPort()
			if not self.port:
				yes = False
				self.gotError()
		LogEvent(INFO, self.switchboard.userHandle)
	
		m = MSNMessage()
		m.setHeader('Content-Type', 'text/x-msmsgsinvite; charset=UTF-8')
		if yes:
			m.message += 'IP-Address: %s\r\n' % str(config.ip)
			m.message += 'Port: %s\r\n' % str(self.port)
			m.message += 'AuthCookie: %s\r\n' % self.authCookie
			m.message += 'Sender-Connect: TRUE\r\n'
			m.message += 'Invitation-Command: ACCEPT\r\n'
			m.message += 'Invitation-Cookie: %s\r\n' % str(self.iCookie)
		else:
			m.message += 'Invitation-Command: CANCEL\r\n'
			m.message += 'Cancel-Code: REJECT\r\n'
		m.message += 'Launch-Application: FALSE\r\n'
		m.message += 'Request-Data: IP-Address:\r\n'
		m.message += '\r\n'
		m.ack = m.MESSAGE_ACK_NONE
		self.switchboard.sendMessage(m)

		if yes:
			self.d = Deferred()
			self.serverSocket = reactor.listenTCP(self.port, self)
			# Abort if the sender has not connected within 20 seconds.
			self.timeout = reactor.callLater(20, self.gotError)
			return self.d
		return None
	
	def buildProtocol(self, addr):
		"""
		Build the protocol for the sender's incoming connection.

		Only one connection is expected, so listening stops immediately.
		Returns None (which makes Twisted drop the connection) if a
		connection was already accepted or the transfer was aborted;
		removeMe() clears serverSocket in the latter case.
		"""
		LogEvent(INFO, self.switchboard.userHandle)
		if self.serverSocket is None:
			return None
		self.serverSocket.stopListening()
		self.serverSocket = None
		if self.timeout and not self.timeout.called:
			self.timeout.cancel()
		self.timeout = None
		if self.d:
			d, self.d = self.d, None
			d.callback(None)
		# This object acts as the "file": MSNFTP_FileReceive writes/closes it.
		return MSNFTP_FileReceive(self.authCookie, self.switchboard.userHandle, self)
		

class MSNFTP_FileReceive(LineReceiver):
	"""
	This class provides support for receiving files from contacts.

	It implements the receiving half of MSNFTP:
	  1. On connect it sends VER, and answers the sender's VER with USR
	     (user handle + auth cookie).
	  2. It reads the file size from FIL, replies TFR and switches to raw mode.
	  3. It reassembles the segments, each framed by a 3-byte header.
	  4. It sends BYE once fileSize bytes have arrived.

	@ivar fileSize: the size of the receiving file, taken from the sender's
					FIL command.
	@ivar connected: true if a connection has been established.
	@ivar completed: true if the transfer is complete.
	@ivar bytesReceived: number of bytes (of the file) received.
						 This does not include header data.
	"""

	def __init__(self, auth, myUserHandle, file, directory="", overwrite=0):
		"""
		@param auth: auth string received in the file invitation.
		@param myUserHandle: your userhandle.
		@param file: A string or file object representing the file
					 to save data to.
		@param directory: optional parameter specifying the directory.
						  Defaults to the current directory.
		@param overwrite: if true and a file of the same name exists on
						  your system, it will be overwritten. (0 by default)
		@raise IOError: if file names an existing file and overwrite is false.
		"""
		import os	# only needed here; os is not imported at module level
		self.auth = auth
		self.myUserHandle = myUserHandle
		self.fileSize = 0
		self.connected = 0
		self.completed = 0
		self.directory = directory
		self.bytesReceived = 0
		self.overwrite = overwrite

		# used for handling current received state
		self.state = 'CONNECTING'
		self.segmentLength = 0
		self.buffer = ''
		
		if isinstance(file, str):
			path = os.path.join(directory, file)
			if os.path.exists(path) and not self.overwrite:
				log.msg('File already exists...')
				raise IOError("File Exists") # is this all we should do here?
			self.file = open(path, 'wb')
		else:
			self.file = file

	def connectionMade(self):
		self.connected = 1
		self.state = 'INHEADER'
		self.sendLine('VER MSNFTP')

	def connectionLost(self, reason):
		"""
		Release the destination when the connection goes away.

		A completed transfer already closed self.file when the last segment
		arrived, so it is not closed again. An incomplete transfer is
		reported through self.file.gotError() when the destination has one
		(MSNFTReceive_Base-derived objects do), so it is not mistaken for a
		finished file. Plain file objects are simply closed.
		"""
		self.connected = 0
		if self.completed:
			return
		gotError = getattr(self.file, 'gotError', None)
		if gotError is not None:
			gotError()
		else:
			self.file.close()

	def parseHeader(self, header):
		"""
		Parse a 3-byte segment header and return the segment length.

		Byte 0 must be 0; any other value means the sender wants the
		connection closed, in which case the connection is dropped and None
		is returned. Bytes 1 and 2 are the length, low byte first.
		"""

		if ord(header[0]) != 0: # they requested that we close the connection
			self.transport.loseConnection()
			return
		try:
			extra, factor = header[1:]
		except ValueError:
			# munged header, ending transfer
			self.transport.loseConnection()
			raise
		extra  = ord(extra)
		factor = ord(factor)
		return factor * 256 + extra

	def sendLine(self, line):
		log.msg("SENDING LINE!!!    " + line)
		LineReceiver.sendLine(self, line)

	def lineReceived(self, line):
		"""Dispatch a command line to handle_<CMD>, or handle_UNKNOWN."""
		temp = line.split(' ')
		cmd, params = temp[0], temp[1:]
		log.msg("GOT A LINE!!!	  " + line)
		handler = getattr(self, "handle_%s" % cmd.upper(), None)
		if handler: handler(params) # try/except
		else: self.handle_UNKNOWN(cmd, params)

	def rawDataReceived(self, data):
		"""
		Consume raw transfer data, alternating between headers and segments.

		Partial headers/segments are kept in self.buffer across calls.
		The chunk is processed in a loop rather than by recursing once per
		segment, so large reads cannot exhaust the recursion limit.
		"""
		# Log only the size: logging the payload itself dumped the whole
		# (binary) file into the log.
		log.msg("RAW DATA: %d bytes" % len(data))
		while data:
			if self.state == 'INHEADER':
				need = 3 - len(self.buffer)
				self.buffer += data[:need]
				data = data[need:]
				if len(self.buffer) < 3:
					return	# wait for the rest of the header
				self.segmentLength = self.parseHeader(self.buffer)
				if not self.segmentLength: return # hrm
				self.buffer = ""
				self.state = 'INSEGMENT'

			elif self.state == 'INSEGMENT':
				need = self.segmentLength - len(self.buffer)
				dataSeg = data[:need]
				data = data[need:]
				self.buffer += dataSeg
				self.bytesReceived += len(dataSeg)
				if len(self.buffer) < self.segmentLength:
					return	# wait for the rest of the segment
				self.gotSegment(self.buffer)
				self.buffer = ""
				if self.bytesReceived == self.fileSize:
					self.completed = 1
					self.file.close()
					self.sendLine("BYE 16777989")
					return
				self.state = 'INHEADER'

			else:
				# Raw data outside a transfer is ignored.
				return

	def handle_VER(self, params):
		checkParamLen(len(params), 1, 'VER')
		if params[0].upper() == "MSNFTP":
			self.sendLine("USR %s %s" % (self.myUserHandle, self.auth))
		else:
			log.msg('they sent the wrong version, time to quit this transfer')
			self.transport.loseConnection()

	def handle_FIL(self, params):
		checkParamLen(len(params), 1, 'FIL')
		try:
			self.fileSize = int(params[0])
		except ValueError: # they sent the wrong file size - probably want to log this
			self.transport.loseConnection()
			return
		self.setRawMode()
		self.sendLine("TFR")

	def handle_UNKNOWN(self, cmd, params):
		log.msg('received unknown command (%s), params: %s' % (cmd, params))

	def gotSegment(self, data):
		""" called when a segment (block) of data arrives. """
		self.file.write(data)


class MSNFTP_FileSend(LineReceiver):
    """
    This class provides support for sending files to other contacts.

    The flow is VER -> USR (the auth cookie is checked) -> FIL <fileSize>
    -> TFR, then the file is streamed in segmentSize chunks, each with a
    3-byte header, one chunk per reactor iteration.

    @ivar bytesSent: the number of bytes that have currently been sent.
    @ivar completed: true if the send has completed.
    @ivar connected: true if a connection has been established.
    @ivar targetUser: the target user (contact).
    @ivar segmentSize: the segment (block) size.
    @ivar auth: the auth cookie (number) to use when sending the
                transfer invitation
    @ivar fileSize: size announced in the FIL command. It is initialised
                    to 0 and nothing in this class updates it, so presumably
                    the caller sets it (TODO confirm).
    """
    
    def __init__(self, file):
        """
        @param file: A string or file object representing the file to send.
        """

        if isinstance(file, (str, unicode)):
            self.file = open(file, 'rb')
        else:
            self.file = file

        self.fileSize = 0
        self.bytesSent = 0
        self.completed = 0
        self.connected = 0
        self.targetUser = None
        self.segmentSize = 2045
        self.auth = random.randint(0, 2**30)
        self._pendingSend = None # DelayedCall for the next sendPart, if any

    def connectionMade(self):
        self.connected = 1

    def connectionLost(self, reason):
        self._stopSending()
        self.connected = 0
        self.file.close()

    def lineReceived(self, line):
        """Dispatch a command line to handle_<CMD>, or handle_UNKNOWN."""
        temp = line.split(' ')
        cmd, params = temp[0], temp[1:]
        handler = getattr(self, "handle_%s" % cmd.upper(), None)
        if handler: handler(params)
        else: self.handle_UNKNOWN(cmd, params)

    def handle_VER(self, params):
        checkParamLen(len(params), 1, 'VER')
        if params[0].upper() == "MSNFTP":
            self.sendLine("VER MSNFTP")
        else: # they sent some weird version during negotiation, i'm quitting.
            self.transport.loseConnection()

    def handle_USR(self, params):
        checkParamLen(len(params), 2, 'USR')
        self.targetUser = params[0]
        try:
            cookie = int(params[1])
        except ValueError:
            cookie = None # a non-numeric cookie simply fails authentication
        if cookie == self.auth:
            self.sendLine("FIL %s" % (self.fileSize))
        else: # they failed the auth test, disconnecting.
            self.transport.loseConnection()

    def handle_TFR(self, params):
        checkParamLen(len(params), 0, 'TFR')
        # they are ready for me to start sending; a repeated TFR while a send
        # is already scheduled must not start a second send loop.
        if self._pendingSend is None:
            self.sendPart()

    def handle_BYE(self, params):
        self._stopSending()
        self.completed = (self.bytesSent == self.fileSize)
        self.transport.loseConnection()

    def handle_CCL(self, params):
        self._stopSending()
        self.completed = (self.bytesSent == self.fileSize)
        self.transport.loseConnection()

    def handle_UNKNOWN(self, cmd, params):
        log.msg('received unknown command (%s), params: %s' % (cmd, params))

    def makeHeader(self, size):
        """ make the appropriate header given a specific segment size. """
        quotient, remainder = divmod(size, 256)
        return chr(0) + chr(remainder) + chr(quotient)

    def _stopSending(self):
        """Cancel the scheduled sendPart call, if one is pending."""
        if self._pendingSend and not self._pendingSend.called:
            self._pendingSend.cancel()
        self._pendingSend = None

    def sendPart(self):
        """ send a segment of data, then schedule the next one """
        if not self.connected:
            self._pendingSend = None
            return
        data = self.file.read(self.segmentSize)
        if data:
            dataSize = len(data)
            header = self.makeHeader(dataSize)
            self.transport.write(header + data)
            self.bytesSent += dataSize
            # handle_BYE/handle_CCL/connectionLost cancel this via _stopSending.
            self._pendingSend = reactor.callLater(0, self.sendPart)
        else:
            self._pendingSend = None
            self.completed = 1
