# GNU Enterprise Application Server - Per-Session Cache
#
# Copyright 2004 Free Software Foundation
#
# This file is part of GNU Enterprise.
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# $Id: data.py 5680 2004-04-08 18:05:36Z reinhard $

from types import *

import string
import whrandom

from gnue.common.datasources import GDataSource, GConditions, GConnections

# =============================================================================
# Cache class
# =============================================================================

class _cache:
  """
  In-memory store addressed by three coordinates: "table", "row" and "field".

  Two layers are kept side by side: the original ("clean") values and the
  modified ("dirty") values. A data item may be present in one layer, in both,
  or in neither, and a stored value may be None.

  The class never talks to a database; everything it knows was handed to it
  through the "write" method.

  This class is only used internally.
  """

  # ---------------------------------------------------------------------------
  # Initialize
  # ---------------------------------------------------------------------------

  def __init__ (self):
    # Both layers are nested dictionaries: table -> row -> field -> value
    self.__old, self.__new = {}, {}

  # ---------------------------------------------------------------------------
  # Select the clean or the dirty layer
  # ---------------------------------------------------------------------------

  def __layer (self, dirty):
    # A true "dirty" selects the modified values, a false one the originals.
    if dirty:
      return self.__new
    return self.__old

  # ---------------------------------------------------------------------------
  # Store data in the cache
  # ---------------------------------------------------------------------------

  def write (self, table, row, field, value, dirty):
    """
    Store "value" for the given table/row/field.

    With a false "dirty" (0) the value becomes the original value of the
    field; with a true "dirty" (1) it becomes the modified value, and any
    original value that was stored before is left untouched.

    A dirty value may be written without an original value being present.
    """
    checktype (table, UnicodeType)
    checktype (row, UnicodeType)
    checktype (field, UnicodeType)

    # Missing table and row levels are created on the fly.
    layer = self.__layer (dirty)
    layer.setdefault (table, {}).setdefault (row, {}) [field] = value

  # ---------------------------------------------------------------------------
  # Return whether a certain value is stored in the clean/dirty cache
  # ---------------------------------------------------------------------------

  def __has (self, table, row, field, dirty):

    # Walk down all three levels; a KeyError on any of them means "not there".
    try:
      self.__layer (dirty) [table] [row] [field]
    except KeyError:
      return 0
    return 1

  # ---------------------------------------------------------------------------
  # Return whether a certain value is stored in the cache or not
  # ---------------------------------------------------------------------------

  def has (self, table, row, field):
    """
    Tell whether the item is known to the cache, in its dirty or in its clean
    version. The result is 1 if it is and 0 if it isn't.
    """
    checktype (table, UnicodeType)
    checktype (row, UnicodeType)
    checktype (field, UnicodeType)

    # Dirty layer first, then the clean one
    for dirty in (1, 0):
      if self.__has (table, row, field, dirty):
        return 1
    return 0

  # ---------------------------------------------------------------------------
  # Read data from the cache
  # ---------------------------------------------------------------------------

  def read (self, table, row, field):
    """
    Return the current value of the item: the dirty value if there is one,
    the original value otherwise.

    A KeyError is raised if the item is in neither layer.
    """
    checktype (table, UnicodeType)
    checktype (row, UnicodeType)
    checktype (field, UnicodeType)

    # The dirty layer wins whenever it holds the item.
    source = self.__layer (self.__has (table, row, field, 1))
    return source [table] [row] [field]

  # ---------------------------------------------------------------------------
  # Get the status of a record
  # ---------------------------------------------------------------------------

  def status (self, table, row):
    """
    Classify the given row. The result is one of:

    'inserted': the row was newly created

    'changed': the row existed before and has dirty fields

    'deleted': the row existed before and is marked for deletion

    '': none of the above (row unknown, unmodified, or inserted and deleted
    again)

    The classification relies on the 'gnue_id' field: every row except a newly
    created one must have an original 'gnue_id', a newly created row has an
    original 'gnue_id' of None, and a dirty 'gnue_id' of None marks deletion.
    """
    checktype (table, UnicodeType)
    checktype (row, UnicodeType)

    if not self.__has (table, row, u'gnue_id', 0):
      return ''                         # no original id: row is unknown

    original = self.__old [table] [row] [u'gnue_id']
    current = original
    if self.__has (table, row, u'gnue_id', 1):
      current = self.__new [table] [row] [u'gnue_id']

    if current is None:
      if original is None:
        return ''                       # inserted and deleted again
      return 'deleted'

    if original is None:
      return 'inserted'

    # The row exists on both sides; it counts as changed as soon as the dirty
    # layer has an entry for it.
    try:
      self.__new [table] [row]
    except KeyError:
      return ''                         # no dirty fields
    return 'changed'

  # ---------------------------------------------------------------------------
  # List all tables with dirty records
  # ---------------------------------------------------------------------------

  def dirtyTables (self):
    """
    Return the dirty layer itself (not a copy): a dictionary keyed by table
    name whose values are dictionaries keyed by row id, whose values in turn
    are dictionaries mapping each dirty field name to its current value.
    """

    return self.__new

  # ---------------------------------------------------------------------------
  # Clear the whole cache
  # ---------------------------------------------------------------------------

  def clear (self):
    """
    Drop everything: original values and dirty values alike.
    """

    self.__old, self.__new = {}, {}

