/*
 *  Rcpp.cpp
 *  PCMBaseCpp
 *
 * Copyright 2017,2018 Venelin Mitov
 *
 * This file is part of PCMBaseCpp: A C++ backend for calculating the likelihood of phylogenetic comparative models.
 *
 * PCMBaseCpp is free software: you can redistribute it and/or modify
 * it under the terms of version 3 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * PCMBaseCpp is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public
 * License along with PCMBaseCpp.  If not, see
 * <http://www.gnu.org/licenses/>.
 *
 * @author Venelin Mitov
 */
#include <RcppArmadillo.h>

#include<vector>
#include<string>
#include<sstream>

#include "QuadraticPolyBM.h"
#include "QuadraticPolyMixedGaussian.h"

// [[Rcpp::plugins("cpp11")]]
// [[Rcpp::plugins(openmp)]]
// [[Rcpp::depends(RcppArmadillo)]]

void R_init_PCMBMkappa(DllInfo *info) {
  // Shared-library load hook. No .C / .Call / .Fortran / .External routine
  // tables are registered (all four are null), and dynamic symbol lookup is
  // left enabled.
  // NOTE(review): R looks this hook up by its unmangled name; as a plain C++
  // function it is not declared extern "C" here — confirm it is actually
  // invoked (or that RcppExports.cpp provides the real hook).
  R_registerRoutines(info, nullptr, nullptr, nullptr, nullptr);
  R_useDynamicSymbols(info, TRUE);
}

void R_unload_PCMBMkappa(DllInfo *info) {
   /* Release resources. */
}

using namespace PCMBMkappa;
using namespace std;

/**
 * Factory registered with the PCMBaseCpp__Tree Rcpp module.
 *
 * Builds a SPLITT::Tree from an R list holding an integer matrix "edge"
 * (column 0 and column 1 become the two node-id vectors passed to the tree
 * constructor — presumably parent / daughter ids as in ape's phylo class)
 * and a numeric vector "edge.length".
 *
 * @param tree R list with elements "edge" and "edge.length".
 * @return a tree allocated with new; handed to Rcpp via .factory().
 */
SPLITT::Tree<uint, double>* CreatePCMBaseCppTree(Rcpp::List const& tree) {
  arma::umat edge = Rcpp::as<arma::umat>(tree["edge"]);
  SPLITT::uvec fromIds = arma::conv_to<SPLITT::uvec>::from(edge.col(0));
  SPLITT::uvec toIds = arma::conv_to<SPLITT::uvec>::from(edge.col(1));
  SPLITT::vec edgeLengths = Rcpp::as<SPLITT::vec>(tree["edge.length"]);
  return new SPLITT::Tree<uint, double>(fromIds, toIds, edgeLengths);
}

// Rcpp module exposing SPLITT::Tree<uint, double> to R as "PCMBaseCpp__Tree".
// Objects are created on the R side through the single .factory entry, which
// forwards an R list to CreatePCMBaseCppTree. The .property entries bind
// getter-only (hence read-only) properties; the .method entries bind member
// functions under the same names.
RCPP_MODULE(PCMBaseCpp__Tree) {
  Rcpp::class_<SPLITT::Tree<uint, double> > ( "PCMBaseCpp__Tree" )
  .factory<Rcpp::List const&>( &CreatePCMBaseCppTree )
  .property("num_nodes", &SPLITT::Tree<uint, double>::num_nodes )
  .property("num_tips", &SPLITT::Tree<uint, double>::num_tips )
  .method("LengthOfBranch", &SPLITT::Tree<uint, double>::LengthOfBranch )
  .method("FindNodeWithId", &SPLITT::Tree<uint, double>::FindNodeWithId )
  .method("FindIdOfNode", &SPLITT::Tree<uint, double>::FindIdOfNode )
  .method("FindIdOfParent", &SPLITT::Tree<uint, double>::FindIdOfParent )
  .method( "FindChildren", &SPLITT::Tree<uint, double>::FindChildren )
  .method("OrderNodes", &SPLITT::Tree<uint, double>::OrderNodes )
  ;
}

