/* $Id: GCArrays.cpp 4488 2009-04-21 11:50:11Z potyra $ 
 *
 * Generate intermediate code, array specific parts.
 *
 * Copyright (C) 2008-2009 FAUmachine Team <info@faumachine.org>.
 * This program is free software. You can redistribute it and/or modify it
 * under the terms of the GNU General Public License, either version 2 of
 * the License, or (at your option) any later version. See COPYING.
 */


#include "frontend/visitor/GCArrays.hpp"
#include <cassert>
#include "frontend/visitor/GenCode.hpp"
#include "frontend/visitor/ResolveTypes.hpp"
#include "frontend/visitor/GCTypes.hpp"
#include "frontend/ast/DiscreteRange.hpp"
#include "frontend/ast/UnconstrainedArrayType.hpp"
#include "intermediate/operands/RegisterFactory.hpp"
#include "intermediate/operands/ImmediateOperand.hpp"
#include "intermediate/operands/IndirectOperand.hpp"
#include "intermediate/container/LabelFactory.hpp"
#include "intermediate/container/TypeFactory.hpp"
#include "intermediate/opcodes/Mov.hpp"
#include "intermediate/opcodes/Je.hpp"
#include "intermediate/opcodes/Jb.hpp"
#include "intermediate/opcodes/Jbe.hpp"
#include "intermediate/opcodes/Jmp.hpp"
#include "intermediate/opcodes/Sub.hpp"
#include "intermediate/opcodes/IMul.hpp"
#include "intermediate/opcodes/Add.hpp"
#include "intermediate/opcodes/AOffset.hpp"