# =============================================================================
# Helper methods
# =============================================================================

# -----------------------------------------------------------------------------
# Create a datasource
# -----------------------------------------------------------------------------

def _createDatasource (connections, database, table, fields, order = None):
  """
  Build a GDataSource.DataSourceWrapper for "table" in "database" and switch
  its data object to unicode mode.

  "order" may be None or a sequence of field names; a non-empty sequence is
  passed on as a comma separated 'order_by' attribute.
  """

  # attributes describing the datasource
  attributes = {'name': '', 'database': database, 'table': table}

  # an ordering is only requested if there is something to order by
  if order is not None and order != []:
    attributes ['order_by'] = ','.join (order)

  datasource = GDataSource.DataSourceWrapper (connections = connections,
                                              attributes = attributes,
                                              fields = fields)

  # the datasource has to work in unicode mode
  datasource._dataObject._unicodeMode = 1

  return datasource

# -----------------------------------------------------------------------------
# Create an empty result set
# -----------------------------------------------------------------------------

def _createEmptyResultSet (connections, database, table, fields):
  """
  Return an empty result set for "table", obtained from a freshly created
  datasource without any ordering.
  """

  return _createDatasource (connections, database, table,
                            fields).createEmptyResultSet ()

# -----------------------------------------------------------------------------
# Create a result set with data
# -----------------------------------------------------------------------------

def _createResultSet (connections, database, table, fields, conditions, order):
  """
  Return a result set for "table" restricted by "conditions" and sorted by
  "order".

  "conditions" may be None (no restriction), a dictionary (handed to
  GConditions.buildConditionFromDict) or a list (handed to
  GConditions.buildTreeFromPrefix). Anything else raises
  GConditions.ConditionError.
  """

  datasource = _createDatasource (connections, database, table, fields, order)

  if conditions is None:
    condition_tree = None

  elif isinstance (conditions, DictType):
    condition_tree = GConditions.buildConditionFromDict (conditions)

  elif isinstance (conditions, ListType):
    condition_tree = GConditions.buildTreeFromPrefix (conditions)

  else:
    raise GConditions.ConditionError (_("Invalid condition format"))

  return datasource.createResultSet (condition_tree)

# -----------------------------------------------------------------------------
# Create a result set containing only one row, identified by the gnue_id
# -----------------------------------------------------------------------------

def _find (connections, database, table, row, fields):
  """
  Return a result set restricted to the row whose 'gnue_id' equals "row",
  already positioned via firstRecord ().
  """

  # prefix notation for: gnue_id = row
  byId = [['eq', ''], ['field', u'gnue_id'], ['const', row]]

  found = _createResultSet (connections, database, table, fields, byId, [])
  found.firstRecord ()
  return found

# =============================================================================
# Connection class
# =============================================================================

