// [[Rcpp::depends(RcppArmadillo)]]

#include <RcppArmadillo.h>
#include <math.h>
#include <unordered_set>
#include <string>
#include <algorithm>
#include <stack>

using namespace Rcpp;
using namespace std;
using namespace arma;

template <typename T>
T sortByDimNames(const T m);

typedef unsigned int uint;


// Returns whether a Markov chain is ergodic
// Declared in this same file
bool isIrreducible(S4 obj);

// Declared in utils.cpp
bool anyElement(const mat& matrix, bool (*condition)(const double&));

// Declared in utils.cpp
bool allElements(const mat& matrix, bool (*condition)(const double&));

// Declared in utils.cpp
bool approxEqual(const cx_double& a, const cx_double& b);

// Computes the communicating structure of a transition matrix P, which must
// be square and stochastic by rows.
//
// Returns a list with two entries:
//   - "classes": logical matrix whose entry (i, j) is TRUE iff j is reachable
//     from i and i is reachable from j
//   - "closed": logical vector whose entry i is TRUE iff every state reachable
//     from i also reaches i back (the class of i cannot be left)
// Both entries are labelled with the row names of P.
// [[Rcpp::export(.commClassesKernelRcpp)]]
List commClassesKernel(NumericMatrix P) {
  const int numStates = P.ncol();
  CharacterVector stateNames = rownames(P);

  // The entry (i,j) of this matrix is true iff we can reach j from i
  // (every state reaches itself)
  vector<vector<bool>> communicates(numStates, vector<bool>(numStates, false));
  vector<vector<int>> adjacencies(numStates);

  // We fill the adjacency lists for the graph
  // A state j is in the adjacency of i iff P(i, j) > 0
  for (int i = 0; i < numStates; ++i)
    for (int j = 0; j < numStates; ++j)
      if (P(i, j) > 0)
        adjacencies[i].push_back(j);

  // Depth-first search from every state to find which states can be reached
  // from a given one. A state is marked as soon as it is pushed, so it enters
  // the stack (and has its adjacency list scanned) at most once per start
  // state: O(n + edges) per search, O(n³) overall in the worst case
  for (int i = 0; i < numStates; ++i) {
    stack<int> pending;
    communicates[i][i] = true;
    pending.push(i);

    while (!pending.empty()) {
      const int j = pending.top();
      pending.pop();

      for (int k : adjacencies[j]) {
        if (!communicates[i][k]) {
          communicates[i][k] = true;
          pending.push(k);
        }
      }
    }
  }

  LogicalMatrix classes(numStates, numStates);
  classes.attr("dimnames") = List::create(stateNames, stateNames);
  // closed is populated with FALSEs
  LogicalVector closed(numStates);
  closed.names() = stateNames;

  for (int i = 0; i < numStates; ++i) {
    int numReachable = 0;
    int classSize = 0;

    /* We mark i and j as the same communicating class iff we can reach the
       state j from i and the state i from j
       We count the size of the communicating class of i (i is fixed here),
       and if it matches the number of states that can be reached from i,
       then the class is closed
    */
    for (int j = 0; j < numStates; ++j) {
      const bool sameClass = communicates[i][j] && communicates[j][i];
      classes(i, j) = sameClass;

      if (sameClass)
        classSize += 1;

      // Number of states reachable from i
      if (communicates[i][j])
        numReachable += 1;
    }

    if (classSize == numReachable)
      closed(i) = true;
  }

  return List::create(_["classes"] = classes, _["closed"] = closed);
}

// Groups the state names into communicating classes.
//   commClasses: matrix as produced by commClassesKernel, where entry (i, j)
//                is TRUE iff states i and j are in the same class
//   states:      state names, in the same order as the matrix rows/columns
// Returns a list holding one character vector of state names per class; a
// class is emitted when its first not-yet-assigned member is visited.
List computeCommunicatingClasses(LogicalMatrix& commClasses, CharacterVector& states) {
  const int total = states.size();
  vector<bool> assigned(total, false);
  List result;

  for (int row = 0; row < total; ++row) {
    // The class of this state was already emitted from an earlier row
    if (assigned[row])
      continue;

    CharacterVector members;

    for (int col = 0; col < total; ++col) {
      if (!commClasses(row, col))
        continue;

      members.push_back(states[col]);
      assigned[col] = true;
    }

    result.push_back(members);
  }

  return result;
}

// Returns the communicating classes of a markovchain object as a list of
// character vectors of state names. A column-wise matrix is transposed first
// because commClassesKernel expects a matrix that is stochastic by rows.
// [[Rcpp::export(.communicatingClassesRcpp)]]
List communicatingClasses(S4 object) {
  CharacterVector states = object.slot("states");
  bool byrow = object.slot("byrow");
  NumericMatrix rowStochastic = object.slot("transitionMatrix");

  if (!byrow)
    rowStochastic = transpose(rowStochastic);

  List kernel = commClassesKernel(rowStochastic);
  LogicalMatrix sameClass = kernel["classes"];

  return computeCommunicatingClasses(sameClass, states);
}

// Selects the transient states: those whose entry in closedClass is FALSE.
//   states:      state names
//   closedClass: for each state (same order), whether its communicating class
//                is closed
// Returns the selected names in their original order.
CharacterVector computeTransientStates(CharacterVector& states, LogicalVector& closedClass) {
  CharacterVector selected;
  int idx = 0;

  while (idx < states.size()) {
    if (!closedClass[idx])
      selected.push_back(states[idx]);

    ++idx;
  }

  return selected;
}

// Selects the recurrent states: those whose entry in closedClass is TRUE.
//   states:      state names
//   closedClass: for each state (same order), whether its communicating class
//                is closed
// Returns the selected names in their original order.
CharacterVector computeRecurrentStates(CharacterVector& states, LogicalVector& closedClass) {
  CharacterVector selected;
  int idx = 0;

  while (idx < states.size()) {
    if (closedClass[idx])
      selected.push_back(states[idx]);

    ++idx;
  }

  return selected;
}

// Returns the names of the states of a markovchain object whose communicating
// class is not closed. A column-wise matrix is transposed first because
// commClassesKernel expects a matrix that is stochastic by rows.
// [[Rcpp::export(.transientStatesRcpp)]]
CharacterVector transientStates(S4 object) {
  bool byrow = object.slot("byrow");
  NumericMatrix rowStochastic = object.slot("transitionMatrix");

  if (!byrow)
    rowStochastic = transpose(rowStochastic);

  List kernel = commClassesKernel(rowStochastic);
  LogicalVector isClosed = kernel["closed"];
  CharacterVector stateNames = object.slot("states");

  return computeTransientStates(stateNames, isClosed);
}

// Returns the names of the states of a markovchain object whose communicating
// class is closed. A column-wise matrix is transposed first because
// commClassesKernel expects a matrix that is stochastic by rows.
// [[Rcpp::export(.recurrentStatesRcpp)]]
CharacterVector recurrentStates(S4 object) {
  CharacterVector stateNames = object.slot("states");
  bool byrow = object.slot("byrow");
  NumericMatrix rowStochastic = object.slot("transitionMatrix");

  if (!byrow)
    rowStochastic = transpose(rowStochastic);

  List kernel = commClassesKernel(rowStochastic);
  LogicalVector isClosed = kernel["closed"];

  return computeRecurrentStates(stateNames, isClosed);
}

