#
# SchoolTool - common information systems platform for school administration
# Copyright (c) 2012 Shuttleworth Foundation
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
"""
LDAP authentication plugin configuration.
"""
import os
import re
import ldap
import ldap.filter

from persistent import Persistent
from zope.annotation.interfaces import IAnnotations
from zope.component import provideUtility, adapter, getUtility
from zope.interface import implements, implementer

from schooltool.app.interfaces import ISchoolToolApplication
from schooltool.ldap.interfaces import ILDAPConfig, ILDAPPersonsConfig
from schooltool.ldap.interfaces import ILDAPAutoConfig


# Annotation key under which an application-specific LDAP configuration
# may be stored (looked up by getAppLDAPConfig; when absent, a copy of the
# global ILDAPConfig utility is used instead).
ST_APP_LDAP_KEY = 'schooltool.ldap.configuration'

# Scope names accepted in "base?scope?filter" queries, mapped to the
# python-ldap search scope constants.
LDAP_SCOPES = {
    'base': ldap.SCOPE_BASE, # only the base entry itself
    'one': ldap.SCOPE_ONELEVEL, # immediate children of the base entry
    'sub': ldap.SCOPE_SUBTREE, # the base entry and its whole subtree
    }


def cleanup_ldap_filter(filter):
    """Return *filter* stripped and wrapped in one pair of parentheses.

    A filter that already starts with '(' is only stripped.  Falsy input
    (None, '') is returned unchanged.
    """
    if filter:
        stripped = filter.strip()
        return stripped if stripped.startswith('(') else '(%s)' % stripped
    return filter


def filter_op_filter(filter, other, op='&'):
    """Combine the LDAP *filter* with the clause *other* using operator *op*.

    *other* is either an ``(attribute, value)`` pair, rendered as an
    ``attribute=value`` assertion with both parts escaped by
    ``ldap.filter.filter_format``, or a filter string used as given.

    Returns the normalized *other* clause alone when *filter* is empty or
    blank; otherwise ``(<op><filter><other>)``, for example
    ``(&(objectClass=person)(uid=jdoe))``.
    """
    if isinstance(other, (str, type(u''))):
        clause = cleanup_ldap_filter(other)
    else:
        # Escaping keeps special characters in the value (e.g. '*', ')')
        # from changing the structure of the resulting filter.
        clause = '(%s)' % ldap.filter.filter_format('%s=%s', list(other))
    if not filter or not filter.strip():
        return clause
    return cleanup_ldap_filter(
        '%s%s%s' % (op, cleanup_ldap_filter(filter), clause))


def decode_ldap_query(query, lenient=False):
    """Decode params from an ldap.conf style query.

    The query uses the RFC2307bis naming-context syntax
    ``base?scope?filter`` where scope is one of LDAP_SCOPES
    (``base``, ``one``, ``sub``).  Every part is whitespace-stripped and
    the scope is lower-cased; any further ``?``-separated text is kept as
    part of the filter.

    :param query: the query string.
    :param lenient: if true, an unrecognized scope is returned as is
        instead of raising.
    :returns: ``(base, scope, filter)``; base is always a string, scope
        and filter are None when the query does not contain them.
    :raises ValueError: if the scope is not in LDAP_SCOPES and *lenient*
        is false.  The first argument is the original query string.
    """
    base = scope = ldap_filter = None
    parts = [part.strip() for part in query.split('?')]
    # str.split always returns at least one element.
    base = parts[0]
    if len(parts) > 1:
        scope = parts[1].lower()
        if scope not in LDAP_SCOPES and not lenient:
            raise ValueError(query, 'expected %r got %r' % (
                    tuple(LDAP_SCOPES), scope))
    if len(parts) > 2:
        ldap_filter = '?'.join(parts[2:])
    return base, scope, ldap_filter


class LDAPConfig(Persistent):
    """Persistent LDAP connection settings.

    Class attributes provide the defaults; instances override them.
    """
    implements(ILDAPConfig)

    # Server URI used unless configured otherwise.
    uri = 'ldap://127.0.0.1:389'
    # Connection timeout (presumably seconds -- confirm with the consumer).
    timeout = 10
    # Credentials to bind with; None when not configured.
    bind_dn = None
    bind_password = None

    def copy(self):
        """Return a new instance of the same class with the same settings."""
        clone = self.__class__()
        for name in ('uri', 'timeout', 'bind_dn', 'bind_password'):
            setattr(clone, name, getattr(self, name))
        return clone