class connection:
  """
  This class encapsulates a connection to the database where data is cached on
  connection level. This means that if one query modifies data, another query
  using the same connection reads the new version even if the changes are not
  committed yet.
  """

  # ---------------------------------------------------------------------------
  # Initialize
  # ---------------------------------------------------------------------------

  def __init__ (self, connections, database):
    """
    Create a new connection object.

    'connections' must be a GConnections.GConnections instance and 'database'
    must be a (non-unicode) string; it is passed on as the 'database'
    attribute of every datasource created for this connection.
    """
    checktype (connections, GConnections.GConnections)
    checktype (database, StringType)

    self.__connections = connections
    self.__database = database
    self.__cache = _cache ()

  # ---------------------------------------------------------------------------
  # Create a recordset from a query
  # ---------------------------------------------------------------------------

  def query (self, table, fields, conditions, order):
    """
    Returns a recordset object. All fields given in 'fields' are fetched from
    the database and cached, so that subsequent access to those fields won't
    trigger another access to the db backend.

    Table and field names must be unicode strings.

    Field values in conditions must be in native Python type; in case of
    strings they must be Unicode.
    """
    checktype (table, UnicodeType)
    checktype (fields, ListType)
    for fields_element in fields: checktype (fields_element, UnicodeType)

    return recordset (self.__cache, self.__connections, self.__database, table,
                      fields, conditions, order)

  # ---------------------------------------------------------------------------
  # Generate a new object id
  # ---------------------------------------------------------------------------

  def __generateId (self):
    """
    Returns a unicode string of 32 random decimal digits.
    """

    # TODO: need a better algorithm here
    result = u''
    for i in range (0, 32):
      result = result + str (int (whrandom.random () * 10))
    return result

  # ---------------------------------------------------------------------------
  # Create a new record
  # ---------------------------------------------------------------------------

  def insertRecord (self, table):
    """
    Inserts a new record. A 'gnue_id' is assigned automatically.

    Table must be a unicode string.

    Returns the record object for the new row. Nothing is written to the
    backend before commit.
    """
    checktype (table, UnicodeType)

    newId = self.__generateId ()
    r = record (self.__cache, self.__connections, self.__database, table,
                newId)
    # An original id of None together with a dirty id marks the row as
    # 'inserted' for the cache.
    self.__cache.write (table, newId, u'gnue_id', None, 0)   # old id is None
    self.__cache.write (table, newId, u'gnue_id', newId, 1)  # new id
    return r

  # ---------------------------------------------------------------------------
  # Delete a record
  # ---------------------------------------------------------------------------

  def deleteRecord (self, table, row):
    """
    Deletes the given record (actually marks it for deletion on commit). All
    data of the record will stay available until commit, but the field
    'gnue_id' will seem to have a value of None.

    Table and row must be unicode strings.
    """
    checktype (table, UnicodeType)
    checktype (row, UnicodeType)

    # The cache can only classify the row if it knows the original id.
    if not self.__cache.has (table, row, u'gnue_id'):    # not yet in cache
      self.__cache.write (table, row, u'gnue_id', row, 0)
    self.__cache.write (table, row, u'gnue_id', None, 1)

  # ---------------------------------------------------------------------------
  # Find a record
  # ---------------------------------------------------------------------------

  def findRecord (self, table, row, fields):
    """
    Loads a record from the database.  All fields given in 'fields' are fetched
    from the database and cached, so that subsequent access to those fields
    won't trigger another access to the db backend.

    This method won't query the db backend for data which is already cached,
    and it won't query the backend at all for a record that was inserted via
    this connection and is not committed yet.

    Table and row must be unicode strings, fields must be a list of unicode
    strings.

    Returns the record object, or None if the backend had to be queried and
    did not return a matching row.
    """
    checktype (table, UnicodeType)
    checktype (row, UnicodeType)
    checktype (fields, ListType)
    for fields_element in fields: checktype (fields_element, UnicodeType)

    uncachedFields = []
    for field in fields:
      if not self.__cache.has (table, row, field):
        uncachedFields.append (field)

    if uncachedFields != [] and \
       self.__cache.status (table, row) != 'inserted':
      # Something is missing and the row is not a pending insert, so the
      # missing fields (and only those) are loaded from the database.
      resultSet = _find (self.__connections, self.__database, table, row,
                         uncachedFields)
      if resultSet.current is None:
        return None
      r = record (self.__cache, self.__connections, self.__database, table, row)
      r._fill (uncachedFields, resultSet.current)
    else:
      # Either everything is cached already, or the row only exists in the
      # cache so far; the backend has nothing to add in both cases.
      r = record (self.__cache, self.__connections, self.__database, table, row)

    # Register the original id, like deleteRecord does. Without it the cache
    # reports an empty status for the row and commit would skip any changes
    # made to the record returned here.
    if not self.__cache.has (table, row, u'gnue_id'):
      self.__cache.write (table, row, u'gnue_id', row, 0)

    return r

  # ---------------------------------------------------------------------------
  # Write all changes back to the database
  # ---------------------------------------------------------------------------

  def commit (self):
    """
    Write all dirty data to the database backend by a single transaction that
    is committed immediately. This operation invalidates the cache.
    """

    tables = self.__cache.dirtyTables ()
    for (table, rows) in tables.items ():
      for (row, fields) in rows.items ():
        status = self.__cache.status (table, row)

        if status == 'inserted':
          resultSet = _createEmptyResultSet (self.__connections,
                                             self.__database,
                                             table, fields.keys ())
          resultSet.insertRecord ()

          for (field, value) in fields.items ():
            resultSet.current.setField (field, value)

        elif status == 'changed':
          # TODO: gnue-common should provide a method for updating a record
          # without reading it first. Until that is done, we have to create a
          # temporary resultSet for every record we update
          resultSet = _find (self.__connections, self.__database, table, row,
                             [u'gnue_id'] + fields.keys ())

          for (field, value) in fields.items ():
            resultSet.current.setField (field, value)

        elif status == 'deleted':
          # TODO: gnue-common should provide a method for deleting a record
          # without reading it first. Until that is done, we have to create a
          # temporary resultSet for every record we delete
          resultSet = _find (self.__connections, self.__database, table, row,
                             [u'gnue_id'])
          resultSet.current.delete ()

        # An empty status means there is nothing to write for this row (for
        # example a record that was inserted and deleted again).
        if status != '':
          resultSet.post ()

    # Commit the whole transaction
    self.__connections.commitAll ()

    # The transaction has ended. Changes from other transactions could become
    # valid in this moment, so we have to clear the whole cache.
    self.__cache.clear ()

  # ---------------------------------------------------------------------------
  # Undo all changes
  # ---------------------------------------------------------------------------

  def rollback (self):
    """
    Undo all uncommitted changes.
    """

    # Send the rollback to the database. Although we have (most probably) not
    # written anything yet, we have to tell the database that a new transaction
    # starts now, so that commits from other sessions become valid now for us
    # too.
    self.__connections.rollbackAll ()

    # The transaction has ended. Changes from other transactions could become
    # valid in this moment, so we have to clear the whole cache.
    self.__cache.clear ()

  # ---------------------------------------------------------------------------
  # Close the connection
  # ---------------------------------------------------------------------------

  def close (self):
    """
    Close the connection to the database backend.
    """

    self.__connections.closeAll ()