// Groups the state names into recurrent classes.
//   commClasses: matrix as produced by commClassesKernel, where entry (i, j)
//                is TRUE iff states i and j are in the same communicating class
//   closedClass: for each state, whether its communicating class is closed
//   states:      state names, in the same order as the matrix rows/columns
// Returns a list holding one character vector of state names per closed class.
List computeRecurrentClasses(LogicalMatrix& commClasses, 
                             LogicalVector& closedClass, 
                             CharacterVector& states) {
  const int total = states.size();
  vector<bool> assigned(total, false);
  List result;

  for (int row = 0; row < total; ++row) {
    // Skip states in non-closed classes and states whose class was already
    // emitted from an earlier row
    if (!closedClass(row) || assigned[row])
      continue;

    CharacterVector members;

    for (int col = 0; col < total; ++col) {
      if (!commClasses(row, col))
        continue;

      members.push_back(states[col]);
      assigned[col] = true;
    }

    result.push_back(members);
  }

  return result;
}

// Returns the recurrent (closed) classes of a markovchain object as a list of
// character vectors of state names. A column-wise matrix is transposed first
// because commClassesKernel expects a matrix that is stochastic by rows.
// [[Rcpp::export(.recurrentClassesRcpp)]]
List recurrentClasses(S4 object) {
  CharacterVector stateNames = object.slot("states");
  bool byrow = object.slot("byrow");
  NumericMatrix rowStochastic = object.slot("transitionMatrix");

  if (!byrow)
    rowStochastic = transpose(rowStochastic);

  List kernel = commClassesKernel(rowStochastic);
  LogicalVector isClosed = kernel["closed"];
  LogicalMatrix sameClass = kernel["classes"];

  return computeRecurrentClasses(sameClass, isClosed, stateNames);
}

// Groups the state names into transient classes.
//   commClasses: matrix as produced by commClassesKernel, where entry (i, j)
//                is TRUE iff states i and j are in the same communicating class
//   closedClass: for each state, whether its communicating class is closed
//   states:      state names, in the same order as the matrix rows/columns
// Returns a list holding one character vector of state names per class that
// is not closed.
List computeTransientClasses(LogicalMatrix& commClasses, 
                             LogicalVector& closedClass, 
                             CharacterVector& states) {
  const int total = states.size();
  vector<bool> assigned(total, false);
  List result;

  for (int row = 0; row < total; ++row) {
    // Skip states in closed classes and states whose class was already
    // emitted from an earlier row
    if (closedClass(row) || assigned[row])
      continue;

    CharacterVector members;

    for (int col = 0; col < total; ++col) {
      if (!commClasses(row, col))
        continue;

      members.push_back(states[col]);
      assigned[col] = true;
    }

    result.push_back(members);
  }

  return result;
}

// Returns the transient (non-closed) classes of a markovchain object as a list
// of character vectors of state names. A column-wise matrix is transposed
// first because commClassesKernel expects a matrix that is stochastic by rows.
// [[Rcpp::export(.transientClassesRcpp)]]
List transientClasses(S4 object) {
  CharacterVector stateNames = object.slot("states");
  bool byrow = object.slot("byrow");
  NumericMatrix rowStochastic = object.slot("transitionMatrix");

  if (!byrow)
    rowStochastic = transpose(rowStochastic);

  List kernel = commClassesKernel(rowStochastic);
  LogicalVector isClosed = kernel["closed"];
  LogicalMatrix sameClass = kernel["classes"];

  return computeTransientClasses(sameClass, isClosed, stateNames);
}


// Defined in probabilistic.cpp
mat matrixPow(const mat& A, int n);


// Returns a logical matrix, with the dimnames of the transition matrix, whose
// entry (i, j) is TRUE iff there is a walk of at most m - 1 steps from index i
// to index j along strictly positive entries of the stored transition matrix
// (m = number of states; i == j is always TRUE). The matrix is used as stored,
// without looking at the byrow slot.
//
// The walk counts of (I + sign(X))^(m-1) grow exponentially and overflow to
// inf on large chains, after which inf * 0 products turn into NaN and entries
// are wrongly reported as unreachable. Only the zero / non-zero pattern is
// needed, so the matrix is squared repeatedly and clamped back to {0, 1} after
// each product, which keeps every entry bounded.
// [[Rcpp::export(.reachabilityMatrixRcpp)]]
LogicalMatrix reachabilityMatrix(S4 obj) {
  NumericMatrix matrix = obj.slot("transitionMatrix");

  int m = matrix.nrow();
  mat X(matrix.begin(), m, m, true);

  // Non-zero entry (i, j) <=> j reachable from i in at most one step
  mat reachability = arma::eye(m, m) + arma::sign(X);

  // After each squaring the covered walk length doubles; stop once walks of
  // length m - 1 are covered (longer walks cannot reach any new state)
  for (int covered = 1; covered < m - 1; covered *= 2)
    reachability = arma::sign(reachability * reachability);

  LogicalMatrix result = wrap(reachability > 0);
  result.attr("dimnames") = matrix.attr("dimnames");

  return result;
}

// Tells whether state `to` can be reached from state `from` by following
// transitions whose probability is not approximately zero.
//   obj:  markovchain object (slots transitionMatrix, states, byrow are read)
//   from: name of the starting state
//   to:   name of the target state
// Stops with an error if either name is not among the states.
// [[Rcpp::export(.isAccessibleRcpp)]]
bool isAccessible(S4 obj, String from, String to) {
  NumericMatrix probs = obj.slot("transitionMatrix");
  CharacterVector states = obj.slot("states");
  bool byrow = obj.slot("byrow");
  int m = probs.ncol();

  // Translate both state names into matrix indices
  int fromPos = -1;
  int toPos = -1;

  for (int idx = 0; idx < m; ++idx) {
    if (states[idx] == from)
      fromPos = idx;
    if (states[idx] == to)
      toPos = idx;
  }

  if (fromPos == -1 || toPos == -1)
    stop("Please give valid states method");

  vector<bool> seen(m, false);
  seen[fromPos] = true;
  stack<int> pending;
  pending.push(fromPos);
  bool isReachable = false;

  // Depth-first search that ends when the target is popped or no state is
  // left to explore
  while (!pending.empty() && !isReachable) {
    const int current = pending.top();
    pending.pop();
    seen[current] = true;
    isReachable = (current == toPos);

    for (int next = 0; next < m; ++next) {
      if (seen[next])
        continue;

      // The probability of moving current -> next sits at (current, next) for
      // row-wise matrices and at (next, current) for column-wise ones
      const double weight = byrow ? probs(current, next) : probs(next, current);

      if (!approxEqual(weight, 0))
        pending.push(next);
    }
  }

  return isReachable;
}


// Builds the class summary of a markovchain object: a list with the entries
// "closedClasses", "recurrentClasses" (both hold the recurrent classes) and
// "transientClasses", each a list of character vectors of state names.
// [[Rcpp::export(.summaryKernelRcpp)]]
List summaryKernel(S4 object) {
  CharacterVector stateNames = object.slot("states");
  bool byrow = object.slot("byrow");
  NumericMatrix rowStochastic = object.slot("transitionMatrix");

  // commClassesKernel expects a matrix that is stochastic by rows
  if (!byrow)
    rowStochastic = transpose(rowStochastic);

  List kernel = commClassesKernel(rowStochastic);
  LogicalMatrix sameClass = kernel["classes"];
  LogicalVector isClosed = kernel["closed"];

  List recurrent = computeRecurrentClasses(sameClass, isClosed, stateNames);
  List transient = computeTransientClasses(sameClass, isClosed, stateNames);

  return List::create(_["closedClasses"]    = recurrent,
                      _["recurrentClasses"] = recurrent,
                      _["transientClasses"] = transient);
}

