//LabPlot : FitListDialog.cc

#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <math.h>
#include <qstring.h>
#include <qlabel.h>
#include <qfiledialog.h>
#include <qcolordialog.h>
#include <klocale.h>
#include <kmessagebox.h>
#ifdef HAVE_GSL
#include <gsl/gsl_rng.h>
#include <gsl/gsl_blas.h>
#endif
#include "FitListDialog.h"

using namespace std;

// Labels of the fit models offered in the model combo box (inserted with
// insertStrList() in the constructor). The list is terminated by 0.
// The order must match enum Model : apply_clicked() casts the combo box index
// directly to a Model value.
// NOTE(review): several labels do not match the formulas evaluated in fun_f()
// (e.g. gaussian, maxwell, planck and lorentz are computed with squared/cubed
// terms that the labels do not show) -- confirm which one is intended.
static const char *modelitems[] = {"linear (a*x+b)","exponentiell (a*exp(-b*x)+c)",
	"a*x^b","a+b*ln(x)","1/(a*x+b)","a*x*exp(-b*x)",
	"gaussian (1/(sqrt(2 PI) a)) exp(-(x-b)/(2 a))","maxwell (a v exp(-b v))",
	"planck (a v/(exp(b v)-1))","lorentz (a/((w-b)+c/4))","a exp(b x)+c exp(d x)",
	"a exp(b x)+c exp(d x)+e exp(f x)",0};

// Identifier of the fit model. The numeric values are used as index into the
// model combo box (see modelitems), so the order of both lists must be kept in sync.
enum Model {MLINEAR,MEXP,MPOT,MLN,M1L,MEXP2,MGAUSSIAN,MMAXWELL,MPLANCK,MLORENTZ,MMULTIEXP2,MMULTIEXP3};

// Data handed to the GSL solver callbacks (fun_f, fun_df) through the
// "params" pointer of gsl_multifit_function_fdf.
struct data {
        int n;		// number of data points used for the fit
        double *x;	// x values of the data points (n entries)
        double *y;	// y values of the data points (n entries)
        double *sigma;	// error of every data point; the residuals are divided by it
	int np;		// number of fit parameters
	Model model;	// selected fit model
	double base;	// baseline : constant added to the model value in fun_f()
};

