#include "miview_fmri.h"
#include "miviewview.h" // for Log

#include <odindata/correlation.h>
#include <odindata/filter.h>

// Working state of the fMRI analysis, owned by MiViewFmri through `data`.
// All 4D arrays are indexed as (first extent, slice, row, column).
struct FmriData {
  // --- inputs ---
  Data<float,1> designdata;  // design value per repetition, parsed from the design file
  Data<float,4> fmri;        // time series; extent(0) is the number of repetitions
  Data<float,4> mask;        // optional mask; only meaningful when has_mask is true
  bool has_mask;             // true once a mask with matching spatial shape was loaded

  // --- results of MiViewFmri::init ---
  Data<float,4> pmap;        // p-value per voxel (first extent is 1)
  Data<float,4> zmap;        // z-score per voxel (first extent is 1)
  ivector pixelinslice;      // number of voxels analysed in each slice
};

/////////////////////////////////////////////////////////////////

// Registers the parameters of the fMRI overlay: the input file names, the
// statistics settings used by get_overlay_map(), and the result fields that
// init()/update() fill in. NOTE(review): the order of append_member() calls
// presumably determines the order in which members appear — keep it stable.
//
// NOTE(review): `data` is a raw owning pointer released in the destructor;
// unless copying is disabled in the header, copying a MiViewFmri would delete
// it twice — confirm.
MiViewFmri::MiViewFmri() {

  valid=false; // no analysis result yet; a successful init() sets this to true

  data=new FmriData(); // released in ~MiViewFmri()

  // Input files, marked `hidden` and settable via command-line options
  designfile.set_parmode(hidden).set_cmdline_option("design").set_description("Load fMRI design from this file (comma or space separated)");
  append_member(designfile,"designfile");

  fmrifile.set_parmode(hidden).set_cmdline_option("fmri").set_description("Load fMRI data from this file");
  append_member(fmrifile,"fmrifile");

  maskfile.set_parmode(hidden).set_cmdline_option("fmask").set_description("fMRI mask file");
  append_member(maskfile,"maskfile");

  // Statistics settings, read by get_overlay_map()
  bonferr=false; // enables the per-slice multiple-comparison correction
  bonferr.set_cmdline_option("bonferr").set_description("Use Bonferroni correction");
  append_member(bonferr,"Bonferroni Correction");

  corr=0.001; // p-value threshold a voxel must fall below
  corr.set_cmdline_option("corr").set_description("Error probability for correlation");
  append_member(corr,"Error Probability");

  neighb=1; // minimum number of active in-plane neighbours; <=0 disables the check
  neighb.set_cmdline_option("neighb").set_description("Minimum next neighbours with significant activation");
  append_member(neighb,"Min Neighbours");

  // Result fields: sigchange is written by update() and marked `noedit`
  sigchange.set_parmode(noedit).set_description("fMRI signal change, averaged over all significant voxels");
  append_member(sigchange,"Signal Change");

  // Filled from the design file in init()
  design.set_description("fMRI design vector");
  append_member(design,"fMRI Design");

  // Mean time course of the active voxels, filled in update()
  tcourse.set_description("fMRI signal");
  append_member(tcourse,"fMRI Timecourse");
}


// Releases the FmriData allocated in the constructor.
MiViewFmri::~MiViewFmri() {
  delete data;
}


bool MiViewFmri::init(const FileReadOpts& ropts, const Protocol& prot, const FilterChain& filterchain) {
  Log<MiViewComp> odinlog("MiViewFmri","init");
  Range all=Range::all();

  Protocol protcopy(prot); // writable copy

  // read and parse design file
  if(designfile=="") return false;
  STD_string designstr;
  if(::load(designstr,designfile)<0) {
    ODINLOG(odinlog,errorLog) << "Unable to load designfile " << designfile << STD_endl;
    return false;
  }
  designstr=replaceStr(designstr,","," ");
  svector designvecstr=tokens(designstr);
  int nrep=designvecstr.size();
  design.resize(nrep);
  tcourse.resize(nrep);
  data->designdata.resize(nrep);
  for(int i=0; i<nrep; i++) {
    float designval=atof(designvecstr[i].c_str());
    data->designdata(i)=designval;
    design[i]=designval;
  }


  if(fmrifile=="") return false;
  ODINLOG(odinlog,infoLog) << "Loading fMRI data ..." << STD_endl;
  if(data->fmri.autoread(fmrifile,ropts,&protcopy)<=0) {
    ODINLOG(odinlog,errorLog) << "Unable to load fmrifile " << fmrifile << STD_endl;
    return false;
  }
  if(!filterchain.apply(protcopy, data->fmri)) return false;

  data->has_mask=false;
  if(maskfile!="") {
    ODINLOG(odinlog,infoLog) << "Loading fMRI mask ..." << STD_endl;
    if(data->mask.autoread(maskfile,ropts,&protcopy)<=0) {
      ODINLOG(odinlog,errorLog) << "Unable to load maskfile " << maskfile << STD_endl;
      return false;
    }

    if(!filterchain.apply(protcopy, data->mask)) return false;

    TinyVector<int,4> fmrishape=data->fmri.shape(); fmrishape(0)=1;
    TinyVector<int,4> maskshape=data->mask.shape(); maskshape(0)=1;

    if(fmrishape!=maskshape) {
      ODINLOG(odinlog,errorLog) << "Shape mismatch:  fmrishape/maskshape" << fmrishape << "/" << maskshape << STD_endl;
      return false;
    }

    data->has_mask=true;
  }
  

  int nrep_data=data->fmri.extent(0);
  if(nrep_data!=nrep) {
    ODINLOG(odinlog,errorLog) << "Repetition size mismatch: " << nrep_data << "!=" << nrep << STD_endl;
    return false;
  }

  if(!nrep) {
    return false;
  }


  int nz=data->fmri.extent(1);
  int ny=data->fmri.extent(2);
  int nx=data->fmri.extent(3);

  data->pmap.resize(1,nz,ny,nx);
  data->zmap.resize(1,nz,ny,nx);
  data->pixelinslice.resize(nz); data->pixelinslice=0;
  overlay_map.redim(nz,ny,nx);
  toberemoved.redim(nz,ny,nx);



  Data<float,1> onepixel(nrep);
  double sigthreshold=0.5*mean(data->fmri);
  for(int iz=0; iz<nz; iz++) {
    for(int iy=0; iy<ny; iy++) {
      for(int ix=0; ix<nx; ix++) {
        onepixel=data->fmri(all,iz,iy,ix);
        float p=0.0;
        float z=0.0;
        if(mean(onepixel)>sigthreshold) {
          correlationResult corr=correlation(data->designdata, onepixel);
          if(corr.r>0.0) { // take only positivie correlation
            p=corr.p;
            z=corr.z;
          }
          data->pixelinslice[iz]++;
        }
        data->pmap(0,iz,iy,ix)=p;
        data->zmap(0,iz,iy,ix)=z;
      }
    }
    ODINLOG(odinlog,infoLog) << "Correlation analysis with " << data->pixelinslice[iz] << " pixels" << STD_endl;
  }

  valid=true;

  return true;
}



