//
// BAGEL - Brilliantly Advanced General Electronic Structure Library
// Filename: zcasscf.h
// Copyright (C) 2013 Toru Shiozaki
//
// Author: Toru Shiozaki <shiozaki@northwestern.edu>
// Maintainer: Shiozaki group
//
// This file is part of the BAGEL package.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.
//

#ifndef __SRC_ZCASSCF_ZCASSCF_H
#define __SRC_ZCASSCF_ZCASSCF_H

#include <src/ci/zfci/zharrison.h>
#include <src/multi/casscf/rotfile.h>
#include <src/wfn/method.h>
#include <src/util/math/bfgs.h>
#include <src/util/math/step_restrict_bfgs.h>

namespace bagel {

// Abstract base class for ZCASSCF calculations.
//
// The class stores the orbital-space dimensions, convergence settings, orbital coefficients,
// one-electron matrices, and the ZHarrison FCI object, and exposes read-only accessors to them.
// compute() is pure virtual, so this class cannot be instantiated directly; concrete drivers
// derive from it. The class also derives from std::enable_shared_from_this.
// NOTE(review): the meaning of many data members below is inferred from their names only;
// confirm against the constructor / init() in the implementation file.
class ZCASSCF : public Method, public std::enable_shared_from_this<ZCASSCF> {
  protected:
    // orbital-space dimensions
    // NOTE(review): nneg_ is presumably the number of negative-energy (positronic) orbitals -- confirm
    int nneg_;
    int nocc_;
    int nclosed_;
    int nact_;
    int nvirt_;
    // NOTE(review): presumably the number of virtual orbitals in the non-relativistic sense -- confirm
    int nvirtnr_;
    int nbasis_;

    int charge_;

    // NOTE(review): gaunt_ / breit_ presumably toggle the Gaunt and Breit interactions;
    // natocc_ presumably toggles printing of natural orbital occupations -- confirm
    bool gaunt_;
    bool breit_;
    bool natocc_;

    // enforce time-reversal symmetry
    bool tsymm_;

    // convergence thresholds (thresh_micro_ presumably applies to the micro iterations -- confirm)
    double thresh_;
    double thresh_micro_;
    // note that both quantities are stored as complex numbers
    std::complex<double> rms_grad_;
    std::complex<double> level_shift_;

    int nstate_;

    // iteration limits (max_micro_iter_ presumably applies to the micro iterations -- confirm)
    int max_iter_;
    int max_micro_iter_;

    // orbital coefficients and one-electron matrices; all held as pointers to const objects
    std::shared_ptr<const RelCoeff_Block> coeff_;
    std::shared_ptr<const Matrix>  nr_coeff_;
    std::shared_ptr<const ZMatrix> hcore_;
    std::shared_ptr<const ZMatrix> overlap_;
    // occupation numbers; individual elements are exposed through occup(i)
    VectorB occup_;

    // printing helpers
    void print_header() const;
    // arguments are, in order: iter, miter, tcount, energy (passed by value), error, time
    void print_iteration(int iter, int miter, int tcount, const std::vector<double> energy, const double error, const double time) const;

    void init();

    // NOTE(review): presumably silence and restore standard output -- confirm in the implementation
    void mute_stdcout() const;
    void resume_stdcout() const;

    // FCI object; exposed read-only through fci()
    std::shared_ptr<ZHarrison> fci_;
    // compute F^{A} matrix ; see Eq. (18) in Roos IJQC 1980
    // all three arguments are optional (defaults: nullptr, false, false)
    std::shared_ptr<const ZMatrix> active_fock(std::shared_ptr<const ZMatrix> transform = nullptr, const bool with_hcore = false, const bool bfgs = false) const;
    // transform RDM from bitset representation in ZFCI to CAS format
    std::shared_ptr<const ZMatrix> transform_rdm1() const;

    // energy
    std::vector<double> energy_;
    std::vector<double> prev_energy_;
    double micro_energy_;

    // internal functions
    // force time-reversal symmetry for a zmatrix with given number of virtual orbitals
    // (non-static overload; a static overload acting on a ZRotFile is declared in the public section)
    void kramers_adapt(std::shared_ptr<ZMatrix> o, const int nvirt) const;

