//                                               -*- C++ -*-
/**
 *  @file  ComposedNumericalMathFunction.cxx
 *  @brief The class that implements the composition of two numerical math functions
 *
 *  (C) Copyright 2005-2007 EDF-EADS-Phimeca
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Lesser General Public
 *  License as published by the Free Software Foundation; either
 *  version 2.1 of the License.
 *
 *  This library is distributed in the hope that it will be useful
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public
 *  License along with this library; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
 *
 *  @author: $LastChangedBy: dutka $
 *  @date:   $LastChangedDate: 2008-10-31 11:52:04 +0100 (ven 31 oct 2008) $
 *  Id:      $Id: ComposedNumericalMathFunction.cxx 995 2008-10-31 10:52:04Z dutka $
 */
#include "ComposedNumericalMathFunction.hxx"
#include "NoNumericalMathGradientImplementation.hxx"
#include "NoNumericalMathHessianImplementation.hxx"
#include "ComposedNumericalMathEvaluationImplementation.hxx"
#include "ComposedNumericalMathGradientImplementation.hxx"
#include "ComposedNumericalMathHessianImplementation.hxx"
#include "PersistentObjectFactory.hxx"

namespace OpenTURNS {

  namespace Base {

    namespace Func {

      // Local aliases for the parent class' gradient and hessian handle types,
      // used by the composition constructor to wrap the composed implementations
      typedef NumericalMathFunctionImplementation::GradientImplementation GradientImplementation;
      typedef NumericalMathFunctionImplementation::HessianImplementation  HessianImplementation;

      // Class-name machinery macro (presumably provides GetClassName(), used by str())
      CLASSNAMEINIT(ComposedNumericalMathFunction);

      // Factory registered under the name "ComposedNumericalMathFunction";
      // presumably used by the persistence layer (PersistentObjectFactory) to
      // rebuild instances by class name when loading — TODO confirm
      static Common::Factory<ComposedNumericalMathFunction> RegisteredFactory("ComposedNumericalMathFunction");

      /* Default constructor */
      //       ComposedNumericalMathFunction::ComposedNumericalMathFunction()
      // 	: NumericalMathFunctionImplementation(),
      // 	  p_leftFunction_(),
      // 	  p_rightFunction_()
      //       {
      // 	// Nothing to do
      //       }

      /* Composition constructor */
      /*
       * Builds H = left o right, i.e. H(x) = left(right(x)).
       * The evaluation is always the composition of the two evaluations.
       * The gradient and hessian start as NoNumericalMath*Implementation
       * placeholders. Each is replaced by the corresponding composed
       * implementation when that one can be built. An InvalidArgumentException
       * raised while building it (presumably when a component lacks the needed
       * derivative or dimensions do not match) is ignored on purpose, so the
       * placeholder stays in place.
       */
      ComposedNumericalMathFunction::ComposedNumericalMathFunction(const Implementation & p_left,
                                                                   const Implementation & p_right)
        : NumericalMathFunctionImplementation(new ComposedNumericalMathEvaluationImplementation(p_left->getEvaluationImplementation(), p_right->getEvaluationImplementation()),
                                              new NoNumericalMathGradientImplementation(),
                                              new NoNumericalMathHessianImplementation()),
          p_leftFunction_(p_left),
          p_rightFunction_(p_right)
      {
        // Chain rule for the gradient needs: left gradient, right evaluation, right gradient
        try {
          GradientImplementation p_gradientImplementation(new ComposedNumericalMathGradientImplementation(p_left->getGradientImplementation(), p_right->getEvaluationImplementation(), p_right->getGradientImplementation()));
          setInitialGradientImplementation(p_gradientImplementation);
          setGradientImplementation(p_gradientImplementation);
        }
        catch (const InvalidArgumentException &) {
          // Deliberate best effort: keep the NoNumericalMathGradientImplementation placeholder
        }
        // Chain rule for the hessian needs the first and second derivatives of both functions
        try {
          HessianImplementation p_hessianImplementation(new ComposedNumericalMathHessianImplementation(p_left->getGradientImplementation(), p_left->getHessianImplementation(), p_right->getEvaluationImplementation(), p_right->getGradientImplementation(), p_right->getHessianImplementation()));
          setInitialHessianImplementation(p_hessianImplementation);
          setHessianImplementation(p_hessianImplementation);
        }
        catch (const InvalidArgumentException &) {
          // Deliberate best effort: keep the NoNumericalMathHessianImplementation placeholder
        }
      }

      /*
       * Virtual constructor.
       * Allocates a new copy of this composed function on the heap; the
       * returned object is not kept by this method, so the caller is
       * responsible for it.
       */
      ComposedNumericalMathFunction * ComposedNumericalMathFunction::clone() const
      {
        ComposedNumericalMathFunction * p_copy = new ComposedNumericalMathFunction(*this);
        return p_copy;
      }

      /* Comparison operator */
      /*
       * NOTE(review): this always returns true, whatever `other` is. Two
       * composed functions built from different left/right functions
       * therefore compare equal. Confirm this is intended; a stricter version
       * would compare p_leftFunction_ and p_rightFunction_ with those of
       * `other`.
       */
      Bool ComposedNumericalMathFunction::operator ==(const ComposedNumericalMathFunction & other) const
      {
	return true;
      }
  
