#
# This file is part of GNU Enterprise.
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# Copyright 2000-2005 Free Software Foundation
#
# FILE:
# odbc/DBdriver.py
#
# DESCRIPTION:
# Driver to provide access to data via the public domain win32all ODBC Driver
#
# NOTES:
# Only works under Win32... requires the win32all extensions.
# (http://aspn.activestate.com/ASPN/Downloads/ActivePython/Extensions/Win32all)
#
#   Supported attributes (via connections.conf or <database> tag)
#
#     service=   This is the ODBC DSN= string to use.
#
#


import sys, string, types
from gnue.common.datasources import GDataObjects, GConditions, GConnections
from gnue.common.apps import GDebug


try:
  import dbi, odbc
except ImportError, message:
  tmsg = u_("Driver not installed: win32all ODBC driver\n\n[%s") % message
  raise GConnections.AdapterNotInstalled, tmsg




class ODBC_RecordSet(GDataObjects.RecordSet):
  def _postChanges(self):
    if not self.isPending(): return
    if self._deleteFlag:
      statement = self._buildDeleteStatement()
    elif self._insertFlag:
      statement = self._buildInsertStatement()
    elif self._updateFlag:
      statement = self._buildUpdateStatement()

    GDebug.printMesg(9, "_postChanges: statement=%s" % statement)

    try:
      self._parent._update_cursor.execute(statement)

      # Set _initialData to be the just-now posted values
      if not self._deleteFlag:
        self._initialData = {}
        for key in self._fields.keys():
          self._initialData[key] = self._fields[key]

    except self._parent._dataObject._connection._DatabaseError, err:
      raise GDataObjects.ConnectionError, err

    self._updateFlag = 0
    self._insertFlag = 0
    self._deleteFlag = 0

    return 1
        

  # If a vendor can do any of these more efficiently (i.e., use a known
  # PRIMARY KEY or ROWID, then override these methods. Otherwise, leave
  # as default.  Note that these functions are specific to DB-SIG based
  # drivers (i.e., these functions are not in the base RecordSet class)

  def _buildDeleteStatement(self):
    if self._initialData.has_key(self._parent._primaryIdField):
      where = [self._parent._primaryIdFormat % \
          self._initialData[self._parent._primaryIdField]  ]
    else:
      where = []
      for field in self._initialData.keys():
        if self._parent.isFieldBound(field):
          if self._initialData[field] == None:
            where.append ("%s IS NULL" % field)
          else:
            where.append ("%s='%s'" % (field, self._initialData[field]))

    statement = "DELETE FROM %s WHERE %s" % \
       (self._parent._dataObject.table, string.join(where,' AND ') )
    return statement

  def _buildInsertStatement(self): 
    vals = []
    fields = []

    # TODO: This should actually only insert modified fields.
    # TODO: Unfortunately, self._modifiedFlags is not being 
    # TODO: set for new records (!@#$)
    #for field in self._modifiedFlags.keys():

    for field in self._fields.keys():
      if self._parent.isFieldBound(field):
        fields.append (field)
        if self._fields[field] == None or self._fields[field] == '':
          vals.append ("NULL") #  % (self._fields[field]))
        else:
          try: 
            if self._parent._fieldTypes[field] == 'number': 
              vals.append ("%s" % (self._fields[field]))
            else: 
              vals.append ("'%s'" % (self._fields[field]))
          except ValueError:
            vals.append ("%s" % (self._fields[field]))

    return "INSERT INTO %s (%s) VALUES (%s)" % \
       (self._parent._dataObject.table, string.join(fields,','), \
        string.join(vals,',') )


  def _buildUpdateStatement(self):
    updates = []
    for field in self._modifiedFlags.keys():
      try:
        if self._parent._fieldTypes[field] == 'number':
          updates.append ("%s=%s" % (field, self._fields[field]))
        else:
          updates.append ("%s='%s'" % (field, self._fields[field]))
      except KeyError:
        updates.append ("%s='%s'" % (field, self._fields[field]))

    if self._initialData.has_key(self._parent._primaryIdField):
      where = [self._parent._primaryIdFormat % \
          self._initialData[self._parent._primaryIdField]  ]
    else:
      where = []
      for field in self._initialData.keys():
        if self._initialData[field] == None:
          where.append ("%s IS NULL" % field)
        else:
          try:
            if self._parent._fieldTypes[field] == 'number':
              where.append ("%s=%s" % (field, self._initialData[field]))
            else:
              where.append ("%s='%s'" % (field, self._initialData[field]))
          except KeyError:
            where.append ("%s='%s'" % (field, self._initialData[field]))

    return "UPDATE %s SET %s WHERE %s" % \
       (self._parent._dataObject.table, string.join(updates,','), \
        string.join(where,' AND ') )


class ODBC_ResultSet(GDataObjects.ResultSet): 
  def __init__(self, dataObject, cursor=None, \
        defaultValues={}, masterRecordSet=None):
    GDataObjects.ResultSet.__init__(
           self,dataObject,cursor,defaultValues,masterRecordSet)
    self._recordSetClass = ODBC_RecordSet
    self._fieldNames = None
    self._fieldTypes = {}

