"This module provides functionality for compilation of strings as dolfin Expressions."

# Copyright (C) 2008-2008 Martin Sandve Alnes
#
# This file is part of DOLFIN.
#
# DOLFIN is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# DOLFIN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Johan Hake 2008-2009
#
# First added:  2008-06-04
# Last changed: 2011-04-18

from __future__ import print_function
import re
import types
import hashlib
import instant

# Import the compilation and code-generation helpers from the local compilemodule module
from dolfin.compilemodules.compilemodule import (compile_extension_module,
                                                 expression_to_code_fragments,
                                                 math_header)

# Public API of this module: only compile_expressions is exported.
__all__ = ["compile_expressions"]

# C++ source template for one generated dolfin::Expression subclass.
# It is filled with Python %-formatting (see expression_to_dolfin_expression)
# from a dict that must provide the keys:
#   classname     - name of the generated class
#   members       - member declarations (presumably produced by
#                   expression_to_code_fragments — confirm there)
#   value_shape   - statements pushing each value dimension onto _value_shape
#   constructor   - extra constructor statements (presumably produced by
#                   expression_to_code_fragments — confirm there)
#   evalcode_cell - body of the eval overload that also receives a ufc::cell
#   evalcode      - body of the eval overload that only receives x
# Because %-formatting is used, any literal '%' added to this template
# must be written as '%%'.
_expression_template = """class %(classname)s: public Expression
{
public:
%(members)s
  %(classname)s():Expression()
  {
%(value_shape)s
%(constructor)s
  }

  void eval(dolfin::Array<double>& values, const dolfin::Array<double>& x,
            const ufc::cell& cell) const
  {
%(evalcode_cell)s
  }

  void eval(dolfin::Array<double>& values, const dolfin::Array<double>& x) const
  {
%(evalcode)s
  }
};
"""

def flatten_and_check_expression(expr):
    """Flatten an expression specification and compute its value shape.

    Accepted forms:

    * a ``str``: a scalar expression, returned as ``((expr,), ())``;
    * a non-empty ``tuple``/``list`` of ``str``: a vector expression with
      shape ``(len(expr),)``; the sequence itself is returned unchanged;
    * a non-empty ``tuple``/``list`` of equally long ``tuple`` of ``str``:
      a matrix expression with shape ``(rows, cols)``; the components are
      flattened row by row into a single tuple.

    Returns a pair ``(flat, shape)`` with the component strings and the
    value shape.

    Raises TypeError for any other form, for an empty sequence, and for a
    matrix whose rows have different lengths.
    """
    if isinstance(expr, str):
        return (expr,), ()

    if isinstance(expr, (tuple, list)) and len(expr) > 0:
        if all(isinstance(row, tuple) for row in expr):
            num_cols = len(expr[0])
            # A ragged matrix cannot be described by a (rows, cols) shape;
            # accepting it would produce a shape that disagrees with the
            # number of flattened components.
            ragged = any(len(row) != num_cols for row in expr)
            flat = tuple(comp for row in expr for comp in row)
            shape = (len(expr), num_cols)
        else:
            ragged = False
            flat = expr
            shape = (len(expr),)
        if not ragged and all(isinstance(comp, str) for comp in flat):
            return flat, shape

    raise TypeError("Wrong type of expressions. Provide a 'str', a 'tuple' of 'str' or a 'tuple' of 'tuple' of 'str': %s" % str(expr))

