
Tidymodels: How to extract importance from training data

I have the following code, which runs a grid search over different values of mtry and min_n. I know how to extract the parameters that give the highest accuracy (see the second code box). How can I extract the importance of each feature in the training dataset? The guides I found online only show how to do it on the test dataset using "last_fit", for example: https://www.tidymodels.org/start/case-study/#data-split

set.seed(seed_number)
data_split <- initial_split(node_strength, prop = 0.8, strata = Group)

train <- training(data_split)
test <- testing(data_split)
train_folds <- vfold_cv(train, v = 10)

rfc <- rand_forest(mode = "classification", mtry = tune(),
                   min_n = tune(), trees = 1500) %>%
    set_engine("ranger", num.threads = 48, importance = "impurity")

rfc_recipe <- recipe(data = train, Group ~ .)

rfc_workflow <- workflow() %>%
    add_model(rfc) %>%
    add_recipe(rfc_recipe)

rfc_result <- rfc_workflow %>%
    tune_grid(train_folds, grid = 40, control = control_grid(save_pred = TRUE),
              metrics = metric_set(accuracy))

best <-
    rfc_result %>%
    select_best(metric = "accuracy")
Orestis asked Jan 23 '26 21:01


1 Answer

To do this, you will want to create a custom extract function, as outlined in this documentation.

For random forest variable importance, your function will look something like this:

get_rf_imp <- function(x) {
    x %>% 
        extract_fit_parsnip() %>% 
        vip::vi()
}
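Before applying it to resamples, it can help to see what this function returns for a single fitted workflow. The following is a minimal sketch, assuming the cells data from modeldata; the names single_fit and single_imp and the choice of trees = 100 are illustrative and only there to keep the example quick. Because the workflow is fitted to the whole training set, the result is also one way to get importance scores for the training data directly.

```r
library(tidymodels)
data(cells, package = "modeldata")

# Extractor from the answer: pull the parsnip fit out of a fitted workflow
# and compute variable importance scores with vip.
get_rf_imp <- function(x) {
    x %>%
        extract_fit_parsnip() %>%
        vip::vi()
}

set.seed(123)
cell_train <- cells %>%
    select(-case) %>%
    initial_split(strata = class) %>%
    training()

# trees = 100 is an arbitrary (assumed) value chosen for speed
rf_spec <- rand_forest(mode = "classification", trees = 100) %>%
    set_engine("ranger", importance = "impurity")

# Fit once on the whole training set (single_fit is a hypothetical name)
single_fit <- workflow(class ~ ., rf_spec) %>%
    fit(data = cell_train)

# A tibble with one row per predictor: columns Variable and Importance
single_imp <- get_rf_imp(single_fit)
single_imp
```

The Variable and Importance columns are the ones that get summarized per resample further down.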

And then you can apply it to your resamples like so (notice that you get a new .extracts column):

library(tidymodels)
data(cells, package = "modeldata")

set.seed(123)
cell_split <- cells %>% select(-case) %>%
    initial_split(strata = class)
cell_train <- training(cell_split)
cell_test  <- testing(cell_split)
folds <- vfold_cv(cell_train)            

rf_spec <- rand_forest(mode = "classification") %>%
    set_engine("ranger", importance = "impurity")

ctrl_imp <- control_grid(extract = get_rf_imp)

cells_res <-
    workflow(class ~ ., rf_spec) %>%
    fit_resamples(folds, control = ctrl_imp)
cells_res
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 5
#>    splits             id     .metrics         .notes           .extracts       
#>    <list>             <chr>  <list>           <list>           <list>          
#>  1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#>  9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>

Created on 2022-06-19 by the reprex package (v2.0.1)

Once you have those variable importance score extracts, you can unnest() them (right now, you have to do this twice because it is deeply nested) and then you can summarize and visualize as you prefer:

cells_res %>%
    select(id, .extracts) %>%
    unnest(.extracts) %>%
    unnest(.extracts) %>%
    group_by(Variable) %>%
    # sd() returns the standard deviation across folds, so name it SD
    summarise(Mean = mean(Importance),
              SD = sd(Importance)) %>%
    slice_max(Mean, n = 15) %>%
    ggplot(aes(Mean, reorder(Variable, Mean))) +
    geom_crossbar(aes(xmin = Mean - SD, xmax = Mean + SD)) +
    labs(x = "Variable importance", y = NULL)

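The same idea should carry over to the tune_grid() setup in the question. The sketch below is an illustration under assumptions, not a tested recipe for the asker's data: it uses the cells data, a deliberately small grid, 3 folds and 100 trees so that it runs quickly, and the names tune_res, best_imp and final_imp are made up for the example. With tuning, the first unnest() gives one row per fold and candidate, so the rows are filtered to the candidate returned by select_best() before the second unnest(). The last step refits the finalized workflow on the whole training set, which gives a single set of training-data importance scores.

```r
library(tidymodels)
data(cells, package = "modeldata")

# Extractor from the answer
get_rf_imp <- function(x) {
    x %>%
        extract_fit_parsnip() %>%
        vip::vi()
}

set.seed(123)
cell_train <- cells %>%
    select(-case) %>%
    initial_split(strata = class) %>%
    training()
folds <- vfold_cv(cell_train, v = 3)  # few folds, only to keep the sketch quick

# Same tuning parameters as in the question; trees = 100 is an assumed value
rf_tune_spec <- rand_forest(mode = "classification", mtry = tune(),
                            min_n = tune(), trees = 100) %>%
    set_engine("ranger", importance = "impurity")

rf_wf <- workflow(class ~ ., rf_tune_spec)

# Pass the extractor through control_grid(), as with fit_resamples()
tune_res <- rf_wf %>%
    tune_grid(folds, grid = 4,
              control = control_grid(extract = get_rf_imp),
              metrics = metric_set(accuracy))

best <- select_best(tune_res, metric = "accuracy")

# Importance across folds for the best candidate only
best_imp <- tune_res %>%
    select(id, .extracts) %>%
    unnest(.extracts) %>%                         # one row per fold and candidate
    semi_join(best, by = c("mtry", "min_n")) %>%  # keep the best candidate
    unnest(.extracts) %>%                         # one row per fold and variable
    group_by(Variable) %>%
    summarise(Mean = mean(Importance), SD = sd(Importance)) %>%
    arrange(desc(Mean))

# Alternatively, refit the best candidate on the whole training set
final_imp <- rf_wf %>%
    finalize_workflow(best) %>%
    fit(data = cell_train) %>%
    get_rf_imp()
```

Here best_imp summarizes how stable each variable's importance is across folds, while final_imp comes from one model trained on all of the training data; which of the two is more useful depends on what the scores are for.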

Julia Silge answered Jan 25 '26 15:01


