#' Prior Distributions for Bayesian Changepoint Detection
#'
#' @description
#' Functions to create prior specifications for Bayesian changepoint detection
#' methods like BOCPD and Shiryaev-Roberts.
#'
#' @name priors
#' @noRd
NULL

#' Normal-Gamma Prior for Unknown Mean and Variance
#'
#' Builds the conjugate Normal-Gamma prior specification used for normally
#' distributed observations whose mean and variance are both unknown.
#'
#' @param mu0 Prior mean for the mean parameter (single numeric value)
#' @param kappa0 Prior pseudo-observations for the mean, i.e. the strength of
#'   the prior on the mean (single positive number)
#' @param alpha0 Shape parameter for the precision (single positive number)
#' @param beta0 Rate parameter for the precision (single positive number)
#'
#' @return An object of class "regime_prior": a list holding the prior type,
#'   the four hyperparameters and a human-readable description.
#'
#' @details
#' The joint prior on (mu, tau), with tau = 1/sigma^2, is
#'
#' tau ~ Gamma(alpha0, beta0)
#' mu | tau ~ Normal(mu0, 1/(kappa0 * tau))
#'
#' The prior mean of mu is mu0; the prior mean of sigma^2 is
#' beta0/(alpha0-1), which exists only for alpha0 > 1.
#'
#' @examples
#' prior <- normal_gamma()
#'
#' prior <- normal_gamma(mu0 = 0, kappa0 = 10, alpha0 = 3, beta0 = 2)
#'
#' @export
normal_gamma <- function(mu0 = 0, kappa0 = 1, alpha0 = 1, beta0 = 1) {
  # Check the hyperparameters one at a time, in argument order.
  stopifnot(is.numeric(mu0), length(mu0) == 1)
  stopifnot(is.numeric(kappa0), length(kappa0) == 1, kappa0 > 0)
  stopifnot(is.numeric(alpha0), length(alpha0) == 1, alpha0 > 0)
  stopifnot(is.numeric(beta0), length(beta0) == 1, beta0 > 0)

  spec <- list(
    type = "normal_gamma",
    mu0 = mu0,
    kappa0 = kappa0,
    alpha0 = alpha0,
    beta0 = beta0,
    description = "Normal-Gamma prior for unknown mean and variance"
  )
  class(spec) <- c("regime_prior", "list")
  spec
}

#' Normal Prior for Unknown Mean with Known Variance
#'
#' Builds a Normal prior specification for observations whose mean is unknown
#' while their variance is treated as known.
#'
#' @param mu0 Prior mean (single numeric value)
#' @param sigma0 Prior standard deviation for the mean (single positive number)
#' @param known_var Known variance of the observations (single positive number)
#'
#' @return An object of class "regime_prior": a list holding the prior type,
#'   the three parameters and a human-readable description.
#'
#' @examples
#' prior <- normal_known_var(mu0 = 0, sigma0 = 1, known_var = 1)
#'
#' @export
normal_known_var <- function(mu0 = 0, sigma0 = 1, known_var = 1) {
  # Check the arguments one at a time, in argument order.
  stopifnot(is.numeric(mu0), length(mu0) == 1)
  stopifnot(is.numeric(sigma0), length(sigma0) == 1, sigma0 > 0)
  stopifnot(is.numeric(known_var), length(known_var) == 1, known_var > 0)

  spec <- list(
    type = "normal_known_var",
    mu0 = mu0,
    sigma0 = sigma0,
    known_var = known_var,
    description = "Normal prior for unknown mean with known variance"
  )
  class(spec) <- c("regime_prior", "list")
  spec
}

