
Using two colors for different rows in a seaborn heatmap splits each row into two

I have the following dataframe:

import pandas as pd

fruits={'fruit':['apple1','apple2','banana1','banana2','peach1','peach2'],'1':[0,0,0,1,0,1],'2':[1,1,0,1,1,1],'3':[1,1,1,1,0,0],'4':[0,1,1,1,1,1]}
df_fruits=pd.DataFrame(data=fruits)
df_fruits=df_fruits.set_index('fruit')


>>>     1   2   3   4
fruit               
apple1  0   1   1   0
apple2  0   1   1   1
banana1 0   0   1   1
banana2 1   1   1   1
peach1  0   1   0   1
peach2  1   1   0   1

I'm trying to create a kind of heatmap in which a value of 1 gets a color and a value of 0 is grey. In addition, and here is the problem, I want all fruits ending in 1 to be blue and all fruits ending in 2 to be green. I tried the script mentioned here, but I get white lines across the cells in undesired locations that divide each row into two:

import numpy as np
import matplotlib.pyplot as plt

N_communities = df_fruits.index.size
N_cols = df_fruits.columns.size
# alternate the two colormaps row by row; zip() below stops at the last dataframe row
cmaps = ['Blues', 'Greens'] * 15

fig, ax = plt.subplots(figsize=(10,8))

for i,((idx,row),cmap) in enumerate(zip(df_fruits.iterrows(), cmaps)):
    ax.imshow(np.vstack([row.values, row.values]), aspect='equal', extent=[-0.5,N_cols-0.5,i,i+1], cmap=cmap)
    for j,val in enumerate(row.values):
        vmin, vmax = row.agg(['min','max'])
        vmid = (vmax-vmin)/2
        #if not np.isnan(val):
            #ax.annotate(val, xy=(j,i+0.5), ha='center', va='center', color='black' if (val<=vmid or vmin==vmax) else 'white')
ax.set_ylim(0,N_communities)

ax.set_xticks(range(N_cols))
ax.set_xticklabels(df_fruits.columns, rotation=90, ha='center')

ax.set_yticks(0.5+np.arange(N_communities))
ax.set_yticklabels(df_fruits.index)
ax.set_ylabel('Index')
ax.hlines([2,4],color="black" ,*ax.get_xlim())
ax.invert_yaxis()

fig.tight_layout()

[image: the resulting plot]

As you can see, it looks like apple1 has two rows, apple2 has two rows, and so on, while I want one row for each. I have tried adjusting the extent but could not get rid of those lines.

My end goal is one heatmap row for each row in the dataframe, where fruits ending in 1 are blue and fruits ending in 2 are green (only where the value is 1). Where the value is 0, the cell should be grey.

Edit: I used ax.grid(False) as suggested, but the result is still not good, because the lines disappear. I also found that the plot itself is wrong: [image]

As you can see, the row "banana2" is supposed to be green but is white.

asked Sep 06 '25 10:09 by Reut


1 Answer

You can use the mask option of sns.heatmap:

mask: If passed, data will not be shown in cells where mask is True. Cells with missing values are automatically masked.

So, to plot the blue fruit1 squares, mask out the fruit2 values and vice versa.
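
As a minimal sketch of what such a mask looks like for the blue pass, the snippet below hides every cell that is either 0 or belongs to a row whose label does not end in '1'. The names hide_row and mask_blue are illustrative only and are not part of seaborn; the dataframe is rebuilt here so the block runs on its own.

```python
import numpy as np
import pandas as pd

fruits = {'fruit': ['apple1', 'apple2', 'banana1', 'banana2', 'peach1', 'peach2'],
          '1': [0, 0, 0, 1, 0, 1], '2': [1, 1, 0, 1, 1, 1],
          '3': [1, 1, 1, 1, 0, 0], '4': [0, 1, 1, 1, 1, 1]}
df_fruits = pd.DataFrame(fruits).set_index('fruit')

# True for rows that must be hidden in the blue pass (label does not end in '1')
hide_row = ~np.asarray(df_fruits.index.str.endswith('1'), dtype=bool)

# a cell is masked if its value is 0 OR its whole row is hidden;
# hide_row[:, None] broadcasts the per-row flag across all columns
mask_blue = pd.DataFrame(
    df_fruits.eq(0).to_numpy() | hide_row[:, None],
    index=df_fruits.index,
    columns=df_fruits.columns,
)
print(mask_blue)
```

Passing mask_blue as mask= to sns.heatmap would then draw only the cells equal to 1 in the fruit1 rows; swapping '1' for '2' gives the green pass.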

The fruit1/fruit2 heatmaps can be plotted together by saving the axes handle ax and reusing it with ax=ax:

import pandas as pd
import seaborn as sns

fruits = {'fruit':['apple1','apple2','banana1','banana2','peach1','peach2'],'1':[0,0,0,1,0,1],'2':[1,1,0,1,1,1],'3':[1,1,1,1,0,0],'4':[0,1,1,1,1,1]}
df_fruits = pd.DataFrame(data=fruits)
df_fruits = df_fruits.set_index('fruit')

# *** this line is needed for seaborn 0.10.1 (not needed for 0.11.1) ***
df_fruits = df_fruits.astype('float')

# common settings: linewidths for grid lines, hide colorbar, set square aspect
kwargs = dict(linewidths=1, cbar=False, square=True)

# plot initial gray squares and save heatmap handle as ax
ax = sns.heatmap(df_fruits, cmap='Greys_r', alpha=0.2, **kwargs)

# iterate ending:cmap pairs
cmaps = {'1': 'Blues_r', '2': 'Greens_r'}
for ending, cmap in cmaps.items():
    
    # create mask for given fruit ending
    mask = df_fruits.apply(
        lambda x: x if x.name.endswith(ending) else 0,
        result_type='broadcast',
        axis=1,
    ).eq(0)
    
    # plot masked heatmap on reusable ax
    sns.heatmap(df_fruits, mask=mask, cmap=cmap, ax=ax, **kwargs)

[image: fruit heatmap]
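
A likely reason the reversed colormaps ('Blues_r', 'Greens_r') are used above, offered as an interpretation rather than something stated in the answer: after masking, each overlay contains only the value 1, so the color range collapses (vmin equals vmax) and matplotlib's Normalize maps every cell to 0.0, the first color of the colormap. For 'Blues' that first color is nearly white, for 'Blues_r' it is dark blue. The same effect would plausibly explain the white "banana2" row in the question, since that row is all ones. A small check, assuming matplotlib 3.5 or later for the matplotlib.colormaps registry:

```python
import numpy as np
import matplotlib
from matplotlib.colors import Normalize

# a row (or masked overlay) that contains only ones
ones = np.array([1.0, 1.0, 1.0, 1.0])

# autoscaling to the data gives vmin == vmax == 1
norm = Normalize(vmin=ones.min(), vmax=ones.max())
scaled = np.asarray(norm(ones))  # degenerate range: every value maps to 0.0

# first color of the regular and the reversed colormap, as RGBA tuples
light = matplotlib.colormaps['Blues'](0.0)
dark = matplotlib.colormaps['Blues_r'](0.0)

print(scaled)
print(light, dark)
```

If a non-reversed colormap is preferred, passing explicit vmin=0 and vmax=1 to sns.heatmap should avoid the collapsed range, though the resulting shade will differ from the one shown in the answer.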

answered Sep 10 '25 12:09 by tdy