#===============================================================================
# Helper Functions for Simulation
# Author: Sisi Shao
# This file contains a null-coalescing operator, functions for calculating
# feature selection and prediction metrics, and an input-consistency check
# used by temporal_forest().
#===============================================================================

#' Null-coalescing operator
#'
#' @description
#' Yields `a` unless it is `NULL`, in which case `b` is returned instead.
#' Useful for supplying a fallback (default) value in a single expression.
#'
#' @details
#' Only `NULL` triggers the fallback: `NA`, `FALSE` and zero-length vectors
#' such as `character(0)` are returned unchanged. Because R arguments are
#' lazy, `b` is only evaluated when `a` is `NULL`. The semantics mirror the
#' `%||%` operator found in **rlang** and **purrr**; it is defined here for
#' convenience within the **TemporalForest** package.
#'
#' @section Conflicts:
#' The `%||%` operator name is also used in packages such as **rlang** and
#' **purrr** (and in base R since version 4.4.0). If you load those packages
#' after **TemporalForest**, they may mask this operator. To avoid ambiguity,
#' call it explicitly using `TemporalForest::\`%||%\``.
#'
#' @param a The primary object or value.
#' @param b The fallback object or value, used only when `a` is `NULL`.
#'
#' @return `b` when `a` is `NULL`; otherwise `a` itself.
#'
#' @name null_or
#' @aliases %||%
#' @export
#'
#' @examples
#' a <- NULL
#' b <- 5
#' a %||% b   # Returns 5
#'
#' x <- 10
#' y <- 20
#' x %||% y   # Returns 10
#'
#' # Safe usage when multiple packages define %||%
#' TemporalForest::`%||%`(NULL, "default")
`%||%` <- function(a, b) {
    # Guard clause: fall back to `b` only for a genuine NULL.
    if (is.null(a)) {
        return(b)
    }
    a
}

#' Calculate Feature Selection Metrics
#' @description
#' Computes TP, FP, FN, TN, Sensitivity, Specificity, Precision, and F1-Score.
#'
#' @details
#' Both `selected_vars` and `true_vars_global` are treated as sets: duplicate
#' names are removed before counting. A variable listed twice is therefore
#' counted once, and TN is not pushed down by the duplicates.
#'
#' TN is derived as `total_feature_count_p_val - TP - FP - FN`. Whenever the
#' denominator of Sensitivity, Specificity, Precision, or F1 is zero, that
#' metric is reported as 0 instead of producing `NaN`.
#'
#' @param selected_vars A character vector of selected variable names.
#' @param true_vars_global A character vector of the true variable names.
#' @param total_feature_count_p_val The total number of candidate features (p).
#'
#' @return A list containing the following numeric elements:
#'   \itemize{
#'     \item \code{TP}: True Positives
#'     \item \code{FP}: False Positives
#'     \item \code{FN}: False Negatives
#'     \item \code{TN}: True Negatives
#'     \item \code{Sens}: Sensitivity (Recall)
#'     \item \code{Spec}: Specificity
#'     \item \code{Prec}: Precision
#'     \item \code{F1}: F1-Score
#'     \item \code{N_Selected}: Number of distinct selected variables
#'   }
#' @note Inputs are character vectors; no NA imputation is performed and
#' duplicate names are ignored. Division by zero is guarded with 0-valued
#' metrics as documented.
#' @export
#' @examples
#' # --- Example ---
#' # Imagine our model selected 3 variables: V1, V2, and V10
#' selected <- c("V1", "V2", "V10")
#'
#' # And the "true" important variables were V1, V2, V3, and V4
#' true_set <- c("V1", "V2", "V3", "V4")
#'
#' # And the total pool of variables was 50
#' p <- 50
#'
#' # Calculate the performance metrics
#' metrics <- calculate_fs_metrics_cv(
#'   selected_vars = selected,
#'   true_vars_global = true_set,
#'   total_feature_count_p_val = p
#' )
#'
#' print(metrics)
#' #> Expected output:
#' #> $TP
#' #> [1] 2
#' #> $FP
#' #> [1] 1
#' #> $FN
#' #> [1] 2
#' #> $TN
#' #> [1] 45
#' #> $Sens
#' #> [1] 0.5
#' #> $Spec
#' #> [1] 0.9782609
#' #> $Prec
#' #> [1] 0.6666667
#' #> $F1
#' #> [1] 0.5714286
#' #> $N_Selected
#' #> [1] 3
calculate_fs_metrics_cv <- function(selected_vars, true_vars_global, total_feature_count_p_val) {
    # Treat selections and ground truth as sets so repeated names are not
    # double-counted (which would also make TN too small).
    selected_vars    <- unique(selected_vars)
    true_vars_global <- unique(true_vars_global)

    # Scalar division guarded against a zero denominator. isTRUE() keeps an
    # NA denominator (e.g. NA total feature count) yielding NA, not an error.
    safe_ratio <- function(num, den) if (isTRUE(den == 0)) 0 else num / den

    TP <- sum(selected_vars %in% true_vars_global)
    FP <- sum(!(selected_vars %in% true_vars_global))
    FN <- sum(!(true_vars_global %in% selected_vars))
    # Candidates that are neither selected nor truly important.
    TN <- total_feature_count_p_val - TP - FP - FN

    Sens <- safe_ratio(TP, TP + FN)
    Spec <- safe_ratio(TN, TN + FP)
    Prec <- safe_ratio(TP, TP + FP)
    F1   <- safe_ratio(2 * Prec * Sens, Prec + Sens)

    list(TP = TP, FP = FP, FN = FN, TN = TN, Sens = Sens, Spec = Spec,
         Prec = Prec, F1 = F1, N_Selected = length(selected_vars))
}


