#
# This file is part of GNU Enterprise.
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or(at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# Copyright 2000-2004 Free Software Foundation
#
# FILE:
# Introspection.py
#
# DESCRIPTION:
#
# NOTES:
#

# Names exported by "from <module> import *": only the Introspection class.
__all__ = [
  'Introspection',
]

import string
from string import lower, join, rstrip, upper
import sys

from gnue.common.apps import GDebug, GConfig
from gnue.common.datasources import GIntrospection

DESCRIPTION_NAME            = 0
DESCRIPTION_TYPE_CODE       = 1
DESCRIPTION_DISPLAY_SIZE    = 2
DESCRIPTION_INTERNAL_SIZE   = 3
DESCRIPTION_PRECISION       = 4
DESCRIPTION_SCALE           = 5
DESCRIPTION_NULL_OK         = 6

class Introspection(GIntrospection.Introspection):
  """
  Schema introspection driven by the rdb$ system tables.

  Reports user relations (tables and views, told apart by
  rdb$relations.rdb$view_source), their fields as described by the DB-API
  cursor, field defaults from rdb$relation_fields and primary-key columns
  from rdb$relation_constraints / rdb$index_segments.
  """

  # Not referenced by this class; kept because code elsewhere may use it.
  _primaryKeyFields = []

  # list of the types of Schema objects this driver provides
  types =[ ('table',_('Tables'),1),
           ('view',_('Views'),1)]

  # Default values (compared upper-cased, after the DEFAULT keyword has been
  # removed) that mean "time of insertion" rather than a constant value.
  _TIMESTAMP_DEFAULTS = ("'NOW'", "CURRENT_TIMESTAMP")

  #
  # TODO: This is a quick hack to get this class
  # TODO: into the new-style schema format.
  # TODO: getSchema* should be merged into find()
  #
  def find(self, name=None, type=None):
    """
    Look up schema objects.

    name: relation name to look up, or None for every user relation.
    type: 'table', 'view' or None (no restriction on the relation type).

    Returns a list of Schema objects when name is None. Otherwise returns a
    one-element list, or None if the named relation does not exist or is
    not of the requested type.
    """
    if name is None:
      return self.getSchemaList(type)

    schema = self.getSchemaByName(name, type)
    if schema:
      return [schema]
    return None

  #
  # Schema (metadata) functions
  #

  def getSchemaList(self, type=None):
    """
    Return Schema objects for all user (rdb$system_flag = 0) tables and
    views, ordered by relation name.

    type: 'table' or 'view' to return only relations of that type; None
          returns both.
    """
    statement = ("select rdb$relation_name, rdb$view_source "
                 "from rdb$relations "
                 "where rdb$system_flag = 0 "
                 "order by rdb$relation_name")

    result = []
    for row in self.__fetchAll(statement):
      if type is None or self.__relationType(row[1]) == type:
        result.append(self.__buildSchema(row))
    return result

  def getSchemaByName(self, name, type=None):
    """
    Return the Schema object for the relation called name.

    Returns None if no such relation exists, or if type is given and the
    relation is not of that type ('table' or 'view').
    """
    statement = ("select rdb$relation_name, rdb$view_source "
                 "from rdb$relations "
                 "where rdb$relation_name = %s") % self.__quote(name)

    rows = self.__fetchAll(statement)
    if not rows:
      return None

    row = rows[0]
    if type is not None and self.__relationType(row[1]) != type:
      return None
    return self.__buildSchema(row)

  def __fetchAll(self, statement):
    """Run statement on a fresh cursor and return all rows; always closes it."""
    cursor = self._connection.native.cursor()
    try:
      cursor.execute(statement)
      return cursor.fetchall()
    finally:
      cursor.close()

  def __quote(self, value):
    """Return value as an SQL string literal, doubling embedded quotes."""
    return "'%s'" % value.replace("'", "''")

  def __relationType(self, viewSource):
    """Map rdb$view_source to a schema type: NULL for tables, SQL for views."""
    if viewSource is None:
      return 'table'
    return 'view'

  def __buildSchema(self, row):
    """Build a relation Schema object from a (rdb$relation_name,
    rdb$view_source) row."""
    # rdb$relation_name comes back blank-padded; 'id' keeps the raw value.
    name = rstrip(row[0])
    return GIntrospection.Schema(attrs={'id': row[0],
                                        'name': name,
                                        'type': self.__relationType(row[1]),
                                        'primarykey': self.__getPrimaryKey(name)},
                                 getChildSchema=self.__getFieldSchema)

  # Return a list of fields (for _buildDeleteStatement and for _buildUpdateStatement)
  def __getPrimaryKey(self, relname):
    """
    Return the lower-cased primary-key column names of relation relname,
    in key-segment order (empty list if it has no primary key).
    """
    statement = ("select ri.rdb$field_name "
                 "from rdb$relation_constraints rc, rdb$index_segments ri "
                 "where ri.rdb$index_name = rc.rdb$index_name "
                 "and rc.rdb$constraint_type = 'PRIMARY KEY' "
                 "and rc.rdb$relation_name = %s "
                 "order by ri.rdb$field_position") % self.__quote(relname)

    return [lower(rstrip(row[0])) for row in self.__fetchAll(statement)]

  def __getFieldDefaults(self, relname):
    """
    Return {field name (blank padding stripped): rdb$default_source} for
    every field of relation relname, using a single query.
    """
    statement = ("select rdb$field_name, rdb$default_source "
                 "from rdb$relation_fields "
                 "where rdb$relation_name = %s") % self.__quote(relname)

    defaults = {}
    for row in self.__fetchAll(statement):
      defaults[rstrip(row[0])] = row[1]
    return defaults

  def __parseDefault(self, source):
    """
    Split an rdb$default_source text such as "DEFAULT 'NOW'" or "DEFAULT 0"
    into (defaulttype, defaultval).

    Returns ('timestamp', None) for 'NOW' / CURRENT_TIMESTAMP (any case).
    Otherwise returns ('constant', text after the DEFAULT keyword), with
    quotes kept as stored.
    """
    value = source.strip()
    if value[:7].upper() == 'DEFAULT':
      value = value[7:].strip()

    # Compare the whole value: a prefix test would misread 'NOWHERE'.
    if value.upper() in self._TIMESTAMP_DEFAULTS:
      return ('timestamp', None)
    return ('constant', value)

  # Get fields for a table
  def __getFieldSchema(self, parent):
    """
    Return a list of field Schema objects for relation parent.name.

    Column metadata comes from the DB-API cursor.description of an empty
    query. Defaults come from rdb$relation_fields.
    """
    # A never-true WHERE clause yields the column description without rows.
    # parent.name is an identifier, so it cannot be passed as a literal.
    cursor = self._connection.native.cursor()
    try:
      cursor.execute("select * from %s where (0=1)" % parent.name)
      # Snapshot the description before the cursor is closed.
      description = list(cursor.description)
    finally:
      cursor.close()

    defaults = self.__getFieldDefaults(parent.name)

    result = []
    for d in description:
      try:
        nativetype = lower(d[DESCRIPTION_TYPE_CODE].__name__)
      except AttributeError:
        # type_code is not a Python type object (no __name__)
        nativetype = 'unknown'

      attrs = {'id': d[DESCRIPTION_NAME],
               'name': lower(d[DESCRIPTION_NAME]),
               'type': 'field',
               'nativetype': nativetype,
               'required': d[DESCRIPTION_NULL_OK] == 0,
               'length': d[DESCRIPTION_DISPLAY_SIZE]}

      if nativetype in ('int', 'float', 'long'):
        attrs['datatype'] = 'number'
        attrs['precision'] = d[DESCRIPTION_SCALE]
      elif nativetype == 'tuple':
        attrs['datatype'] = 'date'
      else:
        attrs['datatype'] = 'text'

      # A missing rdb$relation_fields row simply means "no default".
      source = defaults.get(upper(attrs['name']))
      if source:
        defaulttype, defaultval = self.__parseDefault(source)
        attrs['defaulttype'] = defaulttype
        if defaulttype == 'constant':
          attrs['defaultval'] = defaultval

      result.append(GIntrospection.Schema(attrs=attrs))

    return result
