#ifndef FILE_BILINEARFORM
#define FILE_BILINEARFORM

/*********************************************************************/
/* File:   bilinearform.hpp                                          */
/* Author: Joachim Schoeberl                                         */
/* Date:   25. Mar. 2000                                             */
/*********************************************************************/


class LinearForm;

/** 
    A bilinear form.

    A bilinear form provides the system matrix. It is defined in terms
    of integrators (see AddIntegrator). In most cases it is defined on
    two copies of the same space V, but it can also live on V x W
    (see MixedSpaces and GetFESpace2).
*/
class BilinearForm : public NGS_Object
{
protected:
  /// Finite element space (trial space; also the test space unless fespace2 is set)
  const FESpace & fespace;
  /// Test-space if different from trial-space, otherwise 0
  const FESpace * fespace2;

  /// if true, the matrix is not assembled (set via SetNonAssemble)
  bool nonassemble;
  /// set via SetDiagonal
  int diagonal;
  /// set via SetMultiLevel
  int multilevel;
  /// Galerkin projection of coarse grid matrices (set via SetGalerkin)
  int galerkin;
  /// complex forms are Hermitian (set via SetHermitean)
  int hermitean;
  /// bilinear form is symmetric (set via SetSymmetric)
  int symmetric;
  /// add epsilon for regularization (set via SetEpsRegularization)
  double eps_regularization; 
  /// diagonal value for unused dofs (set via SetUnusedDiag)
  double unuseddiag;

  /// low order bilinear-form, 0 if not used
  BilinearForm * low_order_bilinear_form;

  /// modify linear form due to static condensation (set via SetLinearForm)
  LinearForm * linearform;

  /// matrices (sparse, application, diagonal, ...), one entry per level;
  /// GetMatrix() without argument returns the last entry
  ARRAY<BaseMatrix*> mats;
  /// integrators defining the form (filled by AddIntegrator)
  ARRAY<BilinearFormIntegrator*> parts;
  /// diagnostic switches (SetTiming, SetPrint, SetPrintElmat, SetElmatEigenValues)
  bool timing;
  bool print;
  bool printelmat;
  bool elmat_ev;
  /// set via SetEliminateInternal
  bool eliminate_internal;
  

public:
  /// bilinear form on a single space (trial space == test space)
  BilinearForm (const FESpace & afespace,
		const string & aname, 
		const Flags & flags);

  /// bilinear form on V x W: afespace is the trial space, afespace2 the test space
  BilinearForm (const FESpace & afespace, 
		const FESpace & afespace2, 
		const string & aname,
		const Flags & flags);
  ///
  virtual ~BilinearForm ();
  

  /// Appends bfi to the integrators of this form. The same pointer is
  /// also registered with the low-order bilinear form, if there is one.
  void AddIntegrator (BilinearFormIntegrator * bfi)
  {
    parts.Append (bfi);
    if (low_order_bilinear_form)
      low_order_bilinear_form -> AddIntegrator (parts.Last());
  }
  
  /// number of integrators added so far
  int NumIntegrators () const 
  {
    return parts.Size(); 
  }

  /// i-th integrator (no range check beyond what ARRAY itself performs)
  BilinearFormIntegrator * GetIntegrator (int i) 
  { return parts[i]; }

  /// i-th integrator, const access
  const BilinearFormIntegrator * GetIntegrator (int i) const 
  { return parts[i]; }


  /// for static condensation of internal bubbles
  void SetLinearForm (LinearForm * alf)
  { linearform = alf; }

  ///
  void Assemble (LocalHeap & lh);
  ///
  void ReAssemble (LocalHeap & lh, bool reallocate = false);
  ///
  virtual void AssembleLinearization (const BaseVector & lin,
				      LocalHeap & lh, 
				      bool reallocate = 0) = 0;

  /// y = A x, implemented as y = 0 followed by y += 1.0 * A x
  void ApplyMatrix (const BaseVector & x,
		    BaseVector & y) const
  {
    y = 0;
    // Pass a double literal so the real-valued AddMatrix overload is
    // selected explicitly, not by int->double vs int->Complex ranking.
    AddMatrix (1.0, x, y);
  }

