//                                               -*- C++ -*-
/**
 *  @file  KernelSmoothing.cxx
 *  @brief This class acts like a KernelMixture factory, implementing a kernel smoothing (kernel density estimation) algorithm
 *
 *  (C) Copyright 2005-2007 EDF-EADS-Phimeca
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Lesser General Public
 *  License as published by the Free Software Foundation; either
 *  version 2.1 of the License.
 *
 *  This library is distributed in the hope that it will be useful
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public
 *  License along with this library; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
 *
 *  @author: $LastChangedBy: dutka $
 *  @date:   $LastChangedDate: 2008-06-26 13:50:17 +0200 (jeu 26 jun 2008) $
 *  Id:      $Id: KernelSmoothing.cxx 862 2008-06-26 11:50:17Z dutka $
 */
#include <cmath>
#include "KernelSmoothing.hxx"
#include "Normal.hxx"
#include "KernelMixture.hxx"
#include "TruncatedDistribution.hxx"
#include "PersistentObjectFactory.hxx"

namespace OpenTURNS {

  namespace Uncertainty {

    namespace Distribution {

      /**
       * @class KernelSmoothing
       *
       * Factory that builds a smoothed distribution from a sample: a KernelMixture
       * made of the stored 1D kernel and a per-component bandwidth, optionally
       * wrapped in a TruncatedDistribution when a 1D boundary correction is requested.
       */

      // Project macro; presumably defines the class-name information used by
      // getClassName() and friends — TODO confirm against its definition.
      CLASSNAMEINIT(KernelSmoothing);

      // Registers KernelSmoothing under the name "KernelSmoothing" in the object factory
      // (presumably so that the storage layer can recreate instances on load — confirm).
      static Base::Common::Factory<KernelSmoothing> RegisteredFactory("KernelSmoothing");

      /** Default constructor
       *
       * Uses a default-constructed Normal kernel and an empty (dimension 0) bandwidth.
       */
      KernelSmoothing::KernelSmoothing(const String & name)
	: PersistentObject(name)
	, bandwidth_(NumericalPoint(0))
	, kernel_(Normal())
      {
	// The bandwidth remains empty until setBandwidth() or buildImplementation() assigns one
      }

      /** Parameter constructor
       *
       * Stores the given kernel and an empty (dimension 0) bandwidth.
       * @throw InvalidArgumentException if the kernel is not one-dimensional
       */
      KernelSmoothing::KernelSmoothing(const Distribution & kernel, const String & name) throw (InvalidArgumentException)
	: PersistentObject(name)
	, bandwidth_(NumericalPoint(0))
	, kernel_(kernel)
      {
	// Product kernel smoothing: only a univariate kernel is supported
	const UnsignedLong kernelDimension(kernel.getDimension());
	if (kernelDimension == 1) return;
	throw InvalidArgumentException(HERE) << "Error: only 1D kernel allowed for product kernel smoothing";
      }

      /* Virtual constructor: returns a heap-allocated copy of this factory (caller owns it) */
      KernelSmoothing * KernelSmoothing::clone() const
      {
	KernelSmoothing * duplicate(new KernelSmoothing(*this));
	return duplicate;
      }

      /** Compute the bandwidth according to Silverman's rule
       *
       * For each component i the returned bandwidth is
       *   h_i = sigma_i * n^(-1/(d+4)) / sigma_K
       * where n is the sample size, d the sample dimension, sigma_i the sample standard
       * deviation of component i and sigma_K the standard deviation of the 1D kernel.
       * Note: this is the normal-reference rule without Silverman's extra factor
       * (4/(d+2))^(1/(d+4)), i.e. Scott's rule, rescaled by the kernel standard deviation.
       *
       * @param sample non-empty sample
       * @return the per-component bandwidth
       * @throw InvalidArgumentException if the sample is empty
       */
      KernelSmoothing::NumericalPoint KernelSmoothing::computeSilvermanBandwidth(const NumericalSample & sample)
      {
	const UnsignedLong size(sample.getSize());
	// pow(0, negative) would yield an infinite bandwidth: reject empty samples explicitly
	if (size == 0) throw InvalidArgumentException(HERE) << "Error: cannot compute a bandwidth from an empty sample";
	const UnsignedLong dimension(sample.getDimension());
	NumericalPoint standardDeviations(sample.computeStandardDeviationPerComponent());
	// Dividing by the kernel's own standard deviation makes the rule independent of the kernel scale
	const NumericalScalar factor(pow(static_cast<NumericalScalar>(size), -1.0 / (4.0 + dimension)) / kernel_.getStandardDeviation()[0]);
	return factor * standardDeviations;
      }

      /** Build a kernel mixture based on the given sample, using the stored kernel
       *
       * The bandwidth is always recomputed from the sample by computeSilvermanBandwidth();
       * any bandwidth previously set is ignored and overwritten by this call.
       *
       * @param sample             data to smooth
       * @param boundaryCorrection request the 1D reflection boundary correction
       */
      KernelSmoothing::Distribution KernelSmoothing::buildImplementation(const NumericalSample & sample, const Bool boundaryCorrection)
      {
	const NumericalPoint ruleOfThumbBandwidth(computeSilvermanBandwidth(sample));
	return buildImplementation(sample, ruleOfThumbBandwidth, boundaryCorrection);
      }

