# Chunk options for litedown: collapsed source/output, "#>" output prefix,
# and 10 x 6 figures.
litedown::reactor(
  print = NA,
  collapse = TRUE,
  comment = "#>",
  fig.width = 10,
  fig.height = 6
)
# Keep data.table single-threaded.
data.table::setDTthreads(1)

# Resampling object from mlr3resampling (printed below).
kfoldcv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
kfoldcv

# Copy a built-in classification task, adding a "Fold" column whose values
# 1, 2, 3 are assigned in turn within each target class. "Fold" is used as
# the fold role and (together with the target) as the stratum, while the
# feature set is copied unchanged from the original task.
task_list <- mlr3::tsks(c("spam", "german_credit"))
with_fold_column <- function(task){
  target_col <- task$col_roles$target
  task_dt <- task$data()
  task_dt[, Fold := rep(1:3, length.out = .N), by = c(target_col)]
  fold_task <- mlr3::TaskClassif$new(
    task_dt,
    id = task$id,
    target = target_col)
  fold_task$col_roles$feature <- task$col_roles$feature
  fold_task$col_roles$fold <- "Fold"
  fold_task$col_roles$stratum <- c("Fold", target_col)
  fold_task
}
tasks_with_fold <- lapply(task_list, with_fold_column)
# Name each new task by its id.
names(tasks_with_fold) <- vapply(
  tasks_with_fold, function(fold_task) fold_task$id, character(1))
tasks_with_fold

# Featureless baseline, plus rpart when that package can be loaded.
learner_list <- list(
  mlr3::LearnerClassifFeatureless$new())
if(requireNamespace("rpart")){
  learner_list$rpart <- mlr3::LearnerClassifRpart$new()
}
# Learners are modified in place (reference objects), so setting the field
# on the loop variable updates the learner stored in the list.
for(learner in learner_list){
  learner$predict_type <- "prob"
}

# Scores reported for every experiment: ROC AUC and accuracy.
measure_list <- mlr3::msrs(c("classif.auc", "classif.acc"))

# Benchmark design crossing the fold-annotated tasks with the learners.
bgrid <- mlr3::benchmark_grid(tasks_with_fold, learner_list, kfoldcv)
bgrid

# Make batchtools look for its configuration in mlr3resampling's extdata
# directory. Comment out the Sys.setenv() line to use ~/.batchtools.conf.R
# instead.
extdata <- system.file("extdata", package = "mlr3resampling")
Sys.setenv(R_BATCHTOOLS_SEARCH_PATH = extdata)

# Whether batchtools is configured with the Slurm backend. Default to FALSE
# so that later `if(slurm.available)` checks do not fail with "object not
# found" when mlr3batchmark (and therefore the registry) is unavailable.
slurm.available <- FALSE
if(requireNamespace("mlr3batchmark")){
  # Temporary experiment registry; it becomes batchtools' default registry,
  # which the later batchtools calls without a `reg` argument rely on.
  reg_dir <- tempfile()
  reg <- batchtools::makeExperimentRegistry(reg_dir)
  # identical() always yields TRUE/FALSE, even if the name is NULL.
  slurm.available <- identical(reg$cluster.functions$name, "Slurm")
  mlr3batchmark::batchmark(bgrid)
}

# Project directory: ~/testproj in interactive sessions, otherwise a
# temporary path. It is removed first so that each run starts clean.
proj_dir <- if(interactive()) "~/testproj" else tempfile()
unlink(proj_dir, recursive = TRUE)
mlr3resampling::proj_grid(
  proj_dir,
  tasks_with_fold,
  learner_list,
  kfoldcv,
  score_args = measure_list)

# Smoke-test the first batchtools job before submitting anything.
if(requireNamespace("mlr3batchmark")){
  batchtools::testJob(1)
}

# Smoke-test the mlr3resampling project on a single job.
mlr3resampling::proj_test(proj_dir, max_jobs = 1)