#    self._recordCount = cursor.rowcount > 0 and cursor.rowcount or 0
    self._recordCount = 0

    # If a DB driver supports a unique identifier for rows,
    # list it here.  _primaryIdField is the field name (lower case)
    # that would appear in the recordset (note that this can be
    # a system generated format). If a primary id is supported,
    # _primaryIdFormat is the WHERE clause to be used. It will have
    # the string  % (fieldvalue) format applied to it.
    self._primaryIdField = None
    self._primaryIdFormat = "__gnue__ = '%s'"

    GDebug.printMesg(9, 'ResultSet created')

  def _loadNextRecord(self):
    if self._cursor:
      rs = None

      try:
        rs = self._cursor.fetchone()
      except self._dataObject._connection._DatabaseError, err:
        pass
# TODO: It seems that popy does what the other drivers don't
# TODO: and raises this error ALOT need to find out why
#        raise GDataObjects.ConnectionError, err

      if rs:
        if not self._fieldNames:
          self._fieldNames = []
          for t in (self._cursor.description):
            self._fieldNames.append (string.lower(t[0]))
            self._fieldTypes[string.lower(t[0])] = (string.lower(t[1]))
        i = 0
        dict = {}
        for f in (rs):
          dict[self._fieldNames[i]] = f
          i = i + 1
        self._cachedRecords.append (self._recordSetClass(parent=self, \
                                            initialData=dict))
        return 1
      else:
        return 0
    else:
      return 0


class ODBC_DataObject(GDataObjects.DataObject):

  conditionElements = {
       'add':             (2, 999, '(%s)',                   '+'      ),
       'sub':             (2, 999, '(%s)',                   '-'      ),
       'mul':             (2, 999, '(%s)',                   '*'      ),
       'div':             (2, 999, '(%s)',                   '/'      ),
       'and':             (1, 999, '(%s)',                   ' AND '  ),
       'or':              (2, 999, '(%s)',                   ' OR '   ),
       'not':             (1,   1, '(NOT %s)',               None     ),
       'negate':          (1,   1, '-%s',                    None     ),
       'eq':              (2,   2, '(%s = %s)',              None     ),
       'ne':              (2,   2, '(%s != %s)',             None     ),
       'gt':              (2,   2, '(%s > %s)',              None     ),
       'ge':              (2,   2, '(%s >= %s)',             None     ),
       'lt':              (2,   2, '(%s < %s)',              None     ),
       'le':              (2,   2, '(%s <= %s)',             None     ),
       'like':            (2,   2, '%s LIKE %s',             None     ),
       'notlike':         (2,   2, '%s NOT LIKE %s',         None     ),
       'between':         (3,   3, '%s BETWEEN %s AND %s',   None     ) }

  def __init__(self, strictQueryCount=1):
    GDataObjects.DataObject.__init__(self)

    GDebug.printMesg (9,"DB-SIG database driver backend initializing")

    self._resultSetClass = ODBC_ResultSet
    self._DatabaseError = None
    self._strictQueryCount = strictQueryCount


  # This should be over-ridden only if driver needs more than user/pass
  def getLoginFields(self):
    return [['_username', 'User Name',0],['_password', 'Password',1]]


  def connect(self, connectData={}):

    GDebug.printMesg(9,"ODBC database driver initializing")
    self._DatabaseError = odbc.error

    try:
      service = connectData['service']
    except KeyError:
      service = ""

    try:
      self._dataConnection = odbc.odbc( "%s/%s/%s" % (
                   service,
                   connectData['_username'],
                   connectData['_password']))

    except dbi.opError, value:
      raise GDataObjects.LoginError, value

    except self._DatabaseError, value:
      raise GDataObjects.LoginError, value

    self._postConnect()


  #
  # Schema (metadata) functions
  #

  # TODO: See postgresql for an example of what these functions do.

  # Return a list of the types of Schema objects this driver provides
  def getSchemaTypes(self):
    return None # [('table',_('Tables'),1)]

  # Return a list of Schema objects
  def getSchemaList(self, type=None):
    return []

  # Find a schema object with specified name
  def getSchemaByName(self, name, type=None):
    return None

  def _postConnect(self):
    self.triggerExtensions = TriggerExtensions(self._dataConnection)

  def _createResultSet(self, conditions={}, readOnly=0, masterRecordSet=None,sql=""):
    try:
      cursor = self._dataConnection.cursor()
      cursor.execute(self._buildQuery(conditions))

    except dbi.progError, err:
      raise GDataObjects.ConnectionError, err

    except self._DatabaseError, err:
      raise GDataObjects.ConnectionError, err
    rs = self._resultSetClass(self, cursor=cursor, masterRecordSet=masterRecordSet)

    # pull a record count for the upcomming query
    if self._strictQueryCount:
      rs._recordCount = self._getQueryCount(conditions)

    if readOnly:
      rs._readonly = readOnly
    return rs


  def _getQueryCount(self,conditions={}):
    cursor = self._dataConnection.cursor()

    cursor.execute(self._buildQueryCount(conditions))
    rs = cursor.fetchone()
    return int(rs[0])

    
  def _buildQueryCount(self, conditions={}):
    q = "SELECT count(*) FROM %s%s" % (self.table, self._conditionToSQL(conditions))

    GDebug.printMesg(9,q)

    return q

  def commit(self):
    GDebug.printMesg (9,"DB-SIG database driver: commit()")

    try: 
      self._dataConnection.commit()
    except self._DatabaseError, value:
      raise GDataObjects.ConnectionError, value
    
    self._beginTransaction()

  def rollback(self): 
    GDebug.printMesg (9,"DB-SIG database driver: rollback()")

    try: 
      self._dataConnection.rollback()
    except: 
      pass	# I'm SURE this isn't right (jcater)
                # But not all db's support transactions

    self._beginTransaction()


  def _buildQuery(self, conditions={},forDetail=None,additionalSQL=""):
    return None


  # Used to convert a condition tree to an sql where clause
  def _conditionToSQL (self, condition): 
    if condition == {} or condition == None: 
      return ""
    elif type(condition) == types.DictType: 
      cond = GConditions.buildConditionFromDict(condition)
    else:
      cond = condition
  
    if not len(cond._children): 
      return ""
    elif len(cond._children) > 1: 
      chillun = cond._children
      cond._children = []
      _and = GConditions.GCand(cond)
      _and._children = chillun
  

    where = " WHERE (%s)" % (self.__conditionToSQL (cond._children[0]))
    GDebug.printMesg(9, where)
    return where
  
  # Used internally by _conditionToSQL
  def __conditionToSQL (self, element): 
    if type(element) != types.InstanceType: 
      return "%s" % element
    else: 
      otype = string.lower(element._type[2:])
      if otype == 'cfield': 
        return "%s" % element.name
      elif otype == 'cconst':
        if element.value == None:
          return "NULL"
        elif element.type == 'number':
          return "%s" % element.value
        else:
          return "'%s'" % element.value
      elif otype == 'param':
        v = element.getValue()
        return (v == None and "NULL") or ("'%s'" % element.getValue())
      elif self.conditionElements.has_key(otype):
        for i in range(0, len(element._children)): 
          element._children[i] = self.__conditionToSQL(element._children[i])
        if len(element._children) < self.conditionElements[otype][0]:
          tmsg = u_('Condition element "%(element)s" expects at least '
                    '%(expected)s arguments; found %(found)s') \
                 % {'element' : otype,
                    'expected': self.conditionElements[otype][0],
                    'found'   : len (element._children)}
          raise GConditions.ConditionError, tmsg
        if len(element._children) > self.conditionElements[otype][1]:
          tmsg = u_('Condition element "%(element)s" expects at most '
                    '%(expected)s arguments; found %(found)s') \
                 % {'element' : otype,
                    'expected': self.conditionElements [otype][1],
                    'found'   : len (element._children)}
          raise GConditions.ConditionError, tmsg
        if self.conditionElements[otype][3] == None: 
          return self.conditionElements[otype][2] % tuple(element._children)
        else: 
          return self.conditionElements[otype][2] % \
            (string.join(element._children, self.conditionElements[otype][3]))
      else: 
        tmsg = u_('Condition clause "%s" is not supported by this db driver.') % otype
        raise GConditions.ConditionNotSupported, tmsg

  # Code necessary to force the connection into transaction mode... 
  # this is usually not necessary (MySQL is one of few DBs that must force)
  def _beginTransaction(self): 
    pass      