//! Builds the fit dialog
/*!
	Creates a tab widget with a "Parameter" tab (model, start values, solver
	settings, region, baseline and the sampling of the resulting fit curve)
	and a "Style" tab, plus a text area that receives the fit output.
	\param p worksheet whose active plot is used for the default values
	\param name passed on to ListDialog
*/
FitListDialog::FitListDialog(Worksheet *p,const char *name)
	: ListDialog(p, name)
{
	setCaption(i18n("Fit Dialog"));

	Plot *plot = p->getPlot(p->getAPI());

	QTabWidget *tabs = new QTabWidget(vbox);
	QVBox *partab = new QVBox(tabs);

	// --- model selection ---
	QHBox *row = new QHBox(partab);
	QLabel *modellabel = new QLabel(i18n("Model : "),row);
	modellabel->setMaximumWidth(50);
	modelcb = new KComboBox(row);
	modelcb->insertStrList(modelitems);
	modelcb->setCurrentItem(0);
	QObject::connect(modelcb,SIGNAL(activated (int)),SLOT(updateModel(int)));
	// TODO : custom fit function
	//new QLabel(i18n("Fit Function :"),row);
	//new QLabel(i18n(" y = "),row);
	//funle = new KLineEdit("a*x+b",row);

	// --- number of parameters ---
	row = new QHBox(partab);
	new QLabel(i18n("Nr. of Parameter : "),row);
	parle = new KLineEdit(QString("2"),row);
	// editing is only needed for a custom function (not implemented yet)
	parle->setReadOnly(true);

	// --- start values a, b, c ---
	new QLabel(i18n("Initial Values : "),partab);
	row = new QHBox(partab);
	new QLabel(i18n("a = "),row);
	par1le = new KLineEdit(QString("1.0"),row);
	par1le->setValidator(new QDoubleValidator(par1le));
	new QLabel(i18n(" b = "),row);
	par2le = new KLineEdit(QString("1.0"),row);
	par2le->setValidator(new QDoubleValidator(par2le));
	new QLabel(i18n(" c = "),row);
	par3le = new KLineEdit(QString("1.0"),row);
	par3le->setValidator(new QDoubleValidator(par3le));
	par3le->setEnabled(false);

	// --- start values d, e, f ---
	row = new QHBox(partab);
	new QLabel(i18n("d = "),row);
	par4le = new KLineEdit(QString("1.0"),row);
	par4le->setValidator(new QDoubleValidator(par4le));
	par4le->setEnabled(false);
	new QLabel(i18n(" e = "),row);
	par5le = new KLineEdit(QString("1.0"),row);
	par5le->setValidator(new QDoubleValidator(par5le));
	par5le->setEnabled(false);
	new QLabel(i18n(" f = "),row);
	par6le = new KLineEdit(QString("1.0"),row);
	par6le->setValidator(new QDoubleValidator(par6le));
	par6le->setEnabled(false);

	// --- solver settings ---
	row = new QHBox(partab);
	new QLabel(i18n("Maximum Steps : "),row);
	stepsle = new KLineEdit(QString("100"),row);
	stepsle->setValidator(new QIntValidator(stepsle));
	new QLabel(i18n(" Tolerance : "),row);
	tolle = new KLineEdit(QString("0.001"),row);
	tolle->setValidator(new QDoubleValidator(0,1,20,tolle));

	// --- region : preselected when the plot has a non-empty region ---
	row = new QHBox(partab);
	regioncb = new QCheckBox(i18n("use Region "),row);
	regioncb->setChecked(plot->getRegionMin() != plot->getRegionMax());
	new QLabel(i18n("( From "),row);
	regionminle = new KLineEdit(QString::number(plot->getRegionMin()),row);
	regionminle->setValidator(new QDoubleValidator(regionminle));
	new QLabel(i18n(" To "),row);
	regionmaxle = new KLineEdit(QString::number(plot->getRegionMax()),row);
	regionmaxle->setValidator(new QDoubleValidator(regionmaxle));
	new QLabel(i18n(" )"),row);

	// --- baseline ---
	row = new QHBox(partab);
	baselinecb = new QCheckBox(i18n("Use Baseline @ y = "),row);
	baselinecb->setChecked(false);
	baselinele = new KLineEdit(QString::number(plot->Baseline()),row);
	baselinele->setValidator(new QDoubleValidator(baselinele));

	// --- number of points of the fit curve : defaults to the size of the first graph if it is 2D ---
	row = new QHBox(partab);
	new QLabel(i18n("Number of Points for fit function : "),row);
	GraphList *graphs = plot->getGraphList();
	GRAPHType firsttype = graphs->getStruct(0);
	int points=100;
	if (firsttype == GRAPH2D) {
		Graph2D *first = graphs->getGraph2D(0);
		points = first->Number();
	}
	numberle = new KLineEdit(QString::number(points),row);
	numberle->setValidator(new QIntValidator(numberle));

	// --- x range of the fit curve : defaults to the first range of the plot ---
	row = new QHBox(partab);
	new QLabel(i18n("Range of fit function : "),row);
	LRange *plotrange = plot->getRanges();
	minle = new KLineEdit(QString::number(plotrange[0].Min()),row);
	minle->setValidator(new QDoubleValidator(minle));
	new QLabel(i18n(" .. "),row);
	maxle = new KLineEdit(QString::number(plotrange[0].Max()),row);
	maxle->setValidator(new QDoubleValidator(maxle));

	// --- style tab ---
	Style style;
	Symbol symbol;
	QVBox *styletab;
	if(p->getPlot(p->getAPI())->getType() == PSURFACE)
		styletab = surfaceStyle(tabs);
	else
		styletab = simpleStyle(tabs,0, &style, &symbol);

	tabs->addTab(partab,i18n("Parameter"));
	tabs->addTab(styletab,i18n("Style"));

	// --- output area for the fit progress and result ---
	infote = new QTextEdit(vbox);

#if QT_VERSION > 0x030005
	infote->setTextFormat( Qt::LogText );
#else
	infote->setTextFormat( Qt::PlainText );
#endif

	QObject::connect(ok,SIGNAL(clicked()),SLOT(ok_clicked()));
	QObject::connect(apply,SIGNAL(clicked()),SLOT(apply_clicked()));

	// --- dialog size : minimum size hints plus some margin ---
	const int width = vbox->minimumSizeHint().width()+50;
	int height = vbox->minimumSizeHint().height();
	height += gbox->minimumSizeHint().height();
	height += tabs->minimumSizeHint().height();
	height += 100;
	setMinimumSize(width,height);
	resize(width,height);
}