    // non-const member taking a non-const rotation file
    // NOTE(review): presumably zeroes the positronic part of rot in place -- confirm
    void zero_positronic_elements(std::shared_ptr<ZRotFile> rot);

    // build a RelCoeff_Kramers object from a real (Matrix) coefficient
    // NOTE(review): details of the conversion are not visible in this header
    std::shared_ptr<RelCoeff_Kramers> nonrel_to_relcoeff(std::shared_ptr<const Matrix> nr_coeff) const;

  public:
    // construct from an input tree and a geometry; the reference is optional (defaults to nullptr)
    ZCASSCF(const std::shared_ptr<const PTree> idat, const std::shared_ptr<const Geometry> geom, const std::shared_ptr<const Reference> ref = nullptr);

    // pure virtual: must be implemented by derived classes
    virtual void compute() override = 0;

    // TODO : add FCI quantities to reference
    std::shared_ptr<const Reference> conv_to_ref() const override;

    // diagonalize 1RDM to obtain natural orbital transformation matrix and natural orbital occupation numbers
    // (returned as a pair: transformation matrix first, occupation numbers second)
    std::pair<std::shared_ptr<ZMatrix>, VectorB> make_natural_orbitals(std::shared_ptr<const ZMatrix> rdm1) const;
    // natural orbital transformations for the 1 and 2 RDMs, the coefficient, and qvec
    std::shared_ptr<const ZMatrix> natorb_rdm1_transform(const std::shared_ptr<ZMatrix> coeff, std::shared_ptr<const ZMatrix> rdm1) const;
    std::shared_ptr<const ZMatrix> natorb_rdm2_transform(const std::shared_ptr<ZMatrix> coeff, std::shared_ptr<const ZMatrix> rdm2) const;
    std::shared_ptr<const RelCoeff_Block> update_coeff(std::shared_ptr<const RelCoeff_Block> cold, std::shared_ptr<const ZMatrix> natorb) const;
    std::shared_ptr<const ZMatrix> update_qvec(std::shared_ptr<const ZMatrix> qold, std::shared_ptr<const ZMatrix> natorb) const;
    // kramers adapt for RotFile is a static function!
    // unlike the protected ZMatrix overload, the dimensions are passed explicitly
    static void kramers_adapt(std::shared_ptr<ZRotFile> o, const int nclosed, const int nact, const int nvirt);
    // print natural orbital occupation numbers
    void print_natocc() const;

    // functions to retrieve protected members (all return by value)
    int nocc() const { return nocc_; }
    int nclosed() const { return nclosed_; }
    int nact() const { return nact_; }
    int nvirt() const { return nvirt_; }
    int nvirtnr() const { return nvirtnr_; }
    int nbasis() const { return nbasis_; }
    int nstate() const { return nstate_; }
    int max_iter() const { return max_iter_; }
    int max_micro_iter() const { return max_micro_iter_; }
    double thresh() const { return thresh_; }
    double thresh_micro() const { return thresh_micro_; }
    // returns element i of occup_
    double occup(const int i) const { return occup_[i]; }
    std::complex<double> rms_grad() const { return rms_grad_; }
    bool tsymm() const { return tsymm_; }
    // function to copy electronic rotations from a rotation file TODO: make lambda
    std::shared_ptr<ZRotFile> copy_electronic_rotations(std::shared_ptr<const ZRotFile> rot) const;
    // function to copy positronic rotations from a rotation file TODO: make lambda
    std::shared_ptr<ZRotFile> copy_positronic_rotations(std::shared_ptr<const ZRotFile> rot) const;

    // read-only access to the FCI object (returned as a pointer to const)
    std::shared_ptr<const ZHarrison> fci() const { return fci_; }
};

// Small numerical threshold (1.0e-10) shared by code that includes this header.
// NOTE(review): presumably used when testing occupation numbers against zero -- confirm against its uses.
// Declared constexpr so that it is a genuine compile-time constant (a plain `const double` is not
// usable in constant expressions); `static` keeps the internal linkage of the original declaration.
static constexpr double zoccup_thresh = 1.0e-10;

}

#endif
