I have a problem: I want to plot a data distribution in which some values occur frequently while others are quite rare. The total number of points is around 30,000. Rendering such a plot as png or (god forbid) pdf takes forever, and the pdf is far too large to display.
So I want to subsample the data just for the plots. What I would like to achieve is to remove a lot of points where they overlap (where the density is high), but keep the points where the density is low with probability close to 1.
Now, numpy.random.choice allows one to specify a vector of probabilities, which I've computed from the data histogram with a few tweaks. But I can't seem to tune the choice so that the rare points are actually kept.
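For context, this is roughly the kind of inverse-density weighting I mean; a minimal sketch of the general idea, not my exact code, and the bin count and sample size are arbitrary placeholders:

import numpy as np

# Sketch: weight each point by the inverse of its histogram bin count,
# then subsample with numpy.random.choice.
np.random.seed(0)
x = np.random.rayleigh(3, size=30000)

counts, edges = np.histogram(x, bins=100)
bin_idx = np.clip(np.digitize(x, edges) - 1, 0, len(counts) - 1)
p = 1.0 / counts[bin_idx]      # rare bins get large weights
p /= p.sum()                   # probabilities must sum to 1

keep = np.random.choice(len(x), size=2000, replace=False, p=p)
x_sub = np.sort(x[keep])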
I've attached an image of the data; the right tail of the distribution has orders of magnitude fewer points, so I'd like to keep those. The data is 3d, but the density comes from only one dimension, so I can use that as a measure of how many points are in a given location.
Consider the following function. It bins the data into equal-width bins along the x axis and, within each bin, keeps only the points with the minimum and maximum y value (bins containing two or fewer points are kept in full).
This keeps the original data in regions of low density, but significantly reduces the amount of data to plot in regions of high density. At the same time, all features are preserved as long as the binning is sufficiently dense.
import numpy as np
np.random.seed(42)

def filt(x, y, bins):
    # Assign each x value to a bin
    d = np.digitize(x, bins)
    xfilt = []
    yfilt = []
    for i in np.unique(d):
        xi = x[d == i]
        yi = y[d == i]
        if len(xi) <= 2:
            # Keep sparse bins completely
            xfilt.extend(list(xi))
            yfilt.extend(list(yi))
        else:
            # Keep only the extrema of dense bins
            xfilt.extend([xi[np.argmax(yi)], xi[np.argmin(yi)]])
            yfilt.extend([yi.max(), yi.min()])
    # prepend/append first/last point if necessary
    if x[0] != xfilt[0]:
        xfilt = [x[0]] + xfilt
        yfilt = [y[0]] + yfilt
    if x[-1] != xfilt[-1]:
        xfilt.append(x[-1])
        yfilt.append(y[-1])
    sort = np.argsort(xfilt)
    return np.array(xfilt)[sort], np.array(yfilt)[sort]
To illustrate the concept, let's use some toy data:
x = np.array([1,2,3,4, 6,7,8,9, 11,14, 17, 26,28,29])
y = np.array([4,2,5,3, 7,3,5,5, 2, 4, 5, 2,5,3])
bins = np.linspace(0,30,7)
Then calling xf, yf = filt(x,y,bins)
and plotting both the original data and the filtered data gives:
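For completeness, that comparison plot can be reproduced with something along these lines; an illustrative snippet only, using the toy arrays and the filt function from above:

import matplotlib.pyplot as plt

# Plot the toy data before and after filtering (illustrative only)
xf, yf = filt(x, y, bins)

fig, ax = plt.subplots()
ax.plot(x, y, "o-", label="original, {} points".format(len(x)))
ax.plot(xf, yf, "s--", label="filtered, {} points".format(len(xf)))
ax.legend()
plt.show()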
The use case from the question, with some 30,000 data points, is shown below. The presented technique reduces the number of plotted points from 30,000 to about 500. This number will of course depend on the binning used - here 300 bins. In this case the function takes ~10 ms to compute. This is not super-fast, but still a large improvement over plotting all the points.
import matplotlib.pyplot as plt
# Generate some data
x = np.sort(np.random.rayleigh(3, size=30000))
y = np.cumsum(np.random.randn(len(x)))+250
# Decide for a number of bins
bins = np.linspace(x.min(),x.max(),301)
# Filter data
xf, yf = filt(x,y,bins)
# Plot results
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(7,8),
                                    gridspec_kw=dict(height_ratios=[1,2,2]))
ax1.hist(x, bins=bins)
ax1.set_yscale("log")
ax1.set_yticks([1,10,100,1000])
ax2.plot(x,y, linewidth=1, label="original data, {} points".format(len(x)))
ax3.plot(xf, yf, linewidth=1, label="binned min/max, {} points".format(len(xf)))
for ax in [ax2, ax3]:
    ax.legend()
plt.show()
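To check the timing claim on your own machine, one option is a quick measurement with the standard timeit module (an illustrative snippet; the exact number will vary by hardware):

import timeit

# Average runtime of the filtering step over repeated calls
t = timeit.timeit(lambda: filt(x, y, bins), number=100) / 100
print("filt: {:.2f} ms per call".format(t * 1e3))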