#' Normal-Wishart Prior for Multivariate Data
#'
#' Creates a Normal-Wishart prior for multivariate normal data with unknown
#' mean vector and covariance matrix.
#'
#' @param mu0 Prior mean vector (d-dimensional, d >= 1, no missing values)
#' @param kappa0 Prior pseudo-observations for the mean (single positive number)
#' @param nu0 Degrees of freedom for the Wishart (must be >= d). Defaults to
#'   d when NULL.
#' @param Psi0 Scale matrix for the Wishart (d x d, numeric, symmetric and
#'   positive definite). Defaults to the d x d identity matrix when NULL.
#'
#' @return An object of class "regime_prior": a list holding the prior type,
#'   the hyperparameters, the dimension d and a human-readable description.
#'
#' @details
#' An error is raised when mu0 is empty, non-numeric or contains missing
#' values, or when Psi0 is not a symmetric positive definite d x d matrix.
#'
#' @examples
#' prior <- normal_wishart(
#'   mu0 = c(0, 0),
#'   kappa0 = 1,
#'   nu0 = 3,
#'   Psi0 = diag(2)
#' )
#'
#' @export
normal_wishart <- function(mu0, kappa0 = 1, nu0 = NULL, Psi0 = NULL) {
  # Validate mu0 before the dimension (and the defaults) are derived from it.
  stopifnot(is.numeric(mu0), length(mu0) >= 1, !anyNA(mu0))
  d <- length(mu0)

  if (is.null(nu0)) nu0 <- d

  if (is.null(Psi0)) Psi0 <- diag(d)

  stopifnot(
    is.numeric(kappa0), length(kappa0) == 1, kappa0 > 0,
    is.numeric(nu0), length(nu0) == 1, nu0 >= d,
    is.matrix(Psi0), is.numeric(Psi0), nrow(Psi0) == d, ncol(Psi0) == d
  )

  # chol() reads only the upper triangle, so symmetry is checked separately.
  # unname() keeps differing row/column names from failing the comparison.
  if (anyNA(Psi0) || !isSymmetric(unname(Psi0))) {
    stop("Psi0 must be a symmetric matrix without missing values",
         call. = FALSE)
  }
  # A Cholesky factorisation succeeds only for positive definite input.
  is_pos_def <- tryCatch(
    {
      chol(Psi0)
      TRUE
    },
    error = function(e) FALSE
  )
  if (!is_pos_def) {
    stop("Psi0 must be positive definite", call. = FALSE)
  }

  structure(
    list(
      type = "normal_wishart",
      mu0 = mu0,
      kappa0 = kappa0,
      nu0 = nu0,
      Psi0 = Psi0,
      d = d,
      description = "Normal-Wishart prior for multivariate data"
    ),
    class = c("regime_prior", "list")
  )
}

#' Gamma-Poisson Prior for Count Data
#'
#' Builds a Gamma prior specification on the rate of Poisson-distributed
#' count data.
#'
#' @param alpha0 Shape parameter for the rate (single positive number)
#' @param beta0 Rate parameter (single positive number)
#'
#' @return An object of class "regime_prior": a list holding the prior type,
#'   both hyperparameters and a human-readable description.
#'
#' @examples
#' prior <- poisson_gamma(alpha0 = 1, beta0 = 1)
#'
#' @export
poisson_gamma <- function(alpha0 = 1, beta0 = 1) {
  # Check shape first, then rate.
  stopifnot(is.numeric(alpha0), length(alpha0) == 1, alpha0 > 0)
  stopifnot(is.numeric(beta0), length(beta0) == 1, beta0 > 0)

  spec <- list(
    type = "poisson_gamma",
    alpha0 = alpha0,
    beta0 = beta0,
    description = "Gamma prior for Poisson rate"
  )
  class(spec) <- c("regime_prior", "list")
  spec
}

#' Inverse-Gamma Prior for Variance Only
#'
#' Builds an Inverse-Gamma prior specification for detecting changes in
#' variance when the mean of the observations is known.
#'
#' @param alpha0 Shape parameter (single positive number)
#' @param beta0 Scale parameter (single positive number)
#' @param known_mean Known mean of the observations (single numeric value)
#'
#' @return An object of class "regime_prior": a list holding the prior type,
#'   the three parameters and a human-readable description.
#'
#' @examples
#' prior <- inverse_gamma_var(alpha0 = 1, beta0 = 1, known_mean = 0)
#'
#' @export
inverse_gamma_var <- function(alpha0 = 1, beta0 = 1, known_mean = 0) {
  # Check the arguments one at a time, in argument order.
  stopifnot(is.numeric(alpha0), length(alpha0) == 1, alpha0 > 0)
  stopifnot(is.numeric(beta0), length(beta0) == 1, beta0 > 0)
  stopifnot(is.numeric(known_mean), length(known_mean) == 1)

  spec <- list(
    type = "inverse_gamma_var",
    alpha0 = alpha0,
    beta0 = beta0,
    known_mean = known_mean,
    description = "Inverse-Gamma prior for variance with known mean"
  )
  class(spec) <- c("regime_prior", "list")
  spec
}

