#
# Copyright 2001 by Object Craft P/L, Melbourne, Australia.
#
# LICENCE - see LICENCE file distributed with this software for details.
#

import socket
import base64

try:
    import Cookie
except ImportError:
    pass
try:
    import zlib
    have_zlib = 1
except ImportError:
    have_zlib = 0

from albatross.context import SessionBase
from albatross.common import *


# Simple session server
class SessionServerContextMixin(SessionBase):
    """Execution-context mixin that keeps session state on a session server.

    The session id is carried in a cookie named after the application id
    (``self.app.ses_appid()``).  The encoded session text is fetched from
    and stored to the server through the ``get_session`` / ``new_session``
    / ``put_session`` / ``del_session`` methods of ``self.app``.
    """

    def __init__(self):
        SessionBase.__init__(self)
        # Id of the session bound to this context; set by load_session().
        self.__sesid = None

    def _get_sesid_from_cookie(self):
        """Return the session id found in the request's Cookie header.

        Returns None when the request has no Cookie header, when the
        header cannot be parsed, or when it holds no cookie for this
        application.
        """
        hdr = self.request.get_header('Cookie')
        if not hdr:
            return None
        try:
            return Cookie.SimpleCookie(hdr)[self.app.ses_appid()].value
        except (KeyError, Cookie.CookieError):
            # The header is client-supplied: a missing or malformed cookie
            # simply means "no session" rather than a failed request.
            return None

    def sesid(self):
        """Return the current session id (None before load_session())."""
        return self.__sesid

    def load_session(self):
        """Load the session named by the request cookie, or start a new one.

        When the cookie names a session the server still holds, its text
        is base64-decoded, decompressed (if zlib is available) and handed
        to ``decode_session``.  Otherwise a new session id is requested
        from the server.  In both cases a Set-Cookie header carrying the
        session id is added to the response.

        Any exception raised while decoding stored session text is
        re-raised after the session has been deleted from the server.
        """
        sesid = self._get_sesid_from_cookie()
        text = None
        if sesid:
            text = self.app.get_session(sesid)
        if not text:
            # No cookie, or the server has no data for that id.
            sesid = self.app.new_session()
        self.__sesid = sesid
        if text:
            try:
                # Every decoding stage is inside the try so that corrupt
                # data at any stage removes the unusable session.
                text = base64.decodestring(text)
                if have_zlib:
                    text = zlib.decompress(text)
                self.decode_session(text)
            except:
                # Deliberately bare: the exception is always re-raised,
                # this only performs cleanup first.
                self.app.del_session(sesid)
                raise
        appid = self.app.ses_appid()
        c = Cookie.SimpleCookie()
        c[appid] = self.__sesid
        # OutputString() is the cookie text without the "Set-Cookie: "
        # header name, which set_header() is given separately.
        self.set_header('Set-Cookie', c[appid].OutputString())

    def save_session(self):
        """Encode the session and store it on the session server.

        Nothing is sent when ``should_save_session()`` is false.  The
        encoded text is compressed when zlib is available and then
        base64-encoded before being passed to ``self.app.put_session``.
        """
        if self.should_save_session():
            text = self.encode_session()
            if have_zlib:
                text = zlib.compress(text)
            text = base64.encodestring(text)
            self.app.put_session(self.__sesid, text)

    def remove_session(self):
        """Discard the session locally and delete it from the server."""
        SessionBase.remove_session(self)
        if self.__sesid is not None:
            self.app.del_session(self.__sesid)