/**
 * Factory registered with the PCMBaseCpp__OrderedTree Rcpp module.
 *
 * Same input contract as CreatePCMBaseCppTree: column 0 and column 1 of the
 * R matrix "edge" become the two node-id vectors, and "edge.length" supplies
 * the branch lengths. The result is a SPLITT::OrderedTree instead of a plain
 * SPLITT::Tree.
 *
 * @param tree R list with elements "edge" and "edge.length".
 * @return an ordered tree allocated with new; handed to Rcpp via .factory().
 */
SPLITT::OrderedTree<uint, double>* CreatePCMBaseCppOrderedTree(Rcpp::List const& tree) {
  arma::umat edge = Rcpp::as<arma::umat>(tree["edge"]);
  SPLITT::uvec fromIds = arma::conv_to<SPLITT::uvec>::from(edge.col(0));
  SPLITT::uvec toIds = arma::conv_to<SPLITT::uvec>::from(edge.col(1));
  SPLITT::vec edgeLengths = Rcpp::as<SPLITT::vec>(tree["edge.length"]);
  return new SPLITT::OrderedTree<uint, double>(fromIds, toIds, edgeLengths);
}


// Rcpp module exposing SPLITT::OrderedTree<uint, double> to R as
// "PCMBaseCpp__OrderedTree", declared as deriving from "PCMBaseCpp__Tree".
// The base class is declared again inside this module, presumably because
// .derives<>() resolves the parent name within the module being defined.
// NOTE(review): module PCMBaseCpp__Tree registers the same C++ type under the
// same R name. Rcpp appears to keep a single class_<T> instance per C++ type,
// so confirm this second declaration does not duplicate methods or fail
// depending on module load order.
RCPP_MODULE(PCMBaseCpp__OrderedTree) {
  Rcpp::class_<SPLITT::Tree<uint, double> > ( "PCMBaseCpp__Tree" )
  .factory<Rcpp::List const&>( &CreatePCMBaseCppTree )
  .property("num_nodes", &SPLITT::Tree<uint, double>::num_nodes )
  .property("num_tips", &SPLITT::Tree<uint, double>::num_tips )
  .method("LengthOfBranch", &SPLITT::Tree<uint, double>::LengthOfBranch )
  .method("FindNodeWithId", &SPLITT::Tree<uint, double>::FindNodeWithId )
  .method("FindIdOfNode", &SPLITT::Tree<uint, double>::FindIdOfNode )
  .method("FindIdOfParent", &SPLITT::Tree<uint, double>::FindIdOfParent )
  .method("FindChildren", &SPLITT::Tree<uint, double>::FindChildren )
  .method("OrderNodes", &SPLITT::Tree<uint, double>::OrderNodes )
  ;
  // Derived class: created from an R list via CreatePCMBaseCppOrderedTree and
  // adds the ordering-specific getter-only properties.
  Rcpp::class_<SPLITT::OrderedTree<uint, double> >( "PCMBaseCpp__OrderedTree" )
    .derives<SPLITT::Tree<uint, double> > ( "PCMBaseCpp__Tree" )
    .factory<Rcpp::List const&>( &CreatePCMBaseCppOrderedTree )
    .property("num_levels", &SPLITT::OrderedTree<uint, double>::num_levels )
    .property("num_parallel_ranges_prune", &SPLITT::OrderedTree<uint, double>::num_parallel_ranges_prune )
    .property("ranges_id_visit", &SPLITT::OrderedTree<uint, double>::ranges_id_visit )
    .property("ranges_id_prune", &SPLITT::OrderedTree<uint, double>::ranges_id_prune )
  ;
}

