#include "eval/eval.h"
#include "eval/progress.h"
#include "eval/progressFeature.h"
#include "eval/openMidEnding.h"
#include "eval/progressEval.h"
#include "analyzer.h"
#include "osl/eval/pieceEval.h"
#include "osl/record/csaString.h"
#include "osl/record/csaRecord.h"
#include "osl/oslConfig.h"
#include "osl/eval/ml/openMidEndingEval.h"
#include "osl/eval/progressEval.h"
#include "osl/progress/ml/newProgress.h"
#include <ext/algorithm>
#include <fstream>
#include <iostream>
#include <valarray>

#include "gtest/gtest.h"

using namespace osl;
using namespace osl::eval;
extern bool isShortTest;

/**
 * Fixture shared by the evaluation-function tests.
 *
 * SetUp() initialises the OSL-side tables (OpenMidEndingEval, NewProgress,
 * ProgressEval) that several tests compare the gpsshogi evaluators against.
 */
class EvalTest : public testing::Test
{
protected:
  void SetUp()
  {
    osl::eval::ml::OpenMidEndingEval::setUp();
    osl::progress::ml::NewProgress::setUp();
    osl::eval::ProgressEval::setUp();
  }
  /** Value of a gold according to a default gpsshogi::PieceEval. */
  int Gold() { return gpsshogi::PieceEval().value(GOLD); }
  /** Value of a rook according to a default gpsshogi::PieceEval. */
  int Rook() { return gpsshogi::PieceEval().value(ROOK); }
  /**
   * Recompute an evaluation from a feature vector.
   *
   * @param eval evaluator whose weights (flatValue) are used
   * @param diff (feature index, count) pairs as produced by Eval::features()
   * @return sum of flatValue(index) * count, each term truncated to int
   *         before it is accumulated
   */
  int total(const gpsshogi::Eval& eval,
	    const osl::vector<std::pair<int, double> >& diff)
  {
    int sum = 0;
    for (size_t i=0; i<diff.size(); ++i) {
      sum += static_cast<int>(eval.flatValue(diff[i].first)*diff[i].second);
    }
    return sum;
  }
  /**
   * Feature/eval consistency check shared by several tests.
   *
   * Every evaluator in @p evals is first randomised with setRandom(). Then,
   * for every position reached in up to 900 test records, this checks that:
   *  - eval(state) agrees with total() of the features() vector within
   *    roundUp(), and
   *  - eval(state) agrees with the value reported by features() within
   *    @p error.
   * On a mismatch against total(), the position and the per-feature
   * breakdown are dumped to std::cerr before the assertion fires.
   */
  void testConsistentBase(boost::ptr_vector<gpsshogi::Eval> &evals,
			  double error)
  {
    for (size_t j=0; j<evals.size(); ++j) {
      evals[j].setRandom();
    }
  
    std::ifstream ifs(OslConfig::testCsaFile("FILES"));
    ASSERT_TRUE(ifs);
    std::string file_name;
    for (int file_index=0; file_index<900 && (ifs >> file_name); ++file_index)
    {
      if ((file_index % 100) == 0)
	std::cerr << '.';	// progress indicator
      if (file_name == "") 
	break;
      file_name = OslConfig::testCsaFile(file_name);

      const Record record=CsaFile(file_name).getRecord();
      const vector<osl::Move> moves=record.getMoves();

      NumEffectState state(record.getInitialState());
    
      // Distinct name so the record index above is not shadowed.
      for (size_t ply=0; ply<moves.size(); ++ply) {
	const Move m = moves[ply];
	state.makeMove(m);

	for (size_t j=0; j<evals.size(); ++j) {
	  const int value = evals[j].eval(state);
	
	  double features_value = -1;
	  vector<std::pair<int, double> > diffs;
	  evals[j].features(state, features_value, diffs, 0);
#ifndef L1BALL_NO_SORT
          ASSERT_TRUE(__gnu_cxx::is_sorted(diffs.begin(), diffs.end()));
#endif
	  // total() depends only on the weights and diffs, so compute it once
	  // for the check, the diagnostic dump and the assertion.
	  const int feature_total = total(evals[j], diffs);
	  if (abs(value - feature_total) > evals[j].roundUp()) {
	    std::cerr << "eval " << j << "\n" << state << m << " " << features_value << "\n";
	    for (size_t k=0; k<diffs.size(); ++k)
	      std::cerr << diffs[k].first << " "
			<< evals[j].findFeature(diffs[k].first).get<0>()
			<< " " << diffs[k].second
			<< "  " << evals[j].flatValue(diffs[k].first) << "\n";
	    std::cerr << "inequality " << value << " " << feature_total
		      << ' ' << features_value << "\n";
	  }
	  ASSERT_NEAR(value, feature_total, evals[j].roundUp());
	  ASSERT_NEAR(value, features_value, error);
	}
      }
    }
  }
};