class ODBC_DataObject_Object(ODBC_DataObject):
  """DataObject for datasources of type 'object' (a single table)."""

  def _buildQuery(self, conditions={}):
    """
    Build the SELECT statement for this datasource's table.

    Selects the referenced fields (or * when none are referenced), appends
    the WHERE clause built from conditions, and adds an ORDER BY clause
    only when an order_by attribute is present and non-empty.  An empty or
    None order_by no longer produces the invalid "ORDER BY " / "ORDER BY
    None".
    """
    GDebug.printMesg(9,'Implicit Fields: %s' % self._fieldReferences)
    if len(self._fieldReferences):
      columns = string.join(self._fieldReferences.keys(), ",")
    else:
      columns = "*"
    q = "SELECT %s FROM %s%s" % \
         (columns, self.table, self._conditionToSQL(conditions))

    orderBy = getattr(self, 'order_by', None)
    if orderBy:
      q = "%s ORDER BY %s " % (q, orderBy)

    GDebug.printMesg(9,q)

    return q



#
#  Extensions to Trigger Namespaces
#
class TriggerExtensions:
  """
  Trigger-namespace extensions provided by this driver.

  No extra functions are offered yet; the instance only holds a private
  (name-mangled) reference to the connection it was created with.
  """

  def __init__(self, connection):
    # Double underscore: Python stores this as _TriggerExtensions__connection.
    self.__connection = connection




######################################
#
#  The following hashes describe
#  this driver's characteristics.
#
######################################

#
#  All datasource "types" and corresponding DataObject class
#
# Only 'object' (table) datasources are available; this driver has no
# 'sql' DataObject yet.
supportedDataObjects = {}
supportedDataObjects['object'] = ODBC_DataObject_Object


