## ----setup, include=FALSE-----------------------------------------------------
# Vignette-wide chunk defaults: collapse code and output into one block,
# prefix output lines with "#>", draw 8 x 5 inch figures, and keep warnings
# and messages out of the rendered output.
knitr::opts_chunk$set(
  comment = "#>",
  collapse = TRUE,
  fig.width = 8,
  fig.height = 5,
  message = FALSE,
  warning = FALSE
)

## ----eval=FALSE---------------------------------------------------------------
# # Install from GitHub
# devtools::install_github('klintkanopka/mixedsubjects')

## ----load-package, echo=FALSE, message=FALSE----------------------------------
# Load mixedsubjects if it is installed. Otherwise (e.g. when the vignette is
# built from a source checkout before the package is installed) source the
# package's R files from ../R so the functions used below are still defined.
# The previous version called library() unconditionally, which errors when
# the package is missing and made the source-file fallback unreachable.
if (requireNamespace("mixedsubjects", quietly = TRUE)) {
  library(mixedsubjects)
  pkg_dir <- system.file(package = "mixedsubjects")
} else {
  # Running from source
  pkg_dir <- "../R"
  r_files <- list.files(pkg_dir, pattern = "\\.R$", full.names = TRUE)
  if (length(r_files) == 0L) {
    # Fail loudly here rather than later with "could not find function".
    stop("mixedsubjects is not installed and no R files were found in ",
         normalizePath(pkg_dir, mustWork = FALSE), call. = FALSE)
  }
  for (f in r_files) {
    source(f)
  }
}

## ----simulation-engine--------------------------------------------------------
#' Run a Monte Carlo comparison of all seven estimators
#'
#' Each replication draws one data set from `dgp_fn(seed = i)`, wraps it with
#' `msd_data()`, and applies every estimator in a fixed order. An estimator
#' fit that errors, or that does not return a single numeric `$estimate`, is
#' recorded as `NA` for that replication. Failures are counted in `n_failed`
#' and reported with a warning. Summary statistics use only the non-missing
#' replications.
#'
#' @param dgp_fn A function(seed) that returns a list with components
#'   obs_df, unobs_df (data.frames), and true_tau (scalar).
#' @param n_sims Number of Monte Carlo replications (at least 1).
#' @param n_folds Number of cross-fitting folds, forwarded to msd_ppi,
#'   msd_dt, msd_dip_pp and msd_dt_dip.
#' @return A data.frame with one row per estimator and columns: estimator,
#'   mean_est, bias, variance, mse, n_failed. Bias and MSE are measured
#'   against each replication's own `true_tau`.
run_comparison <- function(dgp_fn, n_sims = 2000, n_folds = 2) {
  stopifnot(is.function(dgp_fn), length(n_sims) == 1L, n_sims >= 1)

  # Applied in this order every replication. If any estimator consumes random
  # numbers (e.g. for fold assignment -- not visible here), a fixed order
  # keeps runs reproducible.
  estimators <- list(
    dim    = function(msd) msd_dim(msd),
    greg   = function(msd) msd_greg(msd),
    ppi    = function(msd) msd_ppi(msd, n_folds = n_folds),
    dt     = function(msd) msd_dt(msd, n_folds = n_folds),
    dip    = function(msd) msd_dip(msd),
    dip_pp = function(msd) msd_dip_pp(msd, n_folds = n_folds),
    dt_dip = function(msd) msd_dt_dip(msd, n_folds = n_folds)
  )
  estimator_names <- names(estimators)

  # Point estimate of one fit, or NA_real_ if the fit errors or yields
  # something other than a single number. Previously a NULL estimate aborted
  # the whole simulation when it was assigned into the results matrix.
  safe_estimate <- function(fit, msd) {
    est <- tryCatch(fit(msd)$estimate, error = function(e) NULL)
    if (is.numeric(est) && length(est) == 1L) as.numeric(est) else NA_real_
  }

  results <- matrix(NA_real_, nrow = n_sims, ncol = length(estimators),
                    dimnames = list(NULL, estimator_names))
  true_taus <- numeric(n_sims)

  for (i in seq_len(n_sims)) {
    d <- dgp_fn(seed = i)
    # Keep each replication's estimand instead of re-running dgp_fn(seed = 1).
    true_taus[i] <- d$true_tau
    msd <- msd_data(observed = d$obs_df, unobserved = d$unobs_df)
    results[i, ] <- vapply(estimators, safe_estimate, numeric(1), msd = msd)
  }

  n_failed <- colSums(is.na(results))
  if (any(n_failed > 0)) {
    failed <- n_failed[n_failed > 0]
    warning("Some estimator fits failed and were recorded as NA: ",
            paste0(names(failed), " (", failed, "/", n_sims, ")",
                   collapse = ", "),
            call. = FALSE)
  }

  # Column-major recycling subtracts true_taus[i] from every entry of row i.
  errors <- results - true_taus

  data.frame(
    estimator = estimator_names,
    mean_est  = colMeans(results, na.rm = TRUE),
    bias      = colMeans(errors, na.rm = TRUE),
    variance  = apply(results, 2, var, na.rm = TRUE),
    mse       = colMeans(errors^2, na.rm = TRUE),
    n_failed  = n_failed,
    stringsAsFactors = FALSE
  )
}