  /// y += val * A x
  virtual void AddMatrix (double val, const BaseVector & x,
			  BaseVector & y) const = 0;
  /// y += val * A x (complex factor)
  virtual void AddMatrix (Complex val, const BaseVector & x,
			  BaseVector & y) const = 0;
  

  /// y += val * A'(lin) x, with the form linearized around lin
  virtual void ApplyLinearizedMatrixAdd (double val,
					 const BaseVector & lin,
					 const BaseVector & x,
					 BaseVector & y) const = 0;
  /// complex-factor variant of the above
  virtual void ApplyLinearizedMatrixAdd (Complex val,
					 const BaseVector & lin,
					 const BaseVector & x,
					 BaseVector & y) const = 0;


  ///
  virtual double Energy (const BaseVector & x) const = 0;

  /// Matrix of the last level. Precondition: at least one matrix exists.
  const BaseMatrix & GetMatrix () const
  { 
    return *mats.Last(); 
  }

  /// matrix of the given level
  const BaseMatrix & GetMatrix (int level) const
  { 
    return *mats[level];
  }

  /// Matrix of the last level. Precondition: at least one matrix exists.
  BaseMatrix & GetMatrix () 
  { 
    return *mats.Last(); 
  }

  /// matrix of the given level
  BaseMatrix & GetMatrix (int level) 
  { 
    return *mats[level];
  }



  /// Precondition: a low-order form exists. low_order_bilinear_form is 0
  /// when not used, and this accessor does not check it.
  const BilinearForm & GetLowOrderBilinearForm() const
  {
    return *low_order_bilinear_form;
  }

  /// trial space
  const FESpace & GetFESpace() const
    { return fespace; }
  /// nonzero iff a separate test space was given
  int MixedSpaces () const
    { return fespace2 != NULL; }
  /// Test space. Precondition: MixedSpaces() is nonzero.
  const FESpace & GetFESpace2() const
    { return *fespace2; }

  /// number of stored matrices (levels)
  int GetNLevels() const
    { return mats.Size(); }

  ///
  void SetNonAssemble (bool na = true)
  { nonassemble = na; }

  ///
  void SetGalerkin (int agalerkin = 1)
    { galerkin = agalerkin; }
  ///
  void SetDiagonal (int adiagonal = 1)
    { diagonal = adiagonal; }
  ///
  void SetSymmetric (int asymmetric = 1)
  { symmetric = asymmetric; }
  ///
  void SetHermitean (int ahermitean = 1)
  { hermitean = ahermitean; }
  ///
  void SetMultiLevel (int amultilevel = 1)
    { multilevel = amultilevel; }

  ///
  void SetTiming (bool at) 
  { timing = at; }

  ///
  void SetEliminateInternal (bool eliminate) 
  { eliminate_internal = eliminate; }

  void SetPrint (bool ap);
  void SetPrintElmat (bool ap);
  void SetElmatEigenValues (bool ee);

  ///
  void GalerkinProjection ();


  ///
  virtual void ComputeInternal (BaseVector & u, LocalHeap & lh) const = 0;
  
  ///
  void SetEpsRegularization(double val)
    { eps_regularization = val; }
  ///
  void SetUnusedDiag (double val)
    { unuseddiag = val; }

  ///
  int UseGalerkin () const
    { return galerkin; }

  ///
  virtual string GetClassName () const
  {
    return "BilinearForm";
  }

  ///
  virtual void PrintReport (ostream & ost);

  ///
  virtual void MemoryUsage (ARRAY<MemoryUsageStruct*> & mu) const;

  ///
  void WriteMatrix (ostream & ost) const;
  /// creates a new vector compatible with this form (caller presumably owns it)
  virtual BaseVector * CreateVector() const = 0;
private:
  /// assembles the matrices; implemented by derived classes
  virtual void DoAssemble (LocalHeap & lh) = 0;

  /// allocates the matrix storage; implemented by derived classes
  virtual void AllocateMatrix () = 0;
};








