#' Multi-step Predictive Simulation for the Bivariate Hurdle Model
#'
#' Generates forward simulations for \code{h} future periods from a fitted
#' bivariate hurdle negative binomial model (I/C), using posterior draws and
#' dynamically updating the lagged history as new simulated values are added.
#'
#' @param fit_obj A list returned by \code{fit_one()}, containing at least
#'   \code{$fit} (a CmdStanR fit object), \code{$spec} (model specification:
#'   \code{"A"} uses lags of \code{C} in the \code{I} equations, \code{"B"}
#'   uses lags of \code{I} in the \code{C} equations, \code{"C"} uses both),
#'   and \code{$controls} (character vector of control variables).
#' @param DT A \code{data.frame} or \code{data.table} with the covariates and
#'   original time series, including columns \code{I}, \code{C},
#'   \code{Regime}, \code{trans_PS}, \code{trans_SF}, \code{trans_FC} and
#'   \code{log_exposure50}.
#' @param k Non-negative integer; lag order used in the fitted model.
#' @param Tcut Non-negative integer; last time index used as the starting
#'   point for prediction (historical window is \code{seq_len(Tcut)}).
#' @param h Non-negative integer; forecast horizon (number of steps ahead to
#'   simulate). \code{h = 0} returns matrices with zero columns.
#' @param ndraws Integer; maximum number of posterior draws to use for
#'   simulation (default 800). If larger than available draws, it is truncated.
#' @param seed Optional integer; if not \code{NULL} it is passed to
#'   \code{set.seed()} (resetting the global RNG state) for reproducibility.
#'
#' @return A list with two components:
#' \item{pred_I}{Numeric matrix of dimension \code{S x h} with simulated paths
#'   for \code{I}, where \code{S} is the number of posterior draws used.}
#' \item{pred_C}{Numeric matrix of dimension \code{S x h} with simulated paths
#'   for \code{C}.}
#' Columns for steps with \code{Tcut + step > nrow(DT)} are left as \code{NA}.
#'
#' @details
#' For each randomly selected posterior draw, the function iteratively
#' simulates \code{h} future values of \code{I} and \code{C}. At each step:
#' \itemize{
#'   \item The covariate vector is built from lagged outcomes
#'     (up to order \code{k}) and the corresponding row \code{t} of
#'     \code{DT} (trend terms, regime dummies, transition variables and
#'     controls). The lag indicators used in the hurdle part are
#'     \code{as.integer(y > 0)}; the count part uses the raw lagged counts.
#'   \item The hurdle probability (inverse logit of the linear predictor) and
#'     the negative-binomial mean (exponential of the linear predictor plus
#'     \code{log_exposure50}) are computed from the draw-specific parameters.
#'   \item A Bernoulli draw decides zero vs. positive; positive counts are
#'     drawn from the zero-truncated negative binomial by inverse-CDF
#'     sampling, consistent with the hurdle likelihood.
#'   \item New counts are appended to the local history so that
#'     subsequent steps use the updated lags.
#' }
#' Non-lag covariates are computed once per future time index and shared by
#' all draws.
#'
#' @examples
#' \donttest{
#' if (interactive() && requireNamespace("cmdstanr", quietly = TRUE)) {
#'   n <- 120
#'   DT <- data.table::data.table(
#'     I = rpois(n, 5), C = rpois(n, 3),
#'     Regime = factor(sample(c("A","B","C"), n, TRUE)),
#'     trans_PS = c(rep(1,5), rep(0,n-5)),
#'     trans_SF = c(rep(0,60), rep(1,5), rep(0,n-65)),
#'     trans_FC = rep(0, n),
#'     log_exposure50 = log(runif(n, 40, 60))
#'   )
#'   fit_obj <- fit_one(DT, k = 1, spec = "C")
#'   pred <- predict_multistep(fit_obj, DT, k = 1, Tcut = 100, h = 12,
#'                             ndraws = 500, seed = 123)
#'   str(pred$pred_I)
#' }
#' }
#' @keywords internal
#' @export
predict_multistep <- function(fit_obj, DT, k, Tcut, h, ndraws = 800, seed = NULL) {
  stopifnot(
    is.numeric(k), length(k) == 1L, k >= 0,
    is.numeric(Tcut), length(Tcut) == 1L, Tcut >= 0,
    is.numeric(h), length(h) == 1L, h >= 0
  )
  if (!is.null(seed)) set.seed(seed)
  DT <- as.data.frame(DT)
  n_obs <- nrow(DT)

  draws <- posterior::as_draws_df(fit_obj$fit$draws())
  n_draws <- nrow(draws)
  S <- min(ndraws, n_draws)
  idx_draws <- sample(n_draws, S)
  spec <- fit_obj$spec
  controls <- fit_obj$controls

  # Pull every parameter out once. Columns are fetched with `[[` so the
  # draws_df `[` method (and its metadata handling) is never involved.
  coef_block <- function(base) {
    nms <- grep(paste0("^", base, "\\["), names(draws), value = TRUE)
    vals <- unlist(lapply(nms, function(nm) draws[[nm]]), use.names = FALSE)
    matrix(as.numeric(vals), nrow = n_draws, ncol = length(nms))
  }
  b_pi_I <- coef_block("b_pi_I")
  b_mu_I <- coef_block("b_mu_I")
  b_pi_C <- coef_block("b_pi_C")
  b_mu_C <- coef_block("b_mu_C")
  a_pi_I <- draws$a_pi_I
  a_mu_I <- draws$a_mu_I
  a_pi_C <- draws$a_pi_C
  a_mu_C <- draws$a_mu_C
  phi_I <- exp(draws$log_phi_I)
  phi_C <- exp(draws$log_phi_C)

  # Future time indices that have covariates available in DT.
  t_last <- min(Tcut + h, n_obs)
  t_future <- if (t_last > Tcut) seq(Tcut + 1, t_last) else numeric(0)

  # Non-lag covariates (trend, regime dummies, transitions, controls) do not
  # depend on the draw, so build them once per future time index.
  # NOTE(review): this layout must mirror the design built by fit_one(),
  # which is not visible here.
  exo <- list()
  if (length(t_future) > 0L) {
    reg_mm <- model.matrix(~ Regime, DT)[, -1, drop = FALSE]
    tr_mat <- as.matrix(DT[, c("trans_PS", "trans_SF", "trans_FC")])
    exo <- lapply(t_future, function(t) {
      t_norm <- (t - 0.5) / n_obs
      ctrl <- NULL
      if (length(controls) > 0) {
        ctrl <- as.numeric(as.matrix(DT[t, controls, drop = FALSE]))
      }
      c(t_norm, t_norm^2, as.numeric(reg_mm[t, , drop = FALSE]),
        as.numeric(tr_mat[t, ]), ctrl)
    })
  }

  # Which cross-lags enter which equations (spec "C" uses both).
  lag_C_into_I <- spec %in% c("A", "C") && k > 0
  lag_I_into_C <- spec %in% c("B", "C") && k > 0

  # Linear predictor with an explicit length check: silent recycling in
  # `sum(b * x)` would otherwise hide a design mismatch.
  lin_pred <- function(b, x, what) {
    if (length(b) != length(x)) {
      stop(sprintf(
        "%s: %d coefficients but %d covariates; rebuilt design does not match the fitted model.",
        what, length(b), length(x)
      ), call. = FALSE)
    }
    sum(b * x)
  }

  # One hurdle draw: Bernoulli(pi) for zero vs. positive, then a draw from the
  # zero-truncated NB by inverting the upper tail on (0, P(X > 0)).
  draw_hurdle <- function(pi, size, mu) {
    if (rbinom(1, 1, pi) == 0) return(0)
    p_pos <- pnbinom(0, size = size, mu = mu, lower.tail = FALSE)
    if (is.na(p_pos)) return(NA_real_)
    if (p_pos <= 0) return(1)
    u <- runif(1, 0, p_pos)
    max(1, qnbinom(u, size = size, mu = mu, lower.tail = FALSE))
  }

  pred_I <- matrix(NA_real_, S, h)
  pred_C <- matrix(NA_real_, S, h)
  hist_idx <- seq_len(Tcut)

  for (s in seq_len(S)) {
    d <- idx_draws[s]
    I_loc <- DT$I[hist_idx]
    C_loc <- DT$C[hist_idx]
    for (step in seq_along(t_future)) {
      t <- t_future[step]
      x_exo <- exo[[step]]
      # NOTE(review): tail() gives lags oldest-first (lag k, ..., lag 1);
      # confirm this matches the column order of the fitted design.
      lag_C <- if (lag_C_into_I) tail(C_loc, k) else NULL
      lag_I <- if (lag_I_into_C) tail(I_loc, k) else NULL

      eta_pi_I <- a_pi_I[d] +
        lin_pred(b_pi_I[d, ], c(as.integer(lag_C > 0), x_exo), "b_pi_I")
      eta_mu_I <- a_mu_I[d] +
        lin_pred(b_mu_I[d, ], c(lag_C, x_exo), "b_mu_I") +
        DT$log_exposure50[t]
      y_I <- draw_hurdle(plogis(eta_pi_I), phi_I[d], exp(eta_mu_I))

      eta_pi_C <- a_pi_C[d] +
        lin_pred(b_pi_C[d, ], c(as.integer(lag_I > 0), x_exo), "b_pi_C")
      eta_mu_C <- a_mu_C[d] +
        lin_pred(b_mu_C[d, ], c(lag_I, x_exo), "b_mu_C") +
        DT$log_exposure50[t]
      y_C <- draw_hurdle(plogis(eta_pi_C), phi_C[d], exp(eta_mu_C))

      pred_I[s, step] <- y_I
      pred_C[s, step] <- y_C
      # Both series are updated only after both are drawn, so each equation
      # sees the other series' history up to t - 1.
      I_loc <- c(I_loc, y_I)
      C_loc <- c(C_loc, y_C)
    }
  }
  list(pred_I = pred_I, pred_C = pred_C)
}