#' Geometric Hazard Prior
#'
#' Builds a geometric (constant hazard) prior for the changepoint process:
#' every time step carries the same, independent probability of a changepoint.
#'
#' @param lambda Hazard rate (probability of changepoint at each step).
#'   Must lie strictly between 0 and 1. Smaller values expect fewer
#'   changepoints.
#'
#' @return An object of class "hazard_prior": a list holding the hazard type,
#'   lambda, the expected run length and a human-readable description.
#'
#' @details
#' With a geometric hazard the expected run length (time between
#' changepoints) equals 1/lambda, so lambda = 0.01 corresponds to one
#' changepoint every 100 observations on average.
#'
#' @examples
#' hazard <- geometric_hazard(lambda = 0.01)
#'
#' hazard <- geometric_hazard(lambda = 0.1)
#'
#' @export
geometric_hazard <- function(lambda = 0.01) {
  stopifnot(is.numeric(lambda), length(lambda) == 1)
  stopifnot(lambda > 0, lambda < 1)

  # Stored on the object and also quoted in the description.
  run_length <- 1 / lambda

  hazard <- list(
    type = "geometric",
    lambda = lambda,
    expected_run_length = run_length,
    description = sprintf(
      "Geometric hazard (expected run length: %.1f)",
      run_length
    )
  )
  class(hazard) <- c("hazard_prior", "list")
  hazard
}

#' Constant Hazard Prior (Alias for Geometric)
#'
#' Alternative name for the geometric hazard constructor. Both names are bound
#' to the same function, so the argument, its validation and the returned
#' object are identical.
#'
#' @inheritParams geometric_hazard
#' @return An object of class "hazard_prior"
#' @export
constant_hazard <- geometric_hazard

#' Negative Binomial Hazard Prior
#'
#' Builds a negative binomial hazard prior, which gives more flexibility in
#' the distribution of run lengths than a constant hazard.
#'
#' @param r Number of successes (shape parameter; single positive number)
#' @param p Probability of success (single number strictly between 0 and 1)
#'
#' @return An object of class "hazard_prior": a list holding the hazard type,
#'   r, p and a human-readable description.
#'
#' @export
negbin_hazard <- function(r = 1, p = 0.01) {
  # Check r first, then p.
  stopifnot(is.numeric(r), length(r) == 1, r > 0)
  stopifnot(is.numeric(p), length(p) == 1, p > 0, p < 1)

  hazard <- list(
    type = "negative_binomial",
    r = r,
    p = p,
    description = "Negative binomial hazard"
  )
  class(hazard) <- c("hazard_prior", "list")
  hazard
}

#' @export
print.regime_prior <- function(x, ...) {
  # One-line summary of a single parameter value: matrices are shown by
  # dimension, long vectors by length, anything else value by value.
  summarise_value <- function(value) {
    if (is.matrix(value)) {
      return(sprintf("%dx%d matrix", nrow(value), ncol(value)))
    }
    if (length(value) > 5) {
      return(sprintf("vector of length %d", length(value)))
    }
    paste(format(value, digits = 4), collapse = ", ")
  }

  cat("Regime Prior Specification\n==========================\n")
  cat("Type:", x$type, "\n")
  cat("Description:", x$description, "\n\n")
  cat("Parameters:\n")

  # Everything except the bookkeeping fields is reported as a parameter.
  is_param <- !(names(x) %in% c("type", "description", "class"))
  for (field in names(x)[is_param]) {
    cat(sprintf("  %s: %s\n", field, summarise_value(x[[field]])))
  }

  invisible(x)
}

