from main import USERHOME

import os
import random
import shelve
import shutil
import threading
import string


class StateSaverServer(object):
    """
      This class provides a singleton server object for saving program state
      in a database. The path of the database file contains random elements in
      order to not be attackable by "known-location attacks".

      Values live in a shelve under the composite key
      "<ident><separator><key>". Every database access made by the public
      methods and by the periodic flush is serialized through one lock.
    """

    # legacy, predictable location; migrated to __STATES_FILE on startup
    __OLD_STATES_FILE = os.path.join(USERHOME, "states.db")
    # SystemRandom draws from os.urandom, so unlike the default Mersenne
    # Twister generator the suffix cannot be reconstructed by an attacker
    __STATES_FILE = os.path.join(
        USERHOME,
        "states" + hex(random.SystemRandom().randrange(0xffffff, 0xffffffff))
        + ".db")
    # joins ident and key; never a legal key character (see __check_key)
    __SEPARATOR = ' '
    # characters permitted anywhere in a key
    __KEY_CHARS = frozenset(string.ascii_letters + string.digits + '_')


    def __init__(self):
        """
          Locate (or choose) the database file, migrate the legacy
          "states.db" to the randomized name if that is what was found,
          open the shelf and register a 5-second flush timer.
        """

        self.__lock = threading.Lock()

        self.__db_file = self.__find_db()
        if (self.__db_file == self.__OLD_STATES_FILE):
            # move the file to protect it; both paths are in USERHOME, so
            # shutil.move can rename instead of copying and deleting, which
            # never leaves a half-written copy behind
            shutil.move(self.__OLD_STATES_FILE, self.__STATES_FILE)
            self.__db_file = self.__STATES_FILE

        # writeback=True keeps accessed entries in an in-memory cache;
        # __flush writes that cache back to disk
        self.__db = shelve.open(self.__db_file, flag='c', writeback=True)

        # flush the cache every 5 seconds
        import gobject
        gobject.timeout_add(5 * 1000, self.__flush)


    def __flush(self):
        """
          Timer callback: write the shelf's writeback cache to disk.

          The lock is held so sync() never walks the cache while another
          thread is inside set_key/get_key/remove. Returns True so the
          timer keeps firing.
        """

        self.__lock.acquire()
        try:
            self.__db.sync()
        finally:
            self.__lock.release()
        return True


    def __find_db(self):
        """
          Return the path of an existing state database in USERHOME, or the
          fresh randomized path if none exists.

          Only files this class can have created ("states.db" or
          "states0x<hex>.db") are considered. Other *.db files in USERHOME
          belong to something else and must not be opened as a shelf. The
          candidates are sorted so the choice does not depend on directory
          order.
          NOTE(review): some dbm backends behind shelve append their own
          suffix to the file name. Confirm that the on-disk name matches
          this pattern for the backend in use.
        """

        candidates = sorted(f for f in os.listdir(USERHOME)
                            if f.startswith("states") and f.endswith(".db"))

        if (not candidates):
            return self.__STATES_FILE
        return os.path.join(USERHOME, candidates[0])


    def __check_key(self, key):
        """
          Validate a key name.

          A key must be non-empty, must not start with a digit, and may
          contain only ASCII letters, digits and '_'. Because of this, the
          separator can never occur inside a key.

          Raises ValueError for an invalid key. Explicit raises are used
          instead of assert so the check also runs under python -O.
        """

        if (not key or key[0] in string.digits
                or not set(key) <= self.__KEY_CHARS):
            raise ValueError("invalid state key: %r" % (key,))


    def set_key(self, ident, key, value):
        """Store value under key in the namespace of ident."""

        self.__check_key(key)

        self.__lock.acquire()

        try:
            pk = ident + self.__SEPARATOR + key
            self.__db[pk] = value
        finally:
            self.__lock.release()


    def get_key(self, ident, key, default):
        """Return the value stored under key for ident, or default."""

        self.__check_key(key)

        self.__lock.acquire()

        try:
            pk = ident + self.__SEPARATOR + key
            return self.__db.get(pk, default)
        finally:
            self.__lock.release()


    def remove(self, ident):
        """Delete every entry stored in the namespace of ident."""

        prefix = ident + self.__SEPARATOR

        self.__lock.acquire()

        try:
            # iterate over a snapshot: deleting while iterating a live key
            # view raises RuntimeError on Python 3
            for k in list(self.__db.keys()):
                if k.startswith(prefix):
                    del self.__db[k]
        finally:
            self.__lock.release()


# Module-level server instance that every StateSaverClient forwards to.
# NOTE: it is created at import time. Importing this module therefore opens
# the state database (migrating a legacy "states.db" if one is found) and
# registers the periodic flush timer.
_SERVER = StateSaverServer()

class StateSaverClient(object):
    """
      Lightweight handle bound to a single identifier. Every call is
      forwarded to the module-level _SERVER with that identifier supplied
      first, so the server stores it under the identifier's namespace.
    """

    # no per-instance __dict__; the only attribute is the bound identifier
    __slots__ = ("__ident",)


    def __init__(self, ident):

        self.__ident = ident


    def set_key(self, key, value):
        """Store value under key for this client's identifier."""

        owner = self.__ident
        _SERVER.set_key(owner, key, value)


    def get_key(self, key, default=None):
        """Return the value stored under key, or default if there is none."""

        owner = self.__ident
        return _SERVER.get_key(owner, key, default)


    def remove(self):
        """Drop every value stored for this client's identifier."""

        _SERVER.remove(self.__ident)



# Public alias: callers construct clients as StateSaver(ident).
StateSaver = StateSaverClient

# One shared client for code that has no identifier of its own.
_singleton = StateSaverClient("__default__")


def DefaultStateSaver():
    """Return the shared client bound to the "__default__" identifier."""
    return _singleton