#' Calculate Prediction Metrics
#'
#' @description
#' Computes Root Mean Squared Error (RMSE) and R-squared.
#'
#' @details
#' Missing values (`NA`/`NaN`) are handled pairwise. Only positions where both
#' `predictions` and `actual` are observed are used, so RMSE, the residual sum
#' of squares, and the total sum of squares all refer to the same set of
#' observations.
#'
#' Both metrics are returned as `NA` when:
#' \itemize{
#'   \item either input is `NULL`, empty, or entirely `NA`;
#'   \item the inputs have unequal length;
#'   \item no position has both values observed.
#' }
#'
#' If the variance of the retained `actual` values is (near) zero, R-squared
#' is undefined. In that case it is reported as 1 for a (near) perfect fit
#' and 0 otherwise.
#'
#' @param predictions A numeric vector of model predictions.
#' @param actual A numeric vector of the true outcome values.
#'
#' @return A list containing the following numeric elements:
#'   \itemize{
#'     \item \code{RMSE}: Root Mean Squared Error.
#'     \item \code{R_squared}: R-squared (Coefficient of Determination).
#'   }
#' @note Inputs are numeric vectors of equal length. Incomplete pairs are
#' dropped before computing either metric. If inputs are
#' NULL/empty/all-NA/length-mismatch, or no complete pair remains, both
#' metrics are returned as `NA_real_`.
#' @export
#' @examples
#' # --- Example ---
#' # Example predicted values from a model
#' predicted_values <- c(2.5, 3.8, 6.1, 7.9)
#'
#' # The corresponding actual, true values
#' actual_values <- c(2.2, 4.1, 5.9, 8.3)
#'
#' # Calculate the prediction metrics
#' metrics <- calculate_pred_metrics_cv(
#'   predictions = predicted_values,
#'   actual = actual_values
#' )
#'
#' print(metrics)
#' #> Expected output:
#' #> $RMSE
#' #> [1] 0.3082207
#' #> $R_squared
#' #> [1] 0.981269
calculate_pred_metrics_cv <- function(predictions, actual) {
    na_result <- list(RMSE = NA_real_, R_squared = NA_real_)
    if (is.null(predictions) || length(predictions) == 0 || all(is.na(predictions)) ||
        is.null(actual) || length(actual) == 0 || all(is.na(actual)) ||
        length(predictions) != length(actual)) {
        return(na_result)
    }

    # Keep only pairs where both values are observed; otherwise SST would be
    # computed over more observations than SSE and bias R-squared upward.
    complete <- !is.na(predictions) & !is.na(actual)
    if (!any(complete)) {
        return(na_result)
    }
    predictions <- predictions[complete]
    actual      <- actual[complete]

    sq_resid <- (actual - predictions)^2
    RMSE <- sqrt(mean(sq_resid))
    SSE  <- sum(sq_resid)
    SST  <- sum((actual - mean(actual))^2)

    # isTRUE() so a non-finite SST (e.g. infinite actuals) falls through to
    # the ratio and propagates NaN instead of erroring inside `if`.
    R_squared <- if (isTRUE(SST < 1e-9)) {
        # Constant outcome: R^2 undefined; 1 for a perfect fit, otherwise 0.
        if (SSE < 1e-9) 1.0 else 0.0
    } else {
        1 - (SSE / SST)
    }
    list(RMSE = RMSE, R_squared = R_squared)
}