/**
 * Compare gpsshogi::StableOpenMidEnding (weights loaded from
 * "../stable-eval.txt") against OSL's incrementally updated
 * osl::eval::ml::OpenMidEndingEval, stage by stage. The comparison runs on
 * every position of the first 10 (short test) or 200 test records, and then
 * on three extra positions that are evaluated from scratch.
 */
TEST_F(EvalTest, testStableOpenMidEndingEval)
{  
  gpsshogi::StableOpenMidEnding eval;
  {
    const char *filename = "../stable-eval.txt";
    // NOTE(review): load()'s result is not checked; confirm how a missing or
    // unreadable weight file is reported.
    eval.load(filename);
  }
  using namespace osl;
  std::ifstream ifs(OslConfig::testCsaFile("FILES"));
  ASSERT_TRUE(ifs);
  std::string file_name;

  // The fixture's SetUp() already calls these; they are repeated here.
  osl::eval::ml::OpenMidEndingEval::setUp();
  osl::progress::ml::NewProgress::setUp();
  for (int i=0;i<(isShortTest ? 10 : 200) && (ifs >> file_name) ; i++)
  {
    if (file_name == "")
      break;
    file_name = OslConfig::testCsaFile(file_name);
    const Record rec=CsaFile(file_name).getRecord();
    const vector<osl::Move> moves=rec.getMoves();
    NumEffectState state(rec.getInitialState());
    // OSL evaluator, updated move by move below.
    osl::eval::ml::OpenMidEndingEval test_eval(state, false);
    // NOTE: this `i` shadows the record index of the enclosing loop.
    for (unsigned int i=0; i<moves.size(); i++){
      const Move m = moves[i];
      state.makeMove(m);
      test_eval.update(state, m);
      // test_eval.debug();
      // Dump both evaluators' internals before the opening-stage assertion fails.
      if (test_eval.openingValue() != eval.openingValue(state)) {
	std::cerr << "osl\n";
	test_eval.debug();
	std::cerr << "gpsshogi\n";
	eval.debug(state);
      }
      // Every stage value must match exactly.
      ASSERT_EQ(test_eval.openingValue(), eval.openingValue(state)) << state << m << " progress " << eval.progress(state)
								    << " progress max " << eval.progressMax();
      ASSERT_EQ(test_eval.midgameValue(), eval.midgameValue(state)) << state << m;
#ifdef EVAL_QUAD
      ASSERT_EQ(test_eval.midgame2Value(), eval.midgame2Value(state)) << state << m;
#endif
      ASSERT_EQ(test_eval.endgameValue(), eval.endgameValue(state)) << state << m;
      ASSERT_EQ(test_eval.progressIndependentValue(), eval.progressIndependentValue(state)) << state << m;
      // Composite value. If the two evaluators differ by more than 2, compare
      // OSL's value against gpsshogi's compose(). compose() is fed gpsshogi's
      // stage values and the progress-independent term adjusted by OSL's
      // progressIndependentValueAdjusted().
      if (abs(test_eval.value() - eval.eval(state)) > 2.0) {
	const int progress = eval.progress(state);
	MultiInt stage_value;
	stage_value[0] = eval.openingValue(state);
	stage_value[1] = eval.midgameValue(state);
	// NOTE(review): midgame2Value is read here even without EVAL_QUAD,
	// unlike the guarded assertion above.
	stage_value[2] = eval.midgame2Value(state);
	stage_value[3] = eval.endgameValue(state);
	int flat = test_eval.progressIndependentValueAdjusted
	  (eval.progressIndependentValue(state),progress,eval.progressMax());
	ASSERT_NEAR(test_eval.value(),
		    eval.compose(flat, stage_value, progress),
		    2.0) << state << m;
      }
      else
	ASSERT_NEAR(test_eval.value(), eval.eval(state), 2.0) << state << m;
    }
  }
  // additional positions (each evaluated from scratch, no incremental update)
  {
    const NumEffectState state(CsaString(
				 "P1-KY *  *  *  * -OU * -KE-KY\n"
				 "P2 * -HI *  *  *  * -KI *  * \n"
				 "P3 *  * -KE-FU * -KI-GI-FU * \n"
				 "P4 * -GI-FU-KA-FU-FU-FU * -FU\n"
				 "P5-FU-FU *  *  *  *  * +FU * \n"
				 "P6 *  * +FU+FU+FU * +FU * +FU\n"
				 "P7+FU+FU+GI+KI+KA+FU+GI * +KY\n"
				 "P8 * +OU+KI *  *  *  *  * +HI\n"
				 "P9+KY+KE *  *  *  *  * +KE * \n"
				 "+\n").getInitialState());
    osl::eval::ml::OpenMidEndingEval test_eval(state, false);
    ASSERT_NEAR(test_eval.value(), eval.eval(state), 2.0) << state;
  }
  {
    const NumEffectState state(CsaString(
				 "P1-KY-HI *  *  *  *  *  *  * \n"
				 "P2 *  *  *  * +GI *  *  *  * \n"
				 "P3-FU * -FU *  * +FU-KE * -KY\n"
				 "P4 *  * +FU *  * -KI-OU-FU-FU\n"
				 "P5 *  *  *  *  * -FU *  *  * \n"
				 "P6 * +GI * +KI+FU-GI-UM-KA * \n"
				 "P7+FU+FU *  * -FU * -TO * +KE\n"
				 "P8+OU+KI * +FU *  * -FU *  * \n"
				 "P9+KY+KE *  *  *  *  * +HI+KY\n"
				 "P+00KI00GI00KE00FU00FU\n"
				 "P-00FU00FU\n"
				 "+\n").getInitialState());
    osl::eval::ml::OpenMidEndingEval test_eval(state, false);
    ASSERT_NEAR(test_eval.value(), eval.eval(state), 2.0) << state;
  }
  {
    const NumEffectState state(CsaString(
				 "P1 *  *  * +UM *  *  * -OU-KY\n"
				 "P2+TO *  *  *  *  * +UM *  * \n"
				 "P3 *  *  *  *  *  *  * -FU * \n"
				 "P4 *  * -FU *  * -FU-FU * -FU\n"
				 "P5 *  *  * -GI * -KE *  *  * \n"
				 "P6-KY+FU+FU *  *  * +FU+FU+FU\n"
				 "P7 *  *  *  *  *  * -KE *  * \n"
				 "P8 *  *  *  * +GI *  * +OU * \n"
				 "P9 *  *  *  *  * +KI * +KE+KY\n"
				 "P+00HI00HI00KI00KI00GI\n"
				 "P-00KI00GI00KE00KY00FU00FU00FU00FU00FU00FU00FU\n"
				 "-\n").getInitialState());
    osl::eval::ml::OpenMidEndingEval test_eval(state, false);
    ASSERT_NEAR(test_eval.value(), eval.eval(state), 2.0) << state;
  }
}

