// [[Rcpp::depends(Rcpp)]]
#define RCPP_NO_RTTI
#define RCPP_NO_SUGAR
#include <vector>
#include <Rcpp.h>
using namespace Rcpp;

// Returns the smaller of the two integers; when they compare equal the
// shared value is returned.
inline int minValue(int a, int b) {
  if (a < b) {
    return a;
  }
  return b;
}

// backwards
//
// Runs a scaled forward pass followed by a backward pass over `numSeq`
// sequences that are stored back to back along a time axis of length
// `totalLen`, then re-estimates the initial distribution and the transition
// weights from the accumulated quantities.
// NOTE(review): the recursions look like the forward-backward step of an
// explicit-duration hidden semi-Markov model -- confirm against the caller.
//
// Arguments (all flat indexing below is exactly what the code uses):
//   transProb    read as trans[j*numStates + i]; OVERWRITTEN at the end with
//                the re-estimated values.
//   initDist     read per state during the passes; OVERWRITTEN with the mean
//                over sequences of zeta at each sequence's first time point.
//   obsProb      read as obs[j*totalLen + t].
//   durProb      read as dur[j*maxDur[j] + u - 1], u = 1..maxDur[j].
//   survProb     read with the same indexing as durProb.
//   seqLen       length of each sequence; every entry must be >= 1 and the
//                first numSeq entries must sum to at most totalLen.
//   numStates    number of states (rows of every per-state array).
//   maxDur       per-state upper bound on the duration index u.
//   backwardProb read and written in place as back[j*totalLen + t]; the value
//                at the last time point of a sequence is never written here.
//   normConst    written: one normalising constant per time point.
//   eta          accumulated in place at eta[j*maxDur[j] + u - 1].
//   forwardVars  written: copy of the scaled forward quantities (alpha).
//   stateProb    not used by this function.
//   xi           written: copy of zeta.
//   numSeq       number of sequences to process.
//   totalLen     length of the shared time axis (row stride of the arrays).
//   backwardVars written: copy of statein; entries the backward pass never
//                writes (e.g. the first time point of each sequence) are 0.
//
// Returns a list with elements xi, forwardVars, backwardVars, init,
// transProb, eta and normConst, i.e. the argument objects after the writes
// described above.
//
// Raises an R error (Rcpp::stop) when the dimensions are negative, when
// seqLen/maxDur are too short, when a sequence length is < 1, or when the
// sequence lengths sum to more than totalLen.
// [[Rcpp::export]]
List backwards(NumericMatrix transProb, NumericVector initDist, NumericMatrix obsProb,
               NumericMatrix durProb, NumericMatrix survProb, IntegerVector seqLen,
               int numStates, IntegerVector maxDur, NumericMatrix backwardProb,
               NumericVector normConst, NumericVector eta, NumericVector forwardVars,
               NumericVector stateProb, NumericVector xi,
               int numSeq, int totalLen, NumericVector backwardVars) {

  // Reject layouts that would make the pointer arithmetic below leave the
  // scratch buffers (previously silent out-of-bounds accesses).
  if(numStates < 0 || numSeq < 0 || totalLen < 0) {
    Rcpp::stop("numStates, numSeq and totalLen must be non-negative");
  }
  if(seqLen.size() < numSeq || maxDur.size() < numStates) {
    Rcpp::stop("seqLen or maxDur is shorter than numSeq or numStates");
  }
  long long coveredLen = 0;
  for(int seq = 0; seq < numSeq; seq++) {
    if(seqLen[seq] < 1) {
      Rcpp::stop("Every sequence length must be at least 1");
    }
    coveredLen += seqLen[seq];
  }
  if(coveredLen > totalLen) {
    Rcpp::stop("Sum of sequence lengths exceeds totalLen");
  }

  double* trans = &transProb[0];
  double* obs = &obsProb[0];
  double* dur = &durProb[0];
  double* surv = &survProb[0];
  double* back = &backwardProb[0];

  // Scratch storage. std::vector releases the memory on every exit path,
  // including the exception thrown by Rcpp::stop, and zero-initialises it so
  // cells that the recursions never write are well defined when copied out.
  using Grid = std::vector<std::vector<double> >;
  const std::size_t stateCount = static_cast<std::size_t>(numStates);
  const std::size_t timeCount = static_cast<std::size_t>(totalLen);

  Grid alpha(stateCount, std::vector<double>(timeCount, 0.0));
  Grid beta(stateCount, std::vector<double>(timeCount, 0.0));
  Grid statein(stateCount, std::vector<double>(timeCount, 0.0));
  Grid zeta(stateCount, std::vector<double>(timeCount, 0.0));
  std::vector<double> sumProb(stateCount, 0.0);
  Grid transMatrix(stateCount, std::vector<double>(stateCount, 0.0));

  // Per-state cursors pointing at the start of the current sequence inside
  // each per-state row.
  std::vector<double*> probPtr(stateCount);
  std::vector<double*> alphaPtr(stateCount);
  std::vector<double*> betaPtr(stateCount);
  std::vector<double*> stateinPtr(stateCount);
  std::vector<double*> zetaPtr(stateCount);
  std::vector<double*> backPtr(stateCount);

  for(int j = 0; j < numStates; j++) {
    probPtr[j] = obs + j*totalLen;
    alphaPtr[j] = alpha[j].data();
    betaPtr[j] = beta[j].data();
  }

  double *norm = &normConst[0];

  // ---- Phase 1: forward pass, one sequence at a time ----------------------
  for(int seq = 0; seq < numSeq; seq++) {
    int currLen = seqLen[seq];
    if(seq > 0) {
      // Advance every cursor past the previous sequence.
      int offset = seqLen[seq-1];
      for(int j = 0; j < numStates; j++) {
        probPtr[j] += offset;
        alphaPtr[j] += offset;
        betaPtr[j] += offset;
      }
      norm += offset;
    }

    for(int t = 0; t < currLen; t++) {
      norm[t] = 0;
      for(int j = 0; j < numStates; j++) {
        alphaPtr[j][t] = 0;
        // Running product of scaled observation terms over the last u steps.
        double currProb = probPtr[j][t];

        if(t < currLen-1) {
          for(int u = 1; u <= minValue(t+1, maxDur[j]); u++) {
            if(u < t+1) {
              alphaPtr[j][t] += currProb * dur[j*maxDur[j]+u-1] * betaPtr[j][t-u+1];

              norm[t] += currProb * surv[j*maxDur[j]+u-1] * betaPtr[j][t-u+1];

              // Floor the divisor at 1e-300 to avoid dividing by zero.
              double denom = norm[t-u] > 1e-300 ? norm[t-u] : 1e-300;
              currProb *= probPtr[j][t-u]/denom;

            } else {
              // u == t+1: the segment reaches back to the sequence start, so
              // it is weighted by initDist instead of beta.
              double survivorProb = (u == 1) ? 1.0 : surv[j*maxDur[j]+u-2];
              alphaPtr[j][t] += currProb * survivorProb * initDist[j];
              norm[t] += currProb * survivorProb * initDist[j];
            }
          }
        } else {
          // Last time point of the sequence: surv is used in place of dur.
          for(int u = 1; u <= minValue(t+1, maxDur[j]); u++) {
            if(u < currLen) {
              alphaPtr[j][currLen-1] += currProb * surv[j*maxDur[j]+u-1] * betaPtr[j][currLen-u];

              double denom = norm[currLen-1-u] > 1e-300 ? norm[currLen-1-u] : 1e-300;
              currProb *= probPtr[j][currLen-1-u]/denom;

            } else {
              double survivorProb = (currLen == 1) ? 1.0 : surv[j*maxDur[j]+currLen-2];
              alphaPtr[j][currLen-1] += currProb * survivorProb * initDist[j];
            }
          }
          norm[currLen-1] += alphaPtr[j][currLen-1];
        }
      }

      // Scale alpha by the (floored) normaliser and keep it strictly
      // positive; alpha is used as a divisor in the backward pass.
      double denom_t = norm[t] > 1e-300 ? norm[t] : 1e-300;
      for(int j = 0; j < numStates; j++) {
        alphaPtr[j][t] /= denom_t;
        alphaPtr[j][t] += 1e-300;
      }

      if(t < currLen-1) {
        // beta at t+1 is the transition-weighted sum of alpha at t.
        for(int j = 0; j < numStates; j++) {
          betaPtr[j][t+1] = 0;
          for(int i = 0; i < numStates; i++) {
            betaPtr[j][t+1] += trans[j*numStates+i] * alphaPtr[i][t];
          }
        }
      }
    }
  }

  // Rewind all cursors to the first sequence for the backward pass.
  for(int j = 0; j < numStates; j++) {
    probPtr[j] = obs + j*totalLen;
    alphaPtr[j] = alpha[j].data();
    betaPtr[j] = beta[j].data();
    stateinPtr[j] = statein[j].data();
    zetaPtr[j] = zeta[j].data();
    backPtr[j] = back + j*totalLen;
  }
  norm = &normConst[0];

  // ---- Phase 2: backward pass ---------------------------------------------
  for(int seq = 0; seq < numSeq; seq++) {
    int currLen = seqLen[seq];
    if(seq > 0) {
      int offset = seqLen[seq-1];
      for(int j = 0; j < numStates; j++) {
        probPtr[j] += offset;
        alphaPtr[j] += offset;
        betaPtr[j] += offset;
        stateinPtr[j] += offset;
        zetaPtr[j] += offset;
        backPtr[j] += offset;
      }
      norm += offset;
    }

    // zeta at the last time point equals the forward quantity there.
    for(int j = 0; j < numStates; j++) {
      zetaPtr[j][currLen-1] = alphaPtr[j][currLen-1];
    }

    if(seq == 0) {
      // Reset the eta accumulators once, before the first sequence.
      // NOTE(review): the bound is u < maxDur[j], so the entry for
      // u == maxDur[j] keeps whatever value the caller passed in even though
      // it is accumulated into below -- confirm this is intended.
      for(int j = 0; j < numStates; j++) {
        for(int u = 1; u < maxDur[j]; u++) {
          eta[j*maxDur[j]+u-1] = 0;
        }
      }
    }

    for(int t = currLen-2; t >= 0; t--) {
      for(int j = 0; j < numStates; j++) {
        stateinPtr[j][t+1] = 0;
        double currProb = 1;

        for(int u = 1; u <= minValue(currLen-1-t, maxDur[j]); u++) {
          double denom = norm[t+u] > 1e-300 ? norm[t+u] : 1e-300;
          currProb *= probPtr[j][t+u]/denom;

          if(u < currLen-1-t) {
            stateinPtr[j][t+1] += backPtr[j][t+u]/alphaPtr[j][t+u] * currProb * dur[j*maxDur[j]+u-1];
            eta[j*maxDur[j]+u-1] += backPtr[j][t+u]/alphaPtr[j][t+u] * currProb * dur[j*maxDur[j]+u-1] * betaPtr[j][t+1];
          } else {
            // The segment runs to the end of the sequence.
            stateinPtr[j][t+1] += currProb * surv[j*maxDur[j]+currLen-1-t-1];
            eta[j*maxDur[j]+u-1] += currProb * dur[j*maxDur[j]+u-1] * betaPtr[j][t+1];
          }

          if(t == 0) {
            // With t == 0 the loop bound caps u at currLen-1, so the first
            // branch below is never taken; it is kept as in the original.
            if(u > currLen-1) {
              eta[j*maxDur[j]+u-1] += backPtr[j][t+u]/alphaPtr[j][t+u] * currProb * dur[j*maxDur[j]+u-1] * initDist[j];
            } else {
              double survivorProb = (u == 1) ? 1.0 : surv[j*maxDur[j]+u-2];
              eta[j*maxDur[j]+u-1] += currProb * survivorProb * initDist[j];
            }
          }
        }
      }

      for(int j = 0; j < numStates; j++) {
        backPtr[j][t] = 0;
        for(int k = 0; k < numStates; k++) {
          backPtr[j][t] += stateinPtr[k][t+1] * trans[k*numStates+j];
        }
        backPtr[j][t] *= alphaPtr[j][t];
        zetaPtr[j][t] = backPtr[j][t] + zetaPtr[j][t+1] - stateinPtr[j][t+1] * betaPtr[j][t+1];
      }
    }
  }

  // ---- Phase 3: re-estimation ---------------------------------------------
  // sumProb and transMatrix already start at zero; initDist is reused as the
  // accumulator for the new initial distribution.
  for(int i = 0; i < numStates; i++) {
    initDist[i] = 0;
  }

  for(int i = 0; i < numStates; i++) {
    for(int seq = 0; seq < numSeq; seq++) {
      int currLen = seqLen[seq];
      if(seq == 0) {
        for(int j = 0; j < numStates; j++) {
          alphaPtr[j] = alpha[j].data();
          stateinPtr[j] = statein[j].data();
          zetaPtr[j] = zeta[j].data();
          backPtr[j] = back + j*totalLen;
        }
      } else {
        int offset = seqLen[seq-1];
        for(int j = 0; j < numStates; j++) {
          alphaPtr[j] += offset;
          stateinPtr[j] += offset;
          zetaPtr[j] += offset;
          backPtr[j] += offset;
        }
      }

      initDist[i] += zetaPtr[i][0];
      // NOTE(review): both sums stop at t < currLen-2, i.e. they skip the
      // last two time points of each sequence -- confirm this bound.
      for(int t = 0; t < currLen-2; t++) {
        sumProb[i] += backPtr[i][t];
      }
      for(int j = 0; j < numStates; j++) {
        for(int t = 0; t < currLen-2; t++) {
          transMatrix[i][j] += stateinPtr[j][t+1] * trans[j*numStates+i] * alphaPtr[i][t];
        }
      }
    }
  }

  // Normalise; the new transition weights are written over transProb.
  for(int i = 0; i < numStates; i++) {
    initDist[i] /= numSeq;
    for(int j = 0; j < numStates; j++) {
      trans[j*numStates+i] = transMatrix[i][j]/sumProb[i];
    }
  }

  // Copy the scratch grids into the caller-supplied output vectors.
  for(int j = 0; j < numStates; j++) {
    for(int t = 0; t < totalLen; t++) {
      xi[j * totalLen + t] = zeta[j][t];
      forwardVars[j * totalLen + t] = alpha[j][t];
      backwardVars[j * totalLen + t] = statein[j][t];
    }
  }

  return List::create(
    Named("xi") = xi,
    Named("forwardVars") = forwardVars,
    Named("backwardVars") = backwardVars,
    Named("init") = initDist,
    Named("transProb") = transProb,
    Named("eta") = eta,
    Named("normConst") = normConst
  );
}