// Builds the activation overlay from the p/z maps computed in init() and
// returns it. The result is cached in the member `overlay_map`.
//
// Pass 1 (threshold): a voxel is shown with its z-score when both hold:
// - its p-value is below the slice threshold;
// - if a mask is loaded, its mask value is positive.
// The slice threshold is `corr`. With `bonferr` set, it becomes
// 1-(1-corr)^(1/N), where N is the number of voxels analysed in that slice.
// That is the Sidak form of the multiple-comparison correction.
//
// Pass 2 (clustering): if `neighb` > 0, a voxel is cleared when fewer than
// `neighb` of its 8 in-plane neighbours are active. The member `toberemoved`
// serves as the keep-mask: 1 keeps the voxel, 0 clears it.
const farray& MiViewFmri::get_overlay_map() const {
  const int nslices=data->pmap.extent(1);
  const int nrows=data->pmap.extent(2);
  const int ncols=data->pmap.extent(3);

  overlay_map=0.0; // start from an empty overlay

  // Pass 1: significance threshold and optional mask
  for(int slice=0; slice<nslices; slice++) {
    const unsigned int analysed=data->pixelinslice[slice];
    float thresh=corr;
    if(bonferr && analysed>0) thresh=1.0-pow(1.0-corr,1.0/float(analysed));

    for(int row=0; row<nrows; row++) {
      for(int col=0; col<ncols; col++) {
        if(!(data->pmap(0,slice,row,col)<thresh)) continue;                 // not significant
        if(data->has_mask && data->mask(0,slice,row,col)<=0.0) continue;  // outside the mask
        overlay_map(slice,row,col)=data->zmap(0,slice,row,col);
      }
    }
  }

  // Pass 2: next-neighbour criterion within each slice
  if(neighb>0) {
    for(int slice=0; slice<nslices; slice++) {
      for(int row=0; row<nrows; row++) {
        const int rowfirst=(row>0 ? row-1 : 0);
        const int rowlast=(row<nrows-1 ? row+1 : nrows-1);
        for(int col=0; col<ncols; col++) {
          const int colfirst=(col>0 ? col-1 : 0);
          const int collast=(col<ncols-1 ? col+1 : ncols-1);

          int active=0;
          for(int r=rowfirst; r<=rowlast; r++) {
            for(int c=colfirst; c<=collast; c++) {
              if(r==row && c==col) continue; // a voxel is not its own neighbour
              if(overlay_map(slice,r,c)>0.0) active++;
            }
          }

          toberemoved(slice,row,col)=(active<neighb ? 0.0 : 1.0);
        }
      }
    }

    overlay_map*=toberemoved;
  }

  return overlay_map;
}


void MiViewFmri::update() {
  Range all=Range::all();
  int nz=data->pmap.extent(1);
  int ny=data->pmap.extent(2);
  int nx=data->pmap.extent(3);


  // update tcourse
  int nrep=tcourse.length();
  Data<float,1> tcourse_all(nrep); tcourse_all=0.0;
  int nvoxels=0;
  tcourse_all=0.0;
  for(int iz=0; iz<nz; iz++) {
    for(int iy=0; iy<ny; iy++) {
      for(int ix=0; ix<nx; ix++) {
        if(overlay_map(iz,iy,ix)>0.0) {
          tcourse_all(all)+=data->fmri(all,iz,iy,ix);
          nvoxels++;
        }
      }
    }
  }
  if(nvoxels) tcourse_all/=float(nvoxels);
  for(int i=0; i<nrep; i++) {
//    tcourse[i]=STD_complex(tcourse_all(i))*expc(float2imag(data->designdata(i))); // Encode as amplitude and phase
    tcourse[i]=tcourse_all(i);
  }

  // update sigchange
  fmriResult fr=fmri_eval(tcourse_all, data->designdata);
  sigchange=ftos(100.0*fr.rel_diff,2)+"+/-"+ftos(100.0*fr.rel_err,2)+"%";

}