class LDAPPersonsConfig(LDAPConfig):
    """LDAP connection settings plus user and group lookup configuration."""
    implements(ILDAPPersonsConfig)

    # Newline-separated user queries; this module writes them as
    # "<login attribute> <base>?<scope>?<filter>".
    queries = u''
    # Newline-separated group queries ("<base>?<scope>?<filter>").
    groupQueries = u''
    # Newline-separated "<year>, <group>, <posix id>" rows.
    posixGroups = u''

    def copy(self):
        """Return a copy that also carries the query and group settings."""
        clone = LDAPConfig.copy(self)
        clone.queries, clone.groupQueries, clone.posixGroups = (
            self.queries, self.groupQueries, self.posixGroups)
        return clone


class LDAPPersonsAutoConfig(LDAPPersonsConfig):
    """LDAPPersonsConfig variant that also provides ILDAPAutoConfig."""
    implements(ILDAPAutoConfig)

    # Flag read outside this module (presumably "editable through the
    # web" -- confirm with its consumer).
    enable_ttw = True


def iter_ldap_config(config_file):
    """Yield ``(keyword, value)`` pairs from an ldap.conf style file.

    :param config_file: a path, or a file-like object with ``read()``.
        A path is opened and closed again here; a file object is read but
        left open for the caller to close.

    Blank lines and lines starting with ``#`` are skipped.  Each remaining
    line is split at its first whitespace character; the value is
    stripped, and is None for lines that contain only a keyword.
    """
    if isinstance(config_file, (str, type(u''))):
        # Close the file we opened ourselves instead of leaking it.
        with open(config_file) as stream:
            text = stream.read()
    else:
        text = config_file.read()
    token_re = re.compile(r'\s')
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        tokens = [t.strip() for t in token_re.split(line, 1)]
        if len(tokens) == 1:
            yield tokens[0], None
        else:
            yield tokens[0], tokens[1]