#' Check Consistency of Temporal Predictor Data
#'
#' @description
#' Validates that the input list `X` has the shape expected by the main
#' `temporal_forest` function.
#'
#' @details
#' `temporal_forest()` calls this helper to validate its input before any
#' expensive computation starts. The checks, in order, are:
#' \enumerate{
#'   \item `X` is a list with at least two elements (time points).
#'   \item The first element has column names.
#'   \item Every later element has exactly the same column names, in the same
#'     order, as the first element.
#' }
#' The function stops at the first failing check with a descriptive error.
#' This surfaces formatting problems early instead of as obscure downstream
#' failures.
#'
#' @param X A list of numeric matrices or data frames, one per time point,
#'   with subjects as rows and predictors as columns.
#'
#' @return Invisibly returns `TRUE` when every check passes. Otherwise an
#'   error is raised (with `call. = FALSE`) describing the first problem found.
#'
#' @export
#' @examples
#' # --- 1. A valid input that will pass ---
#' mat1 <- matrix(1:4, nrow = 2, dimnames = list(NULL, c("V1", "V2")))
#' mat2 <- matrix(5:8, nrow = 2, dimnames = list(NULL, c("V1", "V2")))
#' good_X <- list(mat1, mat2)
#'
#' # This will run silently and return TRUE
#' check_temporal_consistency(good_X)
#'
#' # --- 2. An invalid input that will fail ---
#' mat3 <- matrix(9:12, nrow = 2, dimnames = list(NULL, c("V1", "V3"))) # Mismatched colnames
#' bad_X <- list(mat1, mat3)
#'
#' # This will throw an informative error
#' # We wrap it in try() to prevent the example from stopping
#' try(check_temporal_consistency(bad_X))
#'
check_temporal_consistency <- function(X) {
    # Non-list inputs are treated as having zero time points.
    n_time_points <- if (is.list(X)) length(X) else 0L
    if (n_time_points < 2L) {
        stop("Input 'X' must be a list containing data for at least two time points.", call. = FALSE)
    }

    # The first time point defines the reference column layout.
    reference_names <- colnames(X[[1]])
    if (is.null(reference_names)) {
        stop("All matrices in the list 'X' must have column names.", call. = FALSE)
    }

    # Compare each later time point against the reference; abort on the first mismatch.
    for (tp in seq.int(2L, n_time_points)) {
        if (identical(colnames(X[[tp]]), reference_names)) {
            next
        }
        stop(paste("Inconsistent data format: The column names of the matrix for time point", tp,
                   "do not match the column names of the first time point."), call. = FALSE)
    }

    invisible(TRUE)
}