#' Counterfactual Average Treatment Effects (ATE) for the Bivariate Hurdle Model
#'
#' Computes time-varying counterfactual Average Treatment Effects (ATE) for
#' both series (\code{I} and \code{C}) from a fitted bivariate hurdle
#' negative binomial model. For each time point and posterior draw, the
#' function compares the expected outcome under the observed design matrix
#' with a counterfactual scenario where cross-lag terms and transition
#' covariates are set to zero.
#'
#' @param fit_obj A list returned by \code{fit_one()} (or an equivalent
#'   fitting function), containing at least:
#'   \itemize{
#'     \item \code{$fit}: a CmdStanR fit object.
#'     \item \code{$des}: a list with design matrices
#'       \code{X_pi_I}, \code{X_mu_I}, \code{X_pi_C}, \code{X_mu_C},
#'       a vector \code{log_exposure50}, and an index vector \code{idx}.
#'       The first \code{length(idx)} rows of each design matrix and of
#'       \code{log_exposure50} are used.
#'   }
#' @param compute_intervals Logical; if \code{TRUE}, returns posterior
#'   means and 95\% credible intervals (2.5\% and 97.5\% quantiles). If
#'   \code{FALSE}, only posterior means are returned.
#' @param ndraws Integer; maximum number of posterior draws to use. If
#'   \code{ndraws} exceeds the number of available draws, it is truncated.
#' @param seed Integer or \code{NULL}; random seed used to subsample posterior
#'   draws. If not \code{NULL} it is passed to \code{set.seed()}, which resets
#'   the global RNG state.
#'
#' @details
#' A random subset of \code{min(ndraws, number of draws)} posterior draws is
#' used. The function identifies in the design matrices:
#' \itemize{
#'   \item Cross-lag terms via column names containing
#'     \code{"zC_L"} / \code{"C_L"} (for \code{I}) and
#'     \code{"zI_L"} / \code{"I_L"} (for \code{C}).
#'   \item Transition covariates via column names starting with
#'     \code{"trans_"}.
#' }
#' For each time point \code{t} and posterior draw \code{s}, the hurdle
#' expected value \eqn{E[Y \mid X] = \pi \mu / (1 - p_0)} is computed, where
#' \eqn{\pi} is the inverse logit of the hurdle predictor, \eqn{\mu} the
#' negative-binomial mean (including \code{log_exposure50} as an offset) and
#' \eqn{p_0} the negative-binomial probability of zero. The value under the
#' observed design is contrasted with a counterfactual design where the
#' cross-lag and transition columns are set to zero
#' (\eqn{E[Y \mid X_{cf}]}). The ATE at time \code{t} is defined as the
#' posterior distribution of \eqn{E[Y \mid X] - E[Y \mid X_{cf}]}, computed
#' separately for \code{I} and \code{C}.
#'
#' @return A tibble with one row per effective time index (length
#'   \code{des$idx}). If \code{compute_intervals = TRUE}, the columns are:
#'   \itemize{
#'     \item \code{t}: time index (from \code{des$idx}).
#'     \item \code{ATE_I_mean}, \code{ATE_I_low}, \code{ATE_I_high}:
#'       posterior mean and 95\% credible interval for the ATE on \code{I}.
#'     \item \code{ATE_C_mean}, \code{ATE_C_low}, \code{ATE_C_high}:
#'       posterior mean and 95\% credible interval for the ATE on \code{C}.
#'   }
#'   If \code{compute_intervals = FALSE}, only \code{ATE_I_mean} and
#'   \code{ATE_C_mean} are returned (plus \code{t}).
#'
#' @examples
#' \donttest{
#' if (interactive() && requireNamespace("cmdstanr", quietly = TRUE)) {
#'   n <- 120
#'   DT <- data.table::data.table(
#'     I = rpois(n, 5), C = rpois(n, 3),
#'     Regime = factor(sample(c("A","B","C"), n, TRUE)),
#'     trans_PS = c(rep(1,5), rep(0,n-5)),
#'     trans_SF = c(rep(0,60), rep(1,5), rep(0,n-65)),
#'     trans_FC = rep(0, n),
#'     log_exposure50 = log(runif(n, 40, 60))
#'   )
#'   fit_obj <- fit_one(DT, k = 1, spec = "C")
#'   ate_tab <- contrafactual_ATE(fit_obj, compute_intervals = TRUE)
#'   head(ate_tab)
#' }
#' }
#' @export
contrafactual_ATE <- function(fit_obj, compute_intervals=TRUE, ndraws=1200, seed=42) {
  if (!is.null(seed)) set.seed(seed)
  draws <- posterior::as_draws_df(fit_obj$fit$draws())
  des <- fit_obj$des
  T_eff <- length(des$idx)
  rows <- seq_len(T_eff)
  n_draws <- nrow(draws)
  S <- min(ndraws, n_draws)
  # Random subset: taking the first S rows would draw only from the leading
  # chain(s) when draws are stored chain by chain.
  sel <- sample(n_draws, S)

  # S x p coefficient matrix for parameter vector `base`, aligned with X.
  # Columns are fetched with `[[` to bypass the draws_df `[` method.
  coef_block <- function(base, X) {
    p <- ncol(X)
    if (p == 0) return(matrix(numeric(0), nrow = S, ncol = 0))
    nms <- grep(paste0("^", base, "\\["), names(draws), value = TRUE)
    if (length(nms) != p) {
      stop(sprintf(
        "Found %d '%s' coefficients in the draws but the design matrix has %d columns.",
        length(nms), base, p
      ), call. = FALSE)
    }
    vals <- unlist(lapply(nms, function(nm) draws[[nm]][sel]), use.names = FALSE)
    matrix(as.numeric(vals), nrow = S, ncol = p,
           dimnames = list(NULL, colnames(X)))
  }

  # Column indices of cross-lag / transition terms to zero out.
  idx_cols <- function(X, patt) {
    if (ncol(X) == 0) integer(0) else grep(patt, colnames(X), fixed = TRUE)
  }
  trans_cols <- function(X) {
    if (ncol(X) == 0) integer(0) else grep("^trans_", colnames(X))
  }
  counterfactual <- function(X, cross_patt) {
    cols <- c(idx_cols(X, cross_patt), trans_cols(X))
    if (length(cols) > 0) X[, cols] <- 0
    X
  }
  take_rows <- function(X) as.matrix(X)[rows, , drop = FALSE]

  X_pi_I <- take_rows(des$X_pi_I); X_mu_I <- take_rows(des$X_mu_I)
  X_pi_C <- take_rows(des$X_pi_C); X_mu_C <- take_rows(des$X_mu_C)
  X_pi_I0 <- counterfactual(X_pi_I, "zC_L"); X_mu_I0 <- counterfactual(X_mu_I, "C_L")
  X_pi_C0 <- counterfactual(X_pi_C, "zI_L"); X_mu_C0 <- counterfactual(X_mu_C, "I_L")

  BpiI <- coef_block("b_pi_I", X_pi_I); BmuI <- coef_block("b_mu_I", X_mu_I)
  BpiC <- coef_block("b_pi_C", X_pi_C); BmuC <- coef_block("b_mu_C", X_mu_C)
  a_pi_I <- draws$a_pi_I[sel]; a_mu_I <- draws$a_mu_I[sel]
  a_pi_C <- draws$a_pi_C[sel]; a_mu_C <- draws$a_mu_C[sel]
  phiI <- exp(draws$log_phi_I[sel]); phiC <- exp(draws$log_phi_C[sel])

  # Exposure offset laid out to match an S x T_eff matrix (column-major).
  offset <- rep(des$log_exposure50[rows], each = S)

  # S x T_eff matrix of hurdle expectations pi * mu / (1 - P(NB = 0)).
  # Length-S vectors (intercepts, phi) recycle down columns, i.e. per draw.
  expected_y <- function(a_pi, B_pi, X_pi, a_mu, B_mu, X_mu, phi) {
    eta_pi <- a_pi + B_pi %*% t(X_pi)
    eta_mu <- a_mu + B_mu %*% t(X_mu) + offset
    mu <- exp(eta_mu)
    p0 <- suppressWarnings(dnbinom(0, size = phi, mu = mu))
    unname(plogis(eta_pi) * mu / (1 - p0 + 1e-12))
  }

  ATE_I <- expected_y(a_pi_I, BpiI, X_pi_I, a_mu_I, BmuI, X_mu_I, phiI) -
    expected_y(a_pi_I, BpiI, X_pi_I0, a_mu_I, BmuI, X_mu_I0, phiI)
  ATE_C <- expected_y(a_pi_C, BpiC, X_pi_C, a_mu_C, BmuC, X_mu_C, phiC) -
    expected_y(a_pi_C, BpiC, X_pi_C0, a_mu_C, BmuC, X_mu_C0, phiC)

  col_q <- function(M, p) apply(M, 2, quantile, probs = p, names = FALSE)

  if (compute_intervals) {
    tibble::tibble(
      t = des$idx,
      ATE_I_mean = colMeans(ATE_I),
      ATE_I_low = col_q(ATE_I, 0.025),
      ATE_I_high = col_q(ATE_I, 0.975),
      ATE_C_mean = colMeans(ATE_C),
      ATE_C_low = col_q(ATE_C, 0.025),
      ATE_C_high = col_q(ATE_C, 0.975)
    )
  } else {
    tibble::tibble(
      t = des$idx,
      ATE_I_mean = colMeans(ATE_I),
      ATE_C_mean = colMeans(ATE_C)
    )
  }
}
