
How can I plot random forest variable importance in Python, like R's varImpPlot()?

My data has around 370 features, and I have built a random forest model to find the important ones. When I plot the importances, I can't tell which features to keep, because 370 labels on the x-axis are unreadable.

Can anyone help me plot a graph in Python like the one varImpPlot() produces in R?

asked Sep 01 '25 06:09 by ashwin g



1 Answer

In R's randomForest package, varImpPlot() plots only the 30 most important variables by default (controlled by its n.var argument). You can do the same in Python. Here is an example adapted from the scikit-learn documentation, using 370 features to match the question:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data with the same number of features as in the question
X, y = make_classification(n_samples=1000,
                           n_features=370,
                           n_informative=16,
                           n_classes=2,
                           random_state=0)

# Fit the forest; feature_importances_ then holds impurity-based importances
forest = RandomForestClassifier(random_state=0)
forest.fit(X, y)

To plot them, put the importance scores into a pd.Series indexed by feature name, keep the 30 largest, and draw them as a horizontal bar chart:

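feature_names = [f"feature_{i}" for i in range(X.shape[1])]  # X is a NumPy array here; with a DataFrame use X.columns
importances = pd.Series(forest.feature_importances_, index=feature_names)
ax = importances.nlargest(30).sort_values().plot.barh(figsize=(6, 8))  # 30 largest, biggest at the top
ax.set_xlabel("Mean decrease in impurity")
plt.tight_layout()
plt.show()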
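R's varImpPlot() draws a dot chart rather than bars. If you want that look, the sketch below wraps it in a small helper. The name `plot_var_imp` and its `n_var` parameter are hypothetical, loosely modelled on varImpPlot's `n.var`; the feature names are made up because `make_classification` returns an unlabelled NumPy array.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def plot_var_imp(importances, n_var=30, ax=None, title="Variable importance"):
    """Hypothetical helper: dot chart of the n_var largest importances.

    Loosely mimics R's varImpPlot(n.var = 30); it is not a port of it.
    """
    # Ascending order so the most important feature is drawn at the top
    top = importances.nlargest(n_var).sort_values()
    if ax is None:
        ax = plt.gca()
    positions = np.arange(len(top))
    # Dotted guide lines from zero to each point, as in R's dotchart
    ax.hlines(positions, 0, top.values, color="lightgray", linestyle="dotted")
    ax.plot(top.values, positions, "o")
    ax.set_yticks(positions)
    ax.set_yticklabels(top.index)
    ax.set_xlabel("Mean decrease in impurity")
    ax.set_title(title)
    return top


# Same synthetic setup as above, with made-up feature names
X, y = make_classification(n_samples=1000, n_features=370, n_informative=16,
                           n_classes=2, random_state=0)
feature_names = [f"feature_{i}" for i in range(X.shape[1])]
forest = RandomForestClassifier(random_state=0).fit(X, y)
importances = pd.Series(forest.feature_importances_, index=feature_names)

fig, ax = plt.subplots(figsize=(6, 8))
top = plot_var_imp(importances, n_var=30, ax=ax)
fig.tight_layout()
```

Call `plt.show()` afterwards to display the figure. The helper returns the plotted Series, so `top.index[::-1]` gives the 30 selected feature names in decreasing order of importance.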

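When the R model is fitted with `importance=TRUE`, varImpPlot() shows two panels: MeanDecreaseAccuracy and MeanDecreaseGini. A rough Python analogue pairs `feature_importances_` (impurity-based) with scikit-learn's `permutation_importance`. This is a sketch under some assumptions: accuracy drop is measured on a held-out split here, whereas R uses the out-of-bag samples, and the dataset has 50 features instead of 370 so the permutation step runs quickly.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Smaller than the question's 370 features so permutation importance stays fast
X, y = make_classification(n_samples=1000, n_features=50, n_informative=8,
                           n_classes=2, random_state=0)
feature_names = [f"feature_{i}" for i in range(X.shape[1])]  # made-up names
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25,
                                                    random_state=0)

forest = RandomForestClassifier(random_state=0).fit(X_train, y_train)

# Analogue of MeanDecreaseGini: impurity-based importance from training
gini = pd.Series(forest.feature_importances_, index=feature_names)

# Analogue of MeanDecreaseAccuracy: mean accuracy drop when one column is
# shuffled, computed on the held-out split (R uses out-of-bag samples)
perm = permutation_importance(forest, X_test, y_test, scoring="accuracy",
                              n_repeats=5, random_state=0)
accuracy = pd.Series(perm.importances_mean, index=feature_names)

n_var = 30
fig, axes = plt.subplots(1, 2, figsize=(10, 8))
panels = [(axes[0], accuracy, "Mean decrease in accuracy"),
          (axes[1], gini, "Mean decrease in impurity")]
for ax, scores, label in panels:
    # Each panel is ranked independently, with the largest value at the top
    top = scores.nlargest(n_var).sort_values()
    positions = np.arange(len(top))
    ax.plot(top.values, positions, "o")
    ax.set_yticks(positions)
    ax.set_yticklabels(top.index)
    ax.set_xlabel(label)
fig.suptitle("Random forest variable importance")
fig.tight_layout()
```

Because the two measures can rank features differently, features that appear near the top of both panels are usually the safer candidates to keep.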

answered Sep 02 '25 19:09 by StupidWolf
