How can we get legends for seaborn FacetGrid heatmaps? The .add_legend() method isn't working for me.
Using code from this previous question:
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
# Report which seaborn release is running; FacetGrid's API differs across versions.
print("seaborn version", sns.__version__)
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
    """Return every combination of the given iterables as a dict of columns.

    Combinations are produced by ``itertools.product``, so the LAST iterable
    varies fastest (unlike R's ``expand.grid``, where the first varies
    fastest). Column k (1-based) is stored under the key ``'Var<k>'`` and
    holds the k-th element of each combination. With no arguments the result
    is ``{}``; if any iterable is empty, every column is an empty list.
    """
    rows = list(itertools.product(*itrs))
    grid = {}
    for col in range(len(itrs)):
        grid['Var{}'.format(col + 1)] = [row[col] for row in rows]
    return grid
methods = ['method 1', 'method2', 'method 3', 'method 4']
times = range(0, 100, 10)
# expandgrid() names its columns Var1..Var3; relabel them in the same step.
data = pd.DataFrame(expandgrid(methods, times, times)).rename(
    columns={'Var1': 'method', 'Var2': 'dtsi', 'Var3': 'rtsi'})
# One uniform random score in [0, 1) per row.
data['nw_score'] = np.random.sample(len(data))
def facet(data, color):
    """Draw a heatmap of one facet's ``nw_score`` values on the current axes.

    ``data`` holds the rows of a single facet. It is pivoted so that ``dtsi``
    values become the heatmap rows and ``rtsi`` values its columns.
    ``color`` is accepted because FacetGrid.map_dataframe passes it, but it is
    unused: the heatmap is coloured by ``cmap``. No colorbar is drawn
    (``cbar=False``), and each facet is scaled independently because no
    vmin/vmax is fixed.
    """
    scores = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    sns.heatmap(scores, cmap='Blues', cbar=False)
# NOTE(review): plotting_context() is called without a context name. In that
# case seaborn's implementation does not apply font_scale, so this `with`
# leaves font sizes unchanged. Pass a context (e.g. "notebook") first if
# scaling is really intended -- confirm against the installed seaborn.
with sns.plotting_context(font_scale=5.5):
    # `height` replaced the deprecated `size` argument in seaborn 0.9;
    # recent releases no longer accept `size`.
    g = sns.FacetGrid(data, col="method", col_wrap=2, height=3, aspect=1)
    g = g.map_dataframe(facet)
    # No `hue` variable is mapped, so add_legend() has no entries to show.
    # Continuous heatmap values are described by a colorbar, not a legend.
    g.add_legend()
    g.set_titles(col_template="{col_name}", fontweight='bold', fontsize=18)

What you want (in matplotlib lingo) is a colorbar, not a legend. In matplotlib, colorbars are used for continuous data, while legends are used for categorical data. Colorbar support isn't built into FacetGrid, but it is not hard to extend your example code to add a colorbar:
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
methods = ['method 1', 'method2', 'method 3', 'method 4']
times = range(0, 100, 10)
# Every (method, dtsi, rtsi) combination as one row, labelled in the same call.
data = pd.DataFrame(list(itertools.product(methods, times, times)),
                    columns=['method', 'dtsi', 'rtsi'])
# One uniform random score in [0, 1) per row.
data['nw_score'] = np.random.sample(len(data))
def facet_heatmap(data, color, **kws):
    """Draw one facet's ``nw_score`` heatmap, forwarding extra options.

    ``data`` holds the rows of a single facet. It is pivoted so that ``dtsi``
    values become the heatmap rows and ``rtsi`` values its columns.
    ``color`` is accepted because FacetGrid.map_dataframe passes it, but the
    heatmap is coloured by ``cmap`` instead. Any other keywords (here
    ``cbar_ax``, ``vmin`` and ``vmax``) go straight to ``sns.heatmap``.
    """
    score_grid = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    sns.heatmap(score_grid, cmap='Blues', **kws)  # <-- Pass kwargs to heatmap
# NOTE(review): without a context name, seaborn's plotting_context() does not
# apply font_scale, so this `with` leaves font sizes unchanged. Pass e.g.
# "notebook" as the first argument if scaling is intended.
with sns.plotting_context(font_scale=5.5):
    # `height` replaced the deprecated `size` argument in seaborn 0.9;
    # recent releases no longer accept `size`.
    g = sns.FacetGrid(data, col="method", col_wrap=2, height=3, aspect=1)

# The rect is [left, bottom, width, height] in figure coordinates, i.e. a thin
# strip at the right edge. Every facet's heatmap draws its colorbar into this
# one axes. (Newer seaborn also exposes g.fig as g.figure.)
cbar_ax = g.fig.add_axes([.92, .3, .02, .4])  # <-- Create a colorbar axes

# Fixed vmin/vmax put all facets on the same colour scale, so the single
# shared colorbar is valid for every panel.
g = g.map_dataframe(facet_heatmap,
                    cbar_ax=cbar_ax,
                    vmin=0, vmax=1)  # <-- Specify the colorbar axes and limits

g.set_titles(col_template="{col_name}", fontweight='bold', fontsize=18)
g.fig.subplots_adjust(right=.9)  # <-- Add space so the colorbar doesn't overlap the plot

I've indicated the changes I made and the rationale for them as inline comments.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With