#' Pretty-print a comparison table, highlighting the minimum-variance estimator
#'
#' @param comp A data.frame as returned by `run_comparison()`; it must have
#'   columns estimator, bias, variance and mse.
#' @param title Optional heading, printed as "### title" when non-empty.
#' @return `comp`, unchanged and invisibly, so the call can be chained.
print_comparison <- function(comp, title = "") {
  if (nzchar(title)) cat("###", title, "\n\n")

  # Choose the winner on the unrounded variances. Rounding first can collapse
  # distinct small variances into a tie and flag the wrong estimator.
  best_idx <- which.min(comp$variance)
  # which.min() returns integer(0) when every variance is NA.
  best <- if (length(best_idx) == 1L) comp$estimator[best_idx] else NA_character_

  shown <- comp
  shown$variance <- round(shown$variance, 6)
  shown$bias <- round(shown$bias, 4)
  shown$mse <- round(shown$mse, 6)
  shown$best <- ifelse(seq_len(nrow(shown)) %in% best_idx, " <-- min", "")
  print(shown[, c("estimator", "bias", "variance", "mse", "best")],
        row.names = FALSE)
  cat("\nLowest variance:", best, "\n\n")
  invisible(comp)
}

## ----dgp-poor-predictions-----------------------------------------------------
#' Simulate a design whose predictions carry no information about Y
#'
#' Outcomes are N(0, 1) plus `true_tau` for treated units. Both prediction
#' columns are pure noise (S1 ~ N(0.1, 1), S0 ~ N(0, 1)), drawn independently
#' of Y. Arms are split half/half in both samples.
#'
#' @param seed Seed passed to `set.seed()`.
#' @return A list with obs_df (Y, D, S0, S1), unobs_df (D, S0, S1) and
#'   true_tau.
dgp_poor_predictions <- function(seed) {
  set.seed(seed)
  true_tau <- 0.5
  n_obs <- 200
  n_unobs <- 400

  balanced_arms <- function(k) rep(c(1, 0), each = k / 2)

  # S1 is drawn before S0 for every sample.
  noise_predictions <- function(k) {
    s1 <- rnorm(k, 0.1, 1)
    s0 <- rnorm(k, 0, 1)
    list(S1 = s1, S0 = s0)
  }

  treat_obs <- balanced_arms(n_obs)
  outcome <- rnorm(n_obs) + true_tau * treat_obs
  pred_obs <- noise_predictions(n_obs)

  treat_unobs <- balanced_arms(n_unobs)
  pred_unobs <- noise_predictions(n_unobs)

  list(
    obs_df = data.frame(Y = outcome, D = treat_obs,
                        S0 = pred_obs$S0, S1 = pred_obs$S1),
    unobs_df = data.frame(D = treat_unobs,
                          S0 = pred_unobs$S0, S1 = pred_unobs$S1),
    true_tau = true_tau
  )
}

