//! \file
/*
**  Copyright (C) - Triton
**
**  This program is under the terms of the Apache License 2.0.
*/

#include <triton/astSmtRepresentation.hpp>
#include <triton/exceptions.hpp>
#include <triton/symbolicExpression.hpp>
#include <triton/symbolicVariable.hpp>



namespace triton {
  namespace ast {
    namespace representations {

      /* Constructor: nothing to set up explicitly, so the compiler-generated body is used. */
      AstSmtRepresentation::AstSmtRepresentation() = default;


      /* Representation dispatcher from an abstract node */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::AbstractNode* node) {
        switch (node->getType()) {
          case ARRAY_NODE:                return this->print(stream, reinterpret_cast<triton::ast::ArrayNode*>(node)); break;
          case ASSERT_NODE:               return this->print(stream, reinterpret_cast<triton::ast::AssertNode*>(node)); break;
          case BSWAP_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BswapNode*>(node)); break;
          case BVADD_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvaddNode*>(node)); break;
          case BVAND_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvandNode*>(node)); break;
          case BVASHR_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvashrNode*>(node)); break;
          case BVLSHR_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvlshrNode*>(node)); break;
          case BVMUL_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvmulNode*>(node)); break;
          case BVNAND_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvnandNode*>(node)); break;
          case BVNEG_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvnegNode*>(node)); break;
          case BVNOR_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvnorNode*>(node)); break;
          case BVNOT_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvnotNode*>(node)); break;
          case BVOR_NODE:                 return this->print(stream, reinterpret_cast<triton::ast::BvorNode*>(node)); break;
          case BVROL_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvrolNode*>(node)); break;
          case BVROR_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvrorNode*>(node)); break;
          case BVSDIV_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvsdivNode*>(node)); break;
          case BVSGE_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvsgeNode*>(node)); break;
          case BVSGT_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvsgtNode*>(node)); break;
          case BVSHL_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvshlNode*>(node)); break;
          case BVSLE_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvsleNode*>(node)); break;
          case BVSLT_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvsltNode*>(node)); break;
          case BVSMOD_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvsmodNode*>(node)); break;
          case BVSREM_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvsremNode*>(node)); break;
          case BVSUB_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvsubNode*>(node)); break;
          case BVUDIV_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvudivNode*>(node)); break;
          case BVUGE_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvugeNode*>(node)); break;
          case BVUGT_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvugtNode*>(node)); break;
          case BVULE_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvuleNode*>(node)); break;
          case BVULT_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvultNode*>(node)); break;
          case BVUREM_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvuremNode*>(node)); break;
          case BVXNOR_NODE:               return this->print(stream, reinterpret_cast<triton::ast::BvxnorNode*>(node)); break;
          case BVXOR_NODE:                return this->print(stream, reinterpret_cast<triton::ast::BvxorNode*>(node)); break;
          case BV_NODE:                   return this->print(stream, reinterpret_cast<triton::ast::BvNode*>(node)); break;
          case COMPOUND_NODE:             return this->print(stream, reinterpret_cast<triton::ast::CompoundNode*>(node)); break;
          case CONCAT_NODE:               return this->print(stream, reinterpret_cast<triton::ast::ConcatNode*>(node)); break;
          case DECLARE_NODE:              return this->print(stream, reinterpret_cast<triton::ast::DeclareNode*>(node)); break;
          case DISTINCT_NODE:             return this->print(stream, reinterpret_cast<triton::ast::DistinctNode*>(node)); break;
          case EQUAL_NODE:                return this->print(stream, reinterpret_cast<triton::ast::EqualNode*>(node)); break;
          case EXTRACT_NODE:              return this->print(stream, reinterpret_cast<triton::ast::ExtractNode*>(node)); break;
          case FORALL_NODE:               return this->print(stream, reinterpret_cast<triton::ast::ForallNode*>(node)); break;
          case IFF_NODE:                  return this->print(stream, reinterpret_cast<triton::ast::IffNode*>(node)); break;
          case INTEGER_NODE:              return this->print(stream, reinterpret_cast<triton::ast::IntegerNode*>(node)); break;
          case ITE_NODE:                  return this->print(stream, reinterpret_cast<triton::ast::IteNode*>(node)); break;
          case LAND_NODE:                 return this->print(stream, reinterpret_cast<triton::ast::LandNode*>(node)); break;
          case LET_NODE:                  return this->print(stream, reinterpret_cast<triton::ast::LetNode*>(node)); break;
          case LNOT_NODE:                 return this->print(stream, reinterpret_cast<triton::ast::LnotNode*>(node)); break;
          case LOR_NODE:                  return this->print(stream, reinterpret_cast<triton::ast::LorNode*>(node)); break;
          case LXOR_NODE:                 return this->print(stream, reinterpret_cast<triton::ast::LxorNode*>(node)); break;
          case REFERENCE_NODE:            return this->print(stream, reinterpret_cast<triton::ast::ReferenceNode*>(node)); break;
          case SELECT_NODE:               return this->print(stream, reinterpret_cast<triton::ast::SelectNode*>(node)); break;
          case STORE_NODE:                return this->print(stream, reinterpret_cast<triton::ast::StoreNode*>(node)); break;
          case STRING_NODE:               return this->print(stream, reinterpret_cast<triton::ast::StringNode*>(node)); break;
          case SX_NODE:                   return this->print(stream, reinterpret_cast<triton::ast::SxNode*>(node)); break;
          case VARIABLE_NODE:             return this->print(stream, reinterpret_cast<triton::ast::VariableNode*>(node)); break;
          case ZX_NODE:                   return this->print(stream, reinterpret_cast<triton::ast::ZxNode*>(node)); break;
          default:
            throw triton::exceptions::AstRepresentation("AstSmtRepresentation::print(AbstractNode): Invalid kind node.");
        }
        return stream;
      }


      /* Array nodes are always written as the fixed symbol "Memory"; the node itself is not inspected. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::ArrayNode* node) {
        return stream << "Memory";
      }


      /* Writes an assert node as "(assert <expr>)", where <expr> is the node's first child. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::AssertNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(assert " << operands[0] << ")";
      }


      /* Writes a bswap node as "(bswap<N> <expr>)", where <N> is the node's bitvector size. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BswapNode* node) {
        auto&& operands = node->getChildren();
        stream << "(bswap" << node->getBitvectorSize();
        return stream << " " << operands[0] << ")";
      }


      /* Writes a bvadd node as "(bvadd <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvaddNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvadd " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvand node as "(bvand <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvandNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvand " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvashr node as "(bvashr <value> <amount>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvashrNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvashr " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvlshr node as "(bvlshr <value> <amount>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvlshrNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvlshr " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvmul node as "(bvmul <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvmulNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvmul " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvnand node as "(bvnand <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvnandNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvnand " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvneg node as "(bvneg <expr>)" using its first child. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvnegNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvneg " << operands[0] << ")";
      }


      /* Writes a bvnor node as "(bvnor <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvnorNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvnor " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvnot node as "(bvnot <expr>)" using its first child. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvnotNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvnot " << operands[0] << ")";
      }


      /* Writes a bvor node as "(bvor <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvorNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvor " << operands[0] << " " << operands[1] << ")";
      }


      /*
       * Writes a bvrol node as "((_ rotate_left <amount>) <value>)".
       * Note the swap: the second child goes inside the indexed operator, the first child is its argument.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvrolNode* node) {
        auto&& operands = node->getChildren();
        return stream << "((_ rotate_left " << operands[1] << ") " << operands[0] << ")";
      }


      /*
       * Writes a bvror node as "((_ rotate_right <amount>) <value>)".
       * Note the swap: the second child goes inside the indexed operator, the first child is its argument.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvrorNode* node) {
        auto&& operands = node->getChildren();
        return stream << "((_ rotate_right " << operands[1] << ") " << operands[0] << ")";
      }


      /* Writes a bvsdiv node as "(bvsdiv <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvsdivNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvsdiv " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvsge node as "(bvsge <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvsgeNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvsge " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvsgt node as "(bvsgt <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvsgtNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvsgt " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvshl node as "(bvshl <value> <amount>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvshlNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvshl " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvsle node as "(bvsle <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvsleNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvsle " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvslt node as "(bvslt <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvsltNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvslt " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvsmod node as "(bvsmod <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvsmodNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvsmod " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvsrem node as "(bvsrem <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvsremNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvsrem " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvsub node as "(bvsub <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvsubNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvsub " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvudiv node as "(bvudiv <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvudivNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvudiv " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvuge node as "(bvuge <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvugeNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvuge " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvugt node as "(bvugt <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvugtNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvugt " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvule node as "(bvule <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvuleNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvule " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvult node as "(bvult <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvultNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvult " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvurem node as "(bvurem <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvuremNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvurem " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvxnor node as "(bvxnor <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvxnorNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvxnor " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a bvxor node as "(bvxor <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvxorNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(bvxor " << operands[0] << " " << operands[1] << ")";
      }


      /*
       * Writes a bv node as "(_ bv<c0> <c1>)": the first child is glued directly
       * after "bv" and the second child follows after a space.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::BvNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(_ bv" << operands[0] << " " << operands[1] << ")";
      }


      /* compound representation */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::CompoundNode* node) {
        std::vector<triton::ast::SharedAbstractNode> children = node->getChildren();
        triton::usize size = children.size();

        for (triton::usize index = 0; index < size-1; index++)
          stream << children[index] << std::endl;
        stream << children[size-1];

        return stream;
      }


      /*
       * concat representation
       *
       * Prints "(concat <c0> <c1> ... <cN>)" with the children in their stored order.
       *
       * Throws triton::exceptions::AstRepresentation when the node has fewer than
       * two children. Returns the stream that was passed in.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::ConcatNode* node) {
        /* Bind to the node's children instead of copying the vector of shared pointers. */
        auto&& parts = node->getChildren();

        if (parts.size() < 2)
          throw triton::exceptions::AstRepresentation("AstSmtRepresentation::print(ConcatNode): Exprs must contain at least two expressions.");

        stream << "(concat";
        for (auto&& part : parts)
          stream << " " << part;

        return stream << ")";
      }


      /*
       * Writes a declaration for the node's first child.
       *
       *  - Variable child: "(declare-fun <id> () (_ BitVec <size>))", where <id> is
       *    the variable's alias when one is set and its name otherwise.
       *  - Array child: a "define-fun" of an Array sort from (_ BitVec <n>) to
       *    (_ BitVec 8), initialised as a constant array of (_ bv0 8); <n> is the
       *    array node's first child.
       *
       * Any other kind of child raises triton::exceptions::AstRepresentation.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::DeclareNode* node) {
        auto&& operands = node->getChildren();
        auto& target = operands[0];

        switch (target->getType()) {
          case VARIABLE_NODE: {
            const triton::engines::symbolic::SharedSymbolicVariable& var = reinterpret_cast<triton::ast::VariableNode*>(target.get())->getSymbolicVariable();
            stream << "(declare-fun ";
            if (var->getAlias().empty())
              stream << var->getName();
            else
              stream << var->getAlias();
            stream << " () (_ BitVec " << var->getSize() << "))";
            break;
          }

          case ARRAY_NODE: {
            const auto& size = target->getChildren()[0];
            stream << "(define-fun " << target << " () (Array (_ BitVec " << size << ") (_ BitVec 8)) ";
            stream << "((as const (Array (_ BitVec " << size << ") (_ BitVec 8))) (_ bv0 8)))";
            break;
          }

          default:
            throw triton::exceptions::AstRepresentation("AstSmtRepresentation::print(DeclareNode): Invalid sort.");
        }

        return stream;
      }


      /* Writes a distinct node as "(distinct <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::DistinctNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(distinct " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes an equal node as "(= <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::EqualNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(= " << operands[0] << " " << operands[1] << ")";
      }


      /*
       * Writes an extract node as "((_ extract <c0> <c1>) <c2>)": the first two
       * children are the operator's indices, the third is the expression it applies to.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::ExtractNode* node) {
        auto&& operands = node->getChildren();
        stream << "((_ extract " << operands[0] << " " << operands[1] << ") ";
        return stream << operands[2] << ")";
      }


      /*
       * forall representation
       *
       * Prints "(forall ((<v0> (_ BitVec <s0>)) ... (<vN> (_ BitVec <sN>))) <body>)".
       * Every child except the last is treated as a variable node and is declared
       * with its alias when one is set, otherwise with its name. The last child is
       * printed as the quantified body.
       *
       * Throws triton::exceptions::AstRepresentation when the node has no children.
       * Returns the stream that was passed in.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::ForallNode* node) {
        auto&& children = node->getChildren();

        /* Guard first: size() is unsigned, so "size() - 1" would wrap around on an empty list. */
        if (children.empty())
          throw triton::exceptions::AstRepresentation("AstSmtRepresentation::print(ForallNode): Must contain at least a body expression.");

        /* Index of the body, which is also the number of bound variables. */
        const triton::usize size = children.size() - 1;

        stream << "(forall (";
        /* The counter has the same type as the bound so the comparison can never truncate. */
        for (triton::usize i = 0; i != size; i++) {
          const auto& var = static_cast<triton::ast::VariableNode*>(children[i].get())->getSymbolicVariable();
          if (var->getAlias().empty())  stream << "(" << var->getName()  << " (_ BitVec " << var->getSize() << "))";
          else                          stream << "(" << var->getAlias() << " (_ BitVec " << var->getSize() << "))";
          if (i + 1 != size)            stream << " ";
        }
        stream << ") " << children[size] << ")";

        return stream;
      }


      /* Writes an iff node as "(iff <lhs> <rhs>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::IffNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(iff " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes an integer node by streaming the value returned by getInteger(), with no decoration. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::IntegerNode* node) {
        return stream << node->getInteger();
      }


      /* Writes an ite node as "(ite <c0> <c1> <c2>)" using its first three children in order. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::IteNode* node) {
        auto&& operands = node->getChildren();
        stream << "(ite " << operands[0];
        stream << " " << operands[1];
        return stream << " " << operands[2] << ")";
      }


      /* Writes a land node as "(and <c0> ... <cN>)", one space-prefixed entry per child. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::LandNode* node) {
        stream << "(and";
        for (auto&& operand : node->getChildren())
          stream << " " << operand;
        return stream << ")";
      }


      /*
       * Writes a let node as "(let ((<c0> <c1>)) <c2>)": the first two children form
       * the single binding pair, the third child is printed after it.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::LetNode* node) {
        auto&& operands = node->getChildren();
        stream << "(let ((" << operands[0] << " " << operands[1] << ")) ";
        return stream << operands[2] << ")";
      }


      /* Writes an lnot node as "(not <expr>)" using its first child. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::LnotNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(not " << operands[0] << ")";
      }


      /* Writes a lor node as "(or <c0> ... <cN>)", one space-prefixed entry per child. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::LorNode* node) {
        stream << "(or";
        for (auto&& operand : node->getChildren())
          stream << " " << operand;
        return stream << ")";
      }


      /* Writes an lxor node as "(xor <c0> ... <cN>)", one space-prefixed entry per child. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::LxorNode* node) {
        stream << "(xor";
        for (auto&& operand : node->getChildren())
          stream << " " << operand;
        return stream << ")";
      }


      /* Writes a reference node as the formatted id of the symbolic expression it refers to. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::ReferenceNode* node) {
        return stream << node->getSymbolicExpression()->getFormattedId();
      }


      /* Writes a select node as "(select <c0> <c1>)" using its first two children. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::SelectNode* node) {
        auto&& operands = node->getChildren();
        return stream << "(select " << operands[0] << " " << operands[1] << ")";
      }


      /* Writes a store node as "(store <c0> <c1> <c2>)" using its first three children in order. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::StoreNode* node) {
        auto&& operands = node->getChildren();
        stream << "(store " << operands[0];
        stream << " " << operands[1];
        return stream << " " << operands[2] << ")";
      }


      /* Writes a string node as the text returned by getString(), with no quoting added. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::StringNode* node) {
        return stream << node->getString();
      }


      /*
       * Writes an sx node as "((_ sign_extend <c0>) <c1>)": the first child goes
       * inside the indexed operator, the second child is its argument.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::SxNode* node) {
        auto&& operands = node->getChildren();
        return stream << "((_ sign_extend " << operands[0] << ") " << operands[1] << ")";
      }


      /* Writes a variable node as its symbolic variable's alias, or as its name when no alias is set. */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::VariableNode* node) {
        const auto& var = node->getSymbolicVariable();

        if (!var->getAlias().empty())
          return stream << var->getAlias();

        return stream << var->getName();
      }


      /*
       * Writes a zx node as "((_ zero_extend <c0>) <c1>)": the first child goes
       * inside the indexed operator, the second child is its argument.
       */
      std::ostream& AstSmtRepresentation::print(std::ostream& stream, triton::ast::ZxNode* node) {
        auto&& operands = node->getChildren();
        return stream << "((_ zero_extend " << operands[0] << ") " << operands[1] << ")";
      }

    };
  };
};