/**
 * Compute PieceEval features on two hand-picked positions.
 *
 * Previously this only exercised features() and asserted nothing; it passed
 * an uninitialised double as the output value. It now also checks that the
 * value reported by features() matches eval(). This is the same property
 * testConsistentBase checks for PieceEval, with the 0.1 tolerance used by
 * testConsistent.
 */
TEST_F(EvalTest, testAllDifferential)
{
  gpsshogi::PieceEval eval;
  {
    const NumEffectState state(CsaString(
				 "P1-KY-KE-GI-KI-OU-KI-GI-KE-KY\n"
				 "P2 * -HI *  *  *  *  * -KA * \n"
				 "P3-FU-FU-FU-FU-FU-FU-FU-FU-FU\n"
				 "P4 *  *  *  *  *  *  *  *  * \n"
				 "P5 *  *  * -KI *  *  *  *  * \n"
				 "P6 *  *  *  *  *  *  *  *  * \n"
				 "P7+FU+FU+FU+FU+FU+FU+FU+FU+FU\n"
				 "P8 * +KA *  *  *  *  * +HI * \n"
				 "P9+KY+KE+GI * +OU+KI+GI+KE+KY\n"
				 "+\n").getInitialState());
    double value = -1;
    gpsshogi::MoveData md;
    eval.features(state, value, md.diffs, 0);
    EXPECT_NEAR(eval.eval(state), value, 0.1) << state;
  }
  {
    const NumEffectState state(CsaString(
				 "P1-KY-KE-GI-KI-OU-KI-GI-KE-KY\n"
				 "P2 * -HI *  *  *  *  * -KA * \n"
				 "P3-FU-FU-FU-FU-FU * -FU-FU-FU\n"
				 "P4 *  *  *  *  * +KA *  *  * \n"
				 "P5 *  *  *  *  *  *  *  *  * \n"
				 "P6 *  * +FU *  *  *  *  *  * \n"
				 "P7+FU+FU * +FU+FU+FU+FU+FU+FU\n"
				 "P8 *  *  *  *  *  *  * +HI * \n"
				 "P9+KY+KE+GI+KI+OU+KI+GI+KE+KY\n"
				 "P+00FU\n"
				 "-\n").getInitialState());
    double value = -1;
    gpsshogi::MoveData md;
    eval.features(state, value, md.diffs, 0);
    EXPECT_NEAR(eval.eval(state), value, 0.1) << state;
  }
}

