Tidymodels hyperparameter tuning with targets dynamic branching

Dynamic branching with tune_grid()

Machine learning model parameters

At its core, machine learning is a collection of statistical and mathematical techniques for analyzing and drawing inferences from patterns in data. Machine learning models ‘learn’ by adjusting sets of parameters in order to minimize error and maximize predictive performance. For example, in linear regression the parameters are the regression coefficients \(\beta\) and the intercept. For neural networks, the parameters are the weights and biases of each neuron. In tree-based models, the parameters include the collection of tree topologies and their associated leaf weights.
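As a concrete illustration, a simple linear regression model can be written as

\[ \hat{y} = \beta_0 + \beta_1 x_1 + \dots + \beta_p x_p \]

where the intercept \(\beta_0\) and the coefficients \(\beta_1, \dots, \beta_p\) are the parameters the model estimates from the training data.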

Parameters vs hyper-parameters

While model parameters are learned from the data during training, hyper-parameters control how models learn from the data. For example, xgboost hyper-parameters include \(\gamma\), a regularization parameter that controls the cost of tree complexity, and \(\textbf{max\_depth}\), which limits how deep each tree can grow.
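In tidymodels, hyper-parameters like these are flagged with tune() rather than set to fixed values. A minimal sketch, assuming the xgboost engine (in parsnip’s boost_tree(), loss_reduction corresponds to xgboost’s \(\gamma\) and tree_depth to \(\textbf{max\_depth}\)):

library(parsnip)
library(tune)

# Flag the hyper-parameters to be tuned instead of fixing them up front
xgb_model <- boost_tree(loss_reduction = tune(),  # gamma: penalty on added tree complexity
                        tree_depth = tune()) |>   # max_depth: how deep each tree can grow
  set_engine("xgboost") |>
  set_mode("classification")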

Hyper-parameters must be specified prior to model fitting, but how do you choose the best ones? This is where cross-validation comes in. By splitting the data into training, validation, and testing sets, it’s possible to tune the hyper-parameters on a subset of the data and then fit and evaluate the final model against a hold-out data set.
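A minimal sketch of that splitting with rsample (analysis_data is a hypothetical stand-in for your data; the resulting objects mirror the analysis_data_train and training_data_folds used in the pipeline below):

library(rsample)

set.seed(123)

# Hold out a test set for evaluating the final model
analysis_data_split <- initial_split(analysis_data, prop = 0.8)
analysis_data_train <- training(analysis_data_split)
analysis_data_test  <- testing(analysis_data_split)

# Cross-validation folds for tuning the hyper-parameters
training_data_folds <- vfold_cv(analysis_data_train, v = 5)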

Parallel hyper-parameter tuning

Hyper-parameter tuning is an important step in machine learning pipelines. It generally involves fitting many models across a grid of different hyper-parameters in order to find the optimal set. Unfortunately, fitting so many models takes a lot of computational resources and time. Fortunately, the problem is what is commonly known as embarrassingly parallel: each model can be fit independently and in parallel.

Tidymodels tune_grid()

In the tidymodels framework in R, hyper-parameter tuning is handled by the tune_grid() function. The tune_grid() function can parallelize the work using the future package. However, there are some limitations. For example, there isn’t an easy way to monitor progress. Another problem is that if one of the models fails stochastically it can disrupt the entire tuning process. Lastly, it can be useful to control whether parallelization occurs over the cross-validation folds, over the list of hyper-parameter sets, or over a cross of both.
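For reference, that standard approach looks something like the sketch below, assuming a recent version of tune that picks up a registered future plan and the bart_workflow, training_data_folds, and bart_gridsearch objects defined later in this post:

library(tidymodels)
library(future)

# Run tuning in parallel background R sessions
plan(multisession, workers = 4)

tuned_classic <- tune_grid(bart_workflow,
                           resamples = training_data_folds,
                           grid = bart_gridsearch,
                           metrics = metric_set(roc_auc))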

Targets dynamic branching

Fitting a lot of models involves orchestrating many different processes, something that pipeline management tools like targets are specifically designed to do. In targets, fitting many models at once can be accomplished using dynamic branching. Dynamic branching creates a separate sub-target (branch) for each model. Targets can track the progress of branches as they complete using tar_watch() and tar_poll(), and any failed branches can be re-run after tuning completes by setting error = "null" within the dynamically branched target. The pattern argument of the target can also be used to specify how branching occurs: mapping over the cross-validation folds, over the hyper-parameters, or using cross() to fit a model to every combination.
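For example, a dynamically branched target might look something like the sketch below, where fit_one_model() is a hypothetical stand-in for whatever function does the fitting:

# Inside the target list in _targets.R
tar_target(model_fits,
           fit_one_model(training_data_folds, bart_gridsearch),
           # Branch over every fold x hyper-parameter combination;
           # map(training_data_folds) alone would branch over folds only
           pattern = cross(training_data_folds, bart_gridsearch),
           # Store failed branches as NULL so the rest keep running
           error = "null")

# Monitor progress from another R session with
# targets::tar_poll() or targets::tar_watch()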

targets and tidymodels

Below is an example targets-based tidymodels pipeline. It uses two custom functions, tune_grid_branch() and select_best_params(), within a dynamically branched target to tune the hyper-parameters of a Bayesian Additive Regression Trees (BART) model. The key target is bart_tuned, which branches over every combination of cross-validation fold and hyper-parameter set via pattern = cross(training_data_folds, bart_gridsearch).

  # Set up the preprocessing recipe
  tar_target(analysis_recipe, 
             recipe(paste(response_variable, "~ .") |> as.formula(), 
                    data = analysis_data_train) |>
               step_naomit(everything()) |>
               step_string2factor(all_string()) |>
               step_novel(all_nominal(), -all_outcomes()) |>
               step_dummy(all_nominal(), -all_outcomes()) |>
               step_zv(all_predictors())),
               
# Set up the BART model
  tar_target(bart_model, 
             parsnip::bart(trees = tune(),
                           prior_terminal_node_coef = tune(),
                           prior_terminal_node_expo = tune()) |> 
               set_engine("dbarts") |>
               set_mode("classification")),
  
  # Set up the BART model workflow
  tar_target(bart_workflow, workflow() |> 
               add_recipe(analysis_recipe) |> 
               add_model(bart_model)),
  
  # Set up the hyper-parameter grid search.
  # Automatically extract the parameters to tune across.
  tar_target(bart_gridsearch, bart_workflow |> 
               extract_parameter_set_dials() |>
               dials::grid_latin_hypercube(size = 10)),
  
  # Tune the model
  tar_target(bart_tuned, tune_grid_branch(workflow = bart_workflow,
                                           gridsearch_params = bart_gridsearch,
                                           training_data_folds = training_data_folds),
             pattern = cross(training_data_folds, bart_gridsearch)),
             
  # Extract the best set of hyper-parameters
  tar_target(bart_best_params, select_best_params(bart_tuned, metric = "roc_auc")),

tune_grid_branch()

Below is my implementation of a dynamic-branch-friendly tune_grid() function. In addition to the benefits described above, it also tracks both the fit time and the amount of memory required for each model fit. With targets it’s also possible to fit just the first few fold and hyper-parameter combinations in order to profile your model tuning workflow. This can be accomplished by setting pattern = head(cross(training_data_folds, bart_gridsearch), 5), which will fit the first 5 models only. You can then get a sense of how long the fitting will take and how many resources each branch will require. The best part is that when you’re ready to do the full run you can remove the head() part of the command and targets won’t need to re-run those first 5 models. I’m certain there’s room to improve this but, at least to me, it seems like a very powerful combination.
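Concretely, the bart_tuned target from the pipeline above could temporarily be written like this while profiling (everything else stays the same):

  # Profile by fitting only the first 5 fold x hyper-parameter combinations
  tar_target(bart_tuned, tune_grid_branch(workflow = bart_workflow,
                                           gridsearch_params = bart_gridsearch,
                                           training_data_folds = training_data_folds),
             pattern = head(cross(training_data_folds, bart_gridsearch), 5)),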

#' Grid search for a tidymodels workflow using targets dynamic branching
#'
#' This function performs grid search tuning for a machine learning workflow 
#' using cross-validation. It iterates over provided folds and grid search 
#' parameters and computes specified evaluation metrics (e.g., AUC, F1 score) 
#' and profiles memory usage and timing for each model fit.
#' 
#' @author Nathan Layman
#'
#' @param workflow A tidymodels workflow with the recipe and model already attached
#' @param gridsearch_params A tibble where each row is a set of hyper-parameters
#' @param training_data_folds A tibble where each row is a training data fold
#' @param metrics A set of metrics produced by `yardstick::metric_set()`
#'
#' @return Returns a tibble with fit performance metrics, fit time, and the RAM used while fitting
#'
#' @examples
#' \dontrun{
#' performance <- tune_grid_branch(workflow, gridsearch_params, training_data_folds, verbose = TRUE)
#' }
#'
#' @export
tune_grid_branch <- function(workflow, 
                             gridsearch_params, 
                             training_data_folds, 
                             metrics = metric_set(pr_auc,           # Precision-Recall AUC
                                                  roc_auc,          # ROC AUC
                                                  accuracy,         # Accuracy
                                                  f_meas,           # F1 Score
                                                  recall,           # Recall
                                                  precision),
                             verbose = FALSE) {
  
  # Get the performance and profiling metrics of every combination of 
  # data fold and hyper-parameter combination passed in to the function
  performance <- map_dfr(1:nrow(training_data_folds), function(i) {
    map_dfr(1:nrow(gridsearch_params), function(j) {
      
      rsamp <- rsample::manual_rset(training_data_folds[i,]$splits, training_data_folds[i,]$id)
      params <- gridsearch_params[j,]
      if(verbose) print(rsamp |> bind_cols(params) |> select(-splits))
      
      # Grab start time
      start_time <- Sys.time()
      
      # Fit the model against the training data and profile memory usage
      # Using tune_grid here but we could make this simpler and just
      # fit and evaluate the model manually
      mem_usage_bytes <- profmem::profmem({
        fold_param <- tune::tune_grid(workflow,
                                      resamples = rsamp,
                                      grid = params,
                                      metrics = metrics)
      })
      
      # Report fit performance metrics
      fold_param |> select(-splits) |>
        mutate(id = rsamp$id,
               branch = targets::tar_name(),
               mem_usage_bytes = sum(mem_usage_bytes$bytes, na.rm = TRUE),
               fit_time = Sys.time() - start_time)
    })
  })
  
  # Clean up environment in case targets tries to store extra stuff
  rm(list=setdiff(ls(), "performance"))
  
  # Return performance
  performance
}

select_best_params()

And here is how you select the best set of parameters from the tibble returned by tune_grid_branch(). The result should fit seamlessly back into a tidymodels pipeline for fitting the final model, extracting performance metrics against the hold-out test dataset, and computing variable importance (e.g. with DALEX). A sketch of that final fitting step follows the function definition below.

#' Select Best Parameters from Tuned Model Results
#'
#' This function extracts the best hyper-parameters from a tuned model's metrics based on a specified evaluation 
#' metric. It averages the specified metric (e.g., "roc_auc") across tuning folds for each hyper-parameter set 
#' and selects the set with the best average value. Bookkeeping columns such as splits, IDs, and memory usage are removed.
#'
#' @author Nathan Layman
#'
#' @param tuned A tibble containing the results of the tuning process, including the model metrics.
#' @param metric A character string specifying the evaluation metric to be used for selecting the best parameters. 
#' The default is `"roc_auc"`.
#'
#' @return A tibble containing the best hyper-parameter set, with bookkeeping columns such as `.estimate`, `mem_usage_bytes`, and `branch` removed.
#' 
#' @details The function first unnests the `.metrics` column of the `tuned` tibble, filters by the selected metric, 
#' and calculates the mean of the evaluation metric for each set of hyper-parameters. It then selects the parameter 
#' set with the best average metric, without ties.
#'
#' @examples
#' \dontrun{
#' # Example usage:
#' best_params <- select_best_params(tuned_model_results, metric = "roc_auc")
#' }
#'
#' @export
select_best_params <- function(tuned, metric = "roc_auc") {
  
  best_params <- tuned |> 
    unnest(.metrics) |> 
    filter(.metric == metric) |> 
    # Drop bookkeeping columns so the metric can be averaged across folds
    # for each hyper-parameter set
    select(-any_of(c("splits", "id", "branch", "fit_time")), 
           -starts_with("mem_usage"), -starts_with("."), .estimate) |>
    group_by(across(-.estimate)) |>
    summarize(.estimate = mean(.estimate), .groups = "drop") |>
    # All metrics in the default metric set are maximized, so the best
    # hyper-parameter set has the highest average estimate
    slice_max(.estimate, with_ties = FALSE) |>
    select(-.estimate)
  
  return(best_params)
}
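From there, the final fitting step referenced above might look something like the sketch below. It assumes an analysis_data_split object from the initial train/test split; finalize_workflow(), last_fit(), and collect_metrics() are standard tune functions:

  # Fit the finalized workflow on the full training data and
  # evaluate it once against the hold-out test set
  tar_target(bart_final_fit, bart_workflow |>
               finalize_workflow(bart_best_params) |>
               last_fit(analysis_data_split)),

  # Collect performance metrics on the hold-out data
  tar_target(bart_test_metrics, collect_metrics(bart_final_fit)),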