// Kernel for the first passage computation.
//   P: transition matrix
//   i: 1-based index of the starting state
//   n: number of steps to compute
// Returns an n x ncol(P) matrix whose row m holds row i of G_m, where
// G_0 = P and G_m = P * (G_{m-1} with its diagonal zeroed).
// [[Rcpp::export(.firstpassageKernelRcpp)]]
NumericMatrix firstpassageKernel(NumericMatrix P, int i, int n) {
  arma::mat G = as<arma::mat>(P);
  const arma::mat base = G;
  arma::mat H(n, P.ncol());

  // Copies row i of the current G into row `step` of the result
  auto recordStep = [&](int step) {
    for (unsigned int col = 0; col < G.n_cols; col++)
      H(step, col) = G(i-1, col);
  };

  recordStep(0);

  // Ones everywhere except on the diagonal: the element-wise product with it
  // zeroes the diagonal of G
  const arma::mat offDiagonal = 1 - arma::eye(P.ncol(), P.ncol());

  for (int step = 1; step < n; step++) {
    G = base * (G % offDiagonal);
    recordStep(step);
  }

  NumericMatrix R = wrap(H);

  return R;
}



// First passage kernel for a set of target states.
//   P:     transition matrix
//   i:     1-based index of the starting state
//   setno: 1-based indices of the target states
//   n:     number of steps to compute
// Returns a vector of length n whose entry m is the sum, over the target
// states, of row i of G_m, where G_0 = P and
// G_m = P * (G_{m-1} with its diagonal zeroed).
// [[Rcpp::export(.firstPassageMultipleRCpp)]]
NumericVector firstPassageMultipleRCpp(NumericMatrix P,int i, NumericVector setno, int n) {
  arma::mat G = as<arma::mat>(P);
  const arma::mat base = G;
  arma::vec H = arma::zeros(n);
  const unsigned int numTargets = setno.size();

  // Adds the entries of row i of the current G at the target columns to H[step]
  auto accumulateStep = [&](int step) {
    for (unsigned int t = 0; t < numTargets; t++)
      H[step] += G(i-1, setno[t]-1);
  };

  accumulateStep(0);

  // Ones everywhere except on the diagonal: the element-wise product with it
  // zeroes the diagonal of G
  const arma::mat offDiagonal = 1 - arma::eye(P.ncol(), P.ncol());

  for (int step = 1; step < n; step++) {
    G = base * (G % offDiagonal);
    accumulateStep(step);
  }

  NumericVector R = wrap(H);

  return R;
}

// Expected rewards after n steps, by dynamic programming.
//   matrix:  transition matrix
//   n:       number of steps
//   rewards: reward attached to each state (one entry per column of matrix)
// Starts from v_0 = rewards and applies v_k = rewards + matrix * v_{k-1}
// n times; returns v_n.
// [[Rcpp::export(.expectedRewardsRCpp)]]
NumericVector expectedRewardsRCpp(NumericMatrix matrix, int n, NumericVector rewards) {
  const int numStates = matrix.ncol();
  const arma::mat transition = as<arma::mat>(matrix);

  // Per-state reward vector r; it is also the base case of the recursion
  arma::vec reward = arma::zeros(numStates);

  for (int state = 0; state < numStates; state++)
    reward[state] = rewards[state];

  arma::vec value = reward;

  // v(n, u) = r + [P]v(n−1, u);
  for (int step = 0; step < n; step++)
    value = reward + transition * value;

  NumericVector out = wrap(value);

  return out;
}

// Accumulates the expected reward collected from state s0 over n steps.
//   matrix:  transition matrix (as passed by the caller)
//   s0:      1-based index of the starting state
//   rewards: reward attached to each state
//   n:       number of steps
// Returns the sum over k = 1..n of e_s0 * matrix^k * rewards, where e_s0 is
// the indicator row vector of s0.
// [[Rcpp::export(.expectedRewardsBeforeHittingARCpp)]]
double expectedRewardsBeforeHittingARCpp(NumericMatrix matrix,int s0,
                               NumericVector rewards, int n ) {
  // Accumulate in double: the function returns double, and summing in float
  // would silently truncate every partial sum to single precision
  double result = 0.0;
  int size = rewards.size();
  arma::mat matr = as<arma::mat>(matrix);
  arma::mat temp = as<arma::mat>(matrix);   // holds matrix^(j+1) at step j
  arma::vec r = as<arma::vec>(rewards);
  arma::mat I = arma::zeros(1,size);        // indicator row vector of s0
  
  I(0,s0-1) = 1;
  
  for (int j = 0; j < n; j++) {
    arma::mat res = I*(temp*r);             // 1x1 matrix
    result = result + res(0,0);
    temp = temp*matr;
  }
  
  return result;
}
  

// Greatest common divisor of a and b, computed on their absolute values with
// the Euclidean algorithm. gcd(x, 0) and gcd(0, x) are |x|; gcd(0, 0) is 0.
// [[Rcpp::export(.gcdRcpp)]]
int gcd (int a, int b) {
  int divisor = abs(a);
  int dividend = abs(b);

  // Invariant: gcd(dividend, divisor) never changes; stop when divisor is 0
  while (divisor != 0) {
    const int rest = dividend % divisor;
    dividend = divisor;
    divisor = rest;
  }

  return dividend;
}

// Computes the period of an irreducible DTMC.
//
// A breadth-first search from the first state assigns a depth to every state.
// Each edge i -> j that leads to an already discovered state contributes
// depth(i) + 1 - depth(j) to a running gcd; that gcd is the result.
// If the chain is not irreducible, a warning is raised and 0 is returned.
//
// Discovered states are tracked in a bitmap and depths are kept as integers,
// so each edge costs O(1) instead of a scan over a scratch list that grew on
// every edge.

//' @rdname structuralAnalysis
//' 
//' @export
// [[Rcpp::export(period)]]
int period(S4 object) {
  if (!isIrreducible(object)) {
    warning("The matrix is not irreducible");
    return 0;
  }

  NumericMatrix P = object.slot("transitionMatrix");
  const int n = P.ncol();
  int d = 0;

  if (n > 0) {
    std::vector<int> depth(n, 0);
    std::vector<bool> discovered(n, false);
    // FIFO queue of states to expand; `head` is the next one to process
    std::vector<int> frontier;
    frontier.reserve(n);

    frontier.push_back(0);
    discovered[0] = true;
    depth[0] = 1;

    // A gcd of 1 cannot decrease any further, so the search stops early
    for (std::size_t head = 0; head < frontier.size() && d != 1; ++head) {
      const int i = frontier[head];

      for (int j = 0; j < n; ++j) {
        if (P(i, j) > 0) {
          if (discovered[j]) {
            d = gcd(d, depth[i] + 1 - depth[j]);
          } else {
            discovered[j] = true;
            depth[j] = depth[i] + 1;
            frontier.push_back(j);
          }
        }
      }
    }
  }

  return d;
}

