Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Shap plot crops/truncates the feature names

import csv
import pandas as pd
import numpy as np
from matplotlib import pyplot 
import shap
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
# Read the comma-separated wine data from the working directory.
df1 = pd.read_csv("./wine.data", sep=",", encoding="utf_8_sig")

# NOTE: X_train is a second name for df1 (not a copy), so the in-place
# label encoding below is visible through both names.
X_train = df1

# Replace the alc_class labels with the integer codes produced by LabelEncoder.
le = preprocessing.LabelEncoder()
X_train["alc_class"] = le.fit_transform(X_train["alc_class"].values)

# Show the available columns and their summary statistics.
print(X_train.columns)

print(X_train.describe())


# Target column on one side, every remaining column as features on the other.
y = X_train["alc_class"]
X = X_train.drop(columns="alc_class")
import xgboost as xgb


# split X and y into training and testing sets

from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV


# Hold out 30% of the rows for testing. stratify=y asks for the class
# proportions of y to be kept in both parts; the fixed random_state makes
# the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.30,
    random_state=2100,
    stratify=y,
)

# xgboost itself plus the MSE metric used when scoring the test predictions
import xgboost as xgb
from sklearn.metrics import mean_squared_error

# Wrap both partitions in XGBoost DMatrix containers.
# NOTE(review): DM_train / DM_test are not referenced again in the visible
# part of the script, which fits the scikit-learn wrapper on the DataFrames.
DM_train = xgb.DMatrix(data=X_train, label=y_train)
DM_test = xgb.DMatrix(data=X_test, label=y_test)


# Hyper-parameter grid for the search. Only colsample_bytree has more than
# one candidate (np.linspace(0.5, 0.9, 2) gives 0.5 and 0.9); every other
# entry is a single fixed value.
xgb_param_grid = {
    "colsample_bytree": np.linspace(0.5, 0.9, 2),
    "n_estimators": [30],
    "max_depth": [5],
    "learning_rate": [0.01],
    "alpha": [10],
    "objective": ["binary:logistic"],
    "tree_method": ["hist"],
    "min_child_weight": [1],
    "gamma": [0.5],
    "subsample": [0.6],
}


# Base classifier; the grid search supplies the remaining parameters.
xgb_clf = xgb.XGBClassifier(use_label_encoder=False, eval_metric="auc")


# 5-fold cross-validated grid search, scored with negated mean squared error.
grid_mse = GridSearchCV(
    estimator=xgb_clf,
    param_grid=xgb_param_grid,
    scoring="neg_mean_squared_error",
    cv=5,
    verbose=1,
)

# Run the search on the training partition.
grid_mse.fit(X_train, y_train)


# best_score_ is a negated MSE, hence abs() before taking the square root.
print("Best parameters found: ", grid_mse.best_params_)
print("Lowest RMSE found: ", np.sqrt(np.abs(grid_mse.best_score_)))


# Predict on the held-out test data: hard class labels and per-class
# probabilities from the best estimator found by the grid search.
y_pred = grid_mse.predict(X_test)
y_pred_prob = grid_mse.predict_proba(X_test)


# RMSE between the true and the predicted class labels, rounded to 2 decimals.
test_rmse = np.round(np.sqrt(mean_squared_error(y_test, y_pred)), 2)
print("Root mean square error for test dataset: {}".format(test_rmse))


from sklearn.metrics import accuracy_score, roc_curve, auc,recall_score,precision_score, precision_recall_curve,f1_score, classification_report, confusion_matrix,roc_auc_score


print('XGBoost model accuracy score: {0:0.4f}'. format(accuracy_score(y_test, y_pred)))
print('XGBoost model F1 score: {0:0.4f}'. format(f1_score(y_test, y_pred, average='weighted')))

# The precision-recall curve needs continuous scores so that it can sweep the
# decision threshold. Passing the hard labels in y_pred collapses the curve to
# a single operating point, so use the second predict_proba column, the same
# positive-class score the ROC computations below rely on.
precision, recall, thresholds = precision_recall_curve(y_test, y_pred_prob[:, 1])
area = auc(recall, precision)
print("----------------")
print("\n\n Evaluation Metrics \n\n")


aucroc_score = roc_auc_score(y_test, y_pred_prob[:, 1])
print("Area Under ROC Curve: ",aucroc_score)
# ROC curve of the fitted model
fpr, tpr, thresh = roc_curve(y_test, y_pred_prob[:, 1], pos_label=1)

# Reference ROC curve (tpr = fpr) obtained from a constant score for every sample
random_probs = [0] * len(y_test)
p_fpr, p_tpr, _ = roc_curve(y_test, random_probs, pos_label=1)


print("confusion_matrix ", confusion_matrix(y_test,y_pred))
print("classification_report ", classification_report(y_test,y_pred))


# Compute SHAP values for the training features with the best estimator
# found by the grid search.
explainer = shap.TreeExplainer(grid_mse.best_estimator_)
shap_values = explainer(X_train)

# A single float for plot_size only sets the height of each row, so raising it
# does not give long y-axis labels more horizontal room. show=False keeps SHAP
# from displaying the figure immediately; tight_layout() then adjusts the
# margins so the full feature names fit before the figure is shown.
# To write the figure to disk instead, call
# pyplot.savefig(<path>, bbox_inches='tight') before pyplot.show().
shap.plots.beeswarm(shap_values, plot_size = 1.8, max_display = 13, show = False)
pyplot.tight_layout()
pyplot.show()


# Print the raw importance vector, then one "name, score" line per feature.
importances = grid_mse.best_estimator_.feature_importances_
print(importances)
for col, score in zip(X_train.columns, importances):
    print('%s, %0.3f ' %(col,score))

enter image description here

  1. My feature names are long, and when I draw the SHAP beeswarm plot they get truncated. I would like the full feature names to be displayed on the y-axis. Any help would be greatly appreciated.
  2. I have tried changing the plot size, but it did not work.
like image 959
serenaarez Avatar asked Oct 14 '25 15:10

serenaarez


1 Answer

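Pass `show=False` so that SHAP does not display the figure immediately, then save it with a tight bounding box (`bbox_inches='tight'`) so that the long labels are not cropped: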

# Continues from the question's script: `shap` and `shap_values` come from there.
# The script imports matplotlib as `pyplot`, so `plt` has to be imported here.
import matplotlib.pyplot as plt

path = 'save_path_here.png'
# show=False leaves the figure open so it can be saved afterwards.
shap.plots.beeswarm(shap_values, plot_size = 1.8, max_display = 13, show=False)
# bbox_inches='tight' expands the saved area to include the full y-axis labels.
plt.savefig(path, bbox_inches='tight', dpi=300)
like image 91
peterhunter Avatar answered Oct 17 '25 05:10

peterhunter



Donate For Us

If you love us, you can donate via PayPal or buy us a coffee so that we can maintain and grow the site. Thank you!