class SessionServerAppMixin:
    """Application mixin implementing a client for the session server.

    A single TCP connection is kept open and re-established on demand.
    Commands are text lines terminated by CR LF (``get``, ``new``, ``put``
    and ``del``, each followed by the application id); a successful
    response line starts with ``OK``.  Failures to reach or talk to the
    server are reported as ServerError.
    """

    def __init__(self, appid, server = 'localhost', port = 34343, age = 1800):
        """Remember the connection settings and try an initial connect.

        appid  -- application id sent with every command; also returned
                  by ses_appid().
        server -- host name of the session server.
        port   -- TCP port of the session server.
        age    -- value sent with the ``new`` command (presumably the
                  session lifetime in seconds -- confirm against the
                  server implementation).

        A failed initial connection is ignored on purpose; every
        operation calls _server_connect() again before using the socket.
        """
        self.__appid = appid
        self.__server = server
        self.__port = port
        self.__sock = None
        self.__age = age
        try:
            self._server_connect()
        except ServerError:
            pass

    def ses_appid(self):
        """Return the application id given to the constructor."""
        return self.__appid

    def _server_disconnect(self):
        """Close and forget the current connection, if there is one."""
        sock = self.__sock
        self.__sock = None
        if sock is not None:
            try:
                sock.close()
            except socket.error:
                # The connection is already broken; nothing more to do.
                pass

    def _server_connect(self):
        """Open the connection to the session server unless already open.

        Raises ServerError when the connection cannot be established.
        """
        if self.__sock:
            return
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((self.__server, self.__port))
        except socket.error:
            # Do not leak the half-initialised socket.
            if sock is not None:
                try:
                    sock.close()
                except socket.error:
                    pass
            raise ServerError('could not connect to session server')
        self.__sock = sock

    def _server_read(self):
        """Return the next chunk of data received from the server.

        An empty read or a socket error closes the connection and raises
        ServerError.
        """
        try:
            text = self.__sock.recv(16384)
        except socket.error:
            text = None
        if text:
            return text
        self._server_disconnect()
        raise ServerError('session server disconnect')

    def _server_write(self, text):
        """Send all of *text* to the server.

        A socket error closes the connection and raises ServerError.
        """
        try:
            # sendall() keeps writing until the whole buffer has been
            # transmitted; send() may stop after a partial write.
            self.__sock.sendall(text)
            return
        except socket.error:
            pass
        self._server_disconnect()
        raise ServerError('session server disconnect')

    def _server_read_response(self):
        """Read one response and return what follows its ``OK`` prefix.

        Raises ServerError when the response does not start with ``OK``.
        """
        resp = self._server_read()
        if not resp.startswith('OK'):
            raise ServerError('Session server returned: %s' % resp)
        return resp[2:].strip()

    def get_session(self, sesid):
        """Fetch the stored text of session *sesid*.

        Returns the text following the ``OK`` status line, including the
        CR LF CR LF that terminates it, or None when the status line does
        not start with ``OK``.
        """
        self._server_connect()
        self._server_write('get %s %s\r\n' % (self.__appid, sesid))
        text = ''
        # Read until the complete status line has arrived.
        while 1:
            text = text + self._server_read()
            pos = text.find('\r\n')
            if pos >= 0:
                if not text.startswith('OK'):
                    return None
                # Keep whatever session data arrived with the status line.
                text = text[pos + 2:]
                break
        # The session text is terminated by an empty line.
        while 1:
            if text.endswith('\r\n\r\n'):
                return text
            text = text + self._server_read()

    def new_session(self):
        """Ask the server for a new session and return its id."""
        self._server_connect()
        self._server_write('new %s %s\r\n' % (self.__appid, self.__age))
        return self._server_read_response()

    def put_session(self, sesid, text):
        """Store *text* as the data of session *sesid*.

        Newlines in *text* are sent as CR LF and the data is followed by
        an extra CR LF.  The server is expected to answer ``OK`` both to
        the command and to the data; anything else raises ServerError.
        """
        self._server_connect()
        self._server_write('put %s %s\r\n' % (self.__appid, sesid))
        self._server_read_response()
        self._server_write(text.replace('\n', '\r\n'))
        self._server_write('\r\n')
        self._server_read_response()

    def del_session(self, sesid):
        """Delete session *sesid* from the server."""
        self._server_connect()
        self._server_write('del %s %s\r\n' % (self.__appid, sesid))
        self._server_read_response()