def expression_to_dolfin_expression(expr, generic_function_members):
    """Generate code for a dolfin::Expression subclass for a single expression.

    *expr* is a str, a tuple of str or a tuple of tuples of str (validated
    and flattened by flatten_and_check_expression). *generic_function_members*
    is an iterable of names; for each name the generated eval code calls
    ``shared_<name>->eval`` at x, requires the result to be scalar valued and
    binds it to a local ``const double <name>`` before evaluating *expr*.

    Returns ``(classname, code, members)``: the generated C++ class name,
    the full C++ class definition, and the member list returned by
    expression_to_code_fragments.

    Raises TypeError (from flatten_and_check_expression) for a malformed
    *expr*.
    """
    # Check and flatten the provided expression
    expr, shape = flatten_and_check_expression(expr)

    # Extract code fragments from the expr
    fragments, members = expression_to_code_fragments(
        expr, ["values", "x"], generic_function_members)

    # One push_back per value dimension; empty for scalar expressions
    value_shape_code = ["    _value_shape.push_back(%d);" % value_dim
                        for value_dim in shape]

    evalcode = []

    # Evaluate each generic function member at x into a local scalar,
    # guarding against non-scalar members and self-referencing members
    for name in generic_function_members:
        evalcode.append("    if (shared_%s->value_size()!=1)" % name)
        evalcode.append("      dolfin_error(\"generated code\",")
        evalcode.append("                   \"calling eval\", ")
        evalcode.append("                   \"Parameter \\'%s\\' is not scalar valued\");" % name)
        evalcode.append("    if (shared_%s.get()==this)" % name)
        evalcode.append("      dolfin_error(\"generated code\",")
        evalcode.append("                   \"calling eval\",")
        evalcode.append("                   \"Circular eval call detected. Cannot use itself as parameter \\'%s\\' within eval\");" % name)
        evalcode.append("    Array<double> %s__array_(1);" % name)
        evalcode.append("    shared_%s->eval(%s__array_, x);" % (name, name))
        evalcode.append("    const double %s = %s__array_[0];" % (name, name))

    # Generate code for the actual expression evaluation
    evalcode.extend("    values[%d] = %s;" % (i, c) for (i, c) in enumerate(expr))

    # Connect the code fragments using the expression template code.
    # The cell-aware overload forwards the cell to each member's eval.
    fragments["evalcode"] = "\n".join(evalcode)
    fragments["evalcode_cell"] = fragments["evalcode"].replace(
        "__array_, x", "__array_, x, cell")
    fragments["value_shape"] = "\n".join(value_shape_code)

    # Derive the class name from both the eval code and the value shape.
    # Expressions with the same components but different shapes, e.g.
    # ("a","b","c","d") and (("a","b"),("c","d")), generate identical eval
    # code; hashing the eval code alone gave them the same class name with
    # different class bodies, which clashes when they are compiled together.
    signature = fragments["evalcode"] + "\n" + fragments["value_shape"]
    classname = "Expression_" + hashlib.sha1(signature.encode("utf-8")).hexdigest()
    fragments["classname"] = classname

    # Produce the C++ code for the expression class
    code = _expression_template % fragments
    return classname, code, members


def compile_expression_code(code, classnames=None, module_name=None,
                            additional_declarations=None, mpi_comm=None):
    """Compile C++ source defining Expression subclasses and return the classes.

    *code* is C++ source with one or more class definitions. When
    *classnames* is None, every identifier following the keyword ``class``
    in *code* is used. Otherwise the given names are asserted to agree,
    position by position, with the detected ones. *additional_declarations*
    (a str, defaults to empty) and *mpi_comm* are forwarded to
    compile_extension_module. *module_name* is accepted but not used.

    Returns a list holding the compiled class for each name, in order.
    """
    detected = re.findall(r"class[ ]+([\w]+).*", code)

    if classnames is not None:
        # Sanity check: requested names must line up with what was detected
        assert all(wanted == seen for (wanted, seen) in zip(classnames, detected))
    else:
        classnames = detected

    if not additional_declarations:
        additional_declarations = ""

    # Prepend the math header so the generated code can use math functions
    full_source = "%s\n%s" % (math_header, code)

    module = compile_extension_module(full_source,
                                      additional_declarations=additional_declarations,
                                      mpi_comm=mpi_comm)

    return [getattr(module, classname) for classname in classnames]

