/*
 *  Copyright (c) 2008 Cyrille Berger <cberger@cberger.net>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation;
 * version 2 of the License.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include "Statement.h"

// LLVM
#include <llvm/BasicBlock.h>
#include <llvm/Constants.h>
#include <llvm/Function.h>
#include <llvm/GlobalVariable.h>
#include <llvm/Instructions.h>
#include <llvm/Module.h>

// GTLCore
#include <GTLCore/CodeGenerator_p.h>
#include <GTLCore/ExpressionResult_p.h>
#include <GTLCore/Function.h>
#include <GTLCore/Type.h>
#include <GTLCore/Utils_p.h>
#include <GTLCore/VariableNG_p.h>

// GTLCore
#include <GTLCore/Debug.h>

// AST
#include "Expression.h"

using namespace GTLCore::AST;

// Creates a new, empty basic block and appends it at the end of the block list
// of the function currently being generated (context.llvmFunction(), which must
// be non-null). Returns the new block.
llvm::BasicBlock* Statement::createBlock( GenerationContext& context) const
{
  GTL_ASSERT( context.llvmFunction() );
  llvm::BasicBlock* freshBlock = llvm::BasicBlock::Create();
  llvm::Function* owner = context.llvmFunction();
  owner->getBasicBlockList().push_back( freshBlock );
  return freshBlock;
}

// Generates every statement of the list in order. Each statement starts in the
// block where the previous one ended; the block reached after the last
// statement is returned (or _bb itself when the list is empty).
llvm::BasicBlock* StatementsList::generateStatement( GenerationContext& _context, llvm::BasicBlock* _bb ) const
{
  std::list<Statement*>::const_iterator statementIt = m_list.begin();
  const std::list<Statement*>::const_iterator statementsEnd = m_list.end();
  while( statementIt != statementsEnd )
  {
    _bb = (*statementIt)->generateStatement( _context, _bb );
    ++statementIt;
  }
  return _bb;
}

// Destroys the list together with the statements it holds: deleteAll (from
// GTLCore/Utils_p.h) is presumably deleting every element of m_list — TODO confirm.
StatementsList::~StatementsList()
{
  deleteAll( m_list );
}

//---------------------------------------------------//
//--------------- VariableDeclaration ---------------//
//---------------------------------------------------//

// Builds a declaration of a variable of the given type. The declaration takes
// ownership of the optional initialiser and of the size expressions (both are
// deleted by the destructor). Size expressions are only allowed for arrays;
// a null entry in _initialSizes stands for an unspecified dimension.
VariableDeclaration::VariableDeclaration( const GTLCore::Type* type, Expression* initialiser, bool constant, const std::list<Expression*>& _initialSizes)
  : m_variable( new GTLCore::VariableNG( type, constant ) ),
    m_initialiser( initialiser ),
    m_initialSizes( _initialSizes ),
    m_functionInitialiser( 0 )
{
  // Either no explicit sizes were given, or the declared type must be an array.
  GTL_ASSERT( _initialSizes.empty() or type->dataType() == GTLCore::Type::ARRAY );
}

// Releases everything owned by the declaration: the VariableNG allocated in the
// constructor, the optional initialiser expression, the array size expressions
// (deleteAll presumably deletes each element — TODO confirm) and the optional
// m_functionInitialiser. Deleting a null pointer is a no-op, so absent parts are fine.
VariableDeclaration::~VariableDeclaration()
{
  delete m_variable;
  delete m_initialiser;
  deleteAll( m_initialSizes );
  delete m_functionInitialiser;
}

// Emits the code that initialises the declared variable.
// The optional initialiser and every size expression are evaluated in _bb; a
// null size expression is passed on as the integer constant 0. If a function
// initialiser is present, its code is generated right after the variable
// initialisation. Returns the block in which generation ended.
llvm::BasicBlock* VariableDeclaration::generateStatement( GenerationContext& _context, llvm::BasicBlock* _bb ) const
{
  // Stays default-constructed when the declaration has no initialiser.
  ExpressionResult initialValue;
  if( m_initialiser )
  {
    initialValue = m_initialiser->generateValue( _context, _bb );
  }

  // One value per declared dimension.
  std::list<llvm::Value*> dimensionSizes;
  for( std::list<Expression*>::const_iterator sizeIt = m_initialSizes.begin();
       sizeIt != m_initialSizes.end(); ++sizeIt )
  {
    llvm::Value* dimension = 0;
    if( *sizeIt )
    {
      dimension = (*sizeIt)->generateValue( _context, _bb ).value();
    } else {
      dimension = _context.codeGenerator()->integerToConstant( 0 );
    }
    dimensionSizes.push_back( dimension );
  }

  llvm::BasicBlock* afterInit = m_variable->initialise( _context, _bb, initialValue, dimensionSizes );
  if( not m_functionInitialiser )
  {
    return afterInit;
  }
  // The function initialiser continues where initialise() stopped.
  return m_functionInitialiser->generateStatement( _context, afterInit );
}

//---------------------------------------------------//
//------------------- IfStatement -------------------//
//---------------------------------------------------//
// Deletes the owned condition expression and the body statement.
IfStatement::~IfStatement()
{
  delete m_expression;
  delete m_ifStatement;
}

// Generates an "if" without else branch.
// The condition is evaluated in _bb; the body is generated starting in a new
// block (it may end in a different one), and a new continuation block is
// created for the code that follows. The code generator then wires the
// branches between those blocks. Returns the continuation block.
llvm::BasicBlock* IfStatement::generateStatement( GenerationContext& _context, llvm::BasicBlock* _bb ) const
{
  llvm::Value* condition = m_expression->generateValue( _context, _bb ).value();

  llvm::BasicBlock* thenBegin = createBlock( _context );
  llvm::BasicBlock* thenEnd = m_ifStatement->generateStatement( _context, thenBegin );

  llvm::BasicBlock* continuation = createBlock( _context );
  _context.codeGenerator()->createIfStatement( _bb, condition, m_expression->type(), thenBegin, thenEnd, continuation );
  return continuation;
}

// Deletes the owned condition expression and both branch statements.
IfElseStatement::~IfElseStatement()
{
  delete m_expression;
  delete m_ifStatement;
  delete m_elseStatement;
}

// Generates an "if ... else ...".
// The condition is evaluated in _bb. Each branch is generated starting in its
// own new block (and may end in another one); a new continuation block is then
// created and the code generator connects everything. Returns the continuation.
llvm::BasicBlock* IfElseStatement::generateStatement( GenerationContext& _context, llvm::BasicBlock* _bb) const
{
  llvm::Value* condition = m_expression->generateValue( _context, _bb ).value();

  llvm::BasicBlock* thenBegin = createBlock( _context );
  llvm::BasicBlock* thenEnd = m_ifStatement->generateStatement( _context, thenBegin );

  llvm::BasicBlock* elseBegin = createBlock( _context );
  llvm::BasicBlock* elseEnd = m_elseStatement->generateStatement( _context, elseBegin );

  llvm::BasicBlock* continuation = createBlock( _context );
  _context.codeGenerator()->createIfElseStatement( _bb, condition, m_expression->type(),
                                                   thenBegin, thenEnd,
                                                   elseBegin, elseEnd,
                                                   continuation );
  return continuation;
}

// Deletes the owned parts of the loop. The init statement and the update
// expression are optional (generateStatement checks them for null); deleting
// a null pointer is a no-op.
ForStatement::~ForStatement()
{
  delete m_initStatement;
  delete m_testExpression;
  delete m_updateExpression;
  delete m_forStatement;
}

// Generates a "for( init; test; update ) body" loop.
// The optional init statement is generated in _bb. The test, the update and the
// body each get their own new block, plus a final block for the code after the
// loop; createForStatement connects them. Returns the block after the loop.
llvm::BasicBlock* ForStatement::generateStatement( GenerationContext& _context, llvm::BasicBlock* _bb) const
{
  llvm::BasicBlock* endOfInit = m_initStatement ? m_initStatement->generateStatement( _context, _bb ) : _bb;

  llvm::BasicBlock* conditionBlock = createBlock( _context );
  llvm::Value* condition = m_testExpression->generateValue( _context, conditionBlock ).value();

  llvm::BasicBlock* incrementBlock = createBlock( _context );
  if( m_updateExpression )
  {
    // NOTE(review): the block returned here is discarded; if the update ever
    // spans several blocks, createForStatement only sees incrementBlock.
    m_updateExpression->generateStatement( _context, incrementBlock );
  }

  llvm::BasicBlock* bodyBegin = createBlock( _context );
  llvm::BasicBlock* bodyEnd = m_forStatement->generateStatement( _context, bodyBegin );

  llvm::BasicBlock* exitBlock = createBlock( _context );
  _context.codeGenerator()->createForStatement( endOfInit, conditionBlock, condition, m_testExpression->type(),
                                                incrementBlock, bodyBegin, bodyEnd, exitBlock );
  return exitBlock;
}

// Deletes the owned condition expression and the loop body statement.
WhileStatement::~WhileStatement()
{
  delete m_expression;
  delete m_whileStatement;
}

// Generates a "while( condition ) body" loop.
// The condition is evaluated in its own new block (presumably so the loop can
// branch back to it), the body starts in another new block, and a final block
// receives the code after the loop; createWhileStatement connects them, starting
// from _bb. Returns the block after the loop.
llvm::BasicBlock* WhileStatement::generateStatement( GenerationContext& _context, llvm::BasicBlock* _bb) const
{
  llvm::BasicBlock* conditionBlock = createBlock( _context );
  llvm::Value* condition = m_expression->generateValue( _context, conditionBlock ).value();

  llvm::BasicBlock* bodyBegin = createBlock( _context );
  llvm::BasicBlock* bodyEnd = m_whileStatement->generateStatement( _context, bodyBegin );

  llvm::BasicBlock* exitBlock = createBlock( _context );
  _context.codeGenerator()->createWhileStatement( _bb, conditionBlock, condition, m_expression->type(),
                                                  bodyBegin, bodyEnd, exitBlock );
  return exitBlock;
}

//------------------------- ReturnStatement -------------------------//

// Builds a return statement. _returnExpr may be null (plain "return;"); when it
// is not, the expression is flagged as being the returned value.
// _variablesToClean lists the variables whose cleanUp() runs before returning.
ReturnStatement::ReturnStatement( Expression* _returnExpr, std::list<VariableNG*> _variablesToClean )
  : m_returnExpr( _returnExpr ),
    m_variablesToClean( _variablesToClean )
{
  if( not _returnExpr )
  {
    return;
  }
  _returnExpr->markAsReturnExpression();
}

// Deletes the owned return expression (null for a plain "return;").
// The VariableNG pointers in m_variablesToClean are not deleted here; presumably
// they are owned by their VariableDeclaration (which deletes m_variable) — TODO confirm.
ReturnStatement::~ReturnStatement()
{
  delete m_returnExpr;
}

// Generates a return instruction.
// The returned value (if any) is computed first, then every variable in
// m_variablesToClean is cleaned up, receiving the returned value (or null) so
// that value can be preserved. Non-array, non-structure values are converted to
// the function's return type before the ReturnInst is emitted.
// Returns the block holding the return instruction.
llvm::BasicBlock* ReturnStatement::generateStatement( GenerationContext& _context, llvm::BasicBlock* _bb) const
{
  llvm::Value* returned = 0;
  if( m_returnExpr )
  {
    returned = m_returnExpr->generateValue( _context, _bb).value();
  }

  for( std::list<VariableNG*>::const_iterator varIt = m_variablesToClean.begin();
       varIt != m_variablesToClean.end(); ++varIt)
  {
    _bb = (*varIt)->cleanUp( _context, _bb, returned );
  }

  if( not m_returnExpr )
  {
    llvm::ReturnInst::Create( _bb );
    return _bb;
  }

  // Arrays and structures are returned as-is; scalars are converted.
  const bool needsConversion = m_returnExpr->type()->dataType() != Type::ARRAY
                               and m_returnExpr->type()->dataType() != Type::STRUCTURE;
  if( needsConversion )
  {
    returned = _context.codeGenerator()->convertValueTo( _bb, returned, m_returnExpr->type(), _context.function()->returnType() );
  }
  llvm::ReturnInst::Create( returned, _bb );
  return _bb;
}

//------------------------- PrintStatement --------------------------//

// Deletes the owned expressions to print (deleteAll presumably deletes each
// element of m_expressions — TODO confirm).
PrintStatement::~PrintStatement()
{
  deleteAll( m_expressions );
}

// Emits a call to the runtime function "print", declared with the variadic
// prototype void print(int count, ...).
// The first argument is the number of (tag, value) pairs that follow. Tags:
//   0 -> 32-bit integer,
//   1 -> floating point, passed as double (C variadic promotion),
//   2 -> boolean, passed as a 32-bit integer,
//   3 -> string, passed as an internal constant global holding the text.
// Expressions whose LLVM type is none of the above are logged and skipped, and
// they are NOT included in the count, so print() never reads missing arguments.
// All values are computed in _bb; _bb is returned unchanged.
llvm::BasicBlock* PrintStatement::generateStatement( GenerationContext& _context, llvm::BasicBlock* _bb ) const
{
  std::vector<const llvm::Type*> params;
  params.push_back( llvm::Type::Int32Ty);
  llvm::FunctionType* definitionType = llvm::FunctionType::get( llvm::Type::VoidTy, params, true );
  // Keep the callee as a Constant*: getOrInsertFunction returns a bitcast (not a
  // Function) when "print" already exists with another prototype, and LLVM classes
  // are meant to be queried with isa<>/cast<> (LLVM is commonly built without C++
  // RTTI), so the former dynamic_cast could fail or yield null.
  llvm::Constant* func = _context.llvmModule()->getOrInsertFunction("print", definitionType);

  std::vector<llvm::Value*> values;
  values.reserve( 1 + 2 * m_expressions.size() );
  // Slot for the pair count, filled in after the loop once we know how many
  // pairs were actually emitted.
  values.push_back( 0 );
  int printedCount = 0;

  for( std::list<AST::Expression*>::const_iterator it = m_expressions.begin();
       it != m_expressions.end(); ++it)
  {
    GTLCore::ExpressionResult value = (*it)->generateValue( _context, _bb);
    const llvm::Type* type = value.value()->getType();
    if( (*it)->type() == 0 )
    { // It's a string
      values.push_back( _context.codeGenerator()->integerToConstant( 3) );
      values.push_back( new llvm::GlobalVariable( type, true, llvm::GlobalValue::InternalLinkage, value.constant(), "", _context.llvmModule() ) );
    } else if( type == llvm::Type::Int32Ty )
    {
      values.push_back( _context.codeGenerator()->integerToConstant( 0) );
      values.push_back( value.value() );
    } else if( type == llvm::Type::FloatTy )
    {
      values.push_back( _context.codeGenerator()->integerToConstant( 1) );
      values.push_back( _context.codeGenerator()->convertValueTo( _bb, value.value(), (*it)->type(), GTLCore::Type::Double ));
    } else if( type == llvm::Type::Int1Ty )
    {
      values.push_back( _context.codeGenerator()->integerToConstant( 2) );
      values.push_back( _context.codeGenerator()->convertValueTo( _bb, value.value(), (*it)->type(), GTLCore::Type::Integer32 ));
    } else {
      GTL_DEBUG("Unknown type for print " << *type);
      continue; // nothing was emitted for this expression, so it is not counted
    }
    ++printedCount;
  }
  values[0] = _context.codeGenerator()->integerToConstant( printedCount );
  llvm::CallInst::Create(func, values.begin(), values.end(), "", _bb);
  return _bb;
}