# =============================================================================
# Recordset class
# =============================================================================

class recordset:
  """
  The result of a query. Instances are handed out by connection.query();
  the records are walked through with firstRecord() and nextRecord().
  """

  # ---------------------------------------------------------------------------
  # Initialize
  # ---------------------------------------------------------------------------

  def __init__ (self, cache, connections, database, table, fields, conditions,
                order):
    self.__cache = cache
    self.__connections = connections
    self.__database = database
    self.__table = table

    # 'gnue_id' is always fetched in addition to the requested fields, as it
    # identifies the row every record is built for.
    self.__fields = [u'gnue_id'] + fields

    self.__resultSet = _createResultSet (connections, database, table,
                                         self.__fields, conditions, order)

  # ---------------------------------------------------------------------------
  # Return the number of records
  # ---------------------------------------------------------------------------

  def count (self):
    """
    Returns whatever the underlying result set reports as its record count.
    """

    return self.__resultSet.getRecordCount ()

  # ---------------------------------------------------------------------------
  # Build a record object for the current position of the result set
  # ---------------------------------------------------------------------------

  def __currentRecord (self, moved):

    # "moved" is the result of the preceding firstRecord/nextRecord call on the
    # result set; None means that there is no record at the new position.
    if moved is None:
      return None

    rowId = self.__resultSet.current [u'gnue_id']
    r = record (self.__cache, self.__connections, self.__database,
                self.__table, rowId)
    r._fill (self.__fields, self.__resultSet.current)
    return r

  # ---------------------------------------------------------------------------
  # Return the first record
  # ---------------------------------------------------------------------------

  def firstRecord (self):
    """
    Moves to the first record and returns it, or returns None for an empty
    set.
    """

    return self.__currentRecord (self.__resultSet.firstRecord ())

  # ---------------------------------------------------------------------------
  # Return the next record
  # ---------------------------------------------------------------------------

  def nextRecord (self):
    """
    Moves to the next record and returns it, or returns None if nothing is
    left.
    """

    return self.__currentRecord (self.__resultSet.nextRecord ())