//' @title predictiveDistribution
//'
//' @description The function computes the probability of observing a new data
//'   set, given a data set
//' @usage predictiveDistribution(stringchar, newData, hyperparam = matrix())
//'
//' @param stringchar This is the data using which the Bayesian inference is
//'   performed.
//' @param newData This is the data whose predictive probability is computed.
//' @param hyperparam This determines the shape of the prior distribution of the
//'   parameters. If none is provided, default value of 1 is assigned to each
//'   parameter. This must be of size kxk where k is the number of states in the
//'   chain and the values should typically be non-negative integers.
//' @return The log of the probability is returned.
//'
//' @details The underlying method is Bayesian inference. The probability is
//'   computed by averaging the likelihood of the new data with respect to the
//'   posterior. Since the method assumes conjugate priors, the result can be
//'   represented in a closed form (see the vignette for more details), which is
//'   what is returned.
//' @references 
//' Inferring Markov Chains: Bayesian Estimation, Model Comparison, Entropy Rate, 
//' and Out-of-Class Modeling, Christopher C. Strelioff, James P.
//' Crutchfield, Alfred Hubler, Santa Fe Institute
//' 
//' Yalamanchi SB, Spedicato GA (2015). Bayesian Inference of First Order Markov 
//' Chains. R package version 0.2.5
//' 
//' @author Sai Bhargav Yalamanchi
//' @seealso \code{\link{markovchainFit}}
//' @examples
//' sequence<- c("a", "b", "a", "a", "a", "a", "b", "a", "b", "a", "b", "a", "a", 
//'              "b", "b", "b", "a")
//' hyperMatrix<-matrix(c(1, 2, 1, 4), nrow = 2,dimnames=list(c("a","b"),c("a","b")))
//' predProb <- predictiveDistribution(sequence[1:10], sequence[11:17], hyperparam =hyperMatrix )
//' hyperMatrix2<-hyperMatrix[c(2,1),c(2,1)]
//' predProb2 <- predictiveDistribution(sequence[1:10], sequence[11:17], hyperparam =hyperMatrix2 )
//' predProb2==predProb
//' @export
//' 
// [[Rcpp::export]]
double predictiveDistribution(CharacterVector stringchar, CharacterVector newData, NumericMatrix hyperparam = NumericMatrix()) {
  // Overview:
  //   1. collect the sorted set of states seen in stringchar and newData
  //   2. default / validate the hyperparameter matrix and its dimnames
  //   3. count the transitions of the old and of the new data
  //   4. combine counts and hyperparameters through lgamma terms into the
  //      log of the predictive probability, which is returned

  // construct list of states
  CharacterVector elements = stringchar;
  
  for (int i = 0; i < newData.size(); i++)
    elements.push_back(newData[i]);
  
  elements = unique(elements).sort();
  int sizeMatr = elements.size();
  
  // if no hyperparam argument provided, use default value of 1 for all 
  // (a 1x1 hyperparam is taken to mean "not provided")
  if (hyperparam.nrow() == 1 && hyperparam.ncol() == 1) {
    NumericMatrix temp(sizeMatr, sizeMatr);
    temp.attr("dimnames") = List::create(elements, elements);
    
    for (int i = 0; i < sizeMatr; i++)
      for (int j = 0; j < sizeMatr; j++)
        temp(i, j) = 1;
    
    hyperparam = temp;
  }
  
  // validity check: hyperparam must be square and cover at least as many
  // states as were observed
  if (hyperparam.nrow() != hyperparam.ncol())
    stop("Dimensions of the hyperparameter matrix are inconsistent");
    
  if (hyperparam.nrow() < sizeMatr)
    stop("Hyperparameters for all state transitions must be provided");
    
  List dimNames = hyperparam.attr("dimnames");
  CharacterVector colNames = dimNames[1];
  CharacterVector rowNames = dimNames[0];
  int sizeHyperparam = hyperparam.ncol();
  CharacterVector sortedColNames(sizeHyperparam), sortedRowNames(sizeHyperparam);
  
  // work on sorted copies of the names so that duplicates become adjacent and
  // the row and column name sets can be compared position by position
  for (int i = 0; i < sizeHyperparam; i++)
    sortedColNames(i) = colNames(i), sortedRowNames(i) = rowNames(i);

  sortedColNames.sort();
  sortedRowNames.sort();
  
  for (int i = 0; i < sizeHyperparam; i++) {
    if (i > 0 && (sortedColNames(i) == sortedColNames(i-1) || sortedRowNames(i) == sortedRowNames(i-1)))  
      stop("The states must all be unique");
    else if (sortedColNames(i) != sortedRowNames(i))
      stop("The set of row names must be the same as the set of column names");
    
    bool found = false;
    
    // only the first sizeMatr entries (the observed states) are searched;
    // names appended below are not looked at again
    for (int j = 0; j < sizeMatr; j++)
      if (elements(j) == sortedColNames(i))
        found = true;
    // hyperparam may contain states not in stringchar
    if (!found)  elements.push_back(sortedColNames(i));
  }
  
  // check for the case where hyperparam has missing data, i.e. an observed
  // state that does not appear among the hyperparam names
  for (int i = 0; i < sizeMatr; i++) {
    bool found = false;
    
    for (int j = 0; j < sizeHyperparam; j++)
      if (sortedColNames(j) == elements(i))
        found = true;
    
    if (!found)
      stop("Hyperparameters for all state transitions must be provided");
  }   
      
  // from here on sizeMatr counts the observed states plus the extra states
  // that only appear in hyperparam
  elements = elements.sort();
  sizeMatr = elements.size();
  
  for (int i = 0; i < sizeMatr; i++)
    for (int j = 0; j < sizeMatr; j++)
      if (hyperparam(i, j) < 1.)
        stop("The hyperparameter elements must all be greater than or equal to 1");
  
  // permute the elements of hyperparam such that the row, column names are sorted
  hyperparam = sortByDimNames(hyperparam);
  
  // square sizeMatr x sizeMatr matrices of transition counts
  NumericMatrix freqMatr(sizeMatr), newFreqMatr(sizeMatr);

  double predictiveDist = 0.; // log of the predictive probability

  // populate frequency matrix for old data; this is used for inference 
  int posFrom = 0, posTo = 0;
  
  for (int i = 0; i < stringchar.size() - 1; i ++) {
    // locate the indices of the current state and of its successor
    for (int j = 0; j < sizeMatr; j ++) {
      if (stringchar[i] == elements[j]) posFrom = j;
      if (stringchar[i + 1] == elements[j]) posTo = j;
    }
    freqMatr(posFrom,posTo)++;
  }
  
  // frequency matrix for new data
  for (int i = 0; i < newData.size() - 1; i ++) {
    for (int j = 0; j < sizeMatr; j ++) {
      if (newData[i] == elements[j]) posFrom = j;
      if (newData[i + 1] == elements[j]) posTo = j;
    }
    newFreqMatr(posFrom,posTo)++;
  }
 
  // each row i adds
  //   sum_j [lgamma(old + new + hyper) - lgamma(old + hyper)]
  //   + lgamma(rowSum + paramRowSum) - lgamma(rowSum + newRowSum + paramRowSum)
  for (int i = 0; i < sizeMatr; i++) {
    double rowSum = 0, newRowSum = 0, paramRowSum = 0;
    
    for (int j = 0; j < sizeMatr; j++) { 
      rowSum += freqMatr(i, j), newRowSum += newFreqMatr(i, j), paramRowSum += hyperparam(i, j);
      predictiveDist += lgamma(freqMatr(i, j) + newFreqMatr(i, j) + hyperparam(i, j)) -
                        lgamma(freqMatr(i, j) + hyperparam(i, j));
    }
    predictiveDist += lgamma(rowSum + paramRowSum) - lgamma(rowSum + newRowSum + paramRowSum);
  }

  return predictiveDist;
}