## ----run-scenario-1-----------------------------------------------------------
# Simulate with run_comparison()'s default number of replications, then
# print the per-estimator table.
comp1 <- run_comparison(dgp_fn = dgp_poor_predictions)
print_comparison(comp = comp1, title = "Scenario 1: Poor Predictions")

## ----dgp-high-quality-balanced------------------------------------------------
#' Simulate negatively correlated potential outcomes with good predictions
#'
#' A covariate X ~ N(0, 1) raises Y(0) and lowers Y(1) by the same amount,
#' so Cov(Y(1), Y(0)) < 0. Each prediction is 0.85 times its potential
#' outcome plus N(0, 0.2^2) noise.
#'
#' @param seed Seed passed to `set.seed()`.
#' @return A list with obs_df (Y, D, S0, S1), unobs_df (D, S0, S1) and
#'   true_tau.
dgp_neg_corr_predictions <- function(seed) {
  set.seed(seed)
  true_tau <- 0.5

  # One sample of size k. Random draws run in the order:
  # X, Y0 noise, Y1 noise, S1 noise, S0 noise.
  draw_sample <- function(k) {
    x <- rnorm(k)
    y0 <- rnorm(k, 0, 0.3) + 1.0 * x
    y1 <- rnorm(k, 0, 0.3) - 1.0 * x + true_tau
    s1 <- 0.85 * y1 + rnorm(k, 0, 0.2)
    s0 <- 0.85 * y0 + rnorm(k, 0, 0.2)
    list(D = rep(c(1, 0), each = k / 2), Y0 = y0, Y1 = y1, S0 = s0, S1 = s1)
  }

  obs <- draw_sample(200)
  unobs <- draw_sample(500)

  # Only the realized outcome is observed: Y(1) for treated, Y(0) otherwise.
  realized <- obs$D * obs$Y1 + (1 - obs$D) * obs$Y0

  list(
    obs_df = data.frame(Y = realized, D = obs$D, S0 = obs$S0, S1 = obs$S1),
    unobs_df = data.frame(D = unobs$D, S0 = unobs$S0, S1 = unobs$S1),
    true_tau = true_tau
  )
}

## ----run-scenario-2-----------------------------------------------------------
# Simulate with run_comparison()'s default number of replications, then
# print the per-estimator table.
comp2 <- run_comparison(dgp_fn = dgp_neg_corr_predictions)
print_comparison(comp = comp2,
                 title = "Scenario 2: Negatively Correlated Predictions")

## ----dgp-heterogeneous-quality------------------------------------------------
#' Simulate predictions whose quality differs between arms
#'
#' The treatment effect is constant (Y1 = Y0 + true_tau, Y0 ~ N(0, 1)).
#' Population correlations of the predictions with their potential outcomes:
#'   S1 = 0.9 * Y1 + N(0, 0.3^2): Corr(S1, Y1) = sqrt(0.9)        ~ 0.95
#'   S0 = 0.2 * Y0 + N(0, 0.9^2): Corr(S0, Y0) = 0.2 / sqrt(0.85) ~ 0.22
#'
#' @param seed Seed passed to `set.seed()`.
#' @return A list with obs_df (Y, D, S0, S1), unobs_df (D, S0, S1) and
#'   true_tau.
dgp_heterogeneous_quality <- function(seed) {
  set.seed(seed)
  true_tau <- 0.5

  # One sample of size k. Random draws run in the order:
  # Y0, S1 noise, S0 noise.
  draw_sample <- function(k) {
    y0 <- rnorm(k)
    y1 <- y0 + true_tau
    s1 <- 0.9 * y1 + rnorm(k, 0, 0.3)
    s0 <- 0.2 * y0 + rnorm(k, 0, 0.9)
    list(D = rep(c(1, 0), each = k / 2), Y0 = y0, Y1 = y1, S0 = s0, S1 = s1)
  }

  obs <- draw_sample(200)
  unobs <- draw_sample(500)

  # Only the realized outcome is observed: Y(1) for treated, Y(0) otherwise.
  realized <- obs$D * obs$Y1 + (1 - obs$D) * obs$Y0

  list(
    obs_df = data.frame(Y = realized, D = obs$D, S0 = obs$S0, S1 = obs$S1),
    unobs_df = data.frame(D = unobs$D, S0 = unobs$S0, S1 = unobs$S1),
    true_tau = true_tau
  )
}

