 

Optimization of matplotlib stem plot

I'm trying to produce a stem plot using the 'matplotlib.pyplot.stem' function. The code works, but it takes over 5 minutes to run.

I have a similar code within Matlab that produces the same plot with the same input data almost instantly.

Is there a way to optimize this code for speed or a better function I could be using?

The arguments for the stem plot 'H' and 'plotdata' are 16384 x 1 arrays.

def stemplot():

    import numpy as np
    from scipy.fftpack import fft
    import matplotlib.pyplot as plt

    ################################################
    # Code to set up the plot data

    N=2048
    dr = 100

    k = np.arange(0,N)

    cos = np.cos
    pi = np.pi

    w = 1-1.932617*cos(2*pi*k/(N-1))+1.286133*cos(4*pi*k/(N-1))-0.387695*cos(6*pi*k/(N-1))+0.0322227*cos(8*pi*k/(N-1))

    y = np.concatenate([w, np.zeros((7*N))])

    H = abs(fft(y, axis = 0))
    H = np.fft.fftshift(H)
    H = H/max(H)
    H = 20*np.log10(H)
    H = dr+H 
    H[H < 0] = 0        # Set all negative values in dr+H to 0

    plotdata = ((np.arange(1,(8*N)+1,1))-1-4*N)/8
    #################################################

    # Plotting Code

    plt.figure()
    plt.stem(plotdata,H,markerfmt = " ")

    plt.axis([(-4*N)/8, (4*N)/8, 0, dr])    
    plt.grid()
    plt.ylabel('decibels')
    plt.xlabel('DFT bins')
    plt.title('Frequency response (Flat top)')
    plt.show()


    return
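The data set-up above can be written more compactly with NumPy alone. This is a minimal sketch, assuming `np.fft.fft` is an acceptable stand-in for `scipy.fftpack.fft` (both compute the same DFT here). The helper name `flat_top_response` is hypothetical. It also confirms that `H` and `plotdata` are 1-D arrays with 8*N = 16384 elements.

```python
import numpy as np

def flat_top_response(N=2048, dr=100):
    """Hypothetical helper: return (plotdata, H) as in the question's set-up code."""
    k = np.arange(N)
    # Flat-top window coefficients from the question; term m uses cos(2*pi*m*k/(N-1)).
    coeffs = np.array([1.0, -1.932617, 1.286133, -0.387695, 0.0322227])
    m = np.arange(coeffs.size)
    w = coeffs @ np.cos(2 * np.pi * np.outer(m, k) / (N - 1))

    # n=8*N zero-pads the window with 7*N zeros, like np.concatenate in the question.
    H = np.abs(np.fft.fftshift(np.fft.fft(w, n=8 * N)))
    with np.errstate(divide="ignore"):  # log10(0) -> -inf, which is clipped to 0 below
        H = np.maximum(dr + 20 * np.log10(H / H.max()), 0)

    # Same values as ((np.arange(1, 8*N + 1, 1)) - 1 - 4*N) / 8
    plotdata = (np.arange(8 * N) - 4 * N) / 8
    return plotdata, H

plotdata, H = flat_top_response()
print(plotdata.shape, H.shape)  # prints (16384,) (16384,)
```

`np.maximum(..., 0)` plays the role of the MATLAB `max(0, dr+H)` line and replaces the separate `H[H < 0] = 0` step.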

Here is also the Matlab code for reference:

N=2048;
dr = 100;
k=0:N-1

w = 1 - 1.932617*cos(2*pi*k/(N-1)) + 1.286133*cos(4*pi*k/(N-1)) -0.387695*cos(6*pi*k/(N-1)) +0.0322227*cos(8*pi*k/(N-1));

H = abs(fft([w zeros(1,7*N)]));
H = fftshift(H);
H = H/max(H);
H = 20*log10(H);
H = max(0,dr+H); % Sets negative numbers in dr+H to 0


figure
stem(([1:(8*N)]-1-4*N)/8,H,'-');
set(findobj('Type','line'),'Marker','none','Color',[.871 .49 0])
xlim([-4*N 4*N]/8)
ylim([0 dr])
set(gca,'YTickLabel','-100|-90|-80|-70|-60|-50|-40|-30|-20|-10|0')
grid on
ylabel('decibels')
xlabel('DFT bins')
title('Frequency response (Flat top)')
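The MATLAB version relabels the y axis so that the ticks at 0, 10, ..., 100 read -100, ..., 0 dB, which the Python version does not do. A hedged matplotlib equivalent, assuming ticks every 10 dB; `relabel_db_axis` is a hypothetical helper name:

```python
import numpy as np
from matplotlib.figure import Figure

def relabel_db_axis(ax, dr=100, step=10):
    """Hypothetical helper: label y ticks 0..dr as -dr..0, like the MATLAB YTickLabel line."""
    ticks = np.arange(0, dr + 1, step)
    labels = [str(int(t) - dr) for t in ticks]
    ax.set_ylim(0, dr)          # same as ylim([0 dr])
    ax.set_yticks(ticks)        # fix tick positions before labelling them
    ax.set_yticklabels(labels)  # shifted labels, as in the MATLAB code
    return labels

fig = Figure()
ax = fig.subplots()
print(relabel_db_axis(ax))
# prints ['-100', '-90', '-80', '-70', '-60', '-50', '-40', '-30', '-20', '-10', '0']
```

With pyplot, `relabel_db_axis(plt.gca())` would go after the `plt.axis(...)` call. If the labels should follow whatever ticks matplotlib picks, passing `matplotlib.ticker.FuncFormatter(lambda v, pos: f"{v - dr:g}")` to `ax.yaxis.set_major_formatter` is an alternative. The MATLAB colour `[.871 .49 0]` can be given as an RGB tuple, e.g. `color=(0.871, 0.49, 0)`.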
Steven Goddard asked Apr 23 '26 04:04



2 Answers

You can simulate a stem plot in the format you want with ax.vlines. Write a small helper function:

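def make_stem(ax, x, y, **kwargs):
    # Baseline at y = 0 from the first to the last x value. ax.axhline
    # expects xmin/xmax as axes fractions in [0, 1], so hlines is used here.
    ax.hlines(0, x[0], x[-1], color='r')

    # One vertical line per sample from 0 up to y; extra kwargs go to vlines.
    kwargs.setdefault('color', 'b')
    ax.vlines(x, 0, y, **kwargs)

    ax.set_ylim([1.05*y.min(), 1.05*y.max()])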

Then alter the relevant lines in your example as follows:

    # Plotting Code

##    plt.figure
##    plt.stem(plotdata,H,markerfmt = " ")

##    plt.axis([(-4*N)/8, (4*N)/8, 0, dr])

    fig, ax = plt.subplots()
    make_stem(ax, plotdata, H)

This produces the plot more or less instantly. I don't, however, know whether it is faster or slower than @ImportanceOfBeingErnest's answer.
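`Axes.vlines` itself returns a `LineCollection`, so both answers end up drawing the same kind of artist. To compare them on your own machine, a rough timing harness could look like the sketch below. It assumes rendering with the Agg canvas is a fair proxy for on-screen drawing, uses stand-in data instead of the real `H`, and the helper names `time_draw`, `with_vlines` and `with_line_collection` are hypothetical. No timings are claimed here.

```python
import time
import numpy as np
import matplotlib.collections as mcoll
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg

def time_draw(add_stems, x, y, repeats=3):
    """Best-of-`repeats` wall time (s) to add the stems and render once with Agg."""
    best = float("inf")
    for _ in range(repeats):
        fig = Figure()
        canvas = FigureCanvasAgg(fig)  # off-screen canvas, no GUI window needed
        ax = fig.subplots()
        start = time.perf_counter()
        add_stems(ax, x, y)
        canvas.draw()
        best = min(best, time.perf_counter() - start)
    return best

def with_vlines(ax, x, y):
    # Axes.vlines returns a LineCollection with one segment per x value
    return ax.vlines(x, 0, y)

def with_line_collection(ax, x, y):
    # Same construction as the LineCollection answer: one segment per stem
    segments = [((xi, 0), (xi, yi)) for xi, yi in zip(x, y)]
    lc = mcoll.LineCollection(segments, colors="C0")
    ax.add_collection(lc)
    ax.autoscale_view()  # add_collection does not rescale the view by itself
    return lc

x = (np.arange(16384) - 8192) / 8
y = np.abs(np.sin(x))  # stand-in data with the question's array size
for name, fn in [("vlines", with_vlines), ("LineCollection", with_line_collection)]:
    print(f"{name}: {time_draw(fn, x, y):.3f} s")
```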

Thomas Kühn answered Apr 25 '26 17:04



There seems to be no need for a stem plot here, since the markers are made invisible anyway and would not make sense given the large number of points.

Instead, a LineCollection may make sense. This is how matplotlib will do it in a future version anyway - see this PR. The code below runs within 0.25 seconds for me. (This is still slightly slower than using plot, because of the large number of lines.)

import numpy as np
from scipy.fftpack import fft
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll

N=2048
k = np.arange(0,N)
dr = 100

cos = np.cos
pi = np.pi

w = 1-1.932617*cos(2*pi*k/(N-1))+1.286133*cos(4*pi*k/(N-1))-0.387695*cos(6*pi*k/(N-1))+0.0322227*cos(8*pi*k/(N-1))

y = np.concatenate([w, np.zeros((7*N))])

H = abs(fft(y, axis = 0))
H = np.fft.fftshift(H)
H = H/max(H)
H = 20*np.log10(H)
H = dr+H 
H[H < 0] = 0        # Set all negative values in dr+H to 0

plotdata = ((np.arange(1,(8*N)+1,1))-1-4*N)/8


lines = []
for thisx, thisy in zip(plotdata,H):
    lines.append(((thisx, 0), (thisx, thisy)))
stemlines = mcoll.LineCollection(lines, linestyles="-",
                    colors="C0", label='_nolegend_')
plt.gca().add_collection(stemlines)


plt.axis([(-4*N)/8, (4*N)/8, 0, dr])    
plt.grid()
plt.ylabel('decibels')
plt.xlabel('DFT bins')
plt.title('Frequency response (Flat top)')

plt.show()
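The Python loop that builds `lines` can also be replaced by a single NumPy array of shape (n, 2, 2), which `LineCollection` accepts as its segments. This is a minimal sketch; `stem_segments` is a hypothetical helper name, and the small `x`/`y` arrays are illustration data only.

```python
import numpy as np
import matplotlib.collections as mcoll

def stem_segments(x, y, baseline=0.0):
    """Hypothetical helper: segment i runs from (x[i], baseline) to (x[i], y[i])."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    segs = np.empty((x.size, 2, 2))
    segs[:, :, 0] = x[:, None]  # both endpoints of a stem share the same x
    segs[:, 0, 1] = baseline    # start point on the baseline
    segs[:, 1, 1] = y           # end point at the data value
    return segs

x = np.array([-1.0, 0.0, 1.0])
y = np.array([3.0, 5.0, 2.0])
stemlines = mcoll.LineCollection(stem_segments(x, y), linestyles="-",
                                 colors="C0", label="_nolegend_")
print(len(stemlines.get_segments()))  # prints 3
```

In the script above, `stemlines = mcoll.LineCollection(stem_segments(plotdata, H), ...)` would take the place of the `for` loop; the resulting segments are the same as those built by the loop.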
ImportanceOfBeingErnest answered Apr 25 '26 17:04