//' @title priorDistribution
//'
//' @description Function to evaluate the prior probability of a transition
//'   matrix. It is based on conjugate priors and therefore a Dirichlet
//'   distribution is used to model the transitions of each state.
//' @usage priorDistribution(transMatr, hyperparam = matrix())
//'
//' @param transMatr The transition matrix whose probability is the parameter of
//'   interest.
//' @param hyperparam The hyperparam matrix (optional). If not provided, a
//'   default value of 1 is assumed for each and therefore the resulting
//'   probability distribution is uniform.
//' @return The log of the probabilities for each state is returned in a numeric
//'   vector. Each number in the vector represents the probability (log) of
//'   having a probability transition vector as specified in corresponding the
//'   row of the transition matrix.
//'
//' @details The states (dimnames) of the transition matrix and the hyperparam
//'   may be in any order.
//' @references Yalamanchi SB, Spedicato GA (2015). Bayesian Inference of First
//' Order Markov Chains. R package version 0.2.5
//'
//' @author Sai Bhargav Yalamanchi, Giorgio Spedicato
//'
//' @note This function can be used in conjunction with inferHyperparam. For
//'   example, if the user has a prior data set and a prior transition matrix,
//'   he can infer the hyperparameters using inferHyperparam and then compute
//'   the probability of their prior matrix using the inferred hyperparameters
//'   with priorDistribution.
//' @seealso \code{\link{predictiveDistribution}}, \code{\link{inferHyperparam}}
//' 
//' @examples
//' priorDistribution(matrix(c(0.5, 0.5, 0.5, 0.5), 
//'                   nrow = 2, 
//'                   dimnames = list(c("a", "b"), c("a", "b"))), 
//'                   matrix(c(2, 2, 2, 2), 
//'                   nrow = 2, 
//'                   dimnames = list(c("a", "b"), c("a", "b"))))
//' @export
// [[Rcpp::export]]
NumericVector priorDistribution(NumericMatrix transMatr, NumericMatrix hyperparam = NumericMatrix()) {
  // Evaluates, for every row of transMatr, the log-density of that row under
  // a Dirichlet distribution whose parameters are the matching row of
  // hyperparam. Both matrices are reordered by dimnames first, so their
  // states may come in any order. Returns one log-density per state, named
  // by the sorted state names. Stops with an error when the transition
  // matrix is not row-stochastic, when dimnames are missing or inconsistent,
  // or when a hyperparameter is below 1.

  // begin validity checks for the transition matrix
  if (transMatr.nrow() != transMatr.ncol())
    stop("Transition matrix dimensions are inconsistent");
    
  int sizeMatr = transMatr.nrow();
  
  for (int i = 0; i < sizeMatr; i++) {
    double rowSum = 0., eps = 1e-10;
    
    for (int j = 0; j < sizeMatr; j++)
      if (transMatr(i, j) < 0. || transMatr(i, j) > 1.)
        stop("The entries in the transition matrix must each belong to the interval [0, 1]");
      else
        rowSum += transMatr(i, j);
    
    if (rowSum <= 1. - eps || rowSum >= 1. + eps)
      stop("The rows of the transition matrix must each sum to 1");
  }
  
  List dimNames = transMatr.attr("dimnames");
  
  if (dimNames.size() == 0)
    stop("Provide dimnames for the transition matrix");
  
  CharacterVector colNames = dimNames[1];
  CharacterVector rowNames = dimNames[0];
  CharacterVector sortedColNames(sizeMatr), sortedRowNames(sizeMatr);
  
  // sorted copies: duplicates become adjacent and the two name sets can be
  // compared position by position
  for (int i = 0; i < sizeMatr; i++)
    sortedColNames(i) = colNames(i), sortedRowNames(i) = rowNames(i);
  
  sortedColNames.sort();
  sortedRowNames.sort();
  
  for (int i = 0; i < sizeMatr; i++) 
    if (i > 0 && (sortedColNames(i) == sortedColNames(i-1) || sortedRowNames(i) == sortedRowNames(i-1)))  
      stop("The states must all be unique");
    else if (sortedColNames(i) != sortedRowNames(i))
      stop("The set of row names must be the same as the set of column names");
  
  // if no hyperparam argument provided, use default value of 1 for all 
  // (a 1x1 hyperparam is taken to mean "not provided")
  if (hyperparam.nrow() == 1 && hyperparam.ncol() == 1) {
    NumericMatrix temp(sizeMatr, sizeMatr);
    temp.attr("dimnames") = List::create(sortedColNames, sortedColNames);
  
    for (int i = 0; i < sizeMatr; i++)
      for (int j = 0; j < sizeMatr; j++)
        temp(i, j) = 1;
  
    hyperparam = temp;
  }
  
  // validity check for hyperparam
  if (hyperparam.nrow() != hyperparam.ncol())
    stop("Dimensions of the hyperparameter matrix are inconsistent");
    
  if (hyperparam.nrow() != sizeMatr)
    stop("Hyperparameter and the transition matrices differ in dimensions");
    
  List hyperDimNames = hyperparam.attr("dimnames");

  if (hyperDimNames.size() == 0)
    stop("Provide dimnames for the hyperparameter matrix");
  
  CharacterVector hyperColNames = hyperDimNames[1];
  CharacterVector hyperRowNames = hyperDimNames[0];
  int sizeHyperparam = hyperparam.ncol();
  CharacterVector sortedHyperColNames(sizeHyperparam), sortedHyperRowNames(sizeHyperparam);
  
  // Copy the names of the *hyperparameter* matrix. Copying the transition
  // matrix names here would make the comparison below compare the transition
  // matrix with itself, so mismatching states would never be detected
  for (int i = 0; i < sizeHyperparam; i++)
    sortedHyperColNames(i) = hyperColNames(i), sortedHyperRowNames(i) = hyperRowNames(i);
  
  sortedHyperColNames.sort();
  sortedHyperRowNames.sort();
  
  for (int i = 0; i < sizeHyperparam; i++)
    if (sortedColNames(i) != sortedHyperColNames(i) || sortedRowNames(i) != sortedHyperRowNames(i))
      stop("Hyperparameter and the transition matrices states differ");
  
  for (int i = 0; i < sizeMatr; i++)
    for (int j = 0; j < sizeMatr; j++)
      if (hyperparam(i, j) < 1.)
        stop("The hyperparameter elements must all be greater than or equal to 1");
 
  // bring both matrices to the same (sorted) state order
  transMatr = sortByDimNames(transMatr);
  hyperparam = sortByDimNames(hyperparam);
  NumericVector logProbVec;
  
  // log Dirichlet density of row i:
  //   lgamma(sum_j a_ij) + sum_j [(a_ij - 1) * log(p_ij) - lgamma(a_ij)]
  for (int i = 0; i < sizeMatr; i++) {
    double logProb_i = 0., hyperparamRowSum = 0;
  
    for (int j = 0; j < sizeMatr; j++) {
      const double alpha = hyperparam(i, j);
      hyperparamRowSum += alpha;

      // With alpha == 1 the factor p^(alpha - 1) is p^0 = 1 even for p == 0;
      // evaluating 0 * log(0) would give NaN instead of 0
      const double powerTerm = (alpha != 1.) ? (alpha - 1.) * log(transMatr(i, j)) : 0.;
      logProb_i += powerTerm - lgamma(alpha);
    }
    
    logProb_i += lgamma(hyperparamRowSum);
    logProbVec.push_back(logProb_i);
  }
  
  logProbVec.attr("names") = sortedColNames;

  return logProbVec;
}

