
Taking log transform for Z axis in sns.kdeplot

I am trying to apply a log transform to the z axis (the density) of the plot below, so that I can see the data in the regions where it is sparse. I cannot find a way to do it.

sns.kdeplot( x , y , cmap='gist_gray_r', shade=True, shade_lowest=False)

The image produced by the above command is attached below (x and y are two lists of data). Can anyone help with how to take the log transform, i.e. apply log to the z axis?

Image for above command

tushar shandhilya asked Dec 14 '25 14:12


1 Answer

Seaborn's kdeplot has no option to apply a log transform to the density. But you can compute the kde yourself, evaluate it over a grid, take the log (or use log-spaced levels), and then create the contour plot manually.
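As a rough sketch of that idea, the snippet below evaluates scipy's gaussian_kde on a grid and takes log10 of the result. The helper name log_density_on_grid, the grid size and the floor value are illustrative assumptions, not seaborn or scipy settings; the floor only prevents log(0) where the estimated density underflows.

```python
import numpy as np
from scipy.stats import gaussian_kde


def log_density_on_grid(x, y, n=50, floor=1e-6):
    """Evaluate a 2D Gaussian KDE on an n-by-n grid and return log10 of the density.

    `floor` is a hypothetical cutoff chosen for illustration: densities below it
    are clipped so that the logarithm stays finite.
    """
    kde = gaussian_kde(np.vstack([x, y]))
    xs, ys = np.meshgrid(np.linspace(np.min(x), np.max(x), n),
                         np.linspace(np.min(y), np.max(y), n))
    z = kde(np.vstack([xs.ravel(), ys.ravel()])).reshape(xs.shape)
    return xs, ys, np.log10(np.clip(z, floor, None))


# Made-up sample data, only to exercise the helper
rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, 2000)
y = rng.normal(0.0, 2.0, 2000)
xs, ys, log_z = log_density_on_grid(x, y)
```

The resulting log_z array can then be handed to a matplotlib contour call such as ax.contourf(xs, ys, log_z).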

However, if the x and y arrays aren't extremely large, the kde won't have much meaningful information in the sparse areas. Everything will be smoothed out.

Another idea is a scatterplot. With a comma marker (marker=',') and no marker edges (linewidth=0), very small dots can be drawn. When there are many points, a smaller alpha helps to keep dense regions from saturating. A scatter plot has the benefit of showing the data as they really are, instead of an arbitrary approximation via a kde.

The code below creates an example of the 3 approaches. X and Y axes are shared to illustrate the much larger area occupied by the logscale levels.

from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import gaussian_kde

x = np.random.normal(np.tile([0.02, 0.03, 0.04, 0.06], 10000), 0.01)
y = np.random.normal(np.tile([-0.5, 0.5, 1, 0.9], 10000), 0.4)

fig, axes = plt.subplots(ncols=3, figsize=(14, 4), sharex=True, sharey=True)

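# create a regular kdeplot (the keywords x=, y=, fill= and thresh= assume seaborn >= 0.11;
# older versions used positional x, y with shade=True, shade_lowest=False)
sns.kdeplot(x=x, y=y, cmap='gist_gray_r', fill=True, thresh=0.05, ax=axes[0])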
axes[0].set_title('standard kdeplot')

# create a kdeplot with logscale levels
kde = gaussian_kde([x, y])
xmin, xmax = x.min() - 0.02, x.max() + 0.02
ymin, ymax = y.min() - 0.8, y.max() + 0.8
xs, ys = np.meshgrid(np.linspace(xmin, xmax, 50), np.linspace(ymin, ymax, 50))
z = kde([xs.ravel(), ys.ravel()]).reshape(xs.shape)
N = 10
levels = np.logspace(-6, np.log10(z.max()), N + 1)
cmap = plt.get_cmap('inferno_r', N)
axes[1].contourf(xs, ys, z, levels=levels, colors=[cmap((i + 1) / N) for i in range(N)], alpha=0.5)
axes[1].yaxis.set_tick_params(labelleft=True) # 'sharey' hides the tick labels on the inner axes, here they are shown again
axes[1].set_title('kdeplot with logscale levels')

# draw a scatter plot
sns.scatterplot(x=x, y=y, color='r', marker=',', linewidth=0, s=1, alpha=0.2, ax=axes[2])
axes[2].yaxis.set_tick_params(labelleft=True)
axes[2].set_title('scatterplot')
plt.show()

example plot
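A possible variation on the log-level contour plot, sketched here under the assumption that a continuous colormap with a colorbar is wanted: matplotlib's LogNorm can map the log-spaced levels to colors, instead of building the color list by hand. The sample data mirrors the example above with fewer points and a fixed seed; both choices are only for illustration.

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
from scipy.stats import gaussian_kde

# Smaller, seeded version of the example data (illustrative only)
rng = np.random.default_rng(0)
x = rng.normal(np.tile([0.02, 0.03, 0.04, 0.06], 2500), 0.01)
y = rng.normal(np.tile([-0.5, 0.5, 1, 0.9], 2500), 0.4)

# Evaluate the kde on a regular grid
kde = gaussian_kde([x, y])
xs, ys = np.meshgrid(np.linspace(x.min() - 0.02, x.max() + 0.02, 50),
                     np.linspace(y.min() - 0.8, y.max() + 0.8, 50))
z = kde([xs.ravel(), ys.ravel()]).reshape(xs.shape)

# Log-spaced levels; LogNorm spreads the colormap evenly over the decades
levels = np.logspace(-6, np.log10(z.max()), 11)
fig, ax = plt.subplots()
cs = ax.contourf(xs, ys, z, levels=levels,
                 norm=LogNorm(vmin=levels[0], vmax=levels[-1]), cmap='inferno_r')
fig.colorbar(cs, ax=ax, label='estimated density')
ax.set_title('kde contours with LogNorm')
```

Call plt.show() afterwards to display the figure. Regions where the density falls below the lowest level (1e-6 here) are left unfilled.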

Another idea is to combine a regular kdeplot with a scatterplot for the faraway points:

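# the keywords x=, y=, fill= and thresh= assume seaborn >= 0.11
ax = sns.kdeplot(x=x, y=y, cmap='gist_gray_r', fill=True, thresh=0.05)
# zorder=0 draws the dots behind the filled kde, so they only show where the kde is not drawn
sns.scatterplot(x=x, y=y, color='grey', marker=',', linewidth=0, s=1, alpha=1, zorder=0, ax=ax)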

combined plot
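If only the faraway points should be drawn at all, one option is to evaluate the kde at the sample points themselves and keep those with the lowest estimated density. This is just a sketch: the 5% quantile cutoff and the names density, cutoff, x_far and y_far are assumptions picked for illustration, and the data is made up.

```python
import numpy as np
from scipy.stats import gaussian_kde

# Made-up sample data
rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, 2000)
y = rng.normal(0.0, 1.0, 2000)

kde = gaussian_kde([x, y])
density = kde([x, y])  # estimated density at every sample point

# Hypothetical cutoff: treat the 5% of points with the lowest density as "faraway"
cutoff = np.quantile(density, 0.05)
sparse = density <= cutoff
x_far, y_far = x[sparse], y[sparse]
```

The selected points can then be passed to the scatterplot in place of the full arrays, e.g. sns.scatterplot(x=x_far, y=y_far, ...). Note that evaluating the kde at every sample scales quadratically with the number of points, so this may be slow for very large arrays.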

JohanC answered Dec 16 '25 19:12