//! called when a model is selected
void FitListDialog::updateModel(int model) {
	// TODO : set inital values using selected graph values

	if ((Model)model == MEXP || (Model)model == MLORENTZ) {
		parle->setText(QString("3"));
		par3le->setEnabled(true);
	}
	else if ((Model)model == MMULTIEXP2) {
		parle->setText(QString("4"));
		par3le->setEnabled(true);
		par4le->setEnabled(true);
	}
	else if ((Model)model == MMULTIEXP3) {
		parle->setText(QString("6"));
		par3le->setEnabled(true);
		par4le->setEnabled(true);
		par5le->setEnabled(true);
		par6le->setEnabled(true);
	}
	else  {
		parle->setText(QString("2"));
		par3le->setEnabled(false);
		par4le->setEnabled(false);
		par5le->setEnabled(false);
		par6le->setEnabled(false);
	}
}

#ifdef HAVE_GSL
int fun_f(const gsl_vector *v, void *params, gsl_vector *f) {
        int n = ((struct data *)params)->n;
        int np = ((struct data *)params)->np;
        double *x = ((struct data *)params)->x;
        double *y = ((struct data *)params)->y;
        double *sigma = ((struct data *) params)->sigma;
        Model model = ((struct data *) params)->model;
	double base = ((struct data *) params)->base;

	double* p = new double[np];
	for (int i=0;i<np;i++)
		p[i]=gsl_vector_get(v,i);

        for (int i = 0; i < n; i++) {
                double t = x[i];
		double Yi=0;
		if (model == MLINEAR)
			Yi = p[0] * t + p[1];
		else if (model == MEXP)
			Yi = p[0] * exp (-p[1] * t) + p[2];
		else if (model == MPOT)
			Yi = p[0] * pow(t,p[1]);
		else if (model == MLN) {
			double tlog;
			if (t<=0)
				tlog = 0;
			else
				tlog=log(t);
			Yi = p[0] + p[1]*tlog;
		}
		else if (model == M1L)
			Yi = 1/(p[0]+p[1]*t);
		else if (model == MEXP2)
			Yi = p[0]*t*exp(-p[1]*t);
		else if (model == MGAUSSIAN)
			Yi = 1/(sqrt(2*M_PI)*p[0])*exp(-(t-p[1])*(t-p[1])/(2*p[0]*p[0]));
		else if (model == MMAXWELL)
			Yi = p[0]*t*t*exp(-p[1]*t*t);
		else if (model == MPLANCK) {
			if (t==0)
				Yi = 0;
			else
				Yi = p[0]*t*t*t/(exp(p[1]*t)-1);
		}
		else if (model == MLORENTZ)
			Yi = p[0]/((t-p[1])*(t-p[1])+p[2]*p[2]/4);
		else if (model == MMULTIEXP2)
			Yi = p[0] * exp (p[1] * t) + p[2] * exp (p[3] * t);
		else if (model == MMULTIEXP3)
			Yi = p[0] * exp (p[1] * t) + p[2] * exp (p[3] * t) + p[4] * exp(p[5] * t);

		Yi += base;
		//cout<<"Yi = "<<Yi<<endl;

                gsl_vector_set (f, i, (Yi - y[i])/sigma[i]);
        }

        return GSL_SUCCESS;
}