// Builds the hitting probabilities matrix of a markovchain object by solving
// one linear system per target state j:
//   - for a state i in a non-closed class, row i of the system is
//     (P with column j zeroed, minus I) * h = -P(., j)
//   - for a state i in a closed class, row i is replaced by h_i = 1 when i
//     and j are in the same communicating class, and h_i = 0 otherwise
// The solution for target j becomes column j of the result (by rows); the
// result is labelled with the state names and transposed back when the chain
// is stored by columns.
// [[Rcpp::export(.hittingProbabilitiesRcpp)]]
NumericMatrix hittingProbabilities(S4 object) {
  CharacterVector states = object.slot("states");
  bool byrow = object.slot("byrow");
  NumericMatrix transitionMatrix = object.slot("transitionMatrix");

  if (!byrow)
    transitionMatrix = transpose(transitionMatrix);

  const int numStates = transitionMatrix.nrow();
  const arma::mat transitionProbs = as<arma::mat>(transitionMatrix);
  arma::mat hittingProbs(numStates, numStates);

  // Communicating structure of the chain
  List commClasses = commClassesKernel(transitionMatrix);
  List closedClass = commClasses["closed"];
  LogicalMatrix communicating = commClasses["classes"];

  for (int target = 0; target < numStates; ++target) {
    arma::mat coeffs = transitionProbs;
    arma::vec rhs = -transitionProbs.col(target);

    // Zero the target column first, then subtract the identity, so that the
    // entry (target, target) ends up as -1
    coeffs.col(target).zeros();
    coeffs.diag() -= 1.0;

    for (int state = 0; state < numStates; ++state) {
      if (!closedClass(state))
        continue;

      // Replace the equation of a closed-class state by h_state = 0 or 1
      coeffs.row(state).zeros();
      coeffs(state, state) = 1;
      rhs(state) = communicating(state, target) ? 1 : 0;
    }

    hittingProbs.col(target) = arma::solve(coeffs, rhs);
  }

  NumericMatrix result = wrap(hittingProbs);
  colnames(result) = states;
  rownames(result) = states;

  if (!byrow)
    result = transpose(result);

  return result;
}



// Converts a markovchain object into canonic form: the states are reordered
// so that the members of the recurrent classes come first (class by class, in
// the order returned by recurrentClasses) and all remaining states follow in
// their original relative order. Rows and columns of the transition matrix
// are permuted accordingly.
// Returns a new markovchain object with the permuted matrix and states and
// with the byrow flag and the name of the input object.
// [[Rcpp::export(.canonicFormRcpp)]]
S4 canonicForm(S4 obj) {
  NumericMatrix transitions = obj.slot("transitionMatrix");
  bool byrow = obj.slot("byrow");
  int numRows = transitions.nrow();
  int numCols = transitions.ncol();
  NumericMatrix resultTransitions(numRows, numCols);
  CharacterVector states = obj.slot("states");
  unordered_map<string, int> stateToIndex;
  unordered_set<int> usedIndices;
  int currentIndex;
  List recClasses;
  S4 input("markovchain");
  S4 result("markovchain");
  // indexPermutation[k] = original index of the state placed at position k
  vector<int> indexPermutation(numRows);
  
  // Work on a row-wise view of the chain
  if (!byrow) {
    input.slot("transitionMatrix") = transpose(transitions);
    input.slot("states") = states;
    input.slot("byrow") = true;
    transitions = transpose(transitions);
  } else {
    input = obj;
  }
  
  recClasses = recurrentClasses(input);
  
  // Map each state to the index it has
  for (int i = 0; i < states.size(); ++i) {
    string state = (string) states[i];
    stateToIndex[state] = i;
  }
  
  // First the states of the recurrent classes, class by class
  int toFill = 0;
  for (CharacterVector recClass : recClasses) {
    for (auto state : recClass) {
      currentIndex = stateToIndex[(string) state];
      indexPermutation[toFill] = currentIndex;
      ++toFill;
      usedIndices.insert(currentIndex);
    }
  }
  
  // Then every state that is not in a recurrent class
  for (int i = 0; i < states.size(); ++i) {
    if (usedIndices.count(i) == 0) {
      indexPermutation[toFill] = i;
      ++toFill;
    }
  }
  
  CharacterVector newStates(numRows);
  
  // Apply the same permutation to rows and columns
  for (int i = 0; i < numRows; ++i) {
    int r = indexPermutation[i];
    newStates(i) = states(r);
    
    for (int j = 0; j < numCols; ++j) {
      int c = indexPermutation[j];
      resultTransitions(i, j) = transitions(r, c);
    }
  }
  
  rownames(resultTransitions) = newStates;
  colnames(resultTransitions) = newStates;
  
  // Give the result the same orientation as the input
  if (!byrow)
    resultTransitions = transpose(resultTransitions);
  
  result.slot("transitionMatrix") = resultTransitions;
  result.slot("byrow") = byrow;
  result.slot("states") = newStates;
  // Read the name from the original object: for column-wise chains `input` is
  // a freshly created object whose name slot was never copied from obj
  result.slot("name") = obj.slot("name");
  return result;
}


// Sorts the rows of a matrix lexicographically (each row is compared as a
// vector of doubles). Returns a new matrix holding the sorted rows and the
// column names of m; a matrix without rows or without columns is returned
// as is.
NumericMatrix lexicographicalSort(NumericMatrix m) {
  const int numRows = m.nrow();
  const int numCols = m.ncol();

  if (numRows <= 0 || numCols <= 0)
    return m;

  // Copy every row into its own vector so that std::sort can order them
  vector<vector<double>> rows;
  rows.reserve(numRows);

  for (int r = 0; r < numRows; ++r) {
    vector<double> current(numCols);

    for (int c = 0; c < numCols; ++c)
      current[c] = m(r, c);

    rows.push_back(current);
  }

  std::sort(rows.begin(), rows.end());

  NumericMatrix result(numRows, numCols);

  for (int r = 0; r < numRows; ++r)
    for (int c = 0; c < numCols; ++c)
      result(r, c) = rows[r][c];

  colnames(result) = colnames(m);
  return result;
}


// This method computes the *unique* steady state that exists for an
// ergodic (= irreducible) matrix.
// The matrix has to be stochastic by rows.
//
// If P is ergodic, the system (P^T - I) * w = 0 together with the equation
// w_1 + ... + w_m = 1 must have a solution. Both are stacked into a single
// system with one extra row, which is then solved for w.
vec steadyStateErgodicMatrix(const mat& submatrix) {
  const int nRows = submatrix.n_rows;
  const int nCols = submatrix.n_cols;

  // Rows 0..nRows-1 hold P^T - I, the last row holds the normalisation
  mat coeffs(nRows + 1, nCols);

  for (int col = 0; col < nCols; ++col) {
    for (int row = 0; row < nRows; ++row) {
      // Entry of the transposed matrix, minus one on the diagonal
      const double transposed = submatrix(col, row);
      coeffs(row, col) = (row == col) ? transposed - 1 : transposed;
    }

    coeffs(nRows, col) = 1;
  }

  // Right-hand side: zeros for the stationarity equations, one for the
  // normalisation equation
  vec rightPart(nRows + 1, fill::zeros);
  rightPart(nRows) = 1;

  vec result;
  const bool solved = solve(result, coeffs, rightPart);

  if (!solved)
    stop("Failure computing eigen values / vectors for submatrix in steadyStateErgodicMatrix");

  return result;
}

// Precondition: the matrix should be stochastic by rows
//
// Builds a matrix with one row per recurrent class of the chain and one
// column per state. Each row holds the steady state computed for the
// sub-chain restricted to that recurrent class, written at the columns
// of the states that belong to the class. Columns are named after the
// states of the chain.
NumericMatrix steadyStatesByRecurrentClasses(S4 object) {
  List recClasses = recurrentClasses(object);
  NumericMatrix transitionMatrix = object.slot("transitionMatrix");
  CharacterVector states = object.slot("states");
  
  const int numRecClasses = recClasses.size();
  const int numCols = transitionMatrix.ncol();
  NumericMatrix steady(numRecClasses, numCols);
  
  // Position of every state inside the transition matrix
  unordered_map<string, int> stateToIndex;
  
  for (int pos = 0; pos < states.size(); ++pos)
    stateToIndex[(string) states[pos]] = pos;
  
  // Each recurrent class contributes one row of the result
  int row = 0;
  
  for (CharacterVector recurrentClass : recClasses) {
    const int classSize = recurrentClass.size();
    
    // Resolve once where the states of this class live in the full matrix
    vector<int> positions(classSize);
    
    for (int k = 0; k < classSize; ++k)
      positions[k] = stateToIndex[(string) recurrentClass[k]];
    
    // Subset of the transition matrix restricted to the class
    mat subMatrix(classSize, classSize);
    
    for (int a = 0; a < classSize; ++a)
      for (int b = 0; b < classSize; ++b)
        subMatrix(a, b) = transitionMatrix(positions[a], positions[b]);
    
    // Steady state of the class, scattered back to the original columns
    vec classSteadyState = steadyStateErgodicMatrix(subMatrix);
    
    for (int k = 0; k < classSize; ++k)
      steady(row, positions[k]) = classSteadyState(k);
    
    ++row;
  }
  
  colnames(steady) = states;
  
  return steady;
}

