
Plotting top n features using permutation importance

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder


# rf, X_test and y_test are assumed to come from the model-fitting code omitted here
result = permutation_importance(rf,
                                X_test,
                                y_test,
                                n_repeats=10,
                                random_state=42,
                                n_jobs=2)
sorted_idx = result.importances_mean.argsort()
        

fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
           vert=False,
           labels=X_test.columns[sorted_idx])

ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()

In the code above, taken from an example in the scikit-learn documentation, is there a way to plot only the top 3 features instead of all of them?

user308827 asked Oct 28 '25 04:10


1 Answer

argsort "returns the indices that would sort an array," so here sorted_idx contains the feature indices in order of least to most important. Since you just want the 3 most important features, take only the last 3 indices:

sorted_idx = result.importances_mean.argsort()[-3:]
# array([4, 0, 1])

Then the plotting code can remain as is, but now it will only plot the top 3 features:

# unchanged, apart from the optional figsize
fig, ax = plt.subplots(figsize=(6, 3))
ax.boxplot(result.importances[sorted_idx].T,
           vert=False, labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()
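
To see why the last three indices are the top three features, here is a minimal NumPy sketch. The importance values are made up for illustration and do not come from the question's model:

```python
import numpy as np

# Made-up mean importances for five hypothetical features (illustrative only).
importances_mean = np.array([0.12, 0.30, 0.01, 0.05, 0.02])

# argsort returns the indices ordered from the smallest to the largest value.
sorted_idx = importances_mean.argsort()

# The last three indices therefore point at the three largest values,
# still in ascending order (most important last).
top3_idx = sorted_idx[-3:]

print(sorted_idx)  # [2 4 3 0 1]
print(top3_idx)    # [3 0 1]
```

Because the slice keeps the ascending order, a horizontal boxplot (vert=False) draws the most important feature at the top.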


Note that if you prefer to leave sorted_idx untouched (e.g., to use the full indices elsewhere in the code),

  • either change sorted_idx to sorted_idx[-3:] inline:

    sorted_idx = result.importances_mean.argsort() # unchanged
    
    ax.boxplot(result.importances[sorted_idx[-3:]].T, # replace sorted_idx with sorted_idx[-3:]
               vert=False, labels=X_test.columns[sorted_idx[-3:]]) # replace sorted_idx with sorted_idx[-3:]
    
  • or store the filtered indices in a separate variable:

    sorted_idx = result.importances_mean.argsort() # unchanged
    top3_idx = sorted_idx[-3:] # store top 3 indices
    
    ax.boxplot(result.importances[top3_idx].T, # replace sorted_idx with top3_idx
               vert=False, labels=X_test.columns[top3_idx]) # replace sorted_idx with top3_idx
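
If you need a different n in several places, the slicing could be wrapped in a small helper. This is only a sketch: select_top_n is a hypothetical name, not part of scikit-learn, and the demo array and feature names below are made up. It assumes the input has the same layout as result.importances, i.e. shape (n_features, n_repeats), and recomputes the per-feature mean itself.

```python
import numpy as np


def select_top_n(importances, feature_names, n=3):
    """Return (rows, names) for the n features with the highest mean importance.

    importances: array of shape (n_features, n_repeats), laid out like
    result.importances. Hypothetical helper, not a scikit-learn function.
    """
    if n <= 0:
        # Guard: arr[-0:] would silently return every feature.
        raise ValueError("n must be a positive integer")
    importances = np.asarray(importances)
    feature_names = np.asarray(feature_names)
    # Ascending sort of the per-feature means, then keep the last n indices.
    top_idx = importances.mean(axis=1).argsort()[-n:]
    return importances[top_idx], feature_names[top_idx]


# Made-up importances: 4 features x 3 repeats (illustrative only).
demo = np.array([
    [0.10, 0.12, 0.11],  # "age"
    [0.01, 0.00, 0.02],  # "sex"
    [0.30, 0.28, 0.29],  # "fare"
    [0.05, 0.06, 0.04],  # "pclass"
])
names = ["age", "sex", "fare", "pclass"]

rows, labels = select_top_n(demo, names, n=2)
print(labels)  # ['age' 'fare']
```

With the real result, rows and labels from select_top_n(result.importances, X_test.columns, n=3) could then be passed to the plot as ax.boxplot(rows.T, vert=False, labels=labels).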
    
tdy answered Oct 29 '25 18:10