## ----run-scenario-3-----------------------------------------------------------
# Simulate with run_comparison()'s default number of replications, then
# print the per-estimator table.
comp3 <- run_comparison(dgp_fn = dgp_heterogeneous_quality)
print_comparison(comp = comp3,
                 title = "Scenario 3: Heterogeneous Quality Across Arms")

## ----dgp-high-common-mode-----------------------------------------------------
#' Simulate predictions sharing a large per-unit error term
#'
#' The treatment effect is constant (Y1 = Y0 + true_tau, Y0 ~ N(0, 1)). Both
#' predictions for a unit share the same N(0, 0.8^2) error, which does not
#' enter Y, plus small independent N(0, 0.1^2) noise.
#'
#' @param seed Seed passed to `set.seed()`.
#' @return A list with obs_df (Y, D, S0, S1), unobs_df (D, S0, S1) and
#'   true_tau.
dgp_high_common_mode <- function(seed) {
  set.seed(seed)
  true_tau <- 0.5

  # One sample of size k. Random draws run in the order:
  # Y0, shared error, S1 noise, S0 noise.
  draw_sample <- function(k) {
    y0 <- rnorm(k)
    y1 <- y0 + true_tau
    shared_err <- rnorm(k, 0, 0.8)
    s1 <- 0.9 * y1 + shared_err + rnorm(k, 0, 0.1)
    s0 <- 0.9 * y0 + shared_err + rnorm(k, 0, 0.1)
    list(D = rep(c(1, 0), each = k / 2), Y0 = y0, Y1 = y1, S0 = s0, S1 = s1)
  }

  obs <- draw_sample(200)
  unobs <- draw_sample(500)

  # Only the realized outcome is observed: Y(1) for treated, Y(0) otherwise.
  realized <- obs$D * obs$Y1 + (1 - obs$D) * obs$Y0

  list(
    obs_df = data.frame(Y = realized, D = obs$D, S0 = obs$S0, S1 = obs$S1),
    unobs_df = data.frame(D = unobs$D, S0 = unobs$S0, S1 = unobs$S1),
    true_tau = true_tau
  )
}

## ----run-scenario-4-----------------------------------------------------------
# Simulate with run_comparison()'s default number of replications, then
# print the per-estimator table.
comp4 <- run_comparison(dgp_fn = dgp_high_common_mode)
print_comparison(comp = comp4,
                 title = "Scenario 4: High Common-Mode Prediction Error")

## ----dgp-common-mode-heterogeneous--------------------------------------------
#' Simulate shared prediction error combined with arm-specific quality
#'
#' The treatment effect is constant (Y1 = Y0 + true_tau, Y0 ~ N(0, 1)). Both
#' predictions share a per-unit N(0, 1.2^2) error. S1 tracks Y1 strongly
#' (slope 0.9, noise sd 0.3). S0 tracks Y0 weakly (slope 0.2, noise sd 0.5).
#'
#' @param seed Seed passed to `set.seed()`.
#' @return A list with obs_df (Y, D, S0, S1), unobs_df (D, S0, S1) and
#'   true_tau.
dgp_common_mode_heterogeneous <- function(seed) {
  set.seed(seed)
  true_tau <- 0.5

  # One sample of size k. Random draws run in the order:
  # Y0, shared error, S1 noise, S0 noise.
  draw_sample <- function(k) {
    y0 <- rnorm(k)
    y1 <- y0 + true_tau
    shared_err <- rnorm(k, 0, 1.2)
    s1 <- 0.9 * y1 + shared_err + rnorm(k, 0, 0.3)
    s0 <- 0.2 * y0 + shared_err + rnorm(k, 0, 0.5)
    list(D = rep(c(1, 0), each = k / 2), Y0 = y0, Y1 = y1, S0 = s0, S1 = s1)
  }

  obs <- draw_sample(200)
  unobs <- draw_sample(500)

  # Only the realized outcome is observed: Y(1) for treated, Y(0) otherwise.
  realized <- obs$D * obs$Y1 + (1 - obs$D) * obs$Y0

  list(
    obs_df = data.frame(Y = realized, D = obs$D, S0 = obs$S0, S1 = obs$S1),
    unobs_df = data.frame(D = unobs$D, S0 = unobs$S0, S1 = unobs$S1),
    true_tau = true_tau
  )
}

