
How do I include feature names in the plot_tree function from the XGBoost library?

Tags: python, xgboost

I've been using the XGBoost library to develop a binary classification model. Having trained my model, I am interested in visualizing the individual trees to better understand my model's predictions.

To do this, XGBoost provides a plot_tree function, but it labels each split only with the integer index of the feature. Here is an example of one of my trees:

How do I include the feature name in this image rather than the feature index (e.g. f28)?

FChm asked Oct 20 '25 16:10


1 Answer

The plot_tree function in xgboost has an fmap argument, which takes the path to a 'feature map' file containing a mapping from feature index to feature name.

The documentation on the feature map file is sparse, but it is a tab-delimited file with three columns: the first is the feature index (starting from 0 and ending at the number of features minus one), the second is the feature name, and the third is an indicator of the feature type (q = quantitative feature, i = binary feature).

An example of a feature_map.txt file:

0    feature_name_0    q
1    feature_name_1    i
2    feature_name_2    q
…          …           … 
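
For more than a handful of features it is easier to generate this file than to type it by hand. The following is a minimal sketch, not part of the xgboost API: write_feature_map is a hypothetical helper, the column names are made up, and it assumes every feature is quantitative (q) unless a type list is passed. It also replaces whitespace in names with underscores, on the assumption that spaces inside a name would break the column layout.

```python
from pathlib import Path


def write_feature_map(feature_names, path="feature_map.txt", feature_types=None):
    """Write one 'index<TAB>name<TAB>type' line per feature (hypothetical helper)."""
    if feature_types is None:
        # Assumption: treat every feature as quantitative ("q") by default.
        feature_types = ["q"] * len(feature_names)
    if len(feature_types) != len(feature_names):
        raise ValueError("feature_types must have one entry per feature name")

    lines = []
    for index, (name, ftype) in enumerate(zip(feature_names, feature_types)):
        # Replace any whitespace in the name so each line keeps three columns.
        safe_name = "_".join(str(name).split())
        lines.append(f"{index}\t{safe_name}\t{ftype}")

    Path(path).write_text("\n".join(lines) + "\n", encoding="utf-8")
    return path


if __name__ == "__main__":
    # Made-up column names, listed in the same order as the columns of X.
    write_feature_map(
        ["age", "is_member", "account balance"],
        feature_types=["q", "i", "q"],
    )
```

The order of the names has to match the column order of the data the model was trained on, since the file maps positions to names. If X is a pandas DataFrame, list(X.columns) would be the natural input.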

With this tab-delimited file you can then plot your tree from your trained model instance:

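import matplotlib.pyplot as plt
import xgboost

# X, y: your feature matrix and binary labels (defined elsewhere)
model = xgboost.XGBClassifier()

# train the model
model.fit(X, y)

# plot the first tree (index 0), providing the path to the feature map file;
# plot_tree requires matplotlib and graphviz to be installed
xgboost.plot_tree(model, num_trees=0, fmap='feature_map.txt')
plt.show()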

With the feature map supplied, the plot labels each split with the feature name instead of an index such as f28.

FChm answered Oct 23 '25 04:10