int fun_df(const gsl_vector *v, void *params, gsl_matrix *J) {
        int n = ((struct data *)params)->n;
        int np = ((struct data *)params)->np;
        double *x = ((struct data *)params)->x;
        double *sigma = ((struct data *) params)->sigma;
        Model model = ((struct data *) params)->model;

	double* p = new double[np];
	for (int i=0;i<np;i++)
		p[i]=gsl_vector_get(v,i);

        for (int i = 0; i < n; i++) {
                /* Jacobian matrix J(i,j) = dfi / dxj, */
                /* where fi = (Yi - yi)/sigma[i],      */
                /*       Yi = model  */
                /* and the xj are the parameters */
		// TODO : 1/s in all Jacobians
                double t = x[i];
                double s = sigma[i];

		if (model == MLINEAR) {
			gsl_matrix_set (J, i, 0, t/s);
                	gsl_matrix_set (J, i, 1, 1/s);
		}
		else if (model == MEXP) {
                	double e = exp(-p[1] * t);
                	gsl_matrix_set (J, i, 0, e/s);
                	gsl_matrix_set (J, i, 1, - t * p[0] * e/s);
                	gsl_matrix_set (J, i, 2, 1/s);
		}
		else if (model == MPOT) {
			double tlog;
			if (t<=0)
				tlog = 0;
			else
				tlog=log(t);
			gsl_matrix_set (J, i, 0, pow(t,p[1])/s);
                	gsl_matrix_set (J, i, 1, p[0] * pow(t,p[1]) * tlog/s);
		}
		else if (model == MLN) {
			double plog;
			if (p[1]==0)
				plog = 0;
			else if (p[1]<0)
				plog = log(-p[1]);
			else
				plog=log(p[1]);
			gsl_matrix_set (J, i, 0, 1/s);
                	gsl_matrix_set (J, i, 1, plog/s);
		}
		else if (model == M1L) {
			double tmp = s*(p[0]+p[1]*t)*(p[0]+p[1]*t);
			gsl_matrix_set (J, i, 0, -1/tmp);
                	gsl_matrix_set (J, i, 1, - t/tmp);
		}
		else if (model == MEXP2) {
			gsl_matrix_set (J, i, 0, t*exp(-p[1]*t));
                	gsl_matrix_set (J, i, 1, -p[0]*exp(-p[1]*t)*t*t);
		}
		else if (model == MGAUSSIAN) {
			double e = exp(-(t-p[1])*(t-p[1])/(2*p[0]*p[0]));
			double p2 = p[0]*p[0];
			gsl_matrix_set (J, i, 0, e*(t*t-2*t*p[1]+p[1]*p[1]-p2)/(sqrt(2*M_PI*p2*p2)));
                	gsl_matrix_set (J, i, 1, e*(t-p[1])/(sqrt(2*M_PI)*p[0]*p2));
		}
		else if (model == MMAXWELL) {
			double e = exp(-p[1]*t*t);
			gsl_matrix_set (J, i, 0, t*t*e);
                	gsl_matrix_set (J, i, 1, -p[0]*e*t*t*t*t);
		}
		else if (model == MPLANCK) {
			if (t==0) {
				gsl_matrix_set (J, i, 0, 0);
               		 	gsl_matrix_set (J, i, 1, 0);
			}
			else {
				double e = exp(p[1]*t);
				gsl_matrix_set (J, i, 0, t*t*t/(e-1));
               		 	gsl_matrix_set (J, i, 1, -p[0]*e*t*t*t*t/((e-1)*(e-1)));
			}
		}
		else if (model == MLORENTZ) {
			double tmp = p[2]*p[2]/4+(t-p[1])*(t-p[1]);
			gsl_matrix_set (J, i, 0, 1/tmp);
                	gsl_matrix_set (J, i, 1, 2*p[0]*(t-p[1])/(tmp*tmp));
                	gsl_matrix_set (J, i, 2, -p[0]*p[2]/(2*tmp*tmp));
		}
		else if (model == MMULTIEXP2) {
                	double e1 = exp(p[1] * t);
                	double e2 = exp(p[3] * t);
                	gsl_matrix_set (J, i, 0, e1/s);
                	gsl_matrix_set (J, i, 1, t * p[0] * e1/s);
                	gsl_matrix_set (J, i, 2, e2/s);
                	gsl_matrix_set (J, i, 3, t * p[2] * e2/s);
		}
		else if (model == MMULTIEXP3) {
                	double e1 = exp(p[1] * t);
                	double e2 = exp(p[3] * t);
                	double e3 = exp(p[5] * t);
                	gsl_matrix_set (J, i, 0, e1/s);
                	gsl_matrix_set (J, i, 1, t * p[0] * e1/s);
                	gsl_matrix_set (J, i, 2, e2/s);
                	gsl_matrix_set (J, i, 3, t * p[2] * e2/s);
                	gsl_matrix_set (J, i, 4, e3/s);
                	gsl_matrix_set (J, i, 5, t * p[4] * e3/s);
		}
        }
        return GSL_SUCCESS;
}