/**
 * Sanity-check the material (PieceEval) and rich (RichEval) evaluators on
 * three fixed positions:
 *  - the initial position scores 0;
 *  - a position where a black gold has become a white gold scores -2*Gold,
 *    also for the value reported by RichEval::features();
 *  - a position where black holds white's rook in hand scores +2*Rook.
 */
TEST_F(EvalTest, testConstruct)
{
  ASSERT_TRUE(osl::eval::ml::OpenMidEndingEval::setUp());
  ASSERT_TRUE(osl::progress::ml::NewProgress::setUp());
  gpsshogi::PieceEval piece_eval;
  gpsshogi::RichEval rich_eval(0);

  // Initial (hirate) position: material is balanced.
  {
    const SimpleState hirate(HIRATE);
    const NumEffectState initial(hirate);
    EXPECT_EQ(0, piece_eval.eval(initial));
    EXPECT_EQ(0, rich_eval.eval(initial));
  }

  // Black's 6-9 gold has moved to white (now on 6-5).
  {
    const NumEffectState lost_gold(CsaString(
				 "P1-KY-KE-GI-KI-OU-KI-GI-KE-KY\n"
				 "P2 * -HI *  *  *  *  * -KA * \n"
				 "P3-FU-FU-FU-FU-FU-FU-FU-FU-FU\n"
				 "P4 *  *  *  *  *  *  *  *  * \n"
				 "P5 *  *  * -KI *  *  *  *  * \n"
				 "P6 *  *  *  *  *  *  *  *  * \n"
				 "P7+FU+FU+FU+FU+FU+FU+FU+FU+FU\n"
				 "P8 * +KA *  *  *  *  * +HI * \n"
				 "P9+KY+KE+GI * +OU+KI+GI+KE+KY\n"
				 "+\n").getInitialState());
    const int expected = -Gold()*2;
    EXPECT_EQ(expected, piece_eval.eval(lost_gold));
    EXPECT_EQ(expected, rich_eval.eval(lost_gold));

    double feature_value = -1;
    vector<std::pair<int, double> > feature_diffs;
    rich_eval.features(lost_gold, feature_value, feature_diffs, 0);
    EXPECT_EQ(expected, static_cast<int>(feature_value));
  }

  // White's rook is in black's hand.
  {
    const NumEffectState rook_in_hand(CsaString(
				 "P1-KY-KE-GI-KI-OU-KI-GI-KE-KY\n"
				 "P2 *  *  *  *  *  *  * -KA * \n"
				 "P3-FU-FU-FU-FU-FU-FU-FU-FU-FU\n"
				 "P4 *  *  *  *  *  *  *  *  * \n"
				 "P5 *  *  *  *  *  *  *  *  * \n"
				 "P6 *  *  *  *  *  *  *  *  * \n"
				 "P7+FU+FU+FU+FU+FU+FU+FU+FU+FU\n"
				 "P8 * +KA *  *  *  *  * +HI * \n"
				 "P9+KY+KE+GI+KI+OU+KI+GI+KE+KY\n"
				 "P+00HI\n"
				 "+\n").getInitialState());
    const int expected = Rook()*2;
    EXPECT_EQ(expected, piece_eval.eval(rook_in_hand));
    EXPECT_EQ(expected, rich_eval.eval(rook_in_hand));
  }
}

