Cross-validation with multiple ML algorithms

We can estimate individualized treatment rules (ITRs) with various machine learning algorithms and then compare the performance of each model. The package supports all ML algorithms available in the caret package, plus two additional algorithms: causal forest and Bayesian additive regression trees via the bartCause package.

The package also allows users to estimate heterogeneous treatment effects at both the individual and the group level. At the individual level, the summary statistics and the AUPEC plot show whether assigning treatment according to an individualized treatment rule may outperform a completely randomized treatment assignment. At the group level, we specify the number of groups through ngates and estimate heterogeneous treatment effects across those groups (GATEs). The call below keeps the default, which produces the five groups reported in the GATE output.
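To make the group-level idea concrete, here is a minimal base-R sketch on simulated data. It assumes a completely randomized binary treatment and uses a hypothetical vector score as a stand-in for a model's predicted treatment effects. Units are split into ngates equal-sized groups by score quantile, and the effect in each group is estimated by a simple difference in means. This illustrates the GATE concept only; it is not the estimator evalITR uses internally.

```r
# Minimal sketch of group average treatment effects (GATEs) on simulated data.
# Assumptions: randomized binary treatment; `score` is a hypothetical stand-in
# for a model's predicted CATE. Not evalITR's internal estimator.
set.seed(1)
n <- 1000
x <- rnorm(n)
treat <- rbinom(n, 1, 0.5)          # completely randomized assignment
tau <- 2 * x                        # true (simulated) individual effects
y <- x + tau * treat + rnorm(n)     # observed outcome
score <- x + rnorm(n, sd = 0.5)     # noisy predicted CATE (hypothetical)

ngates <- 5
# assign each unit to a group by score quantile (1 = lowest scores)
breaks <- quantile(score, probs = seq(0, 1, length.out = ngates + 1))
group <- cut(score, breaks = breaks, include.lowest = TRUE, labels = FALSE)

# difference in means between treated and control units within each group
gate <- sapply(seq_len(ngates), function(g) {
  in_g <- group == g
  mean(y[in_g & treat == 1]) - mean(y[in_g & treat == 0])
})
gate
```

Because the simulated effect grows with x and the score is correlated with x, the estimated GATEs should increase from group 1 to group 5; with a score unrelated to the true effects, the groups would show no systematic pattern.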

library(evalITR)
#> Loading required package: MASS
#> 
#> Attaching package: 'MASS'
#> The following object is masked from 'package:dplyr':
#> 
#>     select
#> Loading required package: Matrix
#> Loading required package: quadprog

# specify the caret resampling method: 2-fold cross-validation, repeated twice
fitControl <- caret::trainControl(
  method = "repeatedcv",
  number = 2,
  repeats = 2)
# estimate ITR
set.seed(2021)
fit_cv <- estimate_itr(
  treatment = "treatment",
  form = user_formula,
  data = star_data,
  trControl = fitControl,
  algorithms = c(
    "causal_forest",
    # "bartc",
    # "rlasso", # from rlearner
    # "ulasso", # from rlearner
    "lasso"     # from caret package
    # "rf"      # from caret package
  ),
  budget = 0.2,   # at most 20% of units may be treated
  n_folds = 2)    # number of cross-validation folds
#> Evaluate ITR with cross-validation ...
#> Loading required package: ggplot2
#> Loading required package: lattice
#> Warning: model fit failed for Fold1.Rep1: fraction=0.9 Error in elasticnet::enet(as.matrix(x), y, lambda = 0, ...) : 
#>   Some of the columns of x have zero variance
#> Warning: model fit failed for Fold1.Rep2: fraction=0.9 Error in elasticnet::enet(as.matrix(x), y, lambda = 0, ...) : 
#>   Some of the columns of x have zero variance
#> Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,
#> : There were missing values in resampled performance measures.
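The budget = 0.2 argument caps the share of units an ITR may treat. As a hedged illustration of that idea (not evalITR's code), the sketch below turns a vector of predicted treatment effects into a binary rule that treats at most 20% of units: those with the highest predicted effects, and only when the predicted effect is positive. The helper budget_itr and the example values are hypothetical.

```r
# Hypothetical helper: build a budget-constrained binary ITR from predicted CATEs.
# Treats the top `budget` share of units by predicted effect, skipping units
# whose predicted effect is not positive. Illustrative only.
budget_itr <- function(tau_hat, budget) {
  n_treat <- floor(budget * length(tau_hat))  # maximum number of treated units
  rule <- integer(length(tau_hat))            # start with everyone untreated
  if (n_treat == 0) return(rule)
  top <- order(tau_hat, decreasing = TRUE)[seq_len(n_treat)]
  rule[top] <- as.integer(tau_hat[top] > 0)   # treat only positive predictions
  rule
}

tau_hat <- c(0.5, -1.2, 2.0, 0.1, -0.3, 1.1, 0.0, 0.7, -0.8, 0.4)
budget_itr(tau_hat, budget = 0.2)
#> [1] 0 0 1 0 0 1 0 0 0 0
```

With ten units and a 20% budget, only the two units with the largest predicted effects (2.0 and 1.1) are treated, even though several others also have positive predictions.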

# evaluate ITR
est_cv <- evaluate_itr(fit_cv)
#> 
#> Attaching package: 'purrr'
#> The following object is masked from 'package:caret':
#> 
#>     lift

# summarize estimates
summary(est_cv)
#> ── PAPE ────────────────────────────────────────────────────────────────────────
#>   estimate std.deviation     algorithm statistic p.value
#> 1      1.5           1.1 causal_forest       1.4    0.16
#> 2      1.2           1.0         lasso       1.1    0.26
#> 
#> ── PAPEp ───────────────────────────────────────────────────────────────────────
#>   estimate std.deviation     algorithm statistic p.value
#> 1     0.88          0.93 causal_forest      0.95    0.34
#> 2    -0.22          0.80         lasso     -0.27    0.79
#> 
#> ── PAPDp ───────────────────────────────────────────────────────────────────────
#>   estimate std.deviation             algorithm statistic p.value
#> 1      1.1             1 causal_forest x lasso       1.1    0.27
#> 
#> ── AUPEC ───────────────────────────────────────────────────────────────────────
#>   estimate std.deviation     algorithm statistic p.value
#> 1     1.40          0.77 causal_forest      1.82   0.069
#> 2     0.52          1.20         lasso      0.43   0.666
#> 
#> ── GATE ────────────────────────────────────────────────────────────────────────
#>    estimate std.deviation     algorithm group statistic p.value upper lower
#> 1       -80            68 causal_forest     1    -1.172    0.24    54  -214
#> 2       -38            74 causal_forest     2    -0.517    0.61   106  -182
#> 3        96            59 causal_forest     3     1.627    0.10   212   -20
#> 4       -22            82 causal_forest     4    -0.274    0.78   138  -183
#> 5        63            72 causal_forest     5     0.867    0.39   205   -79
#> 6        27            83         lasso     1     0.321    0.75   188  -135
#> 7       -60            80         lasso     2    -0.750    0.45    96  -215
#> 8        80            76         lasso     3     1.043    0.30   230   -70
#> 9        -4            82         lasso     4    -0.048    0.96   156  -164
#> 10      -24            83         lasso     5    -0.293    0.77   138  -187
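The statistic, p.value, upper, and lower columns above appear consistent with a normal approximation: the statistic is estimate / std.deviation, the two-sided p-value is 2 * pnorm(-abs(statistic)), and the GATE bounds are estimate plus or minus qnorm(0.975) * std.deviation. The sketch below recomputes group 3 of the causal forest GATE from its printed (rounded) estimate and standard deviation. It is a consistency check on the printed table, not a description of evalITR's source code.

```r
# Recompute the test statistic, p-value, and 95% interval for one GATE row
# (causal_forest, group 3) from the rounded values printed by summary().
estimate <- 96
std_dev  <- 59

statistic  <- estimate / std_dev            # z-statistic
p_value    <- 2 * pnorm(-abs(statistic))    # two-sided p-value
half_width <- qnorm(0.975) * std_dev        # 95% normal interval half-width
c(statistic = statistic, p.value = p_value,
  upper = estimate + half_width, lower = estimate - half_width)
# statistic ~1.63, p.value ~0.10, upper ~212, lower ~-20
```

Because the inputs are the rounded values from the table, recomputing other rows can differ from the printed bounds by one unit.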

We plot the estimated Area Under the Prescriptive Effect Curve (AUPEC) for the writing score across the different ML algorithms.

# plot the AUPEC with different ML algorithms
plot(est_cv)
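For intuition about what the AUPEC summarizes, the sketch below computes a simplified oracle version on simulated data where the true individual effects tau are known. For each budget p on a grid, it treats the top-p share of units ranked by a hypothetical score and measures the gain over treating a random p share of units, which is mean(tau * rule) - p * mean(tau). Averaging these gains over p (here with the trapezoidal rule) gives the area between the prescriptive effect curve and the random-assignment line. This is only an illustration under these assumptions; evalITR estimates the AUPEC from experimental data without access to tau.

```r
# Simplified oracle illustration of the AUPEC on simulated data (true effects known).
# Assumption: `score` is a hypothetical stand-in for a model's predicted CATE.
set.seed(2)
n <- 2000
x <- runif(n, -1, 1)
tau <- x                                    # true individual treatment effects
score <- x + rnorm(n, sd = 0.3)             # imperfect predicted CATE
ranking <- order(score, decreasing = TRUE)  # units sorted from highest score

grid <- seq(0, 1, by = 0.01)                # budgets p from 0% to 100%
gain <- sapply(grid, function(p) {
  rule <- integer(n)
  k <- round(p * n)                         # number of units treated at budget p
  if (k > 0) rule[ranking[seq_len(k)]] <- 1L
  mean(tau * rule) - p * mean(tau)          # gain over treating a random p share
})

# trapezoidal approximation of the area under the gain curve
aupec_oracle <- sum(diff(grid) * (head(gain, -1) + tail(gain, -1)) / 2)
aupec_oracle
```

The gain is zero at p = 0 and p = 1, since treating no one or everyone is the same under any ranking; a positive area means the score ranks units better than chance.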