//! Combined callback for the GSL solver : evaluates residuals and Jacobian
/*!
	Calls fun_f() and fun_df() with the same arguments. Their return values
	are not inspected.
	\return always GSL_SUCCESS
*/
int fun_fdf(const gsl_vector *x, void *params, gsl_vector *f,gsl_matrix *J) {
	(void) fun_f (x, params, f);	// residuals
	(void) fun_df (x, params, J);	// Jacobian
	return GSL_SUCCESS;
}

void FitListDialog::print_state(int iter, gsl_multifit_fdfsolver * s) {
	int np = parle->text().toInt();

	QString text;
	text+= "iter : "+QString::number(iter)+"| x = ";
	for (int i=0;i<np;i++)
		text+=QString::number(gsl_vector_get (s->x, i))+" ";
	text+="|f(x)| = "+QString::number(gsl_blas_dnrm2 (s->f));

	infote->append(text);
}
#endif

//! Fits the selected model to the current graph and adds the fit curve as new graph
/*!
	For a 2D graph : collects the data points (optionally only those inside
	the region), runs the GSL Levenberg-Marquardt solver starting from the
	values in the parameter fields, writes the progress and the result into
	the info area, stores the fitted values back into the parameter fields
	and finally adds a new 2D graph sampled from the fitted function.
	3D and matrix graphs are not supported yet.
	Without GSL support only an error message is shown.
*/
void FitListDialog::apply_clicked() {
#ifdef HAVE_GSL
	// TODO : all selected graphs
	if (lv->currentItem() == 0)
		return;		// no graph selected : nothing to fit
	int item = (int) (lv->itemPos(lv->currentItem())/lv->currentItem()->height());
	GraphList *gl = p->getPlot(p->getAPI())->getGraphList();
	GRAPHType type = gl->getStruct(item);

	double base=0;
	if (baselinecb->isChecked()) {
		base = baselinele->text().toDouble();
		p->getPlot(p->getAPI())->setBaseline(base);
	}

	if (type == GRAPH2D) {
		Graph2D *g = gl->getGraph2D(item);
		int nx = g->Number();
		Point *a = g->getData();

		int np = parle->text().toInt();	 //  number of parameter
		Model model = (Model)  modelcb->currentItem();
		double* x = new double[nx];
		double* y = new double[nx];
		double* sigma = new double[nx];
		kdDebug()<<" MODEL = "<<model<<endl;
		kdDebug()<<" NP = "<<np<<endl;

		// start values : a..f are taken from the fields par1le..par6le
		double* x_init = new double[np];
		x_init[0]=par1le->text().toDouble();
		x_init[1]=par2le->text().toDouble();
		if (model == MEXP || model == MLORENTZ)
			x_init[2]=par3le->text().toDouble();
		if (model == MMULTIEXP2) {
			x_init[2]=par3le->text().toDouble();
			x_init[3]=par4le->text().toDouble();
		}
		if (model == MMULTIEXP3) {
			x_init[2]=par3le->text().toDouble();
			x_init[3]=par4le->text().toDouble();
			x_init[4]=par5le->text().toDouble();
			x_init[5]=par6le->text().toDouble();
		}

		gsl_vector_view v = gsl_vector_view_array (x_init, np);
		gsl_rng_env_setup();

		// collect the data points to fit (only those inside the region if requested)
		const bool useregion = regioncb->isChecked();
		const double regionmin = regionminle->text().toDouble();
		const double regionmax = regionmaxle->text().toDouble();
		int N=0;
		for (int i = 0; i < nx; i++) {
			const double xx=a[i].X();
			const double yy=a[i].Y();
			if(!useregion || (xx > regionmin && xx < regionmax)) {
				x[N] = xx;
				y[N] = yy;
				// TODO : use correct errors if available
				sigma[N] = 0.1;
				N++;
			}
		}

		// the solver needs at least as many data points as parameters
		if (N < np) {
			KMessageBox::error(this, i18n("Not enough data points for the selected model!"));
			delete[] x;
			delete[] y;
			delete[] sigma;
			delete[] x_init;
			return;
		}

		struct data d = { N, x, y, sigma, np, model, base};
		gsl_multifit_function_fdf f;
		f.f = &fun_f;
		f.df = &fun_df;
		f.fdf = &fun_fdf;
		f.n = N;
		f.p = np;
		f.params = &d;

		const gsl_multifit_fdfsolver_type *T = gsl_multifit_fdfsolver_lmsder;
		gsl_multifit_fdfsolver *s = gsl_multifit_fdfsolver_alloc (T, N, np);
		gsl_multifit_fdfsolver_set (s, &f, &v.vector);
		int status,iter = 0;
		const int maxsteps = stepsle->text().toInt();
		const double tolerance = tolle->text().toDouble();

		print_state (iter, s);

		// iterate until converged, failed or the maximum number of steps is reached
		do {
			iter++;
			status = gsl_multifit_fdfsolver_iterate (s);

			print_state (iter, s);

			if (status)
				break;

			status = gsl_multifit_test_delta (s->dx, s->x,tolerance, tolerance);
		} while (status == GSL_CONTINUE && iter < maxsteps);

		gsl_matrix *covar = gsl_matrix_alloc (np, np);
		gsl_multifit_covar (s->J, 0.0, covar);

#define FIT(i) gsl_vector_get(s->x, i)
#define ERR(i) sqrt(gsl_matrix_get(covar,i,i))

		// info
		QString text;

		text += "a = "+QString::number(FIT(0))+" +/- "+QString::number(ERR(0));
		text += "\nb = "+QString::number(FIT(1))+" +/- "+QString::number(ERR(1));
		if (model == MEXP || model == MLORENTZ)
			text += "\nc = "+QString::number(FIT(2))+" +/- "+QString::number(ERR(2));
		if (model == MMULTIEXP2 ) {
			text += "\nc = "+QString::number(FIT(2))+" +/- "+QString::number(ERR(2));
			text += "\nd = "+QString::number(FIT(3))+" +/- "+QString::number(ERR(3));
		}
		if (model == MMULTIEXP3) {
			text += "\nc = "+QString::number(FIT(2))+" +/- "+QString::number(ERR(2));
			text += "\nd = "+QString::number(FIT(3))+" +/- "+QString::number(ERR(3));
			text += "\ne = "+QString::number(FIT(4))+" +/- "+QString::number(ERR(4));
			text += "\nf = "+QString::number(FIT(5))+" +/- "+QString::number(ERR(5));
		}
		text += "\nstatus = "+QString(gsl_strerror(status));
		infote->append(text);
		infote->scrollToBottom();

		// update parameter start values
		par1le->setText(QString::number(FIT(0)));
		par2le->setText(QString::number(FIT(1)));
		if (model == MEXP || model == MLORENTZ)
			par3le->setText(QString::number(FIT(2)));
		if (model == MMULTIEXP2 ) {
			par3le->setText(QString::number(FIT(2)));
			par4le->setText(QString::number(FIT(3)));
		}
		if (model == MMULTIEXP3 ) {
			par3le->setText(QString::number(FIT(2)));
			par4le->setText(QString::number(FIT(3)));
			par5le->setText(QString::number(FIT(4)));
			par6le->setText(QString::number(FIT(5)));
		}

		// create fit function
		int numberx = numberle->text().toInt();
		// the curve spans rangemin..rangemax, so it needs at least both end points;
		// this also avoids a division by zero and a negative array size below
		if (numberx < 2)
			numberx = 2;
		Point *ptr = new Point[numberx];
		const double rangemin=minle->text().toDouble();
		const double rangemax=maxle->text().toDouble();
		const double step=(rangemax-rangemin)/(double)(numberx-1);
		double xmin=0,xmax=0,ymin=0, ymax=1;
		for (int i = 0;i<numberx;i++) {
			const double fx=rangemin+i*step;
			double fy=0;
			if(model == MLINEAR)
				fy = FIT(0)*fx+FIT(1);
			else if(model == MEXP)
				fy = FIT(0)*exp(-FIT(1)*fx)+FIT(2);
			else if(model == MPOT)
				fy = FIT(0)*pow(fx,FIT(1));
			else if(model == MLN) {
				if(fx<0)
					fy = FIT(0);
				else if (fx==0)
					fy = 0;	// -inf
				else
					fy = FIT(0) + FIT(1)*log(fx);
			}
			else if(model == M1L)
				fy = 1/(FIT(0) + FIT(1)*fx);
			else if (model == MEXP2)
				fy = FIT(0)*fx*exp(-FIT(1)*fx);
			else if (model == MGAUSSIAN)
				fy = 1/(sqrt(2*M_PI)*FIT(0))*exp(-(fx-FIT(1))*(fx-FIT(1))/(2*FIT(0)*FIT(0)));
			else if (model == MMAXWELL)
				fy = FIT(0)*fx*fx*exp(-FIT(1)*fx*fx);
			else if (model == MPLANCK) {
				if(fx==0)
					fy=0;
				else
					fy = FIT(0)*fx*fx*fx/(exp(FIT(1)*fx)-1);
			}
			else if (model == MLORENTZ)
				fy = FIT(0)/((fx-FIT(1))*(fx-FIT(1))+FIT(2)*FIT(2)/4);
			else if (model == MMULTIEXP2)
				fy = FIT(0)*exp(FIT(1)*fx)+FIT(2)*exp(FIT(3)*fx);
			else if (model == MMULTIEXP3)
				fy = FIT(0)*exp(FIT(1)*fx)+FIT(2)*exp(FIT(3)*fx)+FIT(4)*exp(FIT(5)*fx);

			fy += base;

			// new ranges
			if (i==0) {
				xmin=xmax=fx;
				ymin=ymax=fy;
			}
			else {
				if (fx<xmin) xmin=fx;
				if (fx>xmax) xmax=fx;
				if (fy<ymin) ymin=fy;
				if (fy>ymax) ymax=fy;
			}

			ptr[i].setPoint(fx,fy);
		}

		// the solver, the covariance matrix and the data buffers are not needed any more
		gsl_multifit_fdfsolver_free (s);
		gsl_matrix_free (covar);
		delete[] x;
		delete[] y;
		delete[] sigma;
		delete[] x_init;

		LRange range[2];
		range[0] = LRange(xmin,xmax);
		range[1] = LRange(ymin,ymax);

		// TODO : use fit function? (might be long)
		QString fun = QString("fit of "+g->Label());

		Style style(cb2->currentItem(),color->color(),filled->isChecked(),fcolor->color());
		Symbol symbol((SType)symbolcb->currentItem(),scolor->color(),
			ssize->text().toInt(),(FType)symbolfillcb->currentItem(),sfcolor->color());
		Graph2D *ng = new Graph2D(fun,fun,range,P2D,style,symbol,ptr,numberx);
		p->addGraph2D(ng);
	}
	else if (type == GRAPH3D) {
		// TODO
	}
	else if (type == GRAPHM) {
		// TODO
	}

	updateList();
#else
	KMessageBox::error(this, i18n("Sorry. Your installation doesn't support the GSL!"));
#endif
}