/**
   Bilinear form with scalar type SCAL (presumably double or Complex).

   Implements the scalar-type-independent virtual interface of
   BilinearForm. Each virtual function forwards to the SCAL-typed worker
   (AddMatrix1 / ApplyLinearizedMatrixAdd1). Complex factors are first
   reduced to SCAL with ngbla::ReduceComplex.
*/
template <class SCAL>
class S_BilinearForm : public BilinearForm
{
public:
  /// form on a single space
  S_BilinearForm (const FESpace & afespace, const string & aname,
		  const Flags & flags)
    : BilinearForm (afespace, aname, flags) { ; }

  /// form on two spaces (trial afespace, test afespace2)
  S_BilinearForm (const FESpace & afespace, 
		  const FESpace & afespace2,
		  const string & aname, const Flags & flags)
    : BilinearForm (afespace, afespace2, aname, flags) { ; }


  /// y += val * A x, with the factor already in the scalar type SCAL
  void AddMatrix1 (SCAL val, const BaseVector & x,
		   BaseVector & y) const;

  /// real factor: converted to SCAL implicitly
  virtual void AddMatrix (double val, const BaseVector & x,
			  BaseVector & y) const
  {
    AddMatrix1 (val, x, y);
  }

  /// complex factor: reduced to SCAL before forwarding
  virtual void AddMatrix (Complex val, const BaseVector & x,
			  BaseVector & y) const
  {
    AddMatrix1 (ngbla::ReduceComplex<SCAL> (val), x, y);
  }


  /// y += val * A'(lin) x, with the factor in the scalar type SCAL
  void ApplyLinearizedMatrixAdd1 (SCAL val,
				  const BaseVector & lin,
				  const BaseVector & x,
				  BaseVector & y) const;
  
  /// real factor: converted to SCAL implicitly
  virtual void ApplyLinearizedMatrixAdd (double val,
					 const BaseVector & lin,
					 const BaseVector & x,
					 BaseVector & y) const
  {
    ApplyLinearizedMatrixAdd1 (val, lin, x, y);
  }
  
  /// complex factor: reduced to SCAL before forwarding. This uses the same
  /// qualified ngbla::ReduceComplex as AddMatrix for consistency.
  virtual void ApplyLinearizedMatrixAdd (Complex val,
					 const BaseVector & lin,
					 const BaseVector & x,
					 BaseVector & y) const
  {
    ApplyLinearizedMatrixAdd1 (ngbla::ReduceComplex<SCAL> (val), lin, x, y);
  }
  

  ///
  virtual double Energy (const BaseVector & x) const;

  ///
  virtual void ComputeInternal (BaseVector & u, LocalHeap & lh) const;

  ///
  virtual void DoAssemble (LocalHeap & lh);
  ///
  virtual void AssembleLinearization (const BaseVector & lin,
				      LocalHeap & lh, 
				      bool reallocate = 0);
  /// adds an element matrix into the global matrix (storage-specific)
  virtual void AddElementMatrix (const ARRAY<int> & dnums1,
				 const ARRAY<int> & dnums2,
				 const FlatMatrix<SCAL> & elmat,
				 LocalHeap & lh) = 0;
};



/**
   Bilinear form whose assembled matrix type is SparseMatrix<TM> (TMATRIX).
   TM is the matrix entry type; its scalar type is taken from mat_traits.
   All member functions are defined out of line.
*/
template <class TM>
class T_BilinearForm 
  : public S_BilinearForm<typename mat_traits<TM>::TSCAL>
{
public:
  /// scalar type of the entries of TM
  typedef typename mat_traits<TM>::TSCAL TSCAL;
  /// column-vector type associated with TM
  typedef typename mat_traits<TM>::TV_COL TV_COL;
  /// type of the assembled matrix
  typedef SparseMatrix<TM> TMATRIX;

  /// form on a single space
  T_BilinearForm (const FESpace & space,
		  const string & form_name,
		  const Flags & form_flags);

  /// form on two spaces (V x W)
  T_BilinearForm (const FESpace & space, 
		  const FESpace & space2,
		  const string & form_name,
		  const Flags & form_flags);

  virtual ~T_BilinearForm ();

  /// matrix-allocation hook required by BilinearForm
  virtual void AllocateMatrix ();

  /// vector factory required by BilinearForm
  virtual BaseVector * CreateVector() const;
  
  /// adds an element matrix into the sparse matrix
  virtual void AddElementMatrix (const ARRAY<int> & first_dnums,
				 const ARRAY<int> & second_dnums,
				 const FlatMatrix<TSCAL> & element_matrix,
				 LocalHeap & heap);
};