/** Feature/eval consistency for the material-only and rich evaluators. */
TEST_F(EvalTest, testConsistent)
{
  const double tolerance = 0.1;
  boost::ptr_vector<gpsshogi::Eval> evaluators;
  evaluators.push_back(new gpsshogi::PieceEval);
  evaluators.push_back(new gpsshogi::RichEval(2));
  testConsistentBase(evaluators, tolerance);
}

/**
 * Feature/eval consistency for the progress-dependent evaluators. The
 * tolerance on features()'s reported value is 16 times that of
 * testConsistent.
 */
TEST_F(EvalTest, testConsistentProgress)
{
  const double tolerance = 0.1 * 16;
  boost::ptr_vector<gpsshogi::Eval> evaluators;
#ifdef LEARN_TEST_PROGRESS
  evaluators.push_back(new gpsshogi::HandProgressFeatureEval());
  evaluators.push_back(new gpsshogi::EffectProgressFeatureEval());
#endif
  evaluators.push_back(new gpsshogi::OpenMidEndingForTest(1));
  evaluators.push_back(new gpsshogi::KOpenMidEnding());
  evaluators.push_back(new gpsshogi::StableOpenMidEnding());
  evaluators.push_back(new gpsshogi::OslOpenMidEnding());
  //evaluators.push_back(new gpsshogi::KProgressEval());
  testConsistentBase(evaluators, tolerance);
}

/**
 * Check that each evaluator's incremental stack (newStack()) tracks a
 * from-scratch eval() on every position of up to 900 test records, after
 * setRandom(). Whenever the side to move is not in check, a pass is also
 * pushed and popped to exercise null-move updates.
 */