/**
 * Parses and validates the R objects shared by the QuadraticPoly* factories.
 *
 * The following are extracted once here, so each factory only assembles its
 * model-specific objects:
 *  - numeric thresholds and flags from metaInfo;
 *  - the per-edge regime (metaInfo$r) and jump (metaInfo$xi) vectors;
 *  - the metaInfo$pcListInt entries;
 *  - the edge table of the tree.
 *
 * Lifetime: X is stored by reference, so the matrix passed to the constructor
 * must outlive this object. The factories in this file keep it as a local.
 *
 * @throws std::invalid_argument if threshold SV, EV or Lambda_ij is negative.
 * @throws std::logic_error if metaInfo$pcListInt has fewer than metaInfo$M
 *   entries, or if tree$edge.length, metaInfo$r or metaInfo$xi does not have
 *   exactly one entry per row of tree$edge.
 */
struct ParsedRObjects {
  double threshold_SV;
  double threshold_EV;
  double threshold_skip_singular;
  double threshold_Lambda_ij;
  double NA_double_;
  
  bool skip_singular;
  bool transpose_Sigma_x;
  // Non-owning reference to the X argument of the constructor.
  arma::mat const& X;
  arma::cube VE;
  Rcpp::List pcListInt;
  // One uvec per pcListInt entry; the vector is sized from metaInfo$M.
  std::vector<arma::uvec> Pc;
  // Columns 0 and 1 of tree$edge, and tree$edge.length.
  SPLITT::uvec br_0;
  SPLITT::uvec br_1;
  SPLITT::vec t;
  
  uint RModel;
  // Per-edge regime index as supplied by R (metaInfo$r); factories subtract 1.
  std::vector<arma::uword> regimes;
  // Per-edge jump values as supplied by R (metaInfo$xi).
  std::vector<arma::u8> jumps;
  
  SPLITT::uint num_tips;
  SPLITT::uint num_branches;
  // Tip ids 1..num_tips.
  SPLITT::uvec tip_names;
  
  // Members are initialized in declaration order: tip_names relies on
  // num_tips being declared (and therefore initialized) before it.
  ParsedRObjects(
    arma::mat const& X, 
    Rcpp::List const& tree, 
    Rcpp::List const& model,
    Rcpp::List const& metaInfo):
    