      /*
       * String converter.
       * Produces "class=... name=... description=... left function=... right function=...",
       * where the two sub-functions are rendered by their own str().
       */
      String ComposedNumericalMathFunction::str() const
      {
        OSS description;
        description << "class=" << ComposedNumericalMathFunction::GetClassName();
        description << " name=" << getName();
        description << " description=" << getDescription();
        description << " left function=" << p_leftFunction_->str();
        description << " right function=" << p_rightFunction_->str();
        return description;
      }
  
      /*
       * Gradient according to the marginal parameters
       * H(x, p) = F(G(x, pg), pf)
       * dH/dp = dF/dy(G(x, pg), pf) . dG/dp(x, pg) + dF/dp(G(x, pg), pf)
       * with
       * p = [pg, pf], dG/dp = [dG/dpg, 0], dF/dp = [0, dF/dpf]
       * thus
       * dH/dp = [dF/dy(G(x, pg), pf) . dG/dpg(x, pg), dF/dpf(G(x, pg), pf)]
       * and the needed gradient is (dH/dp)^t
       *
       * The result has one row per parameter: the right function parameters pg
       * come first, then the left function parameters pf. It has one column
       * per output component of H, i.e. per column of the left function
       * gradient. The column count must be the output dimension of H, not its
       * input dimension; the two differ in general.
       */
      ComposedNumericalMathFunction::Matrix ComposedNumericalMathFunction::parametersGradient(const NumericalPoint & in) const
      {
        // Intermediate point y = G(x, pg)
        NumericalPoint y(p_rightFunction_->operator()(in));
        // Transposed gradient of F with respect to y: one column per output of H
        Matrix leftGradientY(p_leftFunction_->gradient(y));
        UnsignedLong outputDimension(leftGradientY.getNbColumns());
        // Transposed gradient of G with respect to its own parameters pg
        Matrix rightGradientP(p_rightFunction_->parametersGradient(in));
        // Chain rule block for pg: (dG/dpg)^t . (dF/dy)^t
        Matrix upper(rightGradientP.getNbRows(), outputDimension);
        // Check if the right function has parameters, if not the matrix product would fail
        if (rightGradientP.getNbRows() > 0) upper = rightGradientP * leftGradientY;
        UnsignedLong rightParametersDimension(upper.getNbRows());
        // Block for pf: the left function parameters gradient taken at y
        Matrix lower(p_leftFunction_->parametersGradient(y));
        UnsignedLong leftParametersDimension(lower.getNbRows());
        Matrix gradient(rightParametersDimension + leftParametersDimension, outputDimension);
        UnsignedLong rowIndex(0);
        // Gradient according to the right function parameters
        for (UnsignedLong i = 0; i < rightParametersDimension; i++)
          {
            for (UnsignedLong j = 0; j < outputDimension; j++)
              {
                gradient(rowIndex, j) = upper(i, j);
              }
            rowIndex++;
          }
        // Gradient according to the left function parameters
        for (UnsignedLong i = 0; i < leftParametersDimension; i++)
          {
            for (UnsignedLong j = 0; j < outputDimension; j++)
              {
                gradient(rowIndex, j) = lower(i, j);
              }
            rowIndex++;
          }
        return gradient;
      }

      /* Method save() stores the object through the StorageManager */
      /*
       * Writes the parent class state first, then the two sub-functions under
       * the member names "leftFunction_" and "rightFunction_". load() looks
       * them up under these same names, so the names must stay in sync.
       */
      void ComposedNumericalMathFunction::save(const StorageManager::Advocate & adv) const
      {
	NumericalMathFunctionImplementation::save(adv);
	adv.writeValue(*p_leftFunction_, StorageManager::MemberNameAttribute, "leftFunction_");
	adv.writeValue(*p_rightFunction_, StorageManager::MemberNameAttribute, "rightFunction_");
      }

      /*
       * Method load() reloads the object from the StorageManager.
       * Restores the parent class state, then reads the sub-functions stored
       * under "leftFunction_" and "rightFunction_". When readValue reports
       * failure for a member, the corresponding pointer is left unchanged.
       */
      void ComposedNumericalMathFunction::load(const StorageManager::Advocate & adv)
      {
        NumericalMathFunctionImplementation::load(adv);
        // One holder is reused for both reads
        Common::TypedInterfaceObject<NumericalMathFunctionImplementation> storedFunction;

        StorageManager::List members = adv.getList(StorageManager::ObjectEntity);
        if (members.readValue(storedFunction, StorageManager::MemberNameAttribute, "leftFunction_")) {
          p_leftFunction_ = storedFunction.getImplementation();
        }
        if (members.readValue(storedFunction, StorageManager::MemberNameAttribute, "rightFunction_")) {
          p_rightFunction_ = storedFunction.getImplementation();
        }
      }

    } /* namespace Func */
  } /* namespace Base */
} /* namespace OpenTURNS */
