/*
 *  Copyright (c) 2007-2008 Cyrille Berger <cberger@cberger.net>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation;
 * version 2 of the License.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#ifndef _GTLCORE_CODE_GENERATOR_H_
#define _GTLCORE_CODE_GENERATOR_H_

#include <list>

#include <GTLCore/String.h>

#include <GTLCore/Macros.h>

// Forward declarations of the llvm classes named in the signatures of this
// header. This header does not directly include any llvm header itself.
// NOTE(review): llvm::Module is forward declared but is not referenced by any
// declaration visible in this header -- confirm whether it is still needed.
namespace llvm {
  class BasicBlock;
  class Constant;
  class Function;
  class FunctionType;
  class Module;
  class Value;
  class Type;
}

// TODO: all functions should be static and take a GenerationContext as argument

namespace GTLCore {
  // Forward declarations of the GTLCore types used in the signatures below.
  // Inside this namespace an unqualified Type / Value / Function refers to the
  // GTLCore class; the llvm counterparts are always spelled with the llvm::
  // prefix.
  class ExpressionResult;
  class ModuleData;
  class Type;
  class Value;
  class VariableNG;
  class Function;
  class GenerationContext;
  namespace AST {
    class Expression;
  }
  /**
   * @internal
   * This class is private to OpenGTL and shouldn't be used outside.
   *
   * This class provides helper functions to create llvm code.
   *
   * Most operators are declared as a family of three overloads:
   *  - one working on llvm::Value operands, which takes the current
   *    llvm::BasicBlock as first argument,
   *  - one working on \ref ExpressionResult operands,
   *  - one working on llvm::Constant operands, which takes no basic block.
   *
   * Only part of the functions are static; the others need an instance.
   *
   * NOTE(review): the class owns a raw pointer to its private data and declares
   * a destructor, but copy construction is not disabled in this header --
   * confirm in the implementation that instances are never copied.
   *
   * @ingroup GTLCore
   */
  class CodeGenerator {
    public:
      /**
       * Create a new \ref CodeGenerator for the given module.
       * @param module the \ref ModuleData this generator works on
       */
      CodeGenerator(ModuleData* module);
      ~CodeGenerator();
      /**
       * @return an integer constant
       */
      static llvm::Constant* integerToConstant(int v);
      /**
       * @return a boolean constant
       */
      static llvm::Constant* boolToConstant(bool v);
      /**
       * @return a float constant
       */
      static llvm::Constant* floatToConstant(float v);
      /**
       * @return convert a \ref Value to a llvm constant
       */
      static llvm::Constant* valueToConstant( const GTLCore::Value& v);
      /**
       * Create a function and add it to the module.
       * @param type type of the function
       * @param name name of the function
       * @return a pointer to the llvm function
       */
      llvm::Function* createFunction( llvm::FunctionType* type, const GTLCore::String& name);
      // Conversion helpers. The descriptions below are inferred from the names
      // and signatures only -- confirm against the implementation.
      // Presumably converts the pointer \p value to a char pointer.
      static llvm::Value* convertPointerToCharP(llvm::BasicBlock* currentBlock, llvm::Value* value);
      // Presumably converts the pointer \p value to the llvm type \p type.
      static llvm::Value* convertPointerTo(llvm::BasicBlock* currentBlock, llvm::Value* value, const llvm::Type* type);
      // Presumably converts \p value from the GTLCore type \p valueType to \p type.
      static llvm::Value* convertValueTo(llvm::BasicBlock* currentBlock, llvm::Value* value, const Type* valueType, const Type* type);
      // Same as convertValueTo, but on a constant and without a basic block.
      static llvm::Constant* convertConstantTo(llvm::Constant* constant, const Type* constantType, const Type* type);
      // Presumably convert to / from a half float representation.
      static llvm::Value* convertToHalf( GenerationContext& generationContext, llvm::BasicBlock* currentBlock, llvm::Value* value, const Type* _valueType);
      static llvm::Value* convertFromHalf( GenerationContext& generationContext, llvm::BasicBlock* currentBlock, llvm::Value* value);
      // Presumably returns the element of \p _vector at position \p _index.
      static llvm::Value* vectorValueAt( llvm::BasicBlock* _currentBlock, llvm::Value* _vector, llvm::Value* _index);
      // Presumably rounds \p _val; the rounding mode is not visible from here.
      static llvm::Value* createRound( llvm::BasicBlock* _currentBlock, llvm::Value* _val );
    public: // Boolean Expression
      // Or
      llvm::Value* createOrExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createOrExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, ExpressionResult rhs);
      llvm::Constant* createOrExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // And
      llvm::Value* createAndExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createAndExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, ExpressionResult rhs);
      llvm::Constant* createAndExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
    public: // Bit expressions
      // Note: unlike the boolean and arithmetic families, the ExpressionResult
      // overloads of this section also take the operand types.
      // Bitwise xor
      llvm::Value* createBitXorExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createBitXorExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createBitXorExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Bitwise or
      llvm::Value* createBitOrExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createBitOrExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createBitOrExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Bitwise and
      llvm::Value* createBitAndExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createBitAndExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createBitAndExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
    public: // Arithmetic expressions
      // Addition
      static llvm::Value* createAdditionExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      static ExpressionResult createAdditionExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, ExpressionResult rhs);
      static llvm::Constant* createAdditionExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Subtraction (the identifier keeps its historical "Substraction" spelling)
      static llvm::Value* createSubstractionExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      static ExpressionResult createSubstractionExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, ExpressionResult rhs);
      static llvm::Constant* createSubstractionExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Multiplication
      static llvm::Value* createMultiplicationExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      static ExpressionResult createMultiplicationExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, ExpressionResult rhs );
      static llvm::Constant* createMultiplicationExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Division
      static llvm::Value* createDivisionExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      static llvm::Constant* createDivisionExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      static ExpressionResult createDivisionExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, ExpressionResult rhs);
      // Modulo (not static, and its ExpressionResult overload takes the operand types)
      llvm::Value* createModuloExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      llvm::Constant* createModuloExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      ExpressionResult createModuloExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
    public: // Shift expressions
      // Right shift
      llvm::Value* createRightShiftExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createRightShiftExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs,  ExpressionResult rhs);
      llvm::Constant* createRightShiftExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Left shift
      llvm::Value* createLeftShiftExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createLeftShiftExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, ExpressionResult rhs);
      llvm::Constant* createLeftShiftExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
    public: // Comparison Expressions
      // Equal
      llvm::Value* createEqualExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createEqualExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createEqualExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Different
      llvm::Value* createDifferentExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createDifferentExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createDifferentExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Strictly inferior (only the llvm::Value overload is static)
      static llvm::Value* createStrictInferiorExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createStrictInferiorExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createStrictInferiorExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Inferior or equal
      llvm::Value* createInferiorOrEqualExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createInferiorOrEqualExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createInferiorOrEqualExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Strictly superior (the identifier keeps its historical "Supperior" spelling)
      llvm::Value* createStrictSupperiorExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createStrictSupperiorExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createStrictSupperiorExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
      // Superior or equal
      llvm::Value* createSupperiorOrEqualExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType);
      ExpressionResult createSupperiorOrEqualExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType);
      llvm::Constant* createSupperiorOrEqualExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType);
    public: // Unary expression
      // Minus
      ExpressionResult createMinusExpression( llvm::BasicBlock* currentBlock, ExpressionResult rhs, const Type* rhsType);
      llvm::Value* createMinusExpression(llvm::BasicBlock* currentBlock, llvm::Value* rhs, const Type* rhsType);
      llvm::Constant* createMinusExpression( llvm::Constant* rhs, const Type* rhsType);
      // Not
      ExpressionResult createNotExpression( llvm::BasicBlock* currentBlock, ExpressionResult rhs, const Type* rhsType);
      llvm::Value* createNotExpression(llvm::BasicBlock* currentBlock, llvm::Value* rhs, const Type* rhsType);
      llvm::Constant* createNotExpression( llvm::Constant* rhs, const Type* rhsType);
      // Tilde
      ExpressionResult createTildeExpression( llvm::BasicBlock* currentBlock, ExpressionResult rhs, const Type* rhsType);
      llvm::Value* createTildeExpression(llvm::BasicBlock* currentBlock, llvm::Value* rhs, const Type* rhsType);
      llvm::Constant* createTildeExpression( llvm::Constant* rhs, const Type* rhsType);
      // Increment: these return nothing; they operate either on a variable or
      // on a raw pointer (presumably the pointed value is updated in place --
      // confirm against the implementation).
      static void createIncrementExpression( GenerationContext& gc, llvm::BasicBlock* currentBlock, VariableNG* var);
      static void createIncrementExpression( llvm::BasicBlock* currentBlock, llvm::Value* pointer);
      // Decrement: same shape as increment, but not static and the variable
      // overload takes no GenerationContext.
      void createDecrementExpression( llvm::BasicBlock* currentBlock, VariableNG* var);
      void createDecrementExpression( llvm::BasicBlock* currentBlock, llvm::Value* pointer);

    public: // Statements
      /**
       * Create an if statement.
       * @param before the basic block before the if statement
       * @param test the test
       * @param testType the type of \p test
       * @param firstAction the block with the first action in case if is true
       * @param lastAction the block with the last action in case if is true
       * @param after the basic block after the if statement
       */
      void createIfStatement( llvm::BasicBlock* before, llvm::Value* test, const Type* testType, llvm::BasicBlock* firstAction, llvm::BasicBlock* lastAction, llvm::BasicBlock* after);
      /**
       * Create an if / else statement.
       * The parameters are those of \ref createIfStatement, plus the two blocks
       * below (meaning inferred from their names -- confirm against the
       * implementation).
       * @param firstElseAction presumably the block with the first action in case if is false
       * @param lastElseAction presumably the block with the last action in case if is false
       */
      void createIfElseStatement( llvm::BasicBlock* before, llvm::Value* test, const Type* testType, llvm::BasicBlock* firstAction, llvm::BasicBlock* lastAction, llvm::BasicBlock* firstElseAction, llvm::BasicBlock* lastElseAction, llvm::BasicBlock* after);
      /**
       * Create a for statement.
       * @param before the basic block before the for statement
       * @param test the basic block which hold the computation of the test
       * @param testResult the value with the result of the test
       * @param testType the type of \p testResult
       * @param update the basic block which hold the computation of the update
       * @param firstAction the block with the first action of the loop
       * @param lastAction the block with the last action of the loop
       * @param after the basic block after the for statement
       */
      static void createForStatement(llvm::BasicBlock* before, llvm::BasicBlock* test, llvm::Value* testResult, const Type* testType, llvm::BasicBlock* update, llvm::BasicBlock* firstAction, llvm::BasicBlock* lastAction,  llvm::BasicBlock* after);

      /**
       * Create an iteration for statement.
       * This is a statement that looks like :
       * @code
       * for(int variable = 0; variable \< maxValue ; ++variable )
       * {
       *   // firstAction
       *   // lastAction
       * }
       * @endcode
       *
       * @param before the basic block before the for statement
       * @param variable the variable to increment
       * @param maxValue the (maximum value + 1) reached by the variable
       * @param maxValueType the type of \p maxValue
       * @param firstAction the block with the first action of the loop
       * @param lastAction the block with the last action of the loop
       * @return the exit block of the loop
       */
      static llvm::BasicBlock* createIterationForStatement( GenerationContext&, llvm::BasicBlock* before, GTLCore::VariableNG* variable, llvm::Value* maxValue, const Type* maxValueType, llvm::BasicBlock* firstAction, llvm::BasicBlock* lastAction);

      /**
       * Create a while statement.
       * Takes the same parameters as \ref createForStatement except that there
       * is no update block.
       */
      void createWhileStatement( llvm::BasicBlock* before, llvm::BasicBlock* test, llvm::Value* testResult, const Type* testType, llvm::BasicBlock* firstAction, llvm::BasicBlock* lastAction,  llvm::BasicBlock* after);
    public:
      /**
       * @param _currentBlock the current basic block
       * @param _pointer a pointer to the array structure
       * @param _index the index of the value in the array
       * @return a pointer on the value of an array
       */
      llvm::Value* accessArrayValue( llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, llvm::Value* _index );
    public:
      /**
       * Presumably generates a call to \p _function in \p _bb, using the
       * expressions of \p m_arguments as arguments -- inferred from the
       * signature, confirm against the implementation.
       * @param _gc the generation context
       * @param _bb the basic block given to the call
       * @param _function the function to call
       * @param m_arguments the list of argument expressions
       */
      GTLCore::ExpressionResult callFunction( GenerationContext& _gc, llvm::BasicBlock* _bb, const GTLCore::Function* _function, const std::list<AST::Expression*>& m_arguments );
    private:
      // Generic comparison, declared with the same three overloads as the
      // public operators. The caller supplies one predicate per kind of
      // operand: unsigned integer, signed integer and float.
      // NOTE(review): presumably the public comparison functions forward to
      // these with llvm predicate codes -- confirm against the implementation.
      static llvm::Value* createComparisonExpression(llvm::BasicBlock* currentBlock, llvm::Value* lhs, const Type* lhsType, llvm::Value* rhs, const Type* rhsType, unsigned int unsignedIntegerPred, unsigned int signedIntegerPred, unsigned int floatPred);
      static ExpressionResult createComparisonExpression( llvm::BasicBlock* currentBlock, ExpressionResult lhs, const Type* lhsType, ExpressionResult rhs, const Type* rhsType, unsigned int unsignedIntegerPred, unsigned int signedIntegerPred, unsigned int floatPred);
      static llvm::Constant* createComparisonExpression( llvm::Constant* lhs, const Type* lhsType, llvm::Constant* rhs, const Type* rhsType, unsigned int unsignedIntegerPred, unsigned int signedIntegerPred, unsigned int floatPred);
    private:
      // Private data of the class; Private is only declared here, its
      // definition is not part of this header.
      struct Private;
      Private* const d;
  };
}

#endif
