// 
// Copyright (c) 2006-2010, Benjamin Kaufmann
// 
// This file is part of Clasp. See http://www.cs.uni-potsdam.de/clasp/ 
// 
// Clasp is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
// 
// Clasp is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
// 
// You should have received a copy of the GNU General Public License
// along with Clasp; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
//
#include <clasp/solve_algorithms.h>
#include <clasp/solver.h>
#include <clasp/enumerator.h>
#include <clasp/lookahead.h>
#include <cmath>
namespace Clasp { 
/////////////////////////////////////////////////////////////////////////////////////////
// SolveParams
/////////////////////////////////////////////////////////////////////////////////////////
// Creates a parameter object whose random-choice frequency and
// shuffle settings (first/next) are all set to zero.
SolveParams::SolveParams()
	: randFreq_(0.0)
	, shuffleFirst_(0)
	, shuffleNext_(0)
{}
/////////////////////////////////////////////////////////////////////////////////////////
// ReduceParams
/////////////////////////////////////////////////////////////////////////////////////////
// Returns the initial limit on the number of learnt constraints.
//
// The limit is computed as base / baseFrac, where baseFrac is bounded from
// below by 0.0001 so that the division is well-defined. The quotient is then
// forced to be at least baseMin and at most baseMax (baseMax is applied last).
uint32 ReduceParams::initDbSize() const {
	const double frac  = std::max(0.0001f, baseFrac);
	// base/frac can exceed the range of uint32 (e.g. for a very small baseFrac).
	// Converting an out-of-range double to an integer is undefined, hence
	// saturate before converting - analogous to maxDbSize().
	const double limit = std::min(base / frac, double(std::numeric_limits<uint32>::max()));
	const uint32 size  = static_cast<uint32>(limit);
	return std::min(std::max(size, baseMin), baseMax);
}
// Returns the upper bound for the size of the learnt db:
// max(base, baseMin) scaled by dbMaxGrow, saturated at the largest uint32 value.
uint32 ReduceParams::maxDbSize() const {
	const double cap     = double(std::numeric_limits<uint32>::max());
	const double desired = std::max(base, baseMin) * dbMaxGrow;
	return static_cast<uint32>(std::min(desired, cap));
}
/////////////////////////////////////////////////////////////////////////////////////////
// Schedule
/////////////////////////////////////////////////////////////////////////////////////////
// Geometric factor for position idx: g raised to the power of idx.
double growR(uint32 idx, double g) {
	const double exponent = static_cast<double>(idx);
	return pow(g, exponent);
}
// Arithmetic increment for position idx: idx times a.
double addR(uint32 idx, double a) {
	return static_cast<double>(idx) * a;
}
// Computes the value for the (0-based) position idx of a Luby-style
// sequence (1, 1, 2, 1, 1, 2, 4, ... provided log2() yields the
// floor of the binary logarithm - confirm against its definition).
uint32 lubyR(uint32 idx) {
	uint32 pos = idx + 1;
	// pos & (pos+1) is zero exactly when pos has the form 2^k - 1 (or is 0).
	while ((pos & (pos + 1)) != 0) {
		// Skip the completed prefix of length 2^log2(pos) - 1.
		const uint32 prefix = (1u << log2(pos)) - 1;
		pos -= prefix;
	}
	return (pos + 1) >> 1;
}
// Returns the length of the schedule at the current index.
//  - base == 0: UINT64_MAX.
//  - grow == 0: lubyR(idx) * base.
//  - type == 0: growR(idx, grow) * base.
//  - otherwise: addR(idx, grow) + base.
// The result is never 0; a computed length of 0 is replaced by 1.
uint64 ScheduleStrategy::current() const {
	if (base == 0) { return UINT64_MAX; }
	uint64 len;
	if      (grow == 0) { len = static_cast<uint64>(lubyR(idx)) * base; }
	else if (type == 0) { len = static_cast<uint64>(growR(idx, grow) * base); }
	else                { len = static_cast<uint64>(addR(idx, grow)  + base); }
	return len != 0 ? len : 1;
}
// Advances the schedule by one step and returns the new length.
// If an outer limit is set (outer != 0) and the new length exceeds it,
// the outer limit is raised to that length, the index is reset to 0,
// and the length for index 0 is returned instead.
uint64 ScheduleStrategy::next() {
	++idx;
	uint64 len = current();
	const bool exceedsOuter = outer && len > outer;
	if (exceedsOuter) {
		outer = len;
		idx   = 0;
		len   = current();
	}
	return len;
}
/////////////////////////////////////////////////////////////////////////////////////////
// solve
/////////////////////////////////////////////////////////////////////////////////////////
// Convenience function: runs a SimpleSolve on ctx using the parameters in p
// and an empty set of assumptions. Returns the result of SimpleSolve::solve().
bool solve(SharedContext& ctx, const SolveParams& p) {
	SimpleSolve algo;
	return algo.solve(ctx, p, LitVec());
}

// Convenience function: runs a SimpleSolve on ctx using the parameters in p
// under the given assumptions. Returns the result of SimpleSolve::solve().
bool solve(SharedContext& ctx, const SolveParams& p, const LitVec& assumptions) {
	SimpleSolve algo;
	return algo.solve(ctx, p, assumptions);
}

/////////////////////////////////////////////////////////////////////////////////////////
// SolveAlgorithm
/////////////////////////////////////////////////////////////////////////////////////////
// Nothing to set up.
SolveAlgorithm::SolveAlgorithm() {}
// Nothing to clean up.
SolveAlgorithm::~SolveAlgorithm() {
}

// Forwards the model found in s to the enumerator of s' shared context.
// Returns true if the enumerator answers with enumerate_continue.
bool SolveAlgorithm::backtrackFromModel(Solver& s) {
	Enumerator* en = s.sharedContext()->enumerator();
	return en->backtrackFromModel(s) == Enumerator::enumerate_continue;
}

void SolveAlgorithm::reportProgress(int t, Solver& s, uint64 maxCfl, uint32 maxL) {
	return s.sharedContext()->enumerator()->reportProgress(Enumerator::ProgressType(t), s, maxCfl, maxL);
}
// Solves the problem stored in ctx with the master solver.
//
// If ctx has a tag literal that is not a sentinel, it is appended to the
// (by-value) assumptions before doSolve() is called. After doSolve() returns,
// the master solver is detached from ctx. Returns the result of doSolve().
bool SolveAlgorithm::solve(SharedContext& ctx, const SolveParams& p, LitVec assume) {
	assert(ctx.master() && "SharedContext not initialized!\n");
	const Literal tag = ctx.tagLiteral();
	if (!isSentinel(tag)) { assume.push_back(tag); }
	const bool result = doSolve(*ctx.master(), p, assume);
	ctx.detach(*ctx.master());
	return result;
}
// Prepares s for a search under the assumptions given in path.
//
// The function performs, in this order:
//  1. propagation and simplification on level 0 - with a temporarily installed
//     lookahead post propagator if params requests initial lookahead,
//  2. assuming the literals of path, each followed by a new root level,
//  3. params.randRuns random probing searches,
//  4. installing a restricted lookahead for the remaining params.initLook choices.
// params is updated along the way: initLook is decremented once when the lookahead
// is created and zeroed in step 4; randRuns is zeroed in step 3.
//
// Returns false if propagation/simplification fails or a literal of path is
// already false. If a probing run ends with a result other than value_free,
// the function returns whether s is free of conflicts. Otherwise returns true.
// Precondition (asserted): s has no conflict and is on decision level 0.
bool SolveAlgorithm::initPath(Solver& s, const LitVec& path, InitParams& params) {
	assert(!s.hasConflict() && s.decisionLevel() == 0);
	SingleOwnerPtr<Lookahead> look(0);
	if (params.initLook != 0 && params.lookType != Lookahead::no_lookahead) {
		look = new Lookahead(static_cast<Lookahead::Type>(params.lookType));
		look->init(s);
		// Hand the propagator to the solver for the initial propagation.
		// NOTE(review): look.get() is used below after release(), so release()
		// presumably only gives up ownership but keeps the pointer - confirm
		// against SingleOwnerPtr.
		s.addPost(look.release());
		--params.initLook;
	}
	bool ok = s.propagate() && s.simplify();
	if (look.get()) { 
		// Take the propagator back, whether or not propagation succeeded.
		s.removePost(look.get());
		look = look.get(); // restore ownership
	}
	if (!ok) { return false; }
	// setup path
	for (LitVec::size_type i = 0, end = path.size(); i != end; ++i) {
		Literal p = path[i];
		if (s.value(p.var()) == value_free) {
			// presumably assume() counts as a choice; the decrement keeps
			// assumptions out of the choice statistics - TODO confirm
			s.assume(p); --s.stats.choices;
			// increase root level - assumption can't be undone during search
			s.pushRootLevel();
			if (!s.propagate())  return false;
		}
		// an already true literal needs no action; a false one makes the path unsatisfiable
		else if (s.isFalse(p)) return false;
	}
	// do random probings if any
	if (uint32 i = params.randRuns) {
		params.randRuns = 0;
		do {
			// a result other than value_free ends initialization early
			if (s.search(params.randConf, UINT32_MAX, false, 1.0) != value_free) { return !s.hasConflict(); }
			s.undoUntil(0);
		} while (--i);
	}
	// do initial lookahead choices if requested
	if (uint32 i = params.initLook) {
		params.initLook = 0;
		// NOTE(review): look is only set if lookType != no_lookahead; this
		// assumes callers never combine initLook != 0 with no_lookahead - confirm.
		assert(look.get());
		RestrictedUnit::decorate(s, i, look.release());
	}
	return true;
}

// Runs the main search loop on s until a result is found or a global limit is reached.
//
// In each iteration, s.search() is called with a conflict budget that is the
// minimum of the remaining restart, deletion, grow, and global conflict limits.
// Afterwards, depending on the outcome:
//  - if a model was found and backtrackFromModel() returns true, search continues,
//  - if a limit was reached, the solver is restarted, the learnt db is reduced,
//    and/or the bound on the learnt db is increased - whichever limit applies.
//
// s: the solver to run. If s already has a conflict, value_false is returned at once.
// p: restart, reduction, and limit parameters. On exit, p.limits receives the
//    remaining global limits.
// Returns value_free if the loop ended because the global limits were reached;
// otherwise the last result of s.search().
ValueRep SolveAlgorithm::solvePath(Solver& s, const SolveParams& p) {
	typedef Enumerator::ProgressType ProgressType;
	typedef RestartParams::Type      RestartType;
	// A solver that is already in conflict has no model on this path.
	// (Returning the bool 'false' would convert to ValueRep(0) instead.)
	if (s.hasConflict()) return value_false;
	SearchLimits sLimit;
	WeightLitVec inDegree;
	// Work on copies of the schedules so that p's schedules are left untouched.
	ScheduleStrategy ds       = p.reduce.cflSched;
	ScheduleStrategy dg       = p.reduce.growSched;
	ScheduleStrategy rs       = p.restart.sched;
	RestartType      rt       = (RestartType)p.restart.type;
	Solver::DBInfo   db       = {0,0,0};
	SolveLimits gLimit        = p.limits;     // global limit
	uint64 rsLimit            = rs.current(); // current restart limit
	uint64 dsLimit            = ds.current(); // current deletion limit
	uint64 dgLimit            = dg.current(); // current grow limit
	uint64 minLimit           = 0;            // min of all limits
	gLimit.restarts           = std::max(gLimit.restarts, uint64(1));
	ValueRep result           = value_free;
	uint32 shuffle            = p.shuffleBase();
	ProgressType t            = Enumerator::progress_restart;
	double maxLearnts         = p.reduce.initDbSize();
	const double boundLearnts = p.reduce.maxDbSize();
	uint64 lastC              = s.stats.conflicts;
	uint64 lastR              = s.stats.restarts;
	bool growOnRestart        = p.reduce.growSched.ignore();
	if (maxLearnts < ds.base) {
		// The deletion schedule starts beyond the initial db limit: lower its base
		// and, for arithmetic schedules, scale the increment proportionally.
		uint32 oldBase = ds.base;
		ds.base        = std::min(uint32(maxLearnts), std::max(uint32(5000), uint32(maxLearnts/2)));
		if (ds.type == ScheduleStrategy::arithmetic_schedule && ds.grow > 0.0) {
			double R     = ds.grow / oldBase;
			ds.grow      = ds.base * R;
		}
		dsLimit = ds.current();
	}
	if (p.reduce.growSched.disabled() && !growOnRestart) {
		// no growing at all - use the largest possible db limit
		maxLearnts = (double)UINT32_MAX; 
	}
	else if (maxLearnts < s.numLearntConstraints()) {
		// the solver already holds more learnt constraints than the initial limit
		maxLearnts = static_cast<double>(s.numLearntConstraints()) + p.reduce.baseMin;
		maxLearnts = std::min(maxLearnts, (double)UINT32_MAX);
	}
	if (rt == RestartParams::dynamic_restarts) {
		// dynamic restarts take their settings from the restart schedule:
		// base -> queue size, grow -> xLbd, outer -> xCfl (scaled by 1e8).
		s.stats.enableQueue(rs.base);
		s.stats.queue->reset();
		sLimit.xLbd = (float)rs.grow;
		sLimit.xCfl = (float)(rs.outer / 1e8);
		rs          = ScheduleStrategy::arith(16000, 10000);
		rsLimit     = rs.current();
	}
	else if (rt == RestartParams::local_restarts) {
		// local restarts are limited via sLimit.local; disable the global restart limit
		sLimit.local= rsLimit;
		rsLimit     = UINT64_MAX;
	}
	while (result == value_free && !gLimit.reached()) {
		minLimit        = std::min(dgLimit, std::min(gLimit.conflicts, std::min(rsLimit, dsLimit)));
		sLimit.learnts  = (uint32)(maxLearnts + (p.reduce.strategy.noGlue*db.pinned));
		sLimit.conflicts= minLimit;
		// num_progress_types is used as "nothing to report"
		if (t != Enumerator::num_progress_types) { reportProgress(t, s, std::min(minLimit, sLimit.local), sLimit.learnts); }
		if (sLimit.conflicts)                    { result = s.search(sLimit, p.randomProbability()); }
		minLimit   = (minLimit - sLimit.conflicts); // number of actual conflicts
		if (gLimit.conflicts != UINT64_MAX)      { gLimit.conflicts -= minLimit; }
		if (result == value_true && backtrackFromModel(s)) {
			result   = value_free; // continue enumeration
			t        = Enumerator::progress_model;
			if (p.restart.resetOnModel) {
				rs.reset();
			}
			// After the first solution was found, we allow further restarts only if this
			// is compatible with the enumerator used. 
			dsLimit  = ds.current();
			rsLimit  = std::max(rsLimit, rs.current());
			if (!p.restart.bounded && s.backtrackLevel() > s.rootLevel()) {
				sLimit = SearchLimits();
				rsLimit= UINT64_MAX;
			}
		}
		else if (result == value_free){  // limit reached
			// charge the conflicts of the last search to all active limits
			rsLimit -= (rt!=RestartParams::local_restarts) * minLimit;
			dsLimit -= (ds.base != 0) * minLimit;
			dgLimit -= (dg.base != 0) * minLimit;
			minLimit = 0;
			t        = Enumerator::num_progress_types;
			if (s.numFreeVars() != 0 && (rsLimit == 0 || sLimit.local == 0 || sLimit.dynamicRestart(s.stats))) {
				// restart reached - do restart
				++s.stats.restarts;
				if (p.restart.cir && (s.stats.restarts % p.restart.cir) == 0 ) {
					// every cir-th restart: bump the heuristic using the solver's in-degree literals
					inDegree.clear();
					s.strategies().heuristic->bump(s, inDegree, p.restart.cirBump / (double)s.inDegree(inDegree));
				}
				if (rt == RestartParams::dynamic_restarts) {
					uint64 num = s.stats.restarts  - lastR;
					uint64 cfl = s.stats.conflicts - lastC;
					if (cfl   >=  rs.current()) {
						// adapt xLbd/schedule to the average number of conflicts per restart
						double avg = cfl / double(num);
						lastR      = s.stats.restarts;
						lastC      = s.stats.conflicts;
						if      (avg >= 16000.0){ sLimit.xLbd += 0.1f;  rs.reset(); }
						else if (rsLimit == 0)  { sLimit.xLbd += 0.05f; rs.idx -= (rs.idx != 0); }
						else if (avg >= 4000.0) { sLimit.xLbd += 0.05f; }
						else if (avg <  1000.0) { sLimit.xLbd -= 0.05f; }
						else                    { ++rs.idx; }
					}
					rsLimit = rs.current();
					minLimit= s.stats.queue->samples;
					s.stats.queue->reset();
				}
				s.stats.lRestart = s.stats.analyzed;
				s.undoUntil(0);
				t = Enumerator::progress_restart;
				if (rsLimit == 0)                { rsLimit      = rs.next(); }
				if (sLimit.local == 0)           { sLimit.local = rs.next(); }
				if (p.reduce.reduceOnRestart)    { db           = s.reduceLearnts(.33f, p.reduce.strategy); }
				if (growOnRestart && !minLimit)  { minLimit     = rs.current(); }
				if (s.stats.restarts == shuffle) {
					shuffle += p.shuffleNext();
					s.shuffleOnNextSimplify();
				}
				--gLimit.restarts;
			}
			else if (dsLimit == 0 || s.numLearntConstraints() >= sLimit.learnts) {
				// reduction reached - remove learnt constraints
				db      = s.reduceLearnts(p.reduce.remFrac, p.reduce.strategy);
				dsLimit = dsLimit != 0 ? ds.current() : ds.next();
				t       = Enumerator::progress_reduce;
				if (db.size >= sLimit.learnts || db.pinned >= maxLearnts) { 
					// The regular reduction did not free enough space: retry with a
					// fixed strategy and, if that is still not enough, raise the limit.
					ReduceStrategy aggressive; aggressive.algo = 2; aggressive.score = 2; aggressive.glue = 0;
					db.pinned /= 2;
					db.size    = s.reduceLearnts(0.5f, aggressive).size;
					if (db.size >= sLimit.learnts){
						maxLearnts += std::max(100.0, s.numLearntConstraints()/10.0);
					}
				}
			}
			if (dgLimit == 0 || (growOnRestart && t == Enumerator::progress_restart)) {
				// grow sched reached - increase max db size
				if (!dgLimit)                                          { dgLimit     = dg.next(); minLimit = dgLimit; }
				if ((s.numLearntConstraints() + minLimit) > maxLearnts){ maxLearnts *= p.reduce.dbGrow; }
				// once the upper bound is reached, growing is switched off for good
				if (maxLearnts > boundLearnts)                         { maxLearnts  = boundLearnts; dgLimit = UINT64_MAX; growOnRestart = false; }
			}
		}
	}
	// hand the remaining global limits back to the caller
	p.limits = gLimit;
	return result;
}
/////////////////////////////////////////////////////////////////////////////////////////
// SimpleSolve
/////////////////////////////////////////////////////////////////////////////////////////
// Always returns false.
bool SimpleSolve::terminate() {
	return false;
}
// Solves the problem with solver s under the given assumptions.
//
// After resetting the statistics of s, the function repeatedly
//  - clears old assumptions (stopping if that fails),
//  - initializes the path given by assume and, on success, runs solvePath(),
//  - asks the enumerator via optimizeNext() whether there is more work.
// A pass counts as complete if solvePath() returned a value other than
// value_free and s ended on its root level. The final state is passed to
// the enumerator's reportResult(). Returns true iff the last state is
// not complete.
bool SimpleSolve::doSolve(Solver& s, const SolveParams& p, const LitVec& assume) {
	s.stats.reset();
	Enumerator* en        = s.sharedContext()->enumerator();
	InitParams initParams = p.init;
	bool complete         = true;
	// clearAssumptions() removes existing assumptions and restores the solver to
	// a usable state; if it fails, the problem is unsat even without assumptions.
	for (bool more = true; s.clearAssumptions() && more; ) {
		// If initPath() fails, the problem is unsat under the current
		// assumptions but not necessarily unsat.
		if (initPath(s, assume, initParams)) {
			const ValueRep res = solvePath(s, p);
			complete = res != value_free && s.decisionLevel() == s.rootLevel();
		}
		// current work item is done - check for more
		more = complete && en->optimizeNext();
	}
	en->reportResult(complete);
	return !complete;
}
}