class LDAPConfigParser(object):
    """Build an LDAPPersonsConfig from an ldap.conf style file.

    ``parse`` makes two passes over the file's ``(keyword, value)``
    entries.  First every matching ``pre_<keyword>`` handler runs; these
    collect values into a scratch ``cache`` dict.  Then every matching
    ``on_<keyword>`` handler runs; these update ``self.config`` and may
    read the cache.  Keywords are matched case-insensitively, keywords
    without a handler are ignored, and keywords given without a value
    arrive as None.  Finally ``after_parse`` fills in derived settings.
    """

    config = None

    def __init__(self):
        self.clear()

    def clear(self):
        """Reset self.config to a fresh LDAPPersonsConfig with no uri."""
        self.config = LDAPPersonsConfig()
        # None means "no uri directive seen"; set_default_uri fills it in.
        self.config.uri = None

    def parse(self, config_file, **defaults):
        """Parse *config_file* (a path or a file object) into self.config.

        Keyword arguments seed the cache; set_fallback_query reads
        ``fallback_login_attr`` and ``fallback_login_filter`` from it.
        """
        cache = dict(defaults)
        entries = list(iter_ldap_config(config_file))
        # All pre_ handlers run before any on_ handler, so on_ handlers see
        # cached values regardless of their order in the file.
        for prefix in ('pre_', 'on_'):
            for token, value in entries:
                handler = getattr(self, prefix + token.lower(), None)
                if handler is not None:
                    handler(value, cache)
        self.after_parse(config_file, cache)

    def after_parse(self, config_file, cache):
        """Derive the settings that were not given explicitly."""
        self.set_default_uri(cache)
        self.append_pam_query(cache)
        self.set_fallback_query(cache)

    def set_default_uri(self, cache):
        """Build ``ldap://host:port`` when no uri directive was seen."""
        if self.config.uri is None:
            self.config.uri = 'ldap://%(host)s:%(port)s' % {
                'host': cache.get('host', '127.0.0.1'),
                'port': cache.get('port', '389'),
                }

    def append_pam_query(self, cache):
        """Add a user query built from pam_filter / pam_login_attribute.

        Done only when a base is known and at least one of the two pam_
        keywords was given; the other one falls back to its default.
        """
        if ('base' in cache and
            ('pam_filter' in cache or
             'pam_login_attribute' in cache)):
            pam_filter = cleanup_ldap_filter(
                cache.get('pam_filter', 'objectClass=posixAccount'))
            pam_login_attr = cache.get('pam_login_attribute', 'uid')
            self._add_user_query('%s %s?%s?%s' % (
                pam_login_attr,
                cache['base'], cache.get('scope', 'sub'), pam_filter
                ))

    def set_fallback_query(self, cache):
        """Use the fallback login query when no user query was found.

        Needs ``base`` plus the ``fallback_login_attr`` and
        ``fallback_login_filter`` values passed to parse().
        """
        if self.config.queries:
            return
        if ('fallback_login_attr' in cache and
            'fallback_login_filter' in cache and
            'base' in cache):
            # Read the filter from the same key that was checked above.
            self.config.queries = '%s %s' % (
                cache['fallback_login_attr'],
                self._join_query(
                    cache['base'],
                    cache.get('scope', 'sub'),
                    cleanup_ldap_filter(cache['fallback_login_filter'])))

    # Helpers

    @staticmethod
    def _remember(cache, key, value):
        """Cache *value* under *key*, ignoring keywords without a value."""
        if value is not None:
            cache[key] = value

    @staticmethod
    def _append_line(text, line):
        """Return *text* with *line* appended on a new line, stripped."""
        return ('%s\n%s' % (text, line)).strip()

    @staticmethod
    def _join_query(base, scope, ldap_filter):
        """Join ``base?scope?filter``, dropping only empty trailing parts.

        Empty leading parts are kept so each part stays in its position.
        """
        parts = [base or '', scope or '', ldap_filter or '']
        while parts and not parts[-1]:
            parts.pop()
        return '?'.join(parts)

    def _add_user_query(self, query):
        """Append *query* as a new line of self.config.queries."""
        self.config.queries = self._append_line(self.config.queries, query)

    def _add_group_query(self, query):
        """Append *query* as a new line of self.config.groupQueries."""
        self.config.groupQueries = self._append_line(
            self.config.groupQueries, query)

    def _decode_nss_query(self, value, cache):
        """Decode an nss_base_* value; an unknown scope gets the default.

        The default is the cached ``scope`` keyword, else ``sub``.  The
        filter, if any, is normalized with cleanup_ldap_filter.
        """
        base, scope, ldap_filter = decode_ldap_query(value, lenient=True)
        scope = scope.lower() if scope is not None else ''
        if scope not in LDAP_SCOPES:
            scope = cache.get('scope', 'sub')
        return base, scope, cleanup_ldap_filter(ldap_filter)

    # Directive handlers

    def pre_host(self, value, cache):
        self._remember(cache, 'host', value)

    def pre_port(self, value, cache):
        self._remember(cache, 'port', value)

    def pre_base(self, value, cache):
        self._remember(cache, 'base', value)

    def on_uri(self, value, cache):
        self.config.uri = value

    def on_binddn(self, value, cache):
        self.config.bind_dn = value

    def on_bindpw(self, value, cache):
        self.config.bind_password = value

    def pre_scope(self, value, cache):
        # Unknown scope names and a bare "scope" keyword are ignored.
        if value is None:
            return
        value = value.lower()
        if value in LDAP_SCOPES:
            cache['scope'] = value

    def on_timelimit(self, value, cache):
        pass # not implemented

    def on_bind_timelimit(self, value, cache):
        pass # not implemented

    def pre_pam_filter(self, value, cache):
        # A filter to AND with uid=%s, e.g. "pam_filter objectclass=account".
        self._remember(cache, 'pam_filter', value)

    def pre_pam_login_attribute(self, value, cache):
        # The user ID attribute (defaults to uid), e.g.
        # "pam_login_attribute uid".
        self._remember(cache, 'pam_login_attribute', value)

    def on_pam_password(self, value, cache):
        # Password type.
        # XXX: we support only md5 (I think)
        pass # not implemented

    def on_pam_groupdn(self, value, cache):
        # Group to enforce membership of, e.g.
        # "pam_groupdn cn=PAM,ou=Groups,dc=example,dc=com".
        pass # not implemented

    def on_nss_base_passwd(self, value, cache):
        """Add a user query from ``nss_base_passwd base?scope?filter``.

        The login attribute is the cached ``fallback_login_attr`` (``uid``
        if absent).
        """
        if not value:
            return
        base, scope, ldap_filter = self._decode_nss_query(value, cache)
        attr = cache.get('fallback_login_attr', 'uid')
        self._add_user_query(
            '%s %s' % (attr, self._join_query(base, scope, ldap_filter)))

    def on_nss_base_group(self, value, cache):
        """Add a group query from ``nss_base_group base?scope?filter``."""
        if not value:
            return
        base, scope, ldap_filter = self._decode_nss_query(value, cache)
        self._add_group_query(self._join_query(base, scope, ldap_filter))