#' @export
print.hazard_prior <- function(x, ...) {
  cat("Hazard Prior Specification\n==========================\n")
  cat("Type:", x$type, "\n")
  cat("Description:", x$description, "\n\n")

  if (x$type != "geometric") {
    # Generic layout: list every field except the bookkeeping ones.
    is_param <- !(names(x) %in% c("type", "description", "class"))
    cat("Parameters:\n")
    for (field in names(x)[is_param]) {
      shown <- paste(format(x[[field]], digits = 4), collapse = ", ")
      cat(sprintf("  %s: %s\n", field, shown))
    }
  } else {
    # Geometric hazards get a dedicated layout with the expected run length.
    cat("Parameters:\n")
    cat(sprintf("  lambda: %g\n", x$lambda))
    cat(sprintf("  Expected run length: %.1f\n", x$expected_run_length))
  }

  invisible(x)
}

#' Posterior update for the Normal-Gamma prior
#'
#' Combines Normal-Gamma hyperparameters with a batch of observations and
#' returns the updated hyperparameters.
#'
#' @param prior List-like object with elements mu0, kappa0, alpha0 and beta0
#' @param x Numeric vector of observations; may be empty
#'
#' @return A plain list with elements mu0, kappa0, alpha0 and beta0. When x is
#'   empty the prior's hyperparameters are returned unchanged.
#'
#' @noRd
update_normal_gamma <- function(prior, x) {
  n <- length(x)

  # With no data, mean(x) is NaN and would turn mu0 and beta0 into NaN;
  # the posterior is simply the prior in that case.
  if (n == 0) {
    return(list(
      mu0 = prior$mu0,
      kappa0 = prior$kappa0,
      alpha0 = prior$alpha0,
      beta0 = prior$beta0
    ))
  }

  x_bar <- mean(x)

  kappa_n <- prior$kappa0 + n
  mu_n <- (prior$kappa0 * prior$mu0 + n * x_bar) / kappa_n
  alpha_n <- prior$alpha0 + n / 2

  # Sum of squared deviations around the sample mean (zero for one point).
  ss <- if (n > 1) sum((x - x_bar)^2) else 0
  # The last term accounts for the gap between sample mean and prior mean.
  beta_n <- prior$beta0 + 0.5 * ss +
    (prior$kappa0 * n * (x_bar - prior$mu0)^2) / (2 * kappa_n)

  list(
    mu0 = mu_n,
    kappa0 = kappa_n,
    alpha0 = alpha_n,
    beta0 = beta_n
  )
}

#' Predictive density under the Normal-Gamma prior
#'
#' Evaluates a location-scale Student-t density at x, with location mu0,
#' 2 * alpha0 degrees of freedom and squared scale
#' beta0 * (kappa0 + 1) / (alpha0 * kappa0).
#'
#' @param prior List-like object with elements mu0, kappa0, alpha0 and beta0
#' @param x Numeric vector of points at which to evaluate the density
#'
#' @return Numeric vector of density values, one per element of x
#'
#' @noRd
pred_density_normal_gamma <- function(prior, x) {
  kappa <- prior$kappa0
  alpha <- prior$alpha0

  spread <- sqrt(prior$beta0 * (kappa + 1) / (alpha * kappa))
  standardised <- (x - prior$mu0) / spread

  # Dividing by the scale converts the standard t density to the x scale.
  dt(standardised, df = 2 * alpha) / spread
}

#' Posterior update for the Gamma prior on a Poisson rate
#'
#' Adds the sum of the observations to alpha0 and the number of observations
#' to beta0.
#'
#' @param prior List-like object with elements alpha0 and beta0
#' @param x Vector of observed counts
#'
#' @return A plain list with elements alpha0 and beta0
#'
#' @noRd
update_poisson_gamma <- function(prior, x) {
  list(
    alpha0 = prior$alpha0 + sum(x),
    beta0 = prior$beta0 + length(x)
  )
}

#' Predictive mass function under the Gamma prior on a Poisson rate
#'
#' Evaluates a negative binomial probability mass at x with size alpha0 and
#' success probability beta0 / (beta0 + 1).
#'
#' @param prior List-like object with elements alpha0 and beta0
#' @param x Vector of counts at which to evaluate the mass function
#'
#' @return Numeric vector of probabilities, one per element of x
#'
#' @noRd
pred_density_poisson_gamma <- function(prior, x) {
  rate <- prior$beta0
  success_prob <- rate / (rate + 1)

  dnbinom(x, size = prior$alpha0, prob = success_prob)
}