## ----run-scenario-5-----------------------------------------------------------
# Simulate with run_comparison()'s default number of replications, then
# print the per-estimator table.
comp5 <- run_comparison(dgp_fn = dgp_common_mode_heterogeneous)
print_comparison(comp = comp5,
                 title = "Scenario 5: Common-Mode Error + Heterogeneous Quality")

## ----dgp-near-perfect---------------------------------------------------------
#' Simulate near-perfect predictions with a small observed sample
#'
#' The treatment effect is constant (Y1 = Y0 + true_tau, Y0 ~ N(0, 1)). Each
#' prediction equals its potential outcome plus N(0, 0.1^2) noise, so
#' rho = 1 / sqrt(1 + 0.01) ~ 0.995. The observed sample has 100 units and
#' the unobserved sample has 500.
#'
#' @param seed Seed passed to `set.seed()`.
#' @return A list with obs_df (Y, D, S0, S1), unobs_df (D, S0, S1) and
#'   true_tau.
dgp_near_perfect <- function(seed) {
  set.seed(seed)
  true_tau <- 0.5

  # One sample of size k. Random draws run in the order:
  # Y0, S1 noise, S0 noise.
  draw_sample <- function(k) {
    y0 <- rnorm(k)
    y1 <- y0 + true_tau
    s1 <- y1 + rnorm(k, 0, 0.1)
    s0 <- y0 + rnorm(k, 0, 0.1)
    list(D = rep(c(1, 0), each = k / 2), Y0 = y0, Y1 = y1, S0 = s0, S1 = s1)
  }

  obs <- draw_sample(100)
  unobs <- draw_sample(500)

  # Only the realized outcome is observed: Y(1) for treated, Y(0) otherwise.
  realized <- obs$D * obs$Y1 + (1 - obs$D) * obs$Y0

  list(
    obs_df = data.frame(Y = realized, D = obs$D, S0 = obs$S0, S1 = obs$S1),
    unobs_df = data.frame(D = unobs$D, S0 = unobs$S0, S1 = unobs$S1),
    true_tau = true_tau
  )
}

## ----run-scenario-6-----------------------------------------------------------
# Simulate with run_comparison()'s default number of replications, then
# print the per-estimator table.
comp6 <- run_comparison(dgp_fn = dgp_near_perfect)
print_comparison(comp = comp6, title = "Scenario 6: Near-Perfect Predictions")

## ----summary-table------------------------------------------------------------
# Gather the six scenario results under human-readable labels.
scenarios <- list(
  "1: Poor predictions"        = comp1,
  "2: Neg corr predictions"    = comp2,
  "3: Heterogeneous quality"   = comp3,
  "4: High common-mode error"  = comp4,
  "5: Common-mode + hetero"    = comp5,
  "6: Near-perfect"            = comp6
)

# One row per scenario: the estimator with the smallest Monte Carlo MSE, its
# MSE and variance, the difference-in-means ("dim") variance, and how much
# lower (in %) the winner's variance is than the dim variance.
summary_df <- do.call(rbind, lapply(seq_along(scenarios), function(k) {
  res <- scenarios[[k]]
  top <- which.min(res$mse)
  baseline_var <- res$variance[res$estimator == "dim"]
  data.frame(
    scenario = names(scenarios)[k],
    winner = res$estimator[top],
    winner_mse = res$mse[top],
    winner_var = res$variance[top],
    dim_var = baseline_var,
    reduction_pct = round((1 - res$variance[top] / baseline_var) * 100, 1),
    stringsAsFactors = FALSE
  )
}))

knitr::kable(summary_df,
             col.names = c("Scenario", "Best Estimator", "Best MSE",
                           "Best Var", "DiM Var", "Var Reduction (%)"),
             digits = 5,
             caption = "Which estimator achieves the lowest Monte Carlo MSE under each DGP?")