AUTODETECT_FROM = '/etc/ldap.conf'
FALLBACK_LOGIN_ATTR = 'uid'
FALLBACK_LOGIN_FILTER = 'objectClass=posixAccount'


def _ldap_user_queries(section):
    """Return the section's non-blank query_users entries as query lines.

    An entry of a single word (no login attribute in front of it) is
    prefixed with the section's default_login_attr.
    """
    lines = []
    for entry in section.query_users:
        if not entry.strip():
            continue
        if len(entry.split()) > 1:
            lines.append(entry)
        else:
            lines.append('%s %s' % (section.default_login_attr, entry))
    return lines


def _ldap_posix_groups(entries):
    """Turn bind_group entries into (year, group, posix_id) triples.

    Two-word entries get an empty year.  Blank entries and entries of any
    other word count are skipped.
    """
    rows = []
    for entry in entries:
        words = entry.split()
        if len(words) == 2:
            rows.append(('',) + tuple(words))
        elif len(words) == 3:
            rows.append(tuple(words))
    return rows


def _apply_ldap_section(parser, section):
    """Fill parser.config from an ``ldap_authentication`` config section.

    The section's autodetect_from file (if it exists) is parsed first;
    explicitly set section values then override what was parsed.
    """
    source = section.autodetect_from
    if source and os.path.isfile(source):
        parser.parse(
            source,
            fallback_login_attr=section.default_login_attr,
            fallback_login_filter=section.default_login_filter)

    target = parser.config
    if section.uri:
        target.uri = section.uri
    if section.query_users:
        target.queries = '\n'.join(_ldap_user_queries(section))
    if section.query_groups:
        target.groupQueries = '\n'.join(filter(None, section.query_groups))
    if section.bind_dn:
        target.bind_dn = section.bind_dn
        # Drop any password parsed from the file; an explicit
        # bind_password (below) overrides this.
        target.bind_password = ''
    if section.bind_password:
        target.bind_password = section.bind_password
    if section.bind_group:
        target.posixGroups = '\n'.join(
            [', '.join(row) for row in _ldap_posix_groups(section.bind_group)])
    target.timeout = int(section.timeout)


def handle_configuration(options, context):
    """Register the ILDAPConfig utility built from the ZConfig options.

    Without an ``ldap_authentication`` section, settings are autodetected
    from AUTODETECT_FROM when that file exists.  Provides the
    ``ldap_authentication`` and ``authentication`` features.
    """
    parser = LDAPConfigParser()
    context.provideFeature('ldap_authentication')

    section = getattr(options.config, 'ldap_authentication', None)
    if section is not None:
        _apply_ldap_section(parser, section)
    elif AUTODETECT_FROM and os.path.isfile(AUTODETECT_FROM):
        parser.parse(
            AUTODETECT_FROM,
            fallback_login_attr=FALLBACK_LOGIN_ATTR,
            fallback_login_filter=FALLBACK_LOGIN_FILTER)

    provideUtility(parser.config, ILDAPConfig)
    context.provideFeature('authentication')


@adapter(ISchoolToolApplication)
@implementer(ILDAPConfig)
def getAppLDAPConfig(app):
    """Return the LDAP configuration for the SchoolTool application *app*.

    A configuration stored in the application's annotations under
    ST_APP_LDAP_KEY wins.  Otherwise a fresh copy of the global
    ILDAPConfig utility is returned; the copy is not stored back here.
    """
    stored = IAnnotations(app).get(ST_APP_LDAP_KEY)
    if stored is not None:
        return stored
    return getUtility(ILDAPConfig).copy()


@adapter(ISchoolToolApplication)
@implementer(ILDAPPersonsConfig)
def getAppLDAPPersonsConfig(app):
    """Return the application's LDAP config adapted to ILDAPPersonsConfig.

    Returns None when there is no config or it cannot be adapted.
    """
    base_config = getAppLDAPConfig(app)
    if base_config is not None:
        return ILDAPPersonsConfig(base_config, None)
    return None


def get_configuration():
    """Return the ZConfig schema snippet for LDAP settings and its handler."""
    schema = """
<import package="schooltool.ldap" />

<section type="ldap_authentication" name="*" required="no" attribute="ldap_authentication">
  <description>
    Configuration for LDAP authentication plugin.
  </description>
  <example>
    &lt;ldap_authentication&gt;
      allow_web_config no
      uri ldap://127.0.0.1:389
    &lt;/ldap_authentication&gt;
  </example>
</section>

"""
    handler = handle_configuration
    return schema, handler