    threshold_SV(static_cast<double>(metaInfo["PCMBase.Threshold.SV"])),
    threshold_EV(static_cast<double>(metaInfo["PCMBase.Threshold.EV"])),
    threshold_skip_singular(static_cast<double>(metaInfo["PCMBase.Threshold.Skip.Singular"])), 
    threshold_Lambda_ij(static_cast<double>(metaInfo["PCMBase.Threshold.Lambda_ij"])),
    NA_double_(static_cast<double>(metaInfo["NA_double_"])),
    skip_singular(static_cast<int>(metaInfo["PCMBase.Skip.Singular"])),
    transpose_Sigma_x(static_cast<int>(metaInfo["PCMBase.Transpose.Sigma_x"])),
    X(X),
    VE(Rcpp::as<arma::cube>(metaInfo["VE"])),
    pcListInt(Rcpp::as<Rcpp::List>(metaInfo["pcListInt"])), 
    Pc(Rcpp::as<arma::uword>(metaInfo["M"])),
    RModel(Rcpp::as<uint>(metaInfo["RModel"])),
    regimes(Rcpp::as<vector<arma::uword> >(metaInfo["r"])),
    jumps(Rcpp::as<vector<arma::u8> >(metaInfo["xi"])), 
    num_tips(Rcpp::as<Rcpp::CharacterVector>(tree["tip.label"]).size()),
    tip_names(SPLITT::Seq(static_cast<SPLITT::uint>(1), num_tips)) {
    
    // The model list is not read by this parser.
    (void)model;
    
    if(threshold_SV < 0) {
      ostringstream os;
      os<<"Rcpp.cpp:ParsedRObjects:: The argument threshold_SV should be non-negative real number.";
      throw invalid_argument(os.str());
    }
    if(threshold_EV < 0) {
      ostringstream os;
      os<<"Rcpp.cpp:ParsedRObjects:: The argument threshold_EV should be non-negative real number.";
      throw invalid_argument(os.str());
    }
    if(threshold_Lambda_ij < 0) {
      ostringstream os;
      os<<"ERR:03825:PCMBaseCpp:Rcpp.cpp:ParsedRObjects:: The argument threshold_Lambda_ij should be non-negative double.";
      throw invalid_argument(os.str());
    }
    
    // Pc has metaInfo$M slots filled from pcListInt; refuse to read past the
    // end of a shorter list instead of indexing out of range.
    if(static_cast<std::size_t>(pcListInt.size()) < Pc.size()) {
      ostringstream os;
      os<<"Rcpp.cpp:ParsedRObjects:: The slot pcListInt in metaInfo has fewer elements ("<<
        pcListInt.size()<<") than the value of M in metaInfo ("<<Pc.size()<<").";
      throw logic_error(os.str());
    }
    for(arma::uword i = 0; i < Pc.size(); ++i) {
      Pc[i] = Rcpp::as<arma::uvec>(pcListInt[i]);
    }
    
    arma::umat branches = tree["edge"];
    br_0 = arma::conv_to<SPLITT::uvec>::from(branches.col(0));
    br_1 = arma::conv_to<SPLITT::uvec>::from(branches.col(1));
    t = Rcpp::as<SPLITT::vec>(tree["edge.length"]);
    num_branches = branches.n_rows;
    
    // The factories read t[i] for every edge, so a short edge.length vector
    // would otherwise be read out of bounds.
    if(t.size() != branches.n_rows) {
      ostringstream os;
      os<<"Rcpp.cpp:ParsedRObjects:: The slot edge.length in tree has different length ("<<t.size()<<
        ") than the number of edges ("<<branches.n_rows<<").";
      throw logic_error(os.str());
    }
    
    if(regimes.size() != branches.n_rows) {
      ostringstream os;
      os<<"ERR:03821:PCMBaseCpp:Rcpp.cpp:ParsedRObjects:: The slot r in metaInfo has different length ("<<regimes.size()<<
        ") than the number of edges ("<<branches.n_rows<<").";
      throw logic_error(os.str());
    }
    
    if(jumps.size() != branches.n_rows) {
      ostringstream os;
      os<<"ERR:03822:PCMBaseCpp:Rcpp.cpp:ParsedRObjects:: The slot xi in metaInfo has different length ("<<jumps.size()<<
        ") than the number of edges ("<<branches.n_rows<<").";
      throw logic_error(os.str());
    }
  }
};

/**
 * Factory registered with the PCMBaseCpp__QuadraticPolyBM Rcpp module.
 *
 * Validates the R inputs via ParsedRObjects, then builds one LengthType
 * record per edge (branch length plus 0-based regime index) and the model
 * data object. No regime model names are passed (an empty vector).
 *
 * @param X        data matrix; must outlive the parsing step below.
 * @param tree     R list with "edge", "edge.length" and "tip.label".
 * @param model    R model list (forwarded to ParsedRObjects).
 * @param metaInfo R list with thresholds, VE, pcListInt, M, RModel, r and xi.
 * @return a calculator allocated with new; handed to Rcpp via .factory().
 */
QuadraticPolyBM* CreateQuadraticPolyBM(
    arma::mat const& X, 
    Rcpp::List const& tree, 
    Rcpp::List const& model,
    Rcpp::List const& metaInfo) { 
  ParsedRObjects parsed(X, tree, model, metaInfo);
  
  using LengthType = QuadraticPolyBM::LengthType;
  std::vector<LengthType> edgeInfo(parsed.num_branches);
  for(arma::uword e = 0; e < parsed.num_branches; ++e) {
    LengthType& info = edgeInfo[e];
    info.length_ = parsed.t[e];
    // Regimes arrive 1-based from R; the C++ side gets them shifted by one.
    info.regime_ = parsed.regimes[e] - 1;
  }
  
  QuadraticPolyBM::DataType data(
      parsed.tip_names, parsed.X, parsed.VE, parsed.Pc, parsed.RModel,
      std::vector<std::string>{},
      parsed.threshold_SV, parsed.threshold_EV,
      parsed.threshold_skip_singular, parsed.skip_singular,
      parsed.transpose_Sigma_x,
      parsed.threshold_Lambda_ij,
      parsed.NA_double_);
  
  return new QuadraticPolyBM(parsed.br_0, parsed.br_1, edgeInfo, data);
}