// Computes the steady states of a markovchain object, one per recurrent
// class, ordered lexicographically. For a chain stored by rows every
// steady state is a row of the result; for a chain stored by columns
// the result is transposed so that every steady state is a column.
// [[Rcpp::export(.steadyStatesRcpp)]]
NumericMatrix steadyStates(S4 obj) {
  NumericMatrix transitions = obj.slot("transitionMatrix");
  CharacterVector states = obj.slot("states");
  bool byrow = obj.slot("byrow");
  S4 rowWise("markovchain");
  
  if (byrow) {
    rowWise = obj;
  } else {
    // Work on a by-rows version of the chain
    rowWise.slot("transitionMatrix") = transpose(transitions);
    rowWise.slot("states") = states;
    rowWise.slot("byrow") = true;
  }
  
  // There is exactly one steady state associated with each recurrent
  // class, so they are computed class by class and then sorted
  NumericMatrix perClass = steadyStatesByRecurrentClasses(rowWise);
  NumericMatrix sorted = lexicographicalSort(perClass);
  
  if (byrow)
    return sorted;
  
  return transpose(sorted);
}


// Returns the names of the absorbing states: those whose diagonal entry
// in the transition matrix is approximately 1 (as judged by approxEqual).
// Only the diagonal is inspected, so it does not matter whether the
// matrix is stochastic by rows or by columns.
// [[Rcpp::export(.absorbingStatesRcpp)]]
CharacterVector absorbingStates(S4 obj) {
  NumericMatrix transitionMatrix = obj.slot("transitionMatrix");
  CharacterVector states = obj.slot("states");
  const int numStates = states.size();
  CharacterVector absorbing;
  
  for (int s = 0; s < numStates; ++s) {
    const double stayProbability = transitionMatrix(s, s);
    
    if (!approxEqual(stayProbability, 1))
      continue;
    
    absorbing.push_back(states(s));
  }
  
  return absorbing;
}


// A Markov chain is irreducible iff all of its states belong to one
// single communicating class.
// [[Rcpp::export(.isIrreducibleRcpp)]]
bool isIrreducible(S4 obj) {
  List classes = communicatingClasses(obj);
  const bool singleClass = (classes.size() == 1);
  return singleClass;
}


// Checks whether the chain is regular, i.e. whether some power of its
// transition matrix has all entries strictly positive.
//
// Taken from the book:
// Matrix Analysis. Roger A. Horn, Charles R. Johnson. 2nd edition.
// Corollary 8.5.8 and Theorem 8.5.9
//
// Let m be the number of states and d the number of positive diagonal
// elements of the matrix A:
//   - If A is irreducible and d > 0, then A is regular
//     and A^{2m - d - 1} > 0
//   - In general, A is regular iff A^{m² - 2m + 2} > 0
// [[Rcpp::export(.isRegularRcpp)]]
bool isRegular(S4 obj) {
  NumericMatrix transitions = obj.slot("transitionMatrix");
  int m = transitions.ncol();
  mat probs(transitions.begin(), m, m, true);
  auto arePositive = [](const double& x){ return x > 0; };
  
  // d = number of strictly positive entries in the diagonal
  int d = 0;
  
  for (int k = 0; k < m; ++k)
    d += (probs(k, k) > 0) ? 1 : 0;
  
  // Pick the exponent given by the applicable bound
  const int exponent = (d > 0) ? 2*m - d - 1 : m*m - 2*m + 2;
  mat reachable = matrixPow(probs, exponent);
  
  return allElements(reachable, arePositive);
}


// Solves (I - Q) * t = 1 where Q is probs restricted to the states that
// are NOT listed in absorbing, and 1 is a column vector of ones.
//
// probs:     transition matrix (callers in this file pass it by rows)
// absorbing: names of the states to be treated as absorbing
// states:    names of all the states, in the order used by probs
//
// Returns the solution wrapped as a matrix whose row names are the
// states that were kept. Calls stop() if the system cannot be solved.
NumericMatrix computeMeanAbsorptionTimes(mat& probs, CharacterVector& absorbing, 
                                         CharacterVector& states) {
  // Names that must be left out of the system
  unordered_set<string> excluded;
  
  for (auto state : absorbing)
    excluded.insert((string) state);
  
  // Collect position and name of every state that is not excluded;
  // these index the sub-probability matrix Q
  vector<uint> keptIndices;
  CharacterVector keptNames;
  
  for (uint pos = 0; pos < states.size(); ++pos) {
    string name = (string) states(pos);
    
    if (excluded.count(name) > 0)
      continue;
    
    keptIndices.push_back(pos);
    keptNames.push_back(name);
  }
  
  const int n = keptIndices.size();
  uvec indices(keptIndices);
  
  // Coefficients of the system: I - Q
  mat coeffs = eye(n, n) - probs(indices, indices);
  vec rightPart = vec(n, fill::ones);
  mat meanTimes;
  
  // Mean absorption times t satisfy (I - Q) * t = 1
  const bool solved = solve(meanTimes, coeffs, rightPart);
  
  if (!solved)
    stop("Error solving system in meanAbsorptionTime");
  
  NumericMatrix result = wrap(meanTimes);
  rownames(result) = keptNames;
  
  return result;
}


// Computes the mean absorption time for the transient states of the
// chain, i.e. the values obtained from computeMeanAbsorptionTimes when
// every recurrent state is treated as absorbing.
//
// Returns a vector named after the transient states. The vector is left
// empty when the computed matrix of mean times has no columns.
// [[Rcpp::export(.meanAbsorptionTimeRcpp)]]
NumericVector meanAbsorptionTime(S4 obj) {
  NumericMatrix transitions = obj.slot("transitionMatrix");
  CharacterVector states = obj.slot("states");
  bool byrow = obj.slot("byrow");
  
  // From here on work with a matrix that is stochastic by rows
  if (!byrow)
    transitions = transpose(transitions);
  
  // Compute recurrent and transient states
  List commKernel = commClassesKernel(transitions);
  LogicalVector closed = commKernel["closed"];
  CharacterVector transient = computeTransientStates(states, closed);
  CharacterVector recurrent = computeRecurrentStates(states, closed);
  
  // Compute the mean absorption time for the transient states
  mat probs(transitions.begin(), transitions.nrow(), transitions.ncol(), true);
  NumericMatrix meanTimes = computeMeanAbsorptionTimes(probs, recurrent, states);
  NumericVector result;
  
  // The solution comes back as a one-column matrix; flatten it
  if (meanTimes.ncol() > 0) {
    result = meanTimes(_, 0);
    result.attr("names") = transient;
  }
  
  return result;
}