def compile_expressions(cppargs, generic_function_members=None,
                        mpi_comm=None):
    """
    Compile a list of either C++ expressions or full subclasses of the
    dolfin::Expression class.

    An expression given as a str is interpreted as a scalar expression, and
    a scalar Expression is generated.

    An expression given as a tuple of str is interpreted as a vector
    expression, and a rank 1 Expression is generated.

    A tuple of tuples of str objects is interpreted as a matrix expression,
    and a rank 2 Expression is generated.

    If an expression string contains a name, it is assumed to be a scalar
    parameter name, and is added as a public member of the generated
    expression. The names of these parameters are then returned in a list
    together with the compiled expression class.

    If an entry of 'cppargs' is a str containing both 'class' and
    'Expression', it is interpreted as C++ code with a complete
    implementation of a subclass of dolfin::Expression; the first class
    name found in it is used.

    'generic_function_members', if given and non-empty, must hold one entry
    per element of 'cppargs'; each entry lists the names of scalar-valued
    parameters of the corresponding generated expression.

    Returns (expression_classes, all_members): the compiled classes and,
    per expression, the list of member names (empty for C++ snippets).

    Raises TypeError if 'cppargs' is not a list or contains an entry that is
    not a str, tuple or list. Raises ValueError if 'generic_function_members'
    has fewer entries than 'cppargs', or if no class name can be found in a
    C++ snippet.
    """
    # FIXME: Hook up this to a more general debug mechanism
    if not isinstance(cppargs, list):
        raise TypeError("Expected a 'list' of expressions, got %s"
                        % type(cppargs).__name__)

    if not generic_function_members:
        generic_function_members_list = [[] for _ in cppargs]
    else:
        generic_function_members_list = list(generic_function_members)
        # zip() below would otherwise silently drop the trailing expressions
        if len(generic_function_members_list) < len(cppargs):
            raise ValueError("Expected %d lists of generic function members "
                             "(one per expression), got %d"
                             % (len(cppargs), len(generic_function_members_list)))

    # Collect code and classnames
    code_snippets = []
    classnames = []
    all_members = []
    additional_declarations = []

    for cpparg, gf_members in zip(cppargs, generic_function_members_list):
        if not isinstance(cpparg, (str, tuple, list)):
            raise TypeError("Expected a 'str', 'tuple' or 'list' expression, got %s"
                            % type(cpparg).__name__)

        # If the cpparg includes the word 'class' and 'Expression',
        # assume it is a c++ code snippet
        if isinstance(cpparg, str) and "class" in cpparg and "Expression" in cpparg:
            code = cpparg

            # Get the class name
            found = re.findall(r"class[ ]+([\w]+).*", code)
            if not found:
                raise ValueError("Could not find a class name in the C++ code:\n%s"
                                 % code)
            classname = found[0]
            members = []

            # FIXME: Check for passed dimension?
        else:
            classname, code, members = \
                expression_to_dolfin_expression(cpparg, gf_members)

            # Expose each shared_<name> member to Python as <name>
            additional_declarations.extend(
                "%%rename(%s) dolfin::%s::shared_%s;" % (name, classname, name)
                for name in gf_members)

        all_members.append(members)
        code_snippets.append(code)
        classnames.append(classname)

    expression_classes = compile_expression_code(
        "\n\n".join(code_snippets), classnames,
        additional_declarations="\n".join(additional_declarations),
        mpi_comm=mpi_comm)

    return expression_classes, all_members

if __name__ == "__main__":
    # Print the generated C++ code and class name for a scalar expression
    # with a parameter, a vector expression and a matrix expression.
    # expression_to_dolfin_expression returns (classname, code, members);
    # the members list is not needed for this demo.
    examples = [("exp(alpha)", {'alpha': 1.5}),
                (("sin(x[0])", "cos(x[1])", "0.0"), {}),
                ((("sin(x[0])", "cos(x[1])"), ("0.0", "1.0")), {})]

    for example_expr, example_members in examples:
        classname, code, _members = \
            expression_to_dolfin_expression(example_expr, example_members)
        print(code)
        print(classname)