# =============================================================================
# Record class
# =============================================================================

class record:
  """
  A single row of a database table. Instances are handed out by
  recordset.firstRecord() and recordset.nextRecord() as well as by
  connection.insertRecord() and connection.findRecord().
  """

  # ---------------------------------------------------------------------------
  # Initialize
  # ---------------------------------------------------------------------------

  def __init__ (self, cache, connections, database, table, row):

    # where the data lives
    self.__cache = cache
    self.__connections = connections
    self.__database = database

    # which row this object stands for
    self.__table = table
    self.__row = row

  # ---------------------------------------------------------------------------
  # Fill the cache for this record with data from a (gnue-common) RecordSet
  # ---------------------------------------------------------------------------

  def _fill (self, fields, RecordSet):

    for name in fields:
      # Whatever the cache already holds wins over data from the backend.
      if self.__cache.has (self.__table, self.__row, name):
        continue
      self.__cache.write (self.__table, self.__row, name, RecordSet [name], 0)

  # ---------------------------------------------------------------------------
  # Get the value for a field
  # ---------------------------------------------------------------------------

  def getField (self, field):
    """
    Return the value of a field. A cached value is returned directly;
    otherwise the database is asked for it and the answer is remembered as the
    original value. If the database doesn't return the row, the result is
    None.

    The field name must be given as a unicode string. The result will be
    returned as the native Python datatype, in case of a string field it will
    be Unicode.
    """
    checktype (field, UnicodeType)

    table, row = self.__table, self.__row

    # Served from the cache whenever possible
    if self.__cache.has (table, row, field):
      return self.__cache.read (table, row, field)

    # Not cached, so ask the backend for exactly this field
    resultSet = _find (self.__connections, self.__database, table, row,
                       [field])
    if resultSet.current is None:
      return None

    value = resultSet.current [field]
    self.__cache.write (table, row, field, value, 0)
    return value

  # ---------------------------------------------------------------------------
  # Put the value for a field
  # ---------------------------------------------------------------------------

  def putField (self, field, value):
    """
    Set a field to a new value. The value is stored as dirty data in the
    cache.

    The field name must be given as a unicode string, value must be the native
    Python datatype of the field, in case of a string field it must be Unicode.
    """
    checktype (field, UnicodeType)

    self.__cache.write (self.__table, self.__row, field, value, 1)

# =============================================================================
# Self test code
# =============================================================================