// Computes, for every transient state, the probability of being
// absorbed by each recurrent state, as N * P[transient, recurrent]
// where N = (I - Q)^{-1} is the fundamental matrix and Q is the
// transition matrix restricted to the transient states.
//
// The result has transient states as rows and recurrent states as
// columns; it is transposed when the chain is stored by columns.
// Calls stop() when there are no transient states or when the
// fundamental matrix cannot be computed.
// [[Rcpp::export(.absorptionProbabilitiesRcpp)]]
NumericMatrix absorptionProbabilities(S4 obj) {
  NumericMatrix transitions = obj.slot("transitionMatrix");
  CharacterVector states = obj.slot("states");
  bool byrow = obj.slot("byrow");
  
  // From here on work with a matrix that is stochastic by rows
  if (!byrow)
    transitions = transpose(transitions);
  
  // Position of every state inside the transition matrix
  unordered_map<string, uint> stateToIndex;
  
  for (int pos = 0; pos < states.size(); ++pos)
    stateToIndex[(string) states[pos]] = pos;
  
  // Split the states into transient and recurrent ones
  List commKernel = commClassesKernel(transitions);
  LogicalVector closed = commKernel["closed"];
  CharacterVector transient = computeTransientStates(states, closed);
  CharacterVector recurrent = computeRecurrentStates(states, closed);
  
  // Matrix indices that correspond to transient and recurrent states
  vector<uint> transientIndxs, recurrentIndxs;
  
  for (auto name : transient)
    transientIndxs.push_back(stateToIndex[(string) name]);
  
  for (auto name : recurrent)
    recurrentIndxs.push_back(stateToIndex[(string) name]);
  
  const int n = transientIndxs.size();
  
  if (n == 0)
    stop("Markov chain does not have transient states, method not applicable");
  
  const int m = transitions.ncol();
  
  // Same indices as arma::uvec, to be used for sub-matrix views
  uvec transientIndices(transientIndxs);
  uvec recurrentIndices(recurrentIndxs);
  
  // Fundamental matrix N = (I - Q)^{-1}
  mat probs(transitions.begin(), m, m, true);
  mat toInvert = eye(n, n) - probs(transientIndices, transientIndices);
  mat fundamentalMatrix;
  
  const bool inverted = inv(fundamentalMatrix, toInvert);
  
  if (!inverted)
    stop("Could not compute fundamental matrix");
  
  // Absorption probabilities F* = N * P[transient, recurrent]
  mat meanProbs = fundamentalMatrix * probs(transientIndices, recurrentIndices);
  NumericMatrix result = wrap(meanProbs);
  rownames(result) = transient;
  colnames(result) = recurrent;
  
  if (byrow)
    return result;
  
  return transpose(result);
}

// Mean first passage times of an ergodic (= irreducible) chain.
//
// - With a non-empty destination, the states in destination are treated
//   as absorbing and the mean times computed by
//   computeMeanAbsorptionTimes are returned as a single row.
// - With an empty destination, the full matrix M is returned, where
//   M(i, j) = (Z(j, j) - Z(i, j)) / w(j), with w the steady state and
//   Z = (I - P + W)^{-1}, W being w pasted row-wise. The matrix is
//   transposed when the chain is stored by columns.
//
// Calls stop() if the chain is not irreducible or if the inverse
// cannot be computed.
// [[Rcpp::export(.meanFirstPassageTimeRcpp)]]
NumericMatrix meanFirstPassageTime(S4 obj, CharacterVector destination) {
  bool isErgodic = isIrreducible(obj);
  
  if (!isErgodic)
    stop("Markov chain needs to be ergodic (= irreducile) for this method to work");
  
  NumericMatrix transitions = obj.slot("transitionMatrix");
  mat probs(transitions.begin(), transitions.nrow(), transitions.ncol(), true);
  CharacterVector states = obj.slot("states");
  bool byrow = obj.slot("byrow");
  const int numStates = states.size();
  
  // From here on probs is stochastic by rows
  if (!byrow)
    probs = probs.t();
  
  if (destination.size() > 0) {
    NumericMatrix toDestination = computeMeanAbsorptionTimes(probs, destination, states);
    // This transpose is intentional to return a row always instead of a column
    return transpose(toDestination);
  }
  
  vec steadyState = steadyStateErgodicMatrix(probs);
  
  // Build I - P + W, where P = probs and W = steadyState pasted row-wise
  mat toInvert(numStates, numStates);
  
  for (int i = 0; i < numStates; ++i) {
    for (int j = 0; j < numStates; ++j) {
      double entry = steadyState(j) - probs(i, j);
      
      if (i == j)
        entry += 1;
      
      toInvert(i, j) = entry;
    }
  }
  
  mat Z;
  const bool inverted = inv(Z, toInvert);
  
  if (!inverted)
    stop("Problem computing inverse of matrix inside meanFirstPassageTime");
  
  // Fill the result column by column
  NumericMatrix result(numStates, numStates);
  
  for (int j = 0; j < numStates; ++j) {
    const double r_j = 1.0 / steadyState(j);
    const double z_jj = Z(j, j);
    
    for (int i = 0; i < numStates; ++i)
      result(i, j) = (z_jj - Z(i, j)) * r_j;
  }
  
  colnames(result) = states;
  rownames(result) = states;
  
  if (byrow)
    return result;
  
  return transpose(result);
}

// Mean recurrence time of the recurrent states, computed as the inverse
// of the corresponding entry in the steady states of the chain.
// The returned vector is named after the states it refers to.
// [[Rcpp::export(.meanRecurrenceTimeRcpp)]]
NumericVector meanRecurrenceTime(S4 obj) {
  NumericMatrix steady = steadyStates(obj);
  bool byrow = obj.slot("byrow");
  
  // Make sure every steady state is a row of the matrix
  if (!byrow)
    steady = transpose(steady);
    
  CharacterVector states = obj.slot("states");
  NumericVector result;
  CharacterVector recurrentStates;
  
  const int numSteadyStates = steady.nrow();
  const int numStates = steady.ncol();
  
  for (int r = 0; r < numSteadyStates; ++r) {
    for (int c = 0; c < numStates; ++c) {
      const double mass = steady(r, c);
      
      // This depends on our implementation of the steady states:
      // the entry of a state in a recurrent class is positive in
      // only one of the vectors, and the entries of transient
      // states are zero, so entries close to zero are skipped
      if (approxEqual(mass, 0))
        continue;
      
      result.push_back(1.0 / mass);
      recurrentStates.push_back(states(c));
    }
  }
  
  result.attr("names") = recurrentStates;
  
  return result;
}

// Mean number of visits to each state j when starting from each state i.
//
// Let f be the matrix of hitting probabilities. The mean number of
// visits from i to j is given by
//            f_{ij} / (1 - f_{jj})
// taking care of two special cases:
//   - f_{ij} = 0                -> the mean number of visits is zero
//   - f_{ij} > 0 and f_{jj} = 1 -> the mean number of visits is infinity
//
// The result is transposed when the chain is stored by columns.
// [[Rcpp::export(.minNumVisitsRcpp)]]
NumericMatrix meanNumVisits(S4 obj) {
  NumericMatrix hitting = hittingProbabilities(obj);
  CharacterVector states = obj.slot("states");
  bool byrow = obj.slot("byrow");
  
  if (!byrow)
    hitting = transpose(hitting);
  
  const int n = hitting.ncol();
  NumericMatrix result(n, n);
  rownames(result) = states;
  colnames(result) = states;
  
  for (int j = 0; j < n; ++j) {
    const double returnProb = hitting(j, j);
    const bool closeToOne = approxEqual(returnProb, 1);
    // Only read when returnProb is not close to one
    const double inverse = closeToOne ? 0 : 1 / (1 - returnProb);
    
    for (int i = 0; i < n; ++i) {
      const double hitProb = hitting(i, j);
      
      if (hitProb == 0)
        result(i, j) = 0;
      else if (closeToOne)
        result(i, j) = R_PosInf;
      else
        result(i, j) = hitProb * inverse;
    }
  }
  
  if (byrow)
    return result;
  
  return transpose(result);
}