TEST_F(EvalTest, testConsistentUpdate)
{
  boost::ptr_vector<gpsshogi::Eval> evals;
  evals.push_back(new gpsshogi::PieceEval);
  evals.push_back(new gpsshogi::RichEval(2));
#ifdef LEARN_TEST_PROGRESS
  evals.push_back(new gpsshogi::HandProgressFeatureEval());
  evals.push_back(new gpsshogi::EffectProgressFeatureEval());
#endif
  evals.push_back(new gpsshogi::OpenMidEndingForTest(1));
  evals.push_back(new gpsshogi::KOpenMidEnding());
  evals.push_back(new gpsshogi::KProgressEval());

  for (size_t j=0; j<evals.size(); ++j) {
    evals[j].setRandom();
  }
  
  std::ifstream ifs(OslConfig::testCsaFile("FILES"));
  ASSERT_TRUE(ifs);
  std::string file_name;
  for (int file_index=0; file_index<900 && (ifs >> file_name); ++file_index)
  {
    if ((file_index % 100) == 0)
      std::cerr << '.';		// progress indicator
    if (file_name == "") 
      break;
    file_name = OslConfig::testCsaFile(file_name);

    const Record record=CsaFile(file_name).getRecord();
    const vector<osl::Move> moves=record.getMoves();

    NumEffectState state(record.getInitialState());
    
    // One incremental stack per evaluator, seeded with the initial position.
    boost::ptr_vector<gpsshogi::EvalValueStack> eval_values;
    for (size_t j=0; j<evals.size(); ++j) {
      eval_values.push_back(evals[j].newStack(state));
    }
    for (size_t ply=0; ply<moves.size(); ++ply) {
      const Move m = moves[ply];
      state.makeMove(m);

      for (size_t j=0; j<evals.size(); ++j) {
	const int value = evals[j].eval(state);

	eval_values[j].push(state, m);
	EXPECT_EQ(value, eval_values[j].value())
	  << "eval " << j << "\n" << state << m << "\n";
      }
      if (! state.inCheck()) {
	// Null move: flip the turn, then push a pass made by the side that
	// just moved, compare, and pop. Finally restore the turn.
	state.changeTurn();
	for (size_t j=0; j<evals.size(); ++j) {
	  eval_values[j].push(state, Move::PASS(alt(state.turn())));
	  const int value = evals[j].eval(state);
	  EXPECT_EQ(value, eval_values[j].value())
	    << " eval(pass) " << j << "\n" << state << m << "\n";
	  eval_values[j].pop();
	}
	state.changeTurn();
      }
    }
  }
}

/**
 * Check that saveWeight() writes out exactly the weights reported by
 * flatValue(), for each evaluator after setRandom().
 */
TEST_F(EvalTest, testSaveWeight)
{
  boost::ptr_vector<gpsshogi::Eval> evals;
  evals.push_back(new gpsshogi::PieceEval);
  evals.push_back(new gpsshogi::RichEval(2));
#ifdef LEARN_TEST_PROGRESS
  evals.push_back(new gpsshogi::HandProgressFeatureEval());
  evals.push_back(new gpsshogi::EffectProgressFeatureEval());
#endif
  evals.push_back(new gpsshogi::OpenMidEndingForTest(1));
  evals.push_back(new gpsshogi::KOpenMidEnding());
  evals.push_back(new gpsshogi::KProgressEval());

  for (size_t j=0; j<evals.size(); ++j) {
    evals[j].setRandom();
    std::valarray<double> weight(evals[j].dimension());
    evals[j].saveWeight(&weight[0]);
    // Exact equality: saveWeight() is expected to copy, not recompute.
    for (size_t i=0; i<evals[j].dimension(); ++i)
      ASSERT_EQ(evals[j].flatValue(i), weight[i])
	<< "eval " << j << " index " << i;
  }  
}

/**
 * Check that each evaluator (after setRandom()) is antisymmetric under a
 * 180-degree rotation of the board: eval(rotated) == -eval(state), within
 * 2*roundUp(). This is checked on every position of the first 10 (short
 * test) or 200 test records. When an asymmetry is found, both evaluation
 * summaries are dumped to std::cerr.
 */