      /** Build a kernel mixture based on the given sample and bandwidth, using the stored kernel
       *
       * The stored bandwidth is replaced by the given one. When boundaryCorrection is true and
       * the sample is 1D with a non-degenerate range [min, max], the points lying within one
       * bandwidth of a bound are reflected across that bound and added to the sample, and the
       * resulting kernel mixture is truncated to [min, max]. Otherwise a plain KernelMixture
       * built on the sample is returned.
       *
       * @throw InvalidDimensionException if the bandwidth and the sample have different dimensions
       * @throw InvalidArgumentException  if a bandwidth component is not strictly positive
       */
      KernelSmoothing::Distribution KernelSmoothing::buildImplementation(const NumericalSample & sample,
							const NumericalPoint & bandwidth,
							const Bool boundaryCorrection)
	throw(InvalidDimensionException, InvalidArgumentException)
      {
	const UnsignedLong dimension(sample.getDimension());
	if (bandwidth.getDimension() != dimension) throw InvalidDimensionException(HERE) << "Error: the given bandwidth must have the same dimension as the given sample, here bandwidth dimension=" << bandwidth.getDimension() << " and sample dimension=" << dimension;
	setBandwidth(bandwidth);
	const String mixtureName("Kernel smoothing from sample " + sample.getName());
	// Cheap boundary correction by extending the sample. Only valid for 1D samples.
	if (boundaryCorrection && (dimension == 1))
	  {
	    const NumericalScalar lowerBound(sample.getMin()[0]);
	    const NumericalScalar upperBound(sample.getMax()[0]);
	    // A constant sample gives an empty truncation interval [c, c]: fall back to the plain mixture
	    if (upperBound > lowerBound)
	      {
		const NumericalScalar h(bandwidth[0]);
		// Reflect the points close to each boundary and add them to the sample
		NumericalSample extendedSample(sample);
		const UnsignedLong size(sample.getSize());
		for (UnsignedLong i = 0; i < size; ++i)
		  {
		    const NumericalScalar realization(sample[i][0]);
		    if (realization <= lowerBound + h) extendedSample.add(NumericalPoint(1, 2.0 * lowerBound - realization));
		    if (realization >= upperBound - h) extendedSample.add(NumericalPoint(1, 2.0 * upperBound - realization));
		  }
		TruncatedDistribution truncatedMixture(KernelMixture(kernel_, bandwidth, extendedSample), lowerBound, upperBound);
		truncatedMixture.setName(mixtureName);
		return truncatedMixture;
	      }
	  }
	KernelMixture kernelMixture(kernel_, bandwidth, sample);
	kernelMixture.setName(mixtureName);
	return kernelMixture;
      }

      /** Bandwidth accessor */
      void KernelSmoothing::setBandwidth(const NumericalPoint & bandwidth)
	throw(InvalidArgumentException)
      {
	// Check the given bandwidth
	for (UnsignedLong i = 0; i < bandwidth.getDimension(); i++)
	  {
	    if (bandwidth[i] <= 0.0) throw InvalidArgumentException(HERE) << "Error: the bandwidth must be > 0, here bandwith=" << bandwidth;
	  }
	bandwidth_ = bandwidth;
      }

      KernelSmoothing::NumericalPoint KernelSmoothing::getBandwidth() const
      {
	// Copy of the currently stored bandwidth (empty after default construction)
	const NumericalPoint & currentBandwidth(bandwidth_);
	return currentBandwidth;
      }

      KernelSmoothing::Distribution KernelSmoothing::getKernel() const
      {
	// Copy of the 1D kernel used to build the mixtures
	const Distribution & currentKernel(kernel_);
	return currentKernel;
      }

      /* Method save() stores the object through the StorageManager
       * Both the bandwidth and the kernel are written, so that the kernel
       * is not lost when the object is reloaded. */
      void KernelSmoothing::save(const StorageManager::Advocate & adv) const
      {
	PersistentObject::save(adv);
	adv.writeValue(bandwidth_, StorageManager::MemberNameAttribute, "bandwidth_");
	adv.writeValue(kernel_, StorageManager::MemberNameAttribute, "kernel_");
      }

      /* Method load() reloads the object from the StorageManager
       * Reads back the attributes written by save().
       * NOTE(review): confirm readValue tolerates files written before "kernel_" was saved. */
      void KernelSmoothing::load(const StorageManager::Advocate & adv)
      {
	PersistentObject::load(adv);
	adv.readValue(bandwidth_, StorageManager::MemberNameAttribute, "bandwidth_");
	adv.readValue(kernel_, StorageManager::MemberNameAttribute, "kernel_");
      }
      
    } /* namespace Distribution */
  } /* namespace Uncertainty */
} /* namespace OpenTURNS */
