
Add annotation to specific cells in heatmap

I am plotting a seaborn heatmap and would like to annotate only specific cells with custom text.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from io import StringIO

data = StringIO(u'''75,83,41,47,19
                    51,24,100,0,58
                    12,94,63,91,7
                    34,13,86,41,77''')

labels = StringIO(u'''7,8,4,,1
                    5,2,,2,8
                    1,,6,,7
                    3,1,,4,7''')

data = pd.read_csv(data, header=None)
data = data.apply(pd.to_numeric)

labels = pd.read_csv(labels, header=None)
#labels = np.ma.masked_invalid(labels)

fig, ax = plt.subplots()
sns.heatmap(data, annot=labels, ax=ax, vmin=0, vmax=100)
plt.show()

The above code generates the following heatmap:

[image: heatmap with nan values]

Uncommenting the np.ma.masked_invalid line instead generates the following heatmap:

[image: heatmap with 0 values]

I would like to show text only on the cells whose label is not NaN (or not zero). How can that be achieved?

Saad asked Oct 13 '25 01:10


1 Answer

Use a string array for annot instead of a masked array:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from io import StringIO

data = StringIO(u'''75,83,41,47,19
                    51,24,100,0,58
                    12,94,63,91,7
                    34,13,86,41,77''')

labels = StringIO(u'''7,8,4,,1
                    5,2,,2,8
                    1,,6,,7
                    3,1,,4,7''')

data = pd.read_csv(data, header=None)
data = data.apply(pd.to_numeric)

labels = pd.read_csv(labels, header=None)
#labels = np.ma.masked_invalid(labels)

# Convert every label to a string, then blank out the cells
# whose label is missing (NaN) so nothing is drawn there:
annotations = labels.astype(str)
annotations[np.isnan(labels)] = ""

fig, ax = plt.subplots()
sns.heatmap(data, annot=annotations, fmt="s", ax=ax, vmin=0, vmax=100)
plt.show()

[image: output]
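
One detail worth noting: pandas reads any column that contains a missing value as floats, so labels.astype(str) may render those labels as e.g. 8.0 rather than 8. If integer text is preferred, the annotation array could be built along the following lines. This is only a sketch; the helper name make_annotations is hypothetical (not part of pandas or seaborn), and it assumes every non-missing label is a whole number.

```python
import numpy as np
import pandas as pd
from io import StringIO


def make_annotations(labels: pd.DataFrame) -> np.ndarray:
    """Hypothetical helper: integer text for valid cells, "" for NaN cells."""
    values = labels.to_numpy(dtype=float)
    missing = np.isnan(values)
    # Fill NaN with 0 only so the integer cast is defined;
    # those cells are blanked out again right below.
    text = np.where(missing, 0, values).astype(int).astype(str)
    text[missing] = ""
    return text


labels = pd.read_csv(
    StringIO("7,8,4,,1\n5,2,,2,8\n1,,6,,7\n3,1,,4,7"), header=None
)
annotations = make_annotations(labels)
print(annotations)
```

The resulting array can then be passed as annot=annotations together with fmt="s" in the sns.heatmap call shown above.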

mrzo answered Oct 14 '25 13:10