///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
/// 
/// =========================================================================
# include "rheolef/config.h"

#ifdef _RHEOLEF_HAVE_MPI
#include "rheolef/dis_macros.h"
#include "rheolef/csr.h"

#include "rheolef/asr_to_csr_dist_logical.h"
#include "rheolef/csr_to_asr.h"
#include "rheolef/asr_to_csr.h"
#include "rheolef/csr_amux.h"
#include "rheolef/csr_cumul_trans_mult.h"
#include "rheolef/mpi_scatter_init.h"
#include "rheolef/mpi_scatter_begin.h"
#include "rheolef/mpi_scatter_end.h"

// standard headers used directly by this file; kept after the project
// headers so that the effective include order is unchanged
#include <algorithm>  // lower_bound, copy, fill
#include <functional> // unary_negate, plus, minus
#include <iterator>   // distance
#include <limits>     // numeric_limits
#include <set>
#include <string>
#include <utility>    // pair
#include <vector>

using namespace std;
namespace rheolef {
// ----------------------------------------------------------------------------
// useful class-functions
// ----------------------------------------------------------------------------
/// Functor mapping an (index, value) pair to (t[index], value), where `t`
/// is a random-access iterator over an index translation table.
/// In this file it is used by to_asr() with `_jext2dis_j.begin()` as table.
///
/// The adaptable-functor typedefs are declared explicitly instead of being
/// inherited from std::unary_function, which was deprecated in C++11 and
/// removed in C++17.
/// No standard header is included here: including one inside a namespace
/// would declare its content in that namespace.
template <class Pair1, class Pair2, class RandomIterator>
struct op_ext2glob_t {
        typedef Pair1 argument_type;
        typedef Pair2 result_type;
        /// @return the pair (t[x.first], x.second); x.first is not range-checked
        Pair2 operator()(const Pair1& x) const {
            return Pair2(t[x.first], x.second); }
        /// @param t1 iterator to the first entry of the translation table
        op_ext2glob_t(RandomIterator t1) : t(t1) {}
        RandomIterator t;
};
/// Functor shifting the index of an (index, value) pair by a constant:
/// (j, v) -> (shift + j, v).
///
/// `Size` is Pair1::first_type. build_from_asr() passes the negated first
/// column index as shift, so a backward shift presumably relies on
/// wrap-around of an unsigned `Size` (to be confirmed against size_type).
///
/// The adaptable-functor typedefs are declared explicitly instead of being
/// inherited from std::unary_function, which was deprecated in C++11 and
/// removed in C++17.
template <class Pair1, class Pair2>
struct op_dia_t {
        typedef Pair1 argument_type;
        typedef Pair2 result_type;
        typedef typename Pair1::first_type Size;
        /// @return the pair (shift + x.first, x.second)
        Pair2 operator()(const Pair1& x) const {
            return Pair2(shift + x.first, x.second); }
        /// @param s offset added to every index
        op_dia_t(Size s) : shift(s) {}
        Size shift;
};
/// Functor mapping a (key, value) pair to (position, value), where position
/// is the rank of `key` in the sorted range [t1, t2).
/// In this file the range is a sorted table of distributed column indices
/// (`_jext2dis_j` or `jext_c2dis_j`) and the result is the matching
/// external-column index.
///
/// The range must be sorted (std::lower_bound requirement) and must contain
/// the key. The assertion tests `t != t2` before dereferencing, so that a
/// key larger than every table entry is reported instead of reading past the
/// end of the table.
///
/// The adaptable-functor typedefs are declared explicitly instead of being
/// inherited from std::unary_function, which was deprecated in C++11 and
/// removed in C++17.
template <class Pair1, class Pair2, class RandomIterator>
struct op_dis_j2jext_t {
        typedef Pair1 argument_type;
        typedef Pair2 result_type;
        /// @return the pair (position of x.first in [t1,t2), x.second)
        Pair2 operator()(const Pair1& x) const {
            RandomIterator t = std::lower_bound(t1, t2, x.first);
            assert_macro(t != t2 && *t == x.first, "problem in ditributed asr_to_csr");
            return Pair2(std::distance(t1,t), x.second); }
        /// @param u1 first entry of the sorted table
        /// @param u2 past-the-end of the sorted table
        op_dis_j2jext_t(RandomIterator u1, RandomIterator u2) : t1(u1), t2(u2) {}
        RandomIterator t1, t2;
};
// ----------------------------------------------------------------------------
// member class functions
// ----------------------------------------------------------------------------
/// Default constructor: empty diagonal block (base class), empty external
/// block, empty external-column table, zero distributed nnz counter and
/// default-constructed scatter descriptors and buffer.
template<class T>
csr_mpi_rep<T>::csr_mpi_rep()
  : csr_seq_rep<T>(), _ext(),
    _jext2dis_j(0), _dis_nnz(0),
    _from(), _to(), _buffer()
{}
/// Copy constructor: member-wise ("physical") copy of the diagonal block
/// (base class), the external block, the external-column table, the
/// distributed nnz counter, the scatter descriptors and the buffer.
template<class T>
csr_mpi_rep<T>::csr_mpi_rep(const csr_mpi_rep<T>& a)
  : csr_seq_rep<T>(a), _ext(a._ext),
    _jext2dis_j(a._jext2dis_j), _dis_nnz(a._dis_nnz),
    _from(a._from), _to(a._to), _buffer(a._buffer)
{}
/// Resize the matrix for the given row and column distributions.
/// @param row_ownership  distribution of the rows
/// @param col_ownership  distribution of the columns
/// @param nnz1           number of entries given to the diagonal block
template<class T>
void
csr_mpi_rep<T>::resize (const distributor& row_ownership, const distributor& col_ownership, size_type nnz1)
{
  // diagonal block: stored in the sequential base class, gets the nnz1 entries
  csr_seq_rep<T>::resize (row_ownership, col_ownership, nnz1);

  // external block: same distributions but no entry yet;
  // it is given its actual size elsewhere
  _ext.resize (row_ownership, col_ownership, 0);
}
/// Conversion constructor from a distributed asr matrix: every member is
/// value-initialized, then the actual conversion is delegated to
/// build_from_asr().
template<class T>
csr_mpi_rep<T>::csr_mpi_rep(const asr_mpi_rep<T>& a)
  : csr_seq_rep<T>(), _ext(),
    _jext2dis_j(), _dis_nnz(),
    _from(), _to(), _buffer()
{
    build_from_asr (a);
}
/// Rebuild this distributed csr matrix from the distributed asr matrix `a`.
///
/// The entries of `a` are split by the `is_dia` predicate into:
///  - the diagonal block, stored in the csr_seq_rep<T> base class, with
///    column indices shifted by minus the first owned column index;
///  - the external block `_ext`, with column indices renumbered as positions
///    in the sorted table `_jext2dis_j` of referenced distributed columns.
/// The scatter descriptors `_from`/`_to` and the `_buffer` used by the
/// matrix-vector products are then (re)initialized.
///
/// @param a  the source matrix; its row and column distributions are adopted
template<class T>
void
csr_mpi_rep<T>::build_from_asr (const asr_mpi_rep<T>& a)
{
    // Adopt the distributions of a first, with no entry yet: the
    // col_ownership() queries below are made on *this.
    csr_seq_rep<T>::resize (a.row_ownership(), a.col_ownership(), 0);
    _ext.resize            (a.row_ownership(), a.col_ownership(), 0);
    _dis_nnz = a.dis_nnz();

    distributor::tag_type tag = distributor::get_new_tag();
    typedef typename asr_mpi_rep<T>::row_type row_type;
    typedef typename row_type::value_type      const_pair_type; 
    typedef pair<size_type,T>                  pair_type; 
    // predicate built from the first and last column indices of this process
    is_dia_t<size_type, const_pair_type> 
        is_dia(col_ownership().first_index(), col_ownership().last_index());
    //
    // step 1: count the external entries and collect the distributed
    //         column indices they reference (sorted, without duplicates)
    //
    set<size_type> colext;

    size_type nnzext
    = asr_to_csr_dist_logical (a.begin(), a.end(), is_dia, colext);
    //
    // step 2: resize and copy jext2dis_j
    //
    size_type nnzdia = a.nnz() - nnzext;
    size_type ncolext = colext.size();
    csr_seq_rep<T>::resize(a.row_ownership(), a.col_ownership(), nnzdia);
    _ext.resize           (a.row_ownership(), a.col_ownership(), nnzext);
    _jext2dis_j.resize (ncolext);
    copy (colext.begin(), colext.end(), _jext2dis_j.begin());
    _buffer.resize(ncolext);
    //
    // step 3: copy values
    //   column indexes of a   in 0..dis_ncol
    //                  of dia in 0..ncoldia (number of locally owned columns)
    //                  of ext in 0..ncolext
    // the shift is the negated first owned column index
    op_dia_t<const_pair_type,pair_type> op_dia(- col_ownership().first_index());
    // step 3.a: copy dia part
    asr_to_csr (
        a.begin(), 
        a.end(), 
        is_dia, 
        op_dia,
        csr_seq_rep<T>::begin(), 
        csr_seq_rep<T>::_data.begin());

    // step 3.b: copy ext part: entries rejected by is_dia, with the column
    // index replaced by its position in the sorted _jext2dis_j table
    unary_negate<is_dia_t<size_type, const_pair_type> > 
        is_ext(is_dia);
    op_dis_j2jext_t<const_pair_type, pair_type, typename vector<size_type>::const_iterator> 
        op_dis_j2jext(_jext2dis_j.begin(),  _jext2dis_j.end());
    asr_to_csr (
        a.begin(), a.end(), is_ext, op_dis_j2jext,
        _ext.begin(), _ext._data.begin());
    //
    // step 4: messages building for A*x
    //   id is the identity numbering 0..ncolext-1 of the external columns
    //
    vector<size_type> id(_jext2dis_j.size());
    for (size_type i = 0; i < id.size(); i++) id[i] = i;

    mpi_scatter_init(
        _jext2dis_j.size(),
        _jext2dis_j.begin().operator->(),
        id.size(),
        id.begin().operator->(),
        dis_ncol(),
        col_ownership().begin().operator->(),
        tag,
        row_ownership().comm(),
        _from,
        _to);
}
/// Copy the entries of this matrix into the distributed asr matrix `b`.
/// Both calls to csr_to_asr() write through `b.begin()`:
///  - diagonal block: column indices are shifted by the first owned column
///    index;
///  - external block: column indices are translated through `_jext2dis_j`.
/// @param b  destination matrix
template<class T>
void
csr_mpi_rep<T>::to_asr(asr_mpi_rep<T>& b) const
{
    typedef pair<size_type,T>                              csr_entry_type;
    typedef typename asr_mpi_rep<T>::row_type::value_type asr_entry_type;

    // 1) diagonal block
    typename vector_of_iterator<csr_entry_type>::const_iterator dia_first = csr_seq_rep<T>::begin();
    typename vector_of_iterator<csr_entry_type>::const_iterator dia_last  = csr_seq_rep<T>::end();
    const csr_entry_type* dia_data = csr_seq_rep<T>::_data.begin().operator->();
    op_dia_t<csr_entry_type,asr_entry_type> shift_column (col_ownership().first_index());
    csr_to_asr (dia_first, dia_last, dia_data, shift_column, b.begin().operator->());

    // 2) external block
    op_ext2glob_t <csr_entry_type, asr_entry_type, typename vector<size_type>::const_iterator>
        translate_column (_jext2dis_j.begin());
    csr_to_asr (
        _ext.begin(),
        _ext.begin() + _ext.size(),
        _ext._data.begin().operator->(),
        translate_column,
        b.begin().operator->());
}
/// Read the matrix from the distributed input stream `ips`.
/// The data is first read into a temporary distributed asr matrix, which is
/// then converted by build_from_asr().
/// The temporary uses the value type T of this matrix, not a hard-coded
/// Float: build_from_asr() expects an asr_mpi_rep<T>.
/// @param ips  the input stream
/// @return     `ips`
template<class T>
idiststream& 
csr_mpi_rep<T>::get (idiststream& ips)
{
    asr_mpi_rep<T> a;
    a.get (ips);
    build_from_asr (a);
    return ips;
}
/// Write the matrix on the distributed output stream `ops`.
/// All entries are assembled into a temporary asr matrix whose rows and
/// columns are entirely owned by the io process; that process alone then
/// writes it with the sequential asr output.
/// @param ops  the output stream
/// @return     `ops`
template<class T>
odiststream& 
csr_mpi_rep<T>::put (odiststream& ops) const
{
    // distributions giving every row and column to the io process
    const size_type io_rank = odiststream::io_proc();
    const size_type my_rank = comm().rank();
    const bool on_io_rank = (my_rank == io_rank);
    distributor a_row_ownership (dis_nrow(), comm(), (on_io_rank ? dis_nrow() : 0));
    distributor a_col_ownership (dis_ncol(), comm(), (on_io_rank ? dis_ncol() : 0));
    asr_mpi_rep<T> a (a_row_ownership, a_col_ownership);

    const size_type first_dis_i = row_ownership().first_index();
    const size_type first_dis_j = col_ownership().first_index();

    // diagonal block: local column index + first owned column index
    if (nnz() != 0) {
      const_iterator dia_row = begin();
      for (size_type loc_i = 0; loc_i < nrow(); ++loc_i) {
        const_data_iterator entry = dia_row[loc_i];
        while (entry < dia_row[loc_i+1]) {
          const size_type& loc_j = (*entry).first;
          a.dis_entry (loc_i + first_dis_i, loc_j + first_dis_j) = (*entry).second;
          ++entry;
        }
      }
    }
    // external block: column index translated through _jext2dis_j
    if (_ext.nnz() != 0) {
      const_iterator ext_row = _ext.begin();
      for (size_type loc_i = 0; loc_i < nrow(); ++loc_i) {
        const_data_iterator entry = ext_row[loc_i];
        while (entry < ext_row[loc_i+1]) {
          const size_type& jext = (*entry).first;
          a.dis_entry (loc_i + first_dis_i, _jext2dis_j[jext]) = (*entry).second;
          ++entry;
        }
      }
    }
    // assembly is called by every process; only the io process writes
    a.dis_entry_assembly_begin();
    a.dis_entry_assembly_end();
    if (on_io_rank) {
      a.asr_seq_rep<T>::put (ops);
    }
    return ops;
}
/// Dump the matrix into a file with suffix "mtx".
/// The file base name is `name` followed by the rank of the calling process.
/// A MatrixMarket coordinate header with the global sizes and nnz is written
/// first, then the entries through put().
/// @param name  base name of the output file
template<class T>
void
csr_mpi_rep<T>::dump (const string& name) const
{
    const std::string filename = name + itos(comm().rank());
    odiststream ops;
    ops.open (filename, "mtx", comm());
    check_macro(ops.good(), "\"" << filename << "[.mtx]\" cannot be created.");

    // header: banner line, then "nrow ncol nnz" with distributed sizes
    ops << "%%MatrixMarket matrix coordinate real general" << std::endl;
    ops << dis_nrow() << " " << dis_ncol() << " " << dis_nnz() << std::endl;

    put(ops);
}
// ----------------------------------------------------------------------------
// basic linear algebra
// ----------------------------------------------------------------------------
/// Distributed matrix-vector product: y := a*x.
/// The product is the sum of the diagonal block applied to the local part
/// of x and of the external block applied to `_buffer`, which receives the
/// needed values of x through the `_from`/`_to` scatter.
/// @param x  input vector; its local size must equal ncol()
/// @param y  output vector; resized to the row distribution
template<class T>
void
csr_mpi_rep<T>::mult(
    const vec<T,distributed>& x,
    vec<T,distributed>&       y)
    const
{
    check_macro (x.size() == ncol(), "csr*vec: incompatible csr("<<nrow()<<","<<ncol()<<") and vec("<<x.size()<<")");
    y.resize (row_ownership());

    // the same tag is used for the begin and end phases of the scatter
    distributor::tag_type tag = distributor::get_new_tag();

    // send x to others: start the scatter of x into _buffer
    mpi_scatter_begin (
	x.begin().operator->(),
        _buffer.begin().operator->(),
	_from,
        _to, 
	set_op<T,T>(), 
        tag, 
	row_ownership().comm());

    // y := dia*x
    // done between the begin and end phases of the scatter, presumably so
    // that the local computation overlaps the communication
    csr_amux (
        csr_seq_rep<T>::begin(), 
        csr_seq_rep<T>::end(), 
        x.begin(), 
        set_op<T,T>(), 
        y.begin());

    // receive tmp from others: complete the scatter into _buffer
    mpi_scatter_end (
	x.begin(),
        _buffer.begin(), 
	_from,
        _to,
	set_op<T,T>(), 
	tag, 
	row_ownership().comm());

    // y += ext*tmp, where tmp is the received _buffer
    csr_amux (_ext.begin(), _ext.end(), _buffer.begin(), set_add_op<T,T>(), y.begin());
}
/// Distributed transposed matrix-vector product, via csr_cumul_trans_mult.
/// The diagonal block contributes directly to y; the external block
/// contributes to `_buffer`, which is then scattered in reverse mode
/// (`_to` and `_from` swapped) and added into y.
/// @param x  input vector; its local size must equal nrow()
/// @param y  output vector; resized to the column distribution
template<class T>
void
csr_mpi_rep<T>::trans_mult(
    const vec<T,distributed>& x,
    vec<T,distributed>&       y)
    const
{
    check_macro (x.size() == nrow(), "csr.trans_mult(vec): incompatible csr("<<nrow()<<","<<ncol()<<") and vec("<<x.size()<<")");
    y.resize (col_ownership());

    // y = trans(dia)*x: y is cleared first since the product is cumulated (+=)
    std::fill (y.begin(), y.end(), T(0));
    csr_cumul_trans_mult (
        csr_seq_rep<T>::begin(), 
        csr_seq_rep<T>::end(), 
        x.begin(), 
        set_add_op<T,T>(),
        y.begin());

    // buffer = trans(ext)*x: same cumulated product, into the cleared buffer
    std::fill (_buffer.begin(), _buffer.end(), T(0));
    csr_cumul_trans_mult (
        _ext.begin(),
        _ext.end(),
        x.begin(), 
        set_add_op<T,T>(),
        _buffer.begin());

    // send buffer to others parts of y (+=)
    // the same tag is used for the begin and end phases of the scatter
    distributor::tag_type tag = distributor::get_new_tag();
    mpi_scatter_begin (
        _buffer.begin().operator->(),
	y.begin().operator->(),
        _to,  // reverse mode
	_from,
	set_add_op<T,T>(), // += 
        tag, 
	col_ownership().comm());

    // receive buffer from others
    mpi_scatter_end (
        _buffer.begin(), 
	y.begin(),
        _to, // reverse mode
	_from,
	set_add_op<T,T>(), // +=
	tag, 
	col_ownership().comm());
}
/// In-place multiplication by a scalar: both the diagonal block (base
/// class) and the external block are multiplied by `lambda`.
/// @param lambda  the scalar factor
/// @return        *this
template<class T>
csr_mpi_rep<T>&
csr_mpi_rep<T>::operator*= (const T& lambda)
{
  // diagonal block, then external block
  this->csr_seq_rep<T>::operator*= (lambda);
  _ext.operator*= (lambda);

  return *this;
}
// ----------------------------------------------------------------------------
// expression c=a+b and c=a-b with a temporary c=*this
// ----------------------------------------------------------------------------
// NOTE: cet algo pourrait servir aussi au cas diag (dans csr_seq.cc)
// a condition de mettre des pseudo-renumerotations et d'enlever
// le set.insert dans la 1ere passe. Pas forcement plus lisible...
template<class T, class BinaryOp>
void
csr_ext_add (
    const csr_seq_rep<T>& a, const std::vector<typename csr<T>::size_type>& jext_a2dis_j,
    const csr_seq_rep<T>& b, const std::vector<typename csr<T>::size_type>& jext_b2dis_j,
          csr_seq_rep<T>& c,       std::vector<typename csr<T>::size_type>& jext_c2dis_j,
    BinaryOp binop)
{
    typedef typename csr_mpi_rep<T>::size_type size_type;
    typedef typename csr_mpi_rep<T>::iterator iterator;
    typedef typename csr_mpi_rep<T>::const_iterator const_iterator;
    typedef typename csr_mpi_rep<T>::data_iterator data_iterator;
    typedef typename csr_mpi_rep<T>::const_data_iterator const_data_iterator;
    typedef std::pair<size_type,T>             pair_type; 
    //
    // first pass: compute nnz_c and resize
    //
    size_type nnz_ext_c = 0;
    const size_type infty = std::numeric_limits<size_type>::max();
    const_iterator ia = a.begin();
    const_iterator ib = b.begin();
    std::set<size_type> jext_c_set;
    for (size_type i = 0, n = a.nrow(); i < n; i++) {
        for (const_data_iterator iter_jva = ia[i], last_jva = ia[i+1],
                                 iter_jvb = ib[i], last_jvb = ib[i+1];
	    iter_jva != last_jva || iter_jvb != last_jvb; ) {

            size_type dis_ja = iter_jva == last_jva ? infty : jext_a2dis_j [(*iter_jva).first];
            size_type dis_jb = iter_jvb == last_jvb ? infty : jext_b2dis_j [(*iter_jvb).first];
	    if (dis_ja == dis_jb) {
		jext_c_set.insert (dis_ja);
		iter_jva++;
		iter_jvb++;
	    } else if (dis_ja < dis_jb) {
		jext_c_set.insert (dis_ja);
		iter_jva++;
            } else {
		jext_c_set.insert (dis_jb);
		iter_jvb++;
            }
	    nnz_ext_c++;
  	}
    }
    c.resize (a.nrow(), b.ncol(), nnz_ext_c);
    jext_c2dis_j.resize (jext_c_set.size());
    std::copy (jext_c_set.begin(), jext_c_set.end(), jext_c2dis_j.begin());
    //
    // second pass: add and store in c
    //
    op_dis_j2jext_t<pair_type, pair_type, typename vector<size_type>::const_iterator> 
	op_dis_j2jext_c (jext_c2dis_j.begin(), jext_c2dis_j.end());
    data_iterator iter_jvc = c._data.begin().operator->();
    iterator ic = c.begin();
    *ic++ = iter_jvc;
    for (size_type i = 0, n = a.nrow(); i < n; i++) {
        for (const_data_iterator iter_jva = ia[i], last_jva = ia[i+1],
                                 iter_jvb = ib[i], last_jvb = ib[i+1];
	    iter_jva != last_jva || iter_jvb != last_jvb; ) {

            size_type dis_ja = iter_jva == last_jva ? infty : jext_a2dis_j [(*iter_jva).first];
            size_type dis_jb = iter_jvb == last_jvb ? infty : jext_b2dis_j [(*iter_jvb).first];
	    if (dis_ja == dis_jb) {
		*iter_jvc++ = op_dis_j2jext_c (pair_type(dis_ja, binop((*iter_jva).second, (*iter_jvb).second)));
		iter_jva++;
		iter_jvb++;
	    } else if (dis_ja < dis_jb) {
		*iter_jvc++ = op_dis_j2jext_c (pair_type(dis_ja, (*iter_jva).second));
		iter_jva++;
            } else {
		*iter_jvc++ = op_dis_j2jext_c (pair_type(dis_jb, (*iter_jvb).second));
		iter_jvb++;
            }
  	}
        *ic++ = iter_jvc;
    }
}
/// Set *this to a (binop) b, e.g. a+b or a-b.
/// The diagonal blocks are combined by the sequential base class and the
/// external blocks by csr_ext_add(); the distributed nnz counter, the
/// scatter descriptors and the buffer are then rebuilt.
/// @param a      first operand
/// @param b      second operand; must have the same global and local sizes
/// @param binop  binary operation applied to the entries
template<class T>
template<class BinaryOp>
void
csr_mpi_rep<T>::assign_add (
    const csr_mpi_rep<T>& a, 
    const csr_mpi_rep<T>& b,
    BinaryOp binop)
{
    // both the global sizes and the local ones must match
    check_macro (a.dis_nrow() == b.dis_nrow() && a.dis_ncol() == b.dis_ncol(),
        "a+b: invalid matrix a(" << a.dis_nrow() << "," << a.dis_ncol() << ") and b("
        << b.dis_nrow() << "," << b.dis_ncol() << ")");
    check_macro (a.nrow() == b.nrow() && a.ncol() == b.ncol(),
        "a+b: matrix local distribution mismatch: a(" << a.nrow() << "," << a.ncol() << ") and b("
        << b.nrow() << "," << b.ncol() << ")");

    // 1) the diagonal part, and its contribution to the distributed nnz
    csr_seq_rep<T>::assign_add (a, b, binop);
    _dis_nnz = mpi::all_reduce (comm(), nnz(), std::plus<size_type>());

    // 2) the extra-diagonal part, and its contribution to the distributed nnz
    csr_ext_add (a._ext, a._jext2dis_j, b._ext, b._jext2dis_j, _ext, _jext2dis_j, binop);
    _dis_nnz += mpi::all_reduce (comm(), _ext.nnz(), std::plus<size_type>());

    // 3) scatter init for a*x, with the identity numbering of the
    //    external columns
    vector<size_type> identity (_jext2dis_j.size());
    size_type next_index = 0;
    for (typename vector<size_type>::iterator it = identity.begin(); it != identity.end(); ++it) {
        *it = next_index;
        ++next_index;
    }
    distributor::tag_type tag = distributor::get_new_tag();

    _buffer.resize (_jext2dis_j.size());
    mpi_scatter_init(
        _jext2dis_j.size(),
        _jext2dis_j.begin().operator->(),
        identity.size(),
        identity.begin().operator->(),
        dis_ncol(),
        col_ownership().begin().operator->(),
        tag,
        row_ownership().comm(),
        _from,
        _to);
}
// ----------------------------------------------------------------------------
// trans(a)
// ----------------------------------------------------------------------------
/// Build in `b` the transpose of this matrix.
/// @param b  destination matrix, rebuilt here
template<class T>
void
csr_mpi_rep<T>::build_transpose (csr_mpi_rep<T>& b) const
{
  //
  // first: assemble the transposed external entries in a distributed asr
  // matrix with swapped row/column distributions, and build b from it
  //
  asr_mpi_rep<T> ext_transposed (col_ownership(), row_ownership());
  const size_type first_dis_i = row_ownership().first_index();
  const_iterator ext_row = ext_begin();
  const size_type n_loc_row = nrow();
  for (size_type loc_i = 0; loc_i < n_loc_row; ++loc_i) {
    const size_type dis_i = first_dis_i + loc_i;
    const_data_iterator entry   = ext_row[loc_i];
    const_data_iterator row_end = ext_row[loc_i+1];
    while (entry < row_end) {
      // entry (dis_i, dis_j) of a becomes entry (dis_j, dis_i) of b
      const size_type dis_j = jext2dis_j ((*entry).first);
      ext_transposed.dis_entry (dis_j, dis_i) = (*entry).second;
      ++entry;
    }
  }
  ext_transposed.dis_entry_assembly_begin();
  ext_transposed.dis_entry_assembly_end();
  b.build_from_asr (ext_transposed);
  //
  // second: add the diagonal part, transposed by the sequential base class
  //
  csr_seq_rep<T>::build_transpose (b);
  //
  // third: add the diagonal nnz of all processes to the distributed nnz of b
  //
  b._dis_nnz += mpi::all_reduce (comm(), csr_seq_rep<T>::nnz(), std::plus<size_type>());
}
// ----------------------------------------------------------------------------
// instantiation in library
// ----------------------------------------------------------------------------
// Explicit instantiation of the class for the Float value type.
template class csr_mpi_rep<Float>;
// assign_add is a member template, so it is instantiated separately for the
// two binary operations listed here: std::plus and std::minus.
template void csr_mpi_rep<Float>::assign_add (
	const csr_mpi_rep<Float>&, const csr_mpi_rep<Float>&, std::plus<Float>);
template void csr_mpi_rep<Float>::assign_add (
	const csr_mpi_rep<Float>&, const csr_mpi_rep<Float>&, std::minus<Float>);
} // namespace rheolef
# endif // _RHEOLEF_HAVE_MPI
