Train an elastic net model to predict sample-level phenomena using high-dimensional cytometry data.
Source: R/patient-level_modeling.R
tof_train_model.Rd
This function uses a training set/test set paradigm to tune and fit an elastic net model according to user-specified options. Tuning can be performed using a simple training vs. test set split, k-fold cross-validation, or bootstrapping, and several preprocessing options (predictor standardization, near-zero-variance filtering, and k-nearest-neighbor imputation) are available.
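The tuning strategy is determined by the kind of object passed as `split_data`. The following is a minimal sketch, assuming tidytof (with its dplyr and rsample dependencies) is installed and that `tof_split_data()` accepts the `split_method`, `simple_prop`, `num_cv_folds`, and `num_bootstrap_samples` arguments as in recent tidytof releases. The toy data are made up for illustration.

```r
library(tidytof)

# Made-up toy data: 100 samples, 3 numeric predictors, 1 numeric outcome
set.seed(2020)
feature_tibble <- dplyr::tibble(
  sample = as.character(1:100),
  cd45 = runif(n = 100),
  pstat5 = runif(n = 100),
  cd34 = runif(n = 100),
  outcome = (3 * cd45) + (4 * pstat5) + rnorm(100)
)

# Simple training vs. test split: a single rsplit object
simple_split <-
  tof_split_data(feature_tibble, split_method = "simple", simple_prop = 0.75)

# k-fold cross-validation: an rset with one rsplit per fold
cv_split <-
  tof_split_data(feature_tibble, split_method = "k-fold", num_cv_folds = 5)

# Bootstrapping: an rset with one rsplit per bootstrap resample
boot_split <-
  tof_split_data(feature_tibble, split_method = "bootstrap", num_bootstrap_samples = 10)

# Any of the three objects can be passed as split_data; 5-fold CV is used here
cv_model <- tof_train_model(
  split_data = cv_split,
  predictor_cols = c(cd45, pstat5, cd34),
  response_col = outcome,
  model_type = "linear"
)
```

With cross-validation or bootstrapping, each hyperparameter combination is scored on every resample's validation set, which gives a less noisy estimate than a single split at the cost of fitting more models.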
Usage
tof_train_model(
split_data,
unsplit_data,
predictor_cols,
response_col = NULL,
time_col = NULL,
event_col = NULL,
model_type = c("linear", "two-class", "multiclass", "survival"),
hyperparameter_grid = tof_create_grid(),
standardize_predictors = TRUE,
remove_zv_predictors = FALSE,
impute_missing_predictors = FALSE,
optimization_metric = "tidytof_default",
best_model_type = c("best", "best with sparsity"),
num_cores = 1
)
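The penalty (lambda) and mixture (alpha) values supplied through `hyperparameter_grid` are the two hyperparameters of glmnet's elastic net objective. For a "linear" model, glmnet minimizes the following over the intercept $\beta_0$ and coefficient vector $\beta$, where $x_i$ is the predictor vector for sample $i$, $y_i$ is its response, and $N$ is the number of samples:

```latex
\min_{\beta_0,\,\beta}\;
\frac{1}{2N}\sum_{i=1}^{N}\left(y_i - \beta_0 - x_i^{\top}\beta\right)^2
+ \lambda\left[\frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\right]
```

Setting $\alpha = 1$ gives the lasso, $\alpha = 0$ gives ridge regression, and larger $\lambda$ shrinks more coefficients toward (or, when $\alpha > 0$, exactly to) zero. For "two-class", "multiclass", and "survival" models, the squared-error term is replaced by the corresponding scaled negative (partial) log-likelihood.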
Arguments
- split_data
An `rsplit` or `rset` object from the rsample package containing the sample-level data to use for modeling. The easiest way to generate this is to use tof_split_data.
- unsplit_data
A tibble containing sample-level data to use for modeling without resampling. While using a resampling method is advised, this argument provides an interface to fit a model without using cross-validation or bootstrap resampling. Ignored if split_data is provided.
- predictor_cols
Unquoted column names indicating which columns in the data contained in `split_data` should be used as predictors in the elastic net model. Supports tidyselect helpers.
- response_col
Unquoted column name indicating which column in the data contained in `split_data` should be used as the outcome in a "two-class", "multiclass", or "linear" elastic net model. Must be a factor for "two-class" and "multiclass" models and numeric for "linear" models. Ignored if `model_type` is "survival".
- time_col
Unquoted column name indicating which column in the data contained in `split_data` represents the time-to-event outcome in a "survival" elastic net model. Must be numeric. Ignored if `model_type` is "two-class", "multiclass", or "linear".
- event_col
Unquoted column name indicating which column in the data contained in `split_data` records whether the event of interest occurred (the event indicator) in a "survival" elastic net model. Must be a binary column: all values should be either 0 or 1 (with 1 indicating the adverse event) or FALSE and TRUE (with TRUE indicating the adverse event). Ignored if `model_type` is "two-class", "multiclass", or "linear".
- model_type
A string indicating which kind of elastic net model to build. If a continuous response is being predicted, use "linear" for linear regression; if a categorical response with only 2 classes is being predicted, use "two-class" for logistic regression; if a categorical response with more than 2 levels is being predicted, use "multiclass" for multinomial regression; and if a time-to-event outcome is being predicted, use "survival" for Cox regression.
- hyperparameter_grid
A hyperparameter grid indicating which values of the elastic net penalty (lambda) and the elastic net mixture (alpha) hyperparameters should be used during model tuning. Generate this grid using tof_create_grid.
- standardize_predictors
A logical value indicating if numeric predictor columns should be standardized (centered and scaled) before model fitting, as is standard practice during elastic net regularization. Defaults to TRUE.
- remove_zv_predictors
A logical value indicating if predictor columns with near-zero variance should be removed before model fitting using step_nzv. Defaults to FALSE.
- impute_missing_predictors
A logical value indicating if predictor columns should have missing values imputed using k-nearest neighbors before model fitting (see step_impute_knn). Imputation is performed using an observation's 5 nearest neighbors. Defaults to FALSE.
- optimization_metric
A string indicating which optimization metric should be used for hyperparameter selection during model tuning. Valid values depend on the model_type.
For "linear" models, choices are "mse" (the mean squared error of the predictions; the default) and "mae" (the mean absolute error of the predictions).
For "two-class" models, choices are "roc_auc" (the area under the Receiver Operating Characteristic curve; the default), "misclassification error" (the proportion of misclassified observations), "binomial_deviance" (see deviance.glmnet), "mse" (the mean squared error of the logit function), and "mae" (the mean absolute error of the logit function).
For "multiclass" models, choices are "roc_auc" (the area under the Receiver Operating Characteristic curve, using the Hand-Till generalization of the ROC AUC for multiclass models in roc_auc; the default), "misclassification error" (the proportion of misclassified observations), "multinomial_deviance" (see deviance.glmnet), and "mse" and "mae" as above.
For "survival" models, choices are "concordance_index" (Harrell's C index; see deviance.glmnet) and "partial_likelihood_deviance" (see deviance.glmnet).
- best_model_type
Currently unused.
- num_cores
Integer indicating how many cores should be used for parallel processing when fitting multiple models. Defaults to 1. The overhead of distributing model fits across multiple cores can be high, so a significant speedup is unlikely unless many large models are being fit.
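Several of the arguments above can be combined in a single call. The following sketch, assuming tidytof is installed and that `tof_create_grid()` accepts `penalty_values` and `mixture_values` vectors, tunes over a custom 3 x 3 grid with 5-fold cross-validation, turns on all three preprocessing options, and selects hyperparameters by mean absolute error. The data (including the injected missing values) are made up for illustration.

```r
library(tidytof)

# Made-up toy data with a few missing predictor values
set.seed(2020)
feature_tibble <- dplyr::tibble(
  sample = as.character(1:100),
  cd45 = runif(n = 100),
  pstat5 = runif(n = 100),
  cd34 = runif(n = 100),
  outcome = (3 * cd45) + (4 * pstat5) + rnorm(100)
)
# Missing values to be filled in by k-nearest-neighbor imputation
feature_tibble$cd34[c(3, 17, 42)] <- NA

# 3 penalty values x 3 mixture values = 9 combinations to tune over
custom_grid <- tof_create_grid(
  penalty_values = c(0.001, 0.01, 0.1),
  mixture_values = c(0, 0.5, 1)
)

cv_split <-
  tof_split_data(feature_tibble, split_method = "k-fold", num_cv_folds = 5)

tuned_model <- tof_train_model(
  split_data = cv_split,
  predictor_cols = c(cd45, pstat5, cd34),
  response_col = outcome,
  model_type = "linear",
  hyperparameter_grid = custom_grid,
  standardize_predictors = TRUE,
  # none of these toy predictors is near-constant; this only enables the step
  remove_zv_predictors = TRUE,
  impute_missing_predictors = TRUE,
  # pick the combination with the lowest mean absolute error
  optimization_metric = "mae"
)
```

Because every grid row is evaluated on every fold, the number of fits grows as (grid rows) x (resamples); this is the situation in which `num_cores > 1` may pay off.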
Value
A `tof_model`, an S3 class that includes the elastic net model with the best performance (assessed via cross-validation, bootstrapping, or simple splitting depending on `split_data`) across all tested hyperparameter value combinations. `tof_models` store the following information:
- model
The final elastic net ("glmnet") model, chosen by selecting the elastic net hyperparameters with the best average `optimization_metric` performance on the validation sets of the resamples used to train the model
- recipe
The recipe used for data preprocessing
- mixture
The optimal mixture hyperparameter (alpha) for the glmnet model
- penalty
The optimal penalty hyperparameter (lambda) for the glmnet model
- model_type
A string indicating which type of glmnet model was fit
- outcome_colnames
A character vector representing the names of the columns in the training data modeled as outcome variables
- training_data
A tibble containing the (not preprocessed) data used to train the model
- tuning_metrics
A tibble containing the validation set performance metrics (and model predictions) for each resample fold during model tuning.
- log_rank_thresholds
For survival models only, a tibble containing information about the relative-risk thresholds that can be used to split the training data into 2 risk groups (low- and high-risk) based on the final model's predictions. For each relative-risk threshold, the log-rank test p-value and an indicator of which threshold gives the most significant separation are provided.
- best_log_rank_threshold
For survival models only, a numeric value representing the relative-risk threshold that yields the most significant log-rank test when separating the training data into low- and high-risk groups.
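The stored components can be inspected directly. This sketch assumes tidytof is installed and that a `tof_model` is a named list underneath its S3 class, so the components listed above are reachable with `$`; the data are made up for illustration.

```r
library(tidytof)

set.seed(2020)
feature_tibble <- dplyr::tibble(
  sample = as.character(1:100),
  cd45 = runif(n = 100),
  pstat5 = runif(n = 100),
  cd34 = runif(n = 100),
  outcome = (3 * cd45) + (4 * pstat5) + rnorm(100)
)

split_data <- tof_split_data(feature_tibble, split_method = "simple")

linear_model <- tof_train_model(
  split_data = split_data,
  predictor_cols = c(cd45, pstat5, cd34),
  response_col = outcome,
  model_type = "linear"
)

# Assumption: components are accessible with `$`
linear_model$mixture            # selected alpha
linear_model$penalty            # selected lambda
linear_model$model_type         # string naming the model type
linear_model$outcome_colnames   # outcome column name(s)
linear_model$recipe             # preprocessing recipe
head(linear_model$tuning_metrics)
```

The `log_rank_thresholds` and `best_log_rank_threshold` components are only populated for "survival" models.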
See also
Other modeling functions: tof_assess_model(), tof_create_grid(), tof_predict(), tof_split_data()
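A trained `tof_model` is typically passed on to tof_predict() and tof_assess_model() to evaluate it on held-out samples. This sketch assumes tidytof is installed, that these functions take `tof_model` and `new_data` arguments (and `prediction_type` for tof_predict()), and that rsample's `testing()` is used to extract the held-out portion of a simple split; the data are made up.

```r
library(tidytof)

set.seed(2020)
feature_tibble <- dplyr::tibble(
  sample = as.character(1:100),
  cd45 = runif(n = 100),
  pstat5 = runif(n = 100),
  cd34 = runif(n = 100),
  outcome = (3 * cd45) + (4 * pstat5) + rnorm(100)
)

split_data <- tof_split_data(feature_tibble, split_method = "simple")

linear_model <- tof_train_model(
  split_data = split_data,
  predictor_cols = c(cd45, pstat5, cd34),
  response_col = outcome,
  model_type = "linear"
)

# Samples held out by the simple split
test_data <- rsample::testing(split_data)

# Predicted outcomes for the held-out samples
predictions <- tof_predict(
  tof_model = linear_model,
  new_data = test_data,
  prediction_type = "response"
)

# Performance metrics on the held-out samples
assessment <- tof_assess_model(tof_model = linear_model, new_data = test_data)
```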
Examples
feature_tibble <-
dplyr::tibble(
sample = as.character(1:100),
cd45 = runif(n = 100),
pstat5 = runif(n = 100),
cd34 = runif(n = 100),
outcome = (3 * cd45) + (4 * pstat5) + rnorm(100),
class =
as.factor(
dplyr::if_else(outcome > median(outcome), "class1", "class2")
),
multiclass =
as.factor(
c(rep("class1", 30), rep("class2", 30), rep("class3", 40))
),
event = c(rep(0, times = 30), rep(1, times = 70)),
time_to_event = rnorm(n = 100, mean = 10, sd = 2)
)
split_data <- tof_split_data(feature_tibble, split_method = "simple")
# train a regression model
tof_train_model(
split_data = split_data,
predictor_cols = c(cd45, pstat5, cd34),
response_col = outcome,
model_type = "linear"
)
#> A linear `tof_model` with a mixture parameter (alpha) of 1 and a penalty parameter (lambda) of 3.162e-03
#> # A tibble: 3 × 2
#> feature coefficient
#> <chr> <dbl>
#> 1 (Intercept) 3.36
#> 2 pstat5 1.15
#> 3 cd45 0.597
# train a logistic regression classifier
tof_train_model(
split_data = split_data,
predictor_cols = c(cd45, pstat5, cd34),
response_col = class,
model_type = "two-class"
)
#> A two-class `tof_model` with a mixture parameter (alpha) of 1 and a penalty parameter (lambda) of 3.162e-03
#> # A tibble: 4 × 2
#> feature coefficient
#> <chr> <dbl>
#> 1 pstat5 -2.73
#> 2 cd45 -1.37
#> 3 cd34 0.309
#> 4 (Intercept) -0.0126
# train a cox regression survival model
tof_train_model(
split_data = split_data,
predictor_cols = c(cd45, pstat5, cd34),
time_col = time_to_event,
event_col = event,
model_type = "survival"
)
#> A survival `tof_model` with a mixture parameter (alpha) of 1 and a penalty parameter (lambda) of 1e+00
#> # A tibble: 0 × 2
#> # ℹ 2 variables: feature <chr>, coefficient <dbl>
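The examples above cover "linear", "two-class", and "survival" models. A "multiclass" model is trained the same way, with a factor response that has more than two levels. This is a sketch assuming tidytof is installed, using a made-up three-level factor.

```r
library(tidytof)

set.seed(2020)
feature_tibble <- dplyr::tibble(
  sample = as.character(1:100),
  cd45 = runif(n = 100),
  pstat5 = runif(n = 100),
  cd34 = runif(n = 100),
  multiclass =
    as.factor(c(rep("class1", 30), rep("class2", 30), rep("class3", 40)))
)

split_data <- tof_split_data(feature_tibble, split_method = "simple")

# train a multinomial regression classifier (tuned by Hand-Till ROC AUC by default)
multiclass_model <- tof_train_model(
  split_data = split_data,
  predictor_cols = c(cd45, pstat5, cd34),
  response_col = multiclass,
  model_type = "multiclass"
)
```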