# Manual smoke test: runs through query, insert, change, find and delete
# against the connection named 'gnue', using the table 'address_person' with
# the fields 'address_name' and 'address_city'. Every step prints its name
# followed by 'Ok' or by the value that was read.
if __name__ == '__main__':

  from gnue.common.apps import GClientApp

  app = GClientApp.GClientApp ()

  print 'create connection object ...',
  c = connection (app.connections, 'gnue')
  print 'Ok'

  # --- Reading existing data -------------------------------------------------

  print 'connection.query for existing records ...'
  rs = c.query (u'address_person', [u'address_name'], None, [u'address_name'])
  print 'Ok'

  # NOTE: the following steps assume that the table holds at least two
  # records; firstRecord/nextRecord return None otherwise.
  print 'recordset.firstRecord ...',
  r = rs.firstRecord ()
  print 'Ok'

  # address_name was part of the query, so it is served from the cache
  print 'record.getField with prefetched data ...',
  print repr (r.getField (u'address_name'))

  # address_city was not part of the query, so it has to be fetched separately
  print 'record.getField with non-prefetched data ...',
  print repr (r.getField (u'address_city'))

  print 'recordset.nextRecord ...',
  r = rs.nextRecord ()
  print 'Ok'

  print 'record.getField with prefetched data ...',
  print repr (r.getField (u'address_name'))

  # --- Inserting a record ----------------------------------------------------

  print 'connection.insertRecord ...',
  r = c.insertRecord (u'address_person')
  print 'Ok'

  # Remember the generated id; it is used again for findRecord/deleteRecord
  print 'record.getField ...',
  id = r.getField (u'gnue_id')
  print repr (id)

  print 'record.putField ...',
  r.putField (u'address_name', u'New Person')
  print 'Ok'

  print 'record.getField of inserted data ...',
  print repr (r.getField (u'address_name'))

  print 'connection.commit with an inserted record ...',
  c.commit ()
  print 'Ok'

  # --- Changing the record ---------------------------------------------------

  # The condition is given in prefix notation: address_name = 'New Person'
  print 'connection.query for previously inserted record ...',
  rs = c.query (u'address_person', [u'address_name'],
                [['eq', ''], ['field', u'address_name'],
                             ['const', u'New Person']], None)
  print 'Ok'

  print 'recordset.firstRecord ...',
  r = rs.firstRecord ()
  print 'Ok'

  print 'record.getField with prefetched data ...',
  print repr (r.getField (u'address_name'))

  print 'record.putField of prefetched data ...',
  r.putField (u'address_name', u'New Name')
  print 'Ok'

  print 'record.putField of non-prefetched data ...',
  r.putField (u'address_city', u'New City')
  print 'Ok'

  print 'record.getField of changed data ...',
  print repr (r.getField (u'address_name'))

  # A second record object for the same row reads the uncommitted changes
  # through the cache of the connection
  print 'connection.findRecord for previously changed record ...',
  r = c.findRecord (u'address_person', id, [u'address_name'])
  print 'Ok'

  print 'record.getField of changed data, independent query ...',
  print repr (r.getField (u'address_city'))

  print 'connection.commit with a changed record ...',
  c.commit ()
  print 'Ok'

  # commit has cleared the cache, so this value is read from the backend again
  print 'record.getField of prefetched data ...',
  print repr (r.getField (u'address_name'))

  # --- Deleting the record ---------------------------------------------------

  print 'connection.deleteRecord ...',
  c.deleteRecord (u'address_person', id)
  print 'Ok'

  # The data of a record marked for deletion stays readable until commit
  print 'record.getField of deleted uncommitted record, prefetched ...',
  print repr (r.getField (u'address_name'))

  print 'record.getField of deleted uncommitted record, non-prefetched ...',
  print repr (r.getField (u'address_city'))

  print 'connection.commit with a deleted record ...',
  c.commit ()
  print 'Ok'

  # The record was the one with address_city = 'New City'; the query for it
  # must come back empty now
  print 'check if the record is really gone now ...',
  rs = c.query (u'address_person', [u'address_name'],
                [['eq', ''], ['field', u'address_city'],
                             ['const', u'New City']],
                None)
  if rs.firstRecord () != None:
    raise Exception
  print 'Ok'

  print 'connection.close ...',
  c.close ()
  print 'Ok'

  print 'Thank you for playing!'