/**
   Symmetric counterpart of T_BilinearForm: the assembled matrix type is
   SparseMatrixSymmetric<TM> (TMATRIX). It can only be built on a single
   space; no two-space constructor is declared. All member functions are
   defined out of line.
*/
template <class TM>
class T_BilinearFormSymmetric 
  : public S_BilinearForm<typename mat_traits<TM>::TSCAL>
{
public:
  /// scalar type of the entries of TM
  typedef typename mat_traits<TM>::TSCAL TSCAL;
  /// column-vector type associated with TM
  typedef typename mat_traits<TM>::TV_COL TV_COL;
  /// type of the assembled (symmetric) matrix
  typedef SparseMatrixSymmetric<TM> TMATRIX;

  /// form on a single space
  T_BilinearFormSymmetric (const FESpace & space,
			   const string & form_name,
			   const Flags & form_flags);

  virtual ~T_BilinearFormSymmetric ();

  /// matrix-allocation hook required by BilinearForm
  virtual void AllocateMatrix ();

  /// vector factory required by BilinearForm
  virtual BaseVector * CreateVector() const;

  /// adds an element matrix into the symmetric sparse matrix
  virtual void AddElementMatrix (const ARRAY<int> & first_dnums,
				 const ARRAY<int> & second_dnums,
				 const FlatMatrix<TSCAL> & element_matrix,
				 LocalHeap & heap);
};



/// Factory for a bilinear form on *space* with the given name and flags.
/// The concrete class it instantiates is chosen by the out-of-line
/// definition (presumably based on space and flags; not visible here).
BilinearForm * CreateBilinearForm (const FESpace * space,
				   const string & name,
				   const Flags & flags);



///
class BilinearFormApplication : public BaseMatrix
{
protected:
  ///
  const BilinearForm * bf;
public:
  ///
  BilinearFormApplication (const BilinearForm * abf);
  ///
  virtual void Mult (const BaseVector & v, BaseVector & prod) const;
  ///
  virtual void MultAdd (double val, const BaseVector & v, BaseVector & prod) const;
  ///
  virtual void MultAdd (Complex val, const BaseVector & v, BaseVector & prod) const;
  ///
  // virtual void MultTransAdd (double val, const BaseVector & v, BaseVector & prod) const;
  ///
  virtual BaseVector * CreateVector () const;

  virtual int VHeight() const
  {
    return bf->GetFESpace().GetNDof(); 
  }
  virtual int VWidth() const
  {
    return bf->GetFESpace().GetNDof(); 
  }
};


/**
   Matrix-like wrapper applying the linearization of a bilinear form
   around a fixed vector. Height and width are inherited from
   BilinearFormApplication. The linearization point is stored by pointer.
*/
class LinearizedBilinearFormApplication : public BilinearFormApplication 
{
protected:
  /// linearization point (presumably passed on as `lin` to
  /// BilinearForm::ApplyLinearizedMatrixAdd; verify in the implementation)
  const BaseVector * veclin;

public:
  LinearizedBilinearFormApplication (const BilinearForm * form,
				     const BaseVector * lin_point);

  /// result = A'(lin) vec
  virtual void Mult (const BaseVector & vec, BaseVector & result) const;
  /// result += factor * A'(lin) vec
  virtual void MultAdd (double factor, const BaseVector & vec, BaseVector & result) const;
  /// result += factor * A'(lin) vec (complex factor)
  virtual void MultAdd (Complex factor, const BaseVector & vec, BaseVector & result) const;
  // A transposed application (MultTransAdd) is not provided.
};

#endif