//RCPP_EXPOSED_CLASS_NODECL(QuadraticPolyBM::TreeType)
// Makes QuadraticPolyBM::AlgorithmType (declared elsewhere, hence _NODECL)
// wrappable by Rcpp — presumably needed for the "algorithm" property below.
RCPP_EXPOSED_CLASS_NODECL(QuadraticPolyBM::AlgorithmType)

/**
 * Rcpp module exposing the QuadraticPolyBM calculator to R, together with its
 * tree, ordered-tree and traversal-algorithm helper classes.
 *
 * The explicit template argument list of .factory<> names every parameter of
 * CreateQuadraticPolyBM (X, tree, model, metaInfo). Rcpp's fixed-arity
 * factory overloads require an exact match, so listing only a prefix of the
 * parameter types does not compile there.
 */
RCPP_MODULE(PCMBaseCpp__QuadraticPolyBM) {
  Rcpp::class_<QuadraticPolyBM::TreeType::Tree> ( "PCMBaseCpp__QuadraticPolyBM_Tree" )
  .property("num_nodes", &QuadraticPolyBM::TreeType::Tree::num_nodes )
  .property("num_tips", &QuadraticPolyBM::TreeType::Tree::num_tips )
  .method("FindNodeWithId", &QuadraticPolyBM::TreeType::Tree::FindNodeWithId )
  .method("FindIdOfNode", &QuadraticPolyBM::TreeType::Tree::FindIdOfNode )
  .method("FindIdOfParent", &QuadraticPolyBM::TreeType::Tree::FindIdOfParent )
  .method("OrderNodes", &QuadraticPolyBM::TreeType::Tree::OrderNodes )
  ;
  Rcpp::class_<QuadraticPolyBM::TreeType>( "PCMBaseCpp__QuadraticPolyBM_OrderedTree" )
    .derives<QuadraticPolyBM::TreeType::Tree> ( "PCMBaseCpp__QuadraticPolyBM_Tree" )
    .method("RangeIdPruneNode", &QuadraticPolyBM::TreeType::RangeIdPruneNode )
    .method("RangeIdVisitNode", &QuadraticPolyBM::TreeType::RangeIdVisitNode )
    .property("num_levels", &QuadraticPolyBM::TreeType::num_levels )
    .property("ranges_id_visit", &QuadraticPolyBM::TreeType::ranges_id_visit )
    .property("ranges_id_prune", &QuadraticPolyBM::TreeType::ranges_id_prune )
  ;
  Rcpp::class_<QuadraticPolyBM::AlgorithmType::ParentType>( "PCMBaseCpp__QuadraticPolyBM_TraversalAlgorithm" )
    .property( "VersionOPENMP", &QuadraticPolyBM::AlgorithmType::ParentType::VersionOPENMP )
    .property( "NumOmpThreads", &QuadraticPolyBM::AlgorithmType::NumOmpThreads )
  ;
  Rcpp::class_<QuadraticPolyBM::AlgorithmType> ( "PCMBaseCpp__QuadraticPolyBM_ParallelPruning" )
    .derives<QuadraticPolyBM::AlgorithmType::ParentType>( "PCMBaseCpp__QuadraticPolyBM_TraversalAlgorithm" )
    .method( "ModeAutoStep", &QuadraticPolyBM::AlgorithmType::ModeAutoStep )
    .property( "ModeAutoCurrent", &QuadraticPolyBM::AlgorithmType::ModeAutoCurrent )
    .property( "IsTuning", &QuadraticPolyBM::AlgorithmType::IsTuning )
    .property( "min_size_chunk_visit", &QuadraticPolyBM::AlgorithmType::min_size_chunk_visit )
    .property( "min_size_chunk_prune", &QuadraticPolyBM::AlgorithmType::min_size_chunk_prune )
    .property( "durations_tuning", &QuadraticPolyBM::AlgorithmType::durations_tuning )
    .property( "fastest_step_tuning", &QuadraticPolyBM::AlgorithmType::fastest_step_tuning )
  ;
  Rcpp::class_<QuadraticPolyBM>( "PCMBaseCpp__QuadraticPolyBM" )
    // One template argument per parameter of CreateQuadraticPolyBM.
    .factory<arma::mat const&, Rcpp::List const&, Rcpp::List const&, Rcpp::List const&>(&CreateQuadraticPolyBM)
    .method( "TraverseTree", &QuadraticPolyBM::TraverseTree )
    .method( "StateAtNode", &QuadraticPolyBM::StateAtNode )
    .property( "tree", &QuadraticPolyBM::tree )
    .property( "algorithm", &QuadraticPolyBM::algorithm )
  ;
}