namespace ast {

using namespace intermediate;

/*
 * ===================== ARRAY HANDLING =======================
 */

/**
 * Prepare array handling for one array type.
 *
 * The index ranges declared by the type itself are gathered with a
 * GCTypes::GenTypeElements visitor. For each such range an immediate
 * operand for its lower and its upper bound is appended behind the
 * bounds that were passed in via @p lbounds / @p ubounds.
 *
 * @param at array type declaration; must not be NULL and must have
 *        the base type BASE_TYPE_ARRAY.
 * @param b operand used as base of the array when an element is
 *        addressed.
 * @param container code container that receives generated code.
 * @param lbounds initial lower bound operands.
 * @param ubounds initial upper bound operands.
 */
ArrayHandling::ArrayHandling(
	TypeDeclaration *at,
	Operand *b,
	CodeContainer &container,
	std::list<Operand *> lbounds,
	std::list<Operand *> ubounds
) :		arrayType(at),
		base(b),
		cc(container),
		lowerBounds(lbounds),
		upperBounds(ubounds) 
{
	assert(at != NULL);
	assert(at->baseType == BASE_TYPE_ARRAY);

	// let the visitor collect index ranges and the element type
	GCTypes::GenTypeElements collector(
		false, NULL, *this->arrayType, NULL);
	this->arrayType->accept(collector);

	this->indices = collector.getIndices();

	// turn the constant bounds of every range into immediate operands
	std::list<DiscreteRange*>::const_iterator range = 
		this->indices.begin();
	while (range != this->indices.end()) {
		this->lowerBounds.push_back(
			new ImmediateOperand((*range)->getLowerBound()));
		this->upperBounds.push_back(
			new ImmediateOperand((*range)->getUpperBound()));
		++range;
	}

	// an array has exactly one element type
	assert(collector.composite.size() == 1);
	assert(collector.referredTypes.size() == 1);

	// the intermediate type of the element is looked up by name
	std::string typeName = collector.composite.front()->name;
	this->itype = TypeFactory::lookupType(typeName);
	assert(this->itype != NULL);

	this->elementType = collector.referredTypes.front();
}

/** Destructor; performs no work of its own. */
ArrayHandling::~ArrayHandling() {}

void
ArrayHandling::factorize(void)
{
	// reverse dimension sizes
	// (..., d6, d5, d4, d3, d2)
	std::list<Operand *> rsizes;

	assert(this->lowerBounds.size() == this->upperBounds.size());
	assert(this->lowerBounds.size() > 0);

	std::list<Operand *>::const_reverse_iterator lbit = 
		this->lowerBounds.rbegin();
	std::list<Operand *>::const_reverse_iterator ubit = 
		this->upperBounds.rbegin();

	std::list<Operand *>::const_reverse_iterator llast = 
		this->lowerBounds.rend();
	llast--;
	// iterate over all but *first* entry (first is the actual dimension
	// that gets subscribed to with the first subscription element,
	// so the factor is always constant 1)
	for (; lbit != llast; lbit++, ubit++) {
		Register *r1 = this->cc.createRegister(OP_TYPE_INTEGER);
		Sub *diff = new Sub(*ubit, *lbit, r1);
		this->cc.addCode(diff);

		Register *r2 = this->cc.createRegister(OP_TYPE_INTEGER);
		Add *inc = new Add(r1, ImmediateOperand::getOne(), r2);
		this->cc.addCode(inc);
		rsizes.push_back(r2);
	}

	// factors should be
	// (d2 * d3 * d4 * d5 * d6) 
	// (d3 * d4 * d5 * d6)
	// (d4 * d5 * d6)
	// (d5 * d6)
	// (d6)
	intermediate::Operand *factorial = ImmediateOperand::getOne();
	for (std::list<Operand *>::const_iterator i = rsizes.begin();
		i != rsizes.end(); i++) {

		Register *r1 = this->cc.createRegister(OP_TYPE_INTEGER);
		IMul *m = new IMul(factorial, *i, r1);
		this->cc.addCode(m);
		factorial = r1;
		this->dimensionFactors.push_front(factorial);
	}
	
	// also store 1 as factor for first dimension
	this->dimensionFactors.push_front(ImmediateOperand::getOne());
}

/**
 * Generate code that addresses an element of the array.
 *
 * Intended formula for array(x, y, z) with dimensions (d1, ..., d6):
 *     (x - lower(d1)) * factor(d2, d3, d4, d5, d6)
 *   + (y - lower(d2)) * factor(d3, d4, d5, d6)
 *   + (z - lower(d3)) * factor(d4, d5, d6)
 *
 * The k-th index is paired with the k-th lower bound and the k-th entry
 * of dimensionFactors; the factors are created on demand by factorize().
 *
 * @param relativeIndices index operands, at most one per dimension.
 * @return pointer register that is the destination of the emitted
 *         AOffset(base, offset, itype) operation.
 */
Register *
ArrayHandling::subscribe(std::list<Operand *> relativeIndices)
{
	typedef std::list<Operand *>::const_iterator OpIter;

	// create the factors lazily on first use
	if (this->dimensionFactors.empty()) {
		this->factorize();
	}

	assert(! this->dimensionFactors.empty());
	assert(this->dimensionFactors.size() >= relativeIndices.size());
	assert(this->itype != NULL);

	Operand *sum = ImmediateOperand::getZero();
	OpIter factor = this->dimensionFactors.begin();
	OpIter lower = this->lowerBounds.begin();

	for (OpIter idx = relativeIndices.begin();
		idx != relativeIndices.end(); ++idx, ++factor, ++lower) {

		Register *rel = this->cc.createRegister(OP_TYPE_INTEGER);
		Register *scaled = this->cc.createRegister(OP_TYPE_INTEGER);
		Register *acc = this->cc.createRegister(OP_TYPE_INTEGER);

		// FIXME this is only valid for up arrays, not for downto
		//       ones!
		// rel = index - lower bound
		Sub *toRelative = new Sub(*idx, *lower, rel);
		// scaled = rel * factor of this dimension
		IMul *scale = new IMul(rel, *factor, scaled);
		// acc = scaled + sum of the previous dimensions
		Add *accumulate = new Add(scaled, sum, acc);

		this->cc.addCode(toRelative);
		this->cc.addCode(scale);
		this->cc.addCode(accumulate);
		sum = acc;
	}

	// Fewer indices than dimensions: the offset gathered so far is
	// multiplied with every factor that has not been consumed.
	// NOTE(review): confirm that this is the intended handling of
	// partially indexed arrays.
	for (; factor != this->dimensionFactors.end(); ++factor) {
		Register *scaled = this->cc.createRegister(OP_TYPE_INTEGER);
		IMul *scale = new IMul(sum, *factor, scaled);
		this->cc.addCode(scale);
		sum = scaled;
	}

	Register *result = this->cc.createRegister(OP_TYPE_POINTER);
	this->cc.addCode(
		new AOffset(this->base, sum, this->itype, result));

	return result;
}

/*
 * =================== STATIC ARRAY ITERATE ===================
 */
void
StaticArrayIterate::iterate(void)
{
	std::list<universal_integer> lbounds = std::list<universal_integer>();
	std::list<universal_integer> ubounds = std::list<universal_integer>();
	std::list<universal_integer> i = std::list<universal_integer>();
	std::list<Operand*> offsetL = std::list<Operand*>();

	for (std::list<DiscreteRange*>::const_iterator d = 
		this->indices.begin();
		d != this->indices.end(); d++) {

		universal_integer lb = (*d)->getLowerBound();
		universal_integer ub = (*d)->getUpperBound();

		lbounds.push_back(lb);
		i.push_back(lb);
		ubounds.push_back(ub);
	}

	while (StaticArrayIterate::checkLoop(i, ubounds)) {

		for (std::list<universal_integer>::const_iterator i1 = 
			i.begin(); i1 != i.end(); i1++) {

			offsetL.push_back(new ImmediateOperand(*i1));
		}

		Register *element = this->subscribe(offsetL);
		this->iterateBody(element, i);

		StaticArrayIterate::incCounters(i, lbounds, ubounds);
		offsetL.clear();
	}
}

bool
StaticArrayIterate::checkLoop(
	const std::list<universal_integer> &counters,
	const std::list<universal_integer> &ubounds
)
{
	std::list<universal_integer>::const_iterator i = counters.begin();
	std::list<universal_integer>::const_iterator j = ubounds.begin();

	while (i != counters.end()) {
		bool ret = ((*i) <= (*j));

		if (! ret) {
			return false;
		}

		i++;
		j++;
	}

	return true;
}

void
StaticArrayIterate::incCounters(
	std::list<universal_integer> &counters,
	const std::list<universal_integer> &lbounds,
	const std::list<universal_integer> &ubounds
)
{
	bool carry = false;
	std::list<universal_integer>::iterator i = counters.begin();
	std::list<universal_integer>::const_iterator l = lbounds.begin();
	std::list<universal_integer>::const_iterator u = ubounds.begin();

	(*i)++;

	while (i != counters.end()) {
		if (carry) {
			(*i)++;
			carry = false;
		}

		if ((*u) < (*i)) {
			carry = true;
			(*i) = (*l);
		}

		i++;
		u++;
		l++;
	}

	// make sure, that at least one index overflows if all 
	// values have been handled, otherwise checkLoop would
	// never yield false.
	if (carry) {
		i = counters.begin();
		u = ubounds.begin();
		(*i) = (*u) + 1;
	}
}

/*
 * =================== ARRAY ITERATE ===================
 */

void
ArrayIterate::initCounters(void)
{
	for (std::list<Operand *>::const_iterator i = 
		this->lowerBounds.begin();
		i != this->lowerBounds.end();
		i++) {

		Register *b = this->cc.createRegister(OP_TYPE_INTEGER);
		Mov *m = new Mov(*i, b);
		this->cc.addCode(m);
		this->counters.push_back(b);
	}
}

void
ArrayIterate::incCounters(void)
{
	std::list<Operand*>::const_reverse_iterator ubi = 
		this->upperBounds.rbegin();
	std::list<Operand*>::const_reverse_iterator lbi = 
		this->lowerBounds.rbegin();

	for (std::list<Register *>::reverse_iterator i = 
		this->counters.rbegin(); i != this->counters.rend();
		i++) {

		// c = c + 1
		Add *a = new Add(ImmediateOperand::getOne(), *i, *i);
		this->cc.addCode(a);

		// if c <= upperBound goto loop
		Jbe *jbe = new Jbe(*i, *ubi, this->loop);
		this->cc.addCode(jbe);

		// c = lowerBound.
		Mov *m = new Mov(*lbi, *i);
		this->cc.addCode(m);

		ubi++;
		lbi++;
	}

	// we're done with the iteration.
}

void
ArrayIterate::iterate(void)
{
	this->initCounters();
	this->cc.addCode(this->loop);
	std::list<Operand *> idx;

	for (std::list<Register *>::iterator i = this->counters.begin();
		i != this->counters.end(); i++) {

		idx.push_back(*i);
	}

	Register *elem = this->subscribe(idx);
	this->iterateBody(elem, this->counters);

	this->incCounters();
}

}; /* namespace ast */
