# Configure litedown output options for this document (comment prefix for
# printed output, collapsed chunks, figure size).
litedown::reactor(
  print = NA,
  collapse = TRUE,
  comment = "#>",
  fig.width = 10,
  fig.height = 6
)
# Restrict data.table to a single thread.
data.table::setDTthreads(1)

# Load the AZtrees data set shipped with mlr3resampling and show the three
# columns that are used as roles / target below.
data(AZtrees, package = "mlr3resampling")
library(data.table)
options(
  datatable.print.keys = FALSE,
  datatable.print.class = FALSE
)
AZtrees[, .(region3, polygon, y)]

# One row per polygon with the number of data rows for each value of y,
# sorted by increasing `Not tree` count, then by decreasing Tree count.
dcast(AZtrees, polygon ~ y, fun.aggregate = length)[order(`Not tree`, -Tree)]

# Tabulate how the rows and polygons of AZtrees are assigned to test folds.
#
# Arguments:
#   ...    Named column-role assignments such as stratum="y" or
#          group="polygon"; each name is written into the task's col_roles.
#   cv     A resampling object; when NULL (default) a new
#          mlr3resampling::ResamplingSameOtherSizesCV is created. The folds
#          parameter of the supplied object is overwritten.
#   folds  Number of folds to set on cv and to look up test sets for.
#
# Returns a named list of two count matrices, "rows" and "groups". Each has
# one row per value of y plus a TOTAL row, and one column per fold label.
# In "groups" every (y, polygon) pair is counted once, and its fold label is
# the comma-separated set of folds its rows were assigned to.
trees <- function(..., cv=NULL, folds=3L){
  trees_task <- mlr3::as_task_classif(AZtrees, target="y")
  role_list <- list(...)
  for(role in names(role_list)){
    trees_task$col_roles[[role]] <- role_list[[role]]
  }
  if(is.null(cv)){
    cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
  }
  cv$param_set$values$folds <- folds
  cv$instantiate(trees_task)
  role_dt <- data.table(AZtrees)
  # Record, for every row, the iteration in which it is part of the test set.
  # seq_len() is used instead of 1:folds so the sequence can never run
  # backwards.
  for(fold_id in seq_len(folds)){
    set(role_dt, cv$test_set(fold_id), "fold", fold_id)
  }
  # Sorting by polygon then fold makes the pasted fold labels below ordered.
  setkey(role_dt, polygon, fold)
  in_list <- list(
    rows=role_dt,
    groups=role_dt[, .(
      fold=paste(unique(fold), collapse=",")
    ), by=.(y,polygon)])
  # Same tabulation for both units: counts per label and fold, plus totals.
  lapply(in_list, function(in_dt){
    rbind(in_dt[, table(y, fold)], TOTAL=in_dt[, table(fold)])
  })
}
# The seed is set only once here; the calls below run in sequence without
# resetting it, so their printed fold assignments depend on this order.
set.seed(1)

# No roles, default resampler, default 3 folds.
trees()

# Identical call repeated without re-seeding.
trees()

# Same setting with the resampler obtained from mlr3::rsmp("cv").
trees(cv=mlr3::rsmp("cv"))

# Stratify on the label column y.
trees(stratum="y")

trees(stratum="y", cv=mlr3::rsmp("cv"))

# Use polygon as the group role.
trees(group="polygon")

# Identical call repeated without re-seeding.
trees(group="polygon")

trees(group="polygon", cv=mlr3::rsmp("cv"))

# Grouping again, with a larger number of folds, for both resamplers.
more.folds <- 5
trees(group="polygon", folds=more.folds)
trees(group="polygon", folds=more.folds, cv=mlr3::rsmp("cv"))

# Stratum and group roles combined (default resampler only).
trees(stratum="y", group="polygon")

# Identical call repeated without re-seeding.
trees(stratum="y", group="polygon")

trees(stratum="y", group="polygon", folds=more.folds)

# Task using three roles at once: region3 as subset, region3 and y as
# strata, and polygon as group.
trees_task <- mlr3::as_task_classif(AZtrees, target = "y")
trees_task$col_roles$subset <- "region3"
trees_task$col_roles$stratum <- c("region3", "y")
trees_task$col_roles$group <- "polygon"

cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
cv$instantiate(trees_task)
# Put the resampler's fold table next to the original data columns.
role_dt <- data.table(AZtrees, cv$instance$fold.dt)
in_list <- list(
  rows = role_dt,
  groups = role_dt[, .(rows = .N), by = .(y, polygon, test.subset, fold)]
)
out_list <- list()
for (unit in names(in_list)) {
  in_dt <- in_list[[unit]]
  # Column label: test subset name immediately followed by the fold number.
  in_dt[, fsub := paste0(test.subset, fold)]
  out_list[[unit]] <- rbind(
    in_dt[, table(y, fsub)],
    TOTAL = in_dt[, table(fsub)]
  )
}
out_list

# Tabulate fold assignments of the geepack respiratory data when one column
# is used as the subset role.
#
# Arguments:
#   Subset  Name (character) of the respiratory column to use as the subset
#           role; it is also added to the stratum role together with
#           "outcome".
#
# Returns a named list of two count matrices, "rows" and "groups", with one
# row per outcome value plus a TOTAL row and one column per
# "<subset value>_fold<fold>" label. When the geepack package is not
# available nothing is computed and NULL is returned invisibly.
resp_counts <- function(Subset){
  if(requireNamespace("geepack")){
    data(respiratory, package="geepack")
    # The resampling group is the person, identified by center and id together.
    resp_dt <- data.table(respiratory)[, person := paste(center, id)]
    resp_task <- mlr3::as_task_classif(resp_dt, target="outcome")
    resp_task$col_roles$subset <- Subset
    resp_task$col_roles$stratum <- c(Subset, "outcome")
    resp_task$col_roles$group <- "person"
    cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
    cv$instantiate(resp_task)
    role_dt <- data.table(
      respiratory, cv$instance$fold.dt
    )[, subval := get(Subset)]
    in_list <- list(
      rows=role_dt,
      # Key the groups by center AND id, matching the person group role.
      # Keying by id alone would merge people from different centers who
      # share an id, under-counting groups when Subset is not "center".
      groups=role_dt[, .(rows=.N), by=.(outcome, center, id, fold, subval)])
    lapply(in_list, function(in_dt){
      in_dt[, fsub := paste0(subval, "_fold", fold)]
      rbind(in_dt[, table(outcome, fsub)], TOTAL=in_dt[, table(fsub)])
    })
  }
}

# Fold counts with the sex column as the subset role.
resp_counts("sex")

# Fold counts with the center column as the subset role.
resp_counts("center")

