In mlr3, a data set is represented as a Task, with meta-data called column roles that define how columns should be used for learning and evaluation. The goal of this vignette is to explain how to use the new subset role together with the existing group and stratum roles.

Introduction: role definitions

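Previous column roles in mlr3 are group (rows with the same group value, such as a polygon, are assigned to the same fold) and stratum (proportions of each stratum value, such as a label, are kept similar across folds).

New column role proposed in mlr3resampling is subset (data are divided into subsets, such as regions, so that we can ask whether a model can combine data across subsets, or predict well on one subset after training on another).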

Example using AZtrees data

The AZtrees data represent an image segmentation problem. Each row is a pixel, which we want to classify as tree or not (stratum role). Data were labeled by drawing polygons on a map (group role). Data are divided into either 3 or 4 regions (subset role).

data(AZtrees, package="mlr3resampling")
library(data.table)
options(datatable.print.keys=FALSE, datatable.print.class=FALSE)
AZtrees[, .(region3, polygon, y)]
#>      region3 polygon        y
#>   1:      NE       1 Not tree
#>   2:      NE       1 Not tree
#>   3:      NE       1 Not tree
#>   4:      NE       1 Not tree
#>   5:      NE       2 Not tree
#>  ---                         
#>5952:       S     187     Tree
#>5953:       S     188     Tree
#>5954:       S     189     Tree
#>5955:       S     189     Tree
#>5956:       S     190     Tree

Above we see the three columns that are relevant for these roles in these data. Below we count the number of rows for each polygon:

dcast(AZtrees, polygon ~ y, length)[order(`Not tree`, -Tree)]
#>     polygon Not tree Tree
#>  1:     102        0   45
#>  2:     112        0   33
#>  3:      85        0   30
#>  4:      95        0   26
#>  5:     104        0   24
#> ---                      
#>185:      28      223    0
#>186:      27      228    0
#>187:      55      297    0
#>188:     107      776    0
#>189:     108     1478    0

Above we see that there are several big polygons with no trees (hundreds of rows), whereas the polygons for presence of trees are smaller (max 45 rows). Below we define a helper function which creates a Task and assigns roles based on arguments.

trees <- function(..., cv=NULL, folds=3L){
  trees_task <- mlr3::as_task_classif(AZtrees, target="y")
  role_list <- list(...)
  for(role in names(role_list)){
    trees_task$col_roles[[role]] <- role_list[[role]]
  }
  if(is.null(cv))cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
  cv$param_set$values$folds <- folds
  cv$instantiate(trees_task)
  role_dt <- data.table(AZtrees)
  for(fold in 1:folds){
    set(role_dt, cv$test_set(fold), "fold", fold)
  }
  setkey(role_dt, polygon, fold)
  in_list <- list(
    rows=role_dt,
    groups=role_dt[, .(fold=paste(unique(fold), collapse=",")), by=.(y,polygon)])
  out_list <- list()
  for(unit in names(in_list)){
    in_dt <- in_list[[unit]]
    out_list[[unit]] <- rbind(in_dt[, table(y, fold)], TOTAL=in_dt[, table(fold)])
  }
  out_list
}
set.seed(1)

Above we set the seed to obtain reproducible results below (fold assignment is random).

Cross-validation with no column roles

Below we show counts of rows and groups in each fold, when no column roles are set:

trees()
#>$rows
#>            1    2    3
#>Not tree 1775 1748 1759
#>Tree      211  237  226
#>TOTAL    1986 1985 1985
#>
#>$groups
#>          1 1,2 1,2,3 1,3  2 2,3  3
#>Not tree  0   3    59   3  2   2  3
#>Tree     10  12    42  14 12  12 15
#>TOTAL    10  15   101  17 14  14 18
#>

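We see above that the total number of rows is nearly equal across folds (1985 or 1986), the label counts vary somewhat across folds (211 to 237 Tree rows), and most polygons have rows in more than one fold (101 of the 189 polygons appear in all three folds).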

Below we do it again to emphasize the random variation that can occur (especially in label counts) when roles are not set:

trees()
#>$rows
#>            1    2    3
#>Not tree 1782 1740 1760
#>Tree      204  245  225
#>TOTAL    1986 1985 1985
#>
#>$groups
#>          1 1,2 1,2,3 1,3  2 2,3  3
#>Not tree  1   4    60   3  1   2  1
#>Tree     11  12    49  11 11  10 13
#>TOTAL    12  16   109  14 12  12 14
#>

Below we compare with the previous mlr3 code:

trees(cv=mlr3::rsmp("cv"))
#>$rows
#>            1    2    3
#>Not tree 1749 1761 1772
#>Tree      237  224  213
#>TOTAL    1986 1985 1985
#>
#>$groups
#>          1 1,2 1,2,3 1,3  2 2,3  3
#>Not tree  1   4    60   4  0   0  3
#>Tree     16  13    42   8 13  15 10
#>TOTAL    17  17   102  12 13  15 13
#>

When no column roles are set, results are similar between previous mlr3 and proposed mlr3resampling.

Cross-validation using strata

Next, we illustrate the stratum column role.

trees(stratum="y")
#>$rows
#>            1    2    3
#>Not tree 1761 1761 1760
#>Tree      225  224  225
#>TOTAL    1986 1985 1985
#>
#>$groups
#>          1 1,2 1,2,3 1,3  2 2,3  3
#>Not tree  1   2    61   3  0   2  3
#>Tree     13  11    45  11 11  13 13
#>TOTAL    14  13   106  14 11  15 16
#>

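Above we see that label counts are now nearly equal across folds (224 or 225 Tree rows, 1760 or 1761 Not tree rows), whereas polygons are still split across folds, because the group role is not set.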

Below we run the previous mlr3 method for comparison.

trees(stratum="y", cv=mlr3::rsmp("cv"))
#>$rows
#>            1    2    3
#>Not tree 1761 1761 1760
#>Tree      225  225  224
#>TOTAL    1986 1986 1984
#>
#>$groups
#>          1 1,2 1,2,3 1,3  2 2,3  3
#>Not tree  2   1    60   3  3   3  0
#>Tree     14  14    44   9 11  13 12
#>TOTAL    16  15   104  12 14  16 12
#>

Above we see the same patterns, except that rows TOTAL is slightly less consistent (fold sizes differ by up to 2 rows instead of 1).
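The balancing effect of the stratum role can be illustrated with a generic stratified fold assignment. Below is a minimal base R sketch, which is an assumption for illustration and not the code used by mlr3 or mlr3resampling. It uses the label counts from these data (5282 Not tree rows and 674 Tree rows), shuffles the rows within each label, and deals them out to folds in turn.

```r
set.seed(1)
# Labels with the same counts as in AZtrees (sums of the rows tables above).
y <- rep(c("Not tree", "Tree"), c(5282, 674))
n_folds <- 3L
fold <- integer(length(y))
for (label in unique(y)) {
  idx <- which(y == label)
  # Fold ids 1,2,3,1,2,3,... for this label, then shuffled, so that
  # fold counts within this label differ by at most one row.
  fold[idx] <- sample(rep(seq_len(n_folds), length.out = length(idx)))
}
label_by_fold <- table(y, fold)
rbind(label_by_fold, TOTAL = colSums(label_by_fold))
```

In this sketch the leftover rows of each label always go to the first folds, so the fold totals are 1986, 1986, and 1984, which is the same pattern as in the mlr3 output above. Spreading the leftover rows over different folds would be one way to obtain totals that differ by at most 1.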

Cross-validation on polygons

Below we compute counts when setting group role to polygon:

trees(group="polygon")
#>$rows
#>            1    2    3
#>Not tree 1763 1753 1766
#>Tree      223  232  219
#>TOTAL    1986 1985 1985
#>
#>$groups
#>          1  2  3
#>Not tree 15 28 29
#>Tree     43 36 38
#>TOTAL    58 64 67
#>

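Above we see that each polygon is assigned to exactly one fold, and that the total number of rows is nearly equal across folds (1985 or 1986), whereas the number of polygons (58 to 67) and the label counts (219 to 232 Tree rows) vary across folds.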

Below we do it again to emphasize the random variation that can occur (especially in label counts) in this case:

trees(group="polygon")
#>$rows
#>            1    2    3
#>Not tree 1803 1756 1723
#>Tree      183  229  262
#>TOTAL    1986 1985 1985
#>
#>$groups
#>          1  2  3
#>Not tree 19 28 25
#>Tree     39 36 42
#>TOTAL    58 64 67
#>

Above we see TOTAL is the same, but label counts per fold have changed.

Next, we try regular CV from mlr3:

trees(group="polygon", cv=mlr3::rsmp("cv"))
#>$rows
#>            1    2    3
#>Not tree 1806 1066 2410
#>Tree      320  187  167
#>TOTAL    2126 1253 2577
#>
#>$groups
#>          1  2  3
#>Not tree 23 25 24
#>Tree     40 38 39
#>TOTAL    63 63 63
#>

We see a very different pattern: TOTAL=63 polygons per fold, but there is large variation in the number of rows per fold (1253 to 2577).
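The difference between the two patterns can be illustrated on a toy example. Below is a minimal base R sketch with hypothetical polygon sizes, which contrasts equalizing the number of groups per fold with equalizing the number of rows per fold. The greedy heuristic used for the second strategy is an assumption chosen for illustration; it is not taken from the source code of mlr3 or mlr3resampling.

```r
# Hypothetical polygon sizes (number of rows per polygon), for illustration only.
group_size <- c(p1 = 300, p2 = 120, p3 = 100, p4 = 60, p5 = 40, p6 = 20)
n_folds <- 3L

# Strategy A: equalize the number of groups per fold, by dealing groups
# out to folds in turn (no shuffling here, to keep the sketch deterministic).
fold_by_count <- rep(seq_len(n_folds), length.out = length(group_size))

# Strategy B: equalize the number of rows per fold (greedy heuristic):
# visit groups from largest to smallest, and put each one in the fold
# which currently has the fewest rows.
fold_by_rows <- integer(length(group_size))
rows_in_fold <- numeric(n_folds)
for (g in order(group_size, decreasing = TRUE)) {
  smallest <- which.min(rows_in_fold)
  fold_by_rows[g] <- smallest
  rows_in_fold[smallest] <- rows_in_fold[smallest] + group_size[g]
}

# Number of rows per fold under each strategy.
rbind(
  equal_groups = tapply(group_size, fold_by_count, sum),
  equal_rows = tapply(group_size, fold_by_rows, sum))
```

Strategy A gives two groups in every fold but quite different row totals, whereas strategy B gives more similar row totals at the cost of different numbers of groups per fold.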

Finally, we compare results for a larger number of folds.

more.folds <- 5
trees(group="polygon", folds=more.folds)
#>$rows
#>            1    2    3    4    5
#>Not tree 1478  959  926  953  966
#>Tree        0  161  194  166  153
#>TOTAL    1478 1120 1120 1119 1119
#>
#>$groups
#>         1  2  3  4  5
#>Not tree 1 16 17 20 18
#>Tree     0 27 31 28 31
#>TOTAL    1 43 48 48 49
#>
trees(group="polygon", folds=more.folds, cv=mlr3::rsmp("cv"))
#>$rows
#>           1   2    3   4   5
#>Not tree 303 629 2855 834 661
#>Tree     209 116  121  55 173
#>TOTAL    512 745 2976 889 834
#>
#>$groups
#>          1  2  3  4  5
#>Not tree 12 11 16 22 11
#>Tree     26 27 22 16 26
#>TOTAL    38 38 38 38 37
#>

Above we see that mlr3resampling results in a fold (fold 1) with one polygon and one class (unusable), whereas mlr3 yields both classes in every fold. This kind of failure happens only when the total number of rows is small, one group is very large, the number of folds is large, and the stratum role is not set (see fix below).
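One way to see why the five-fold case is difficult is to compare the size of the largest polygon with the number of rows that each fold would get if rows were divided equally. Below is a small arithmetic sketch using counts shown above (5956 rows in total, 1478 rows in the largest polygon). The interpretation is an assumption based on these outputs, not a description of the package internals.

```r
n_rows <- 5956           # total number of rows in AZtrees (see output above)
largest_polygon <- 1478  # rows in the largest polygon (see counts above)
n_folds <- c(3, 5)
# Number of rows per fold if rows were divided equally.
target_rows <- n_rows / n_folds
data.frame(n_folds, target_rows, exceeds_target = largest_polygon > target_rows)
```

With 3 folds the largest polygon is smaller than the equal share (about 1985 rows), but with 5 folds it is larger than the equal share (about 1191 rows), which is consistent with fold 1 containing only that polygon.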

This section showed that the proposed mlr3resampling provides a new method for fold assignment using groups (total row counts equalized across folds), which differs from the previous method available in mlr3 (group counts equalized across folds).

Cross-validation on polygons with strata

Below we compute counts when setting stratum and group roles.

trees(stratum="y", group="polygon")
#>$rows
#>            1    2    3
#>Not tree 1761 1761 1760
#>Tree      225  224  225
#>TOTAL    1986 1985 1985
#>
#>$groups
#>          1  2  3
#>Not tree 19 26 27
#>Tree     39 39 39
#>TOTAL    58 65 66
#>

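Above we see that each polygon is assigned to exactly one fold, and that label counts are nearly equal across folds (224 or 225 Tree rows, 1760 or 1761 Not tree rows).

Below we do it again to show that the random variation seen without strata does not occur in this case: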

trees(stratum="y", group="polygon")
#>$rows
#>            1    2    3
#>Not tree 1761 1761 1760
#>Tree      225  224  225
#>TOTAL    1986 1985 1985
#>
#>$groups
#>          1  2  3
#>Not tree 19 26 27
#>Tree     39 39 39
#>TOTAL    58 65 66
#>

Above we see consistent counts (no variation in counts across random seeds in these data).

Next, we try regular CV from mlr3:

trees(stratum="y", group="polygon", cv=mlr3::rsmp("cv"))
#>Error: 
#>✖ Cannot combine stratification with grouping
#>→ Class: Mlr3ErrorInput

Above we see an error, because mlr3 does not implement the heuristic fold assignment algorithm for the case of groups and strata at the same time.

Finally, we compare results for a larger number of folds.

trees(stratum="y", group="polygon", folds=more.folds)
#>$rows
#>            1    2    3    4    5
#>Not tree  951 1478  951  951  951
#>Tree      135  134  135  135  135
#>TOTAL    1086 1612 1086 1086 1086
#>
#>$groups
#>          1  2  3  4  5
#>Not tree 13  1 18 20 20
#>Tree     22 23 24 24 24
#>TOTAL    35 24 42 44 44
#>

We see above that all folds have some data from both classes. One fold is much larger than the others, because there is one group that is very large (1478 rows).

This comparison shows that the proposed mlr3resampling provides a new method for fold assignment using groups and strata, which was not previously available in mlr3.

Cross-validation with subsets, strata, and groups

Below we define a task with all three roles:

trees_task <- mlr3::as_task_classif(AZtrees, target="y")
trees_task$col_roles$subset <- "region3"
trees_task$col_roles$stratum <- c("region3","y")
trees_task$col_roles$group <- "polygon"

Note above that there are two strata variables: the subset variable region3, and the target variable y. In general, both the target and subset variables should be used for stratum, so that folds are as balanced as possible.

Below we compute the row and group counts by label, fold, and subset.

cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
cv$instantiate(trees_task)
role_dt <- data.table(AZtrees, cv$instance$fold.dt)
in_list <- list(
  rows=role_dt,
  groups=role_dt[, .(rows=.N), by=.(y, polygon, test.subset, fold)])
out_list <- list()
for(unit in names(in_list)){
  in_dt <- in_list[[unit]][, fsub := paste0(test.subset, fold)]
  out_list[[unit]] <- rbind(in_dt[, table(y, fsub)], TOTAL=in_dt[, table(fsub)])
}
out_list
#>$rows
#>         NE1 NE2 NE3 NW1 NW2 NW3  S1   S2  S3
#>Not tree 378 378 378 504 502 502 776 1478 386
#>Tree     110 110 110  18  18  19  96   96  97
#>TOTAL    488 488 488 522 520 521 872 1574 483
#>
#>$groups
#>         NE1 NE2 NE3 NW1 NW2 NW3 S1 S2 S3
#>Not tree  16  17  16   6   5   4  1  1  6
#>Tree      18  20  18   6   7   4 14 15 15
#>TOTAL     34  37  34  12  12   8 15 16 21
#>

Above we see that the label row counts in NE and NW are very similar across folds, but subset S is more variable, because there are so few polygons (two folds with only one negative polygon).
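To make the role of the two strata variables concrete, below is a minimal base R sketch which lists the six strata (combinations of region3 and y), with row counts obtained by summing over folds in the rows table above, and the number of rows per fold that a perfectly balanced assignment would give. The data frame is a stand-in built from those counts, not the AZtrees data, and the column names stratum and ideal_rows_per_fold are illustrative.

```r
# Row counts per region and label (sums over the three folds shown above).
counts <- data.frame(
  region3 = rep(c("NE", "NW", "S"), each = 2),
  y = rep(c("Not tree", "Tree"), times = 3),
  rows = c(1134, 330, 1508, 55, 2640, 289))
# Each stratum is one combination of the two strata variables.
counts$stratum <- interaction(counts$region3, counts$y, sep = "/")
# Rows per fold if each stratum could be divided equally among 3 folds.
n_folds <- 3
counts$ideal_rows_per_fold <- counts$rows / n_folds
counts
```

In NE and NW the observed counts per fold are close to these ideal values, whereas for Not tree in S the ideal value is 880 rows per fold, which cannot be achieved while keeping the polygon with 1478 rows in a single fold.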

Example using respiratory data

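These are medical data from the geepack package, with several rows per person (group role), a binary outcome (stratum role), and sex and center columns which can each be used as the subset role.

First we define a helper function: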

resp_counts <- function(Subset)if(requireNamespace("geepack")){
  data(respiratory, package="geepack")
  resp_dt <- data.table(respiratory)[, person := paste(center, id)]
  resp_task <- mlr3::as_task_classif(resp_dt, target="outcome")
  resp_task$col_roles$subset <- Subset
  resp_task$col_roles$stratum <- c(Subset, "outcome")
  resp_task$col_roles$group <- "person"
  cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
  cv$instantiate(resp_task)
  role_dt <- data.table(respiratory, cv$instance$fold.dt)[, subval := get(Subset)]
  in_list <- list(
    rows=role_dt,
    groups=role_dt[, .(rows=.N), by=.(outcome, id, fold, subval)])
  out_list <- list()
  for(unit in names(in_list)){
    in_dt <- in_list[[unit]][, fsub := paste0(subval, "_fold", fold)]
    out_list[[unit]] <- rbind(in_dt[, table(outcome, fsub)], TOTAL=in_dt[, table(fsub)])
  }
  out_list
}

Are models consistent across sex?

To see if we can learn a model which can combine data across sex, or predict well on one sex after training on the other, we can use sex as subset role.

resp_counts("sex")
#>Loading required namespace: geepack
#>$rows
#>      F_fold1 F_fold2 F_fold3 M_fold1 M_fold2 M_fold3
#>0          14      14      18      52      48      50
#>1          14      18      14      68      68      66
#>TOTAL      28      32      32     120     116     116
#>
#>$groups
#>      F_fold1 F_fold2 F_fold3 M_fold1 M_fold2 M_fold3
#>0           5       5       6      17      16      17
#>1           5       6       5      20      21      21
#>TOTAL      10      11      11      37      37      38
#>

The results above show row and group counts per label, sex, and fold. We see that the code works well for generating balanced folds.
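The questions about combining data across sex, or training on one sex and predicting on the other, correspond to different choices of training data for each test subset and test fold. Below is a minimal base R sketch which enumerates these choices. The names same, other, and all, and the column names, are assumptions used here for illustration; this table is not output from mlr3resampling.

```r
# One row per combination of test subset, test fold, and choice of train data.
iterations <- expand.grid(
  train.subsets = c("same", "other", "all"),
  test.fold = 1:3,
  test.subset = c("F", "M"),
  stringsAsFactors = FALSE)
# Which sex values would be used for training in each combination.
other_sex <- c(F = "M", M = "F")
iterations$train.sex <- with(iterations, ifelse(
  train.subsets == "same", test.subset,
  ifelse(train.subsets == "other", other_sex[test.subset], "F,M")))
iterations
```

Comparing test error between the same, other, and all rows for a given test subset and fold is one way to answer the question of whether models are consistent across sex.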

Are models consistent across centers?

To see if we can learn a model which can combine data across centers, or predict well on one center after training on the other, we can use center as subset role.

resp_counts("center")
#>$rows
#>      1_fold1 1_fold2 1_fold3 2_fold1 2_fold2 2_fold3
#>0          41      41      40      26      23      25
#>1          35      35      32      50      49      47
#>TOTAL      76      76      72      76      72      72
#>
#>$groups
#>      1_fold1 1_fold2 1_fold3 2_fold1 2_fold2 2_fold3
#>0          15      15      14      10      10      10
#>1          13      13      12      16      16      15
#>TOTAL      28      28      26      26      26      25
#>

The results above show row and group counts per label, center, and fold. We see that the code works well for generating balanced folds.

Conclusion

Overall these results show that the proposed resampling methods can handle complex cross-validation experiments, including previous roles (group and stratum) and a new subset role.