TEST_F(EvalTest, testSymmetry)
{  
  boost::ptr_vector<gpsshogi::Eval> evals;
  evals.push_back(new gpsshogi::PieceEval);
  evals.push_back(new gpsshogi::RichEval(2));
#ifdef LEARN_TEST_PROGRESS
  evals.push_back(new gpsshogi::HandProgressFeatureEval());
  evals.push_back(new gpsshogi::EffectProgressFeatureEval());
#endif
  evals.push_back(new gpsshogi::OpenMidEndingForTest(1));
  evals.push_back(new gpsshogi::KOpenMidEnding());
  // evals.push_back(new gpsshogi::KProgressEval()); not symmetric yet

  for (size_t j=0; j<evals.size(); ++j) {
    evals[j].setRandom();
  }

  using namespace osl;
  std::ifstream ifs(OslConfig::testCsaFile("FILES"));
  ASSERT_TRUE(ifs);
  std::string file_name;
  for (int file_index=0; file_index<(isShortTest ? 10 : 200) && (ifs >> file_name); ++file_index)
  {
    if (file_name == "")
      break;
    file_name = OslConfig::testCsaFile(file_name);
    const Record rec=CsaFile(file_name).getRecord();
    const vector<osl::Move> moves=rec.getMoves();
    NumEffectState state(rec.getInitialState());
    for (size_t ply=0; ply<moves.size(); ++ply){
      const Move m = moves[ply];
      state.makeMove(m);
      NumEffectState state_r(state.rotate180());
      for (size_t j=0; j<evals.size(); ++j) {
	// Evaluate each orientation once, for the check, the dump and the
	// assertion.
	const int value = evals[j].eval(state);
	const int value_r = evals[j].eval(state_r);
	if (value != -value_r) {
	  std::cerr << "asymmetry found " << j << "\n" << state
		    << m
		    << "\n";
	  evals[j].showEvalSummary(state);
	  std::cerr << "\n";
	  evals[j].showEvalSummary(state_r);
	}
	ASSERT_NEAR(value, -value_r, evals[j].roundUp()*2);
      }
    }
  } 
}

#ifdef LEARN_TEST_PROGRESS
/**
 * Compare gpsshogi::StableEffectProgressFeatureEval (weights from
 * $OSL_HOME/data/progress.txt) against OSL's incrementally updated
 * NewProgress. The comparison runs on every position of the first 10
 * (short test) or 200 test records. gpsshogi's raw progress is clamped to
 * [0, maxProgress - ProgressScale] and divided by ProgressScale before the
 * comparison. Finally, the two pawn values are checked to agree within 10%.
 */
TEST_F(EvalTest, testStableProgress)
{  
  gpsshogi::StableEffectProgressFeatureEval progress;
  {
    const std::string weight_file
      = std::string(OslConfig::home()) + "/data/progress.txt";
    progress.load(weight_file.c_str());
  }
  using namespace osl;
  std::ifstream ifs(OslConfig::testCsaFile("FILES"));
  ASSERT_TRUE(ifs);
  std::string file_name;

  const int scale = osl::progress::ml::NewProgress::ProgressScale;

  osl::progress::ml::NewProgress::setUp();
  ASSERT_EQ(osl::progress::ml::NewProgress::maxProgress(),
	    progress.maxProgress() / scale) << progress.maxProgress();
  for (int file_index=0; file_index<(isShortTest ? 10 : 200) && (ifs >> file_name); ++file_index)
  {
    if (file_name == "")
      break;
    file_name = OslConfig::testCsaFile(file_name);
    const Record record=CsaFile(file_name).getRecord();
    const vector<osl::Move> moves=record.getMoves();
    NumEffectState state(record.getInitialState());
    osl::progress::ml::NewProgress osl_progress(state);
    for (size_t ply=0; ply<moves.size(); ++ply){
      const Move move = moves[ply];
      state.makeMove(move);
      osl_progress.update(state, move);
      const int upper = progress.maxProgress() - scale;
      const int clamped = std::max(std::min(progress.progress(state), upper), 0);
      ASSERT_EQ(osl_progress.progress(), clamped / scale)
	<< state << move;
    }
  } 

  gpsshogi::StableOpenMidEnding eval;
  eval.load("../stable-eval.txt");
  ASSERT_NEAR(progress.pawnValue() / scale, eval.pawnValue(),
    eval.pawnValue() / 10);
}
#endif