# Test every job of the first replication (repl == 1) with
# batchtools::testJob(). lapply() always returns one list element per job.
# sapply() might instead try to simplify equal-length per-job result lists
# into a matrix of lists.
if(requireNamespace("mlr3batchmark")){
  jt <- batchtools::getJobTable()
  jt1 <- jt[repl==1]
  testJob.repl1 <- lapply(jt1$job.id, batchtools::testJob)
}

# Submit all jobs in `jobs_dt` as one chunk. The chunk is submitted as an
# array job when the global `slurm.available` is TRUE.
#
# jobs_dt:   table of jobs with a `job.id` column, e.g. a subset of
#            batchtools::getJobTable().
# minutes:   walltime requested per job, in minutes.
# gigabytes: memory requested, in gigabytes (passed as megabytes).
# ncpus:     >1 for multicore/parallel jobs.
# ntasks:    >1 for MPI jobs.
#
# Returns whatever batchtools::submitJobs() returns. When there are no jobs
# (e.g. every job is already done), it returns NULL invisibly and does not
# contact the scheduler.
submit_job_array <- function(
    jobs_dt, minutes=1, gigabytes=1, ncpus=1, ntasks=1){
  if(NROW(jobs_dt) == 0L){
    return(invisible(NULL))
  }
  # The same chunk id for every job, so all jobs are submitted together.
  jobs_dt$chunk <- 1L
  batchtools::submitJobs(jobs_dt, resources=list(
    walltime = minutes*60,#seconds
    memory = gigabytes*1000,#megabytes per cpu
    ncpus = ncpus,
    ntasks = ntasks,
    chunks.as.arrayjobs = slurm.available))
}
# Submit the first-replication jobs, wait until they finish, then reduce
# them into a benchmark result and print its scores.
if(requireNamespace("mlr3batchmark")){
  submit_job_array(jt1)
  batchtools::waitForJobs(jt1)
  test_res <- mlr3batchmark::reduceResultsBatchmark(jt1)
  test_res$score(measure_list)
}

# Test the mlr3resampling project (no max_jobs limit given here).
mlr3resampling::proj_test(proj_dir)

# Switch future to multisession so that the computations below may run in
# parallel. Only the future namespace is needed, so requireNamespace() is
# used instead of attaching the package with require().
if(requireNamespace("future")){
  future::plan("multisession")
}

# The same benchmark computed directly with mlr3::benchmark().
bench_result <- mlr3::benchmark(bgrid)
bench_score <- bench_result$score(measure_list)
bench_score[, .(task_id, learner_id, iteration, classif.auc, classif.acc)]

# Compute every job of the mlr3resampling project and show the same columns.
proj_score <- mlr3resampling::proj_compute_all(proj_dir)
proj_score[, .(task_id, learner_id, iteration, classif.auc, classif.acc)]

# Summarize the status of the registry's jobs.
if(requireNamespace("mlr3batchmark")){
  batchtools::getStatus()
}

# Resubmit the jobs that have no `done` time yet.
if(requireNamespace("mlr3batchmark")){
  unfinished_jobs <- batchtools::getJobTable()[is.na(done)]
  submit_job_array(unfinished_jobs)
}

# Once everything has finished, reduce the results of all jobs in `jt` and
# score them. The callback removes the trained model from each job result
# (presumably to keep the reduced object small).
if(requireNamespace("mlr3batchmark")){
  batchtools::waitForJobs()
  drop_model <- function(job_result){
    job_result$learner_state$model <- NULL
    job_result
  }
  bt_res <- mlr3batchmark::reduceResultsBatchmark(jt, fun = drop_model)
  bt_score <- bt_res$score(measure_list)
}

# Submit the mlr3resampling project to Slurm only when the batchtools
# registry reported a Slurm backend. `slurm.available` is only set when
# mlr3batchmark is installed, so read it with get0() and a FALSE default
# instead of failing with "object 'slurm.available' not found".
if(isTRUE(get0("slurm.available", ifnotfound = FALSE))){
  slurm_job_id <- mlr3resampling::proj_submit(
    proj_dir, tasks=2, hours=1, gigabytes=1)
}