/**
 * Factory registered with the PCMBaseCpp__QuadraticPolyMixedGaussian module.
 *
 * Validates the R inputs via ParsedRObjects, then builds one LengthType
 * record per edge (branch length, 0-based regime index and jump value) and
 * the model data object, forwarding regimeModels unchanged.
 *
 * @param X            data matrix; must outlive the parsing step below.
 * @param tree         R list with "edge", "edge.length" and "tip.label".
 * @param model        R model list (forwarded to ParsedRObjects).
 * @param metaInfo     R list with thresholds, VE, pcListInt, M, RModel, r, xi.
 * @param regimeModels names passed straight through to the data object.
 * @return a calculator allocated with new; handed to Rcpp via .factory().
 */
QuadraticPolyMixedGaussian* CreateQuadraticPolyMixedGaussian(
    arma::mat const& X,
    Rcpp::List const& tree,
    Rcpp::List const& model,
    Rcpp::List const& metaInfo,
    std::vector<std::string> const& regimeModels) {
  ParsedRObjects parsed(X, tree, model, metaInfo);
  
  using LengthType = QuadraticPolyMixedGaussian::LengthType;
  std::vector<LengthType> edgeInfo(parsed.num_branches);
  for(arma::uword e = 0; e < parsed.num_branches; ++e) {
    LengthType& info = edgeInfo[e];
    info.length_ = parsed.t[e];
    // Regimes arrive 1-based from R; the C++ side gets them shifted by one.
    info.regime_ = parsed.regimes[e] - 1;
    info.jump_ = parsed.jumps[e];
  }
  
  QuadraticPolyMixedGaussian::DataType data(
      parsed.tip_names, parsed.X, parsed.VE, parsed.Pc,
      parsed.RModel,
      regimeModels,
      parsed.threshold_SV, parsed.threshold_EV,
      parsed.threshold_skip_singular, parsed.skip_singular,
      parsed.transpose_Sigma_x,
      parsed.threshold_Lambda_ij,
      parsed.NA_double_);
  
  return new QuadraticPolyMixedGaussian(parsed.br_0, parsed.br_1, edgeInfo, data);
}

  // Makes QuadraticPolyMixedGaussian::AlgorithmType (declared elsewhere, hence
  // _NODECL) wrappable by Rcpp — presumably needed for the "algorithm" property.
  RCPP_EXPOSED_CLASS_NODECL(QuadraticPolyMixedGaussian::AlgorithmType)

  /**
   * Rcpp module exposing the QuadraticPolyMixedGaussian calculator to R,
   * together with its tree, ordered-tree and traversal-algorithm helpers.
   *
   * The explicit template argument list of .factory<> names all five
   * parameters of CreateQuadraticPolyMixedGaussian (X, tree, model, metaInfo,
   * regimeModels). Rcpp's fixed-arity factory overloads require an exact
   * match, so a shorter list does not compile there.
   */
  RCPP_MODULE(PCMBaseCpp__QuadraticPolyMixedGaussian) {
    Rcpp::class_<QuadraticPolyMixedGaussian::TreeType::Tree> ( "PCMBaseCpp__QuadraticPolyMixedGaussian_Tree" )
    .property("num_nodes", &QuadraticPolyMixedGaussian::TreeType::Tree::num_nodes )
    .property("num_tips", &QuadraticPolyMixedGaussian::TreeType::Tree::num_tips )
    .method("FindNodeWithId", &QuadraticPolyMixedGaussian::TreeType::Tree::FindNodeWithId )
    .method("FindIdOfNode", &QuadraticPolyMixedGaussian::TreeType::Tree::FindIdOfNode )
    .method("FindIdOfParent", &QuadraticPolyMixedGaussian::TreeType::Tree::FindIdOfParent )
    .method("OrderNodes", &QuadraticPolyMixedGaussian::TreeType::Tree::OrderNodes )
    ;
    Rcpp::class_<QuadraticPolyMixedGaussian::TreeType>( "PCMBaseCpp__QuadraticPolyMixedGaussian_OrderedTree" )
      .derives<QuadraticPolyMixedGaussian::TreeType::Tree> ( "PCMBaseCpp__QuadraticPolyMixedGaussian_Tree" )
      .method("RangeIdPruneNode", &QuadraticPolyMixedGaussian::TreeType::RangeIdPruneNode )
      .method("RangeIdVisitNode", &QuadraticPolyMixedGaussian::TreeType::RangeIdVisitNode )
      .property("num_levels", &QuadraticPolyMixedGaussian::TreeType::num_levels )
      .property("ranges_id_visit", &QuadraticPolyMixedGaussian::TreeType::ranges_id_visit )
      .property("ranges_id_prune", &QuadraticPolyMixedGaussian::TreeType::ranges_id_prune )
    ;
    Rcpp::class_<QuadraticPolyMixedGaussian::AlgorithmType::ParentType>( "PCMBaseCpp__QuadraticPolyMixedGaussian_TraversalAlgorithm" )
      .property( "VersionOPENMP", &QuadraticPolyMixedGaussian::AlgorithmType::ParentType::VersionOPENMP )
      .property( "NumOmpThreads", &QuadraticPolyMixedGaussian::AlgorithmType::NumOmpThreads )
    ;
    Rcpp::class_<QuadraticPolyMixedGaussian::AlgorithmType> ( "PCMBaseCpp__QuadraticPolyMixedGaussian_ParallelPruning" )
      .derives<QuadraticPolyMixedGaussian::AlgorithmType::ParentType>( "PCMBaseCpp__QuadraticPolyMixedGaussian_TraversalAlgorithm" )
      .method( "ModeAutoStep", &QuadraticPolyMixedGaussian::AlgorithmType::ModeAutoStep )
      .property( "ModeAutoCurrent", &QuadraticPolyMixedGaussian::AlgorithmType::ModeAutoCurrent )
      .property( "IsTuning", &QuadraticPolyMixedGaussian::AlgorithmType::IsTuning )
      .property( "min_size_chunk_visit", &QuadraticPolyMixedGaussian::AlgorithmType::min_size_chunk_visit )
      .property( "min_size_chunk_prune", &QuadraticPolyMixedGaussian::AlgorithmType::min_size_chunk_prune )
      .property( "durations_tuning", &QuadraticPolyMixedGaussian::AlgorithmType::durations_tuning )
      .property( "fastest_step_tuning", &QuadraticPolyMixedGaussian::AlgorithmType::fastest_step_tuning )
    ;
    Rcpp::class_<QuadraticPolyMixedGaussian>( "PCMBaseCpp__QuadraticPolyMixedGaussian" )
      // One template argument per parameter of CreateQuadraticPolyMixedGaussian.
      .factory<arma::mat const&, Rcpp::List const&, Rcpp::List const&, Rcpp::List const&,
               std::vector<std::string> const&>(&CreateQuadraticPolyMixedGaussian)
      .method( "TraverseTree", &QuadraticPolyMixedGaussian::TraverseTree )
      .method( "StateAtNode", &QuadraticPolyMixedGaussian::StateAtNode )
      .property( "tree", &QuadraticPolyMixedGaussian::tree )
      .property( "algorithm", &QuadraticPolyMixedGaussian::algorithm )
    ;
  }