/**
 * Compare gpsshogi::StableProgressEval (weights from
 * "../stable-progresseval.txt") against OSL's incrementally updated
 * osl::eval::ProgressEval. The comparison runs on every position of the
 * first 10 (short test) or 200 test records, and then on two extra positions
 * evaluated from scratch. Opening and endgame values must match exactly;
 * the composite value may differ by at most 2.
 */
TEST_F(EvalTest, testStableProgressEval)
{  
  osl::eval::ProgressEval::setUp();
  gpsshogi::StableProgressEval eval;
  eval.load("../stable-progresseval.txt");
  using namespace osl;
  std::ifstream ifs(OslConfig::testCsaFile("FILES"));
  ASSERT_TRUE(ifs);
  std::string file_name;

  for (int file_index=0; file_index<(isShortTest ? 10 : 200) && (ifs >> file_name); ++file_index)
  {
    if (file_name == "")
      break;
    file_name = OslConfig::testCsaFile(file_name);
    const Record record=CsaFile(file_name).getRecord();
    const vector<osl::Move> moves=record.getMoves();
    NumEffectState state(record.getInitialState());
    osl::eval::ProgressEval reference(state);
    for (size_t ply=0; ply<moves.size(); ++ply){
      const Move move = moves[ply];
      state.makeMove(move);
      reference.update(state, move);
      ASSERT_EQ(reference.openingValue(), eval.openingValue(state))
	<< state << move << " progress " << eval.progress(state)
	<< " progress max " << eval.progressMax();
      ASSERT_EQ(reference.endgameValue(), eval.endgameValue(state)) << state << move;
      ASSERT_NEAR(reference.value(), eval.eval(state), 2.0) << state << move;
    }
  }

  // Extra positions, each evaluated from scratch (no incremental updates).
  const char *const extra_positions[] = {
    "P1-KY *  *  *  * -OU * -KE-KY\n"
    "P2 * -HI *  *  *  * -KI *  * \n"
    "P3 *  * -KE-FU * -KI-GI-FU * \n"
    "P4 * -GI-FU-KA-FU-FU-FU * -FU\n"
    "P5-FU-FU *  *  *  *  * +FU * \n"
    "P6 *  * +FU+FU+FU * +FU * +FU\n"
    "P7+FU+FU+GI+KI+KA+FU+GI * +KY\n"
    "P8 * +OU+KI *  *  *  *  * +HI\n"
    "P9+KY+KE *  *  *  *  * +KE * \n"
    "+\n",
    "P1-KY-HI *  *  *  *  *  *  * \n"
    "P2 *  *  *  * +GI *  *  *  * \n"
    "P3-FU * -FU *  * +FU-KE * -KY\n"
    "P4 *  * +FU *  * -KI-OU-FU-FU\n"
    "P5 *  *  *  *  * -FU *  *  * \n"
    "P6 * +GI * +KI+FU-GI-UM-KA * \n"
    "P7+FU+FU *  * -FU * -TO * +KE\n"
    "P8+OU+KI * +FU *  * -FU *  * \n"
    "P9+KY+KE *  *  *  *  * +HI+KY\n"
    "P+00KI00GI00KE00FU00FU\n"
    "P-00FU00FU\n"
    "+\n",
  };
  const size_t extra_count = sizeof(extra_positions) / sizeof(extra_positions[0]);
  for (size_t k=0; k<extra_count; ++k) {
    const NumEffectState state(CsaString(extra_positions[k]).getInitialState());
    osl::eval::ProgressEval reference(state);
    ASSERT_NEAR(reference.value(), eval.eval(state), 2.0) << state;
  }
}

/* ------------------------------------------------------------------------- */
// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