# Result tables read back from the mlr3resampling project (printed).
result_file_list <- mlr3resampling::proj_fread(proj_dir)
result_file_list

# Score tables to compare, named by the package that produced them.
acc_in_list <- list(
  mlr3resampling = result_file_list$results.csv)
if(requireNamespace("mlr3batchmark")){
  acc_in_list$mlr3batchmark <- bt_score
}
library(data.table)
# Long format for one package: one row per measure, tagged with the
# package name.
to_long <- function(package){
  long_dt <- melt(
    acc_in_list[[package]],
    id.vars = c("task_id", "learner_id", "iteration"),
    measure.vars = c("classif.auc", "classif.acc"))
  data.table(package, long_dt)
}
acc_out <- rbindlist(lapply(names(acc_in_list), to_long))
# Wide format: one column of scores per package.
acc_compare <- dcast(
  acc_out,
  variable + task_id + learner_id + iteration ~ package)
acc_compare
if(requireNamespace("mlr3batchmark")){
  acc_compare[, all.equal(mlr3batchmark, mlr3resampling)]
}

# mlr3resampling scores per learner, one panel per task and measure.
if(require(ggplot2)){
  ggplot() +
    geom_point(
      aes(x = mlr3resampling, y = learner_id),
      data = acc_compare) +
    facet_wrap(
      c("task_id", "variable"),
      labeller = label_both,
      scales = "free",
      ncol = 1)
}

# Start and end time of each process, stacked across packages. The
# batchtools rows are only included when mlr3batchmark is available.
time_compare <- rbind(
  if(requireNamespace("mlr3batchmark")){
    batchtools::getJobTable()[, .(
      package = "mlr3batchmark",
      process = .I,
      start.time = started,
      end.time = done)]
  },
  result_file_list$results.csv[, .(
    package = "mlr3resampling",
    process,
    start.time,
    end.time)])
if(require(ggplot2)){
  ggplot(time_compare, aes(x = start.time, y = process)) +
    geom_segment(aes(xend = end.time, yend = process)) +
    geom_point() +
    facet_grid(package ~ ., labeller = label_both, scales = "free")
}

# Second project (spam task only) with the learners available on this
# machine: featureless always, plus torch and glmnet when installed.
proj_new <- if(interactive()) "~/proj_new" else tempfile()
unlink(proj_new, recursive = TRUE)
learners_new <- list(
  mlr3::LearnerClassifFeatureless$new())
if(requireNamespace("torch") && torch::torch_is_installed()){
  # Module with a single linear layer from all task features to one output.
  gen_linear <- torch::nn_module(
    "my_linear",
    initialize = function(task) {
      self$weight <- torch::nn_linear(task$n_features, 1)
    },
    forward = function(x) {
      self$weight(x)
    }
  )
  learners_new$torch <- mlr3resampling::AutoTunerTorch_epochs$new(
    "torch_linear",
    module_generator = gen_linear,
    max_epochs = 1000,
    batch_size = 10,
    measure_list = mlr3::msrs("classif.auc")
  )
}
if(requireNamespace("glmnet")){
  learners_new$glmnet <- mlr3resampling::LearnerClassifCVGlmnetSave$new()
}
# Learners are modified in place, so this updates the list entries.
for(learner in learners_new){
  learner$predict_type <- "prob"
}
mlr3resampling::proj_grid(
  proj_new,
  tasks_with_fold$spam,
  learners_new,
  kfoldcv,
  score_args = measure_list)
system.time({
  test_result_list <- mlr3resampling::proj_test(proj_new)
})

names(test_result_list)

test_result_list$learners_history.csv

test_result_list$learners_weights.csv

# For every task saved under the project's test/tasks directory, print a
# frequency table of its first data column (presumably the target).
for(rds_path in Sys.glob(file.path(proj_new, "test", "tasks", "*rds"))){
  mini_task <- readRDS(rds_path)
  print(table(mini_task$data()[[1]]))
}

# Return future to sequential evaluation. Only the namespace is needed, so
# the package is not attached.
if(requireNamespace("future")){
  future::plan("sequential")
}

