Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Is there a way in matplotlib to check which artists are in the currently displayed area of the axes?

I have a program with an interactive figure where occasionally many artists are drawn. In this figure, you can also zoom and pan using the mouse. However, the performance during zooming and panning is not very good because every artist is always redrawn. Is there a way to check which artists are in the currently displayed area and only redraw those? (In the example below the performance is still relatively good, but it can be made arbitrarily worse by using more or more complex artists.)

I had a similar performance problem with the hover method: whenever it was called, it ran canvas.draw() at the end. But as you can see, I found a neat workaround for that by caching and restoring the background of the axes (based on this). This significantly improved the performance, and now it runs very smoothly even with many artists. Maybe there is a similar way of doing this for the pan and zoom methods?

Sorry for the long code sample, most of it is not directly relevant for the question but necessary for a working example to highlight the issue.

EDIT

I updated the MWE to something that is more representative of my actual code.

import numpy as np
import sys
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import \
    FigureCanvasQTAgg
import matplotlib.patheffects as PathEffects
from matplotlib.text import Annotation
from matplotlib.collections import LineCollection

from PyQt5.QtWidgets import QApplication, QVBoxLayout, QDialog


def check_limits(base_xlim, base_ylim, new_xlim, new_ylim):
    """Shift or shrink the requested view so it stays inside the base view.

    ``new_xlim`` and ``new_ylim`` are two-element mutable sequences; they are
    modified in place and also returned.  The x limits are treated as
    ascending (index 0 holds the smaller bound), the y limits as descending
    (index 1 holds the smaller bound), i.e. an inverted y-axis.

    When one edge sticks out of the base range, the window is moved back by
    the overshoot; if the opposite edge would then stick out as well, that
    edge is pinned to the base limit instead.
    """

    def pull_inside(base, new, lo, hi):
        # ``lo`` / ``hi`` are the indices of the smaller and larger bound.
        if new[lo] < base[lo]:
            excess = base[lo] - new[lo]
            new[lo] = base[lo]
            shifted = new[hi] + excess
            new[hi] = base[hi] if shifted > base[hi] else shifted
        if new[hi] > base[hi]:
            excess = new[hi] - base[hi]
            new[hi] = base[hi]
            shifted = new[lo] - excess
            new[lo] = base[lo] if shifted < base[lo] else shifted

    pull_inside(base_xlim, new_xlim, 0, 1)
    pull_inside(base_ylim, new_ylim, 1, 0)

    return new_xlim, new_ylim


class FigureCanvas(FigureCanvasQTAgg):
    """Canvas that keeps a cached copy of the figure without hover overlays.

    ``draw`` expects the first axes of the figure to carry two extra
    attributes: ``annot`` (the hover annotation) and ``last_artist`` (the
    currently highlighted artist, or a falsy value when there is none).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Region copied from the figure by the most recent ``draw`` call.
        self.bg_cache = None

    def draw(self):
        """Render the figure, cache it, then paint the overlays back on top.

        The annotation is hidden and the highlighted artist's path effects
        are reset before the full draw, so the copy stored in ``bg_cache``
        contains neither.  Both are then restored and drawn individually.
        """
        axes = self.figure.axes[0]

        annot_was_shown = bool(axes.annot.get_visible())
        if annot_was_shown:
            axes.annot.set_visible(False)

        had_highlight = bool(axes.last_artist)
        if had_highlight:
            axes.last_artist.set_path_effects([PathEffects.Normal()])

        super().draw()
        self.bg_cache = self.copy_from_bbox(self.figure.bbox)

        if had_highlight:
            glow = PathEffects.withStroke(
                linewidth=7, foreground="c", alpha=0.4
            )
            axes.last_artist.set_path_effects([glow])
            axes.draw_artist(axes.last_artist)
        if annot_was_shown:
            axes.annot.set_visible(True)
            axes.draw_artist(axes.annot)

        # Request a repaint so the re-drawn highlight reaches the screen.
        if had_highlight:
            self.update()


def position(t_, coeff, var=0.1):
    """Return noisy ``(x, y)`` samples of a polynomial trajectory.

    Column 0 of ``coeff`` holds the polynomial coefficients for x and
    column 1 those for y, highest power first (``np.polyval`` order).  Each
    polynomial is evaluated at ``t_`` and used as the mean of a normal
    draw; ``var`` is passed as the ``scale`` argument of
    ``np.random.normal``.  The x values are drawn before the y values.
    """
    return tuple(
        np.random.normal(np.polyval(coeff[:, axis], t_), var)
        for axis in (0, 1)
    )


class Data:
    """Randomly generated noisy 2-D track sampled at a subset of ``times``."""

    def __init__(self, times):
        """Draw a random track.

        ``times`` is the pool of sample times; between 1 and 19 distinct
        entries are picked from it without replacement.
        """
        # Number of samples on this track.
        n_samples = np.random.randint(1, 20)
        self.length = n_samples

        # Sorted sample times.
        chosen = np.random.choice(times, size=n_samples, replace=False)
        self.t = np.sort(chosen)

        # Linear and quadratic polynomial coefficients as [x, y] pairs.
        self.vel = [np.random.uniform(-2, 2) for _ in range(2)]
        self.accel = [np.random.uniform(-0.01, 0.01) for _ in range(2)]

        # Constant term: the position at t = 0, before noise is added.
        x0, y0 = np.random.uniform(0, 1000, 2)

        # Rows are ordered highest power first, as ``position`` evaluates
        # each column with ``np.polyval``.
        coeff = np.array([self.accel, self.vel, [x0, y0]])
        self.x, self.y = position(self.t, coeff)


class Test(QDialog):
    """Dialog hosting an interactive Matplotlib canvas.

    The axes show a random grayscale image overlaid with 1000 randomly
    generated ``LineCollection`` tracks.  Scrolling zooms about the cursor,
    dragging with mouse button 1 pans, and hovering over a track highlights
    it and shows an annotation with its id.  Zooming and panning are
    confined to the initial view limits.
    """

    def __init__(self):
        super().__init__()
        self.fig, self.ax = plt.subplots()
        self.canvas = FigureCanvas(self.fig)
        self.artists = []
        self.zoom_factor = 1.5
        # Data coordinates of the last button press inside the axes; ``None``
        # while there is no valid pan anchor.
        self.x_press = None
        self.y_press = None
        self.annot = Annotation(
            "", xy=(0, 0), xytext=(-20, 20), textcoords="offset points",
            bbox=dict(boxstyle="round", fc="w", alpha=0.7), color='black',
            arrowprops=dict(arrowstyle="->"), zorder=6, visible=False,
            annotation_clip=False, in_layout=False,
        )
        self.annot.set_clip_on(False)
        # FigureCanvas.draw() reads ``annot`` and ``last_artist`` from the
        # axes, so both are attached to it as attributes.
        setattr(self.ax, 'annot', self.annot)
        self.ax.add_artist(self.annot)
        self.last_artist = None
        setattr(self.ax, 'last_artist', self.last_artist)

        self.image = np.random.uniform(0, 100, 1000000).reshape((1000, 1000))
        self.ax.imshow(self.image, cmap='gray', interpolation='nearest')
        self.times = np.linspace(0, 20)
        for i in range(1000):
            data = Data(self.times)
            # Turn the N track points into N-1 two-point line segments.
            points = np.array([data.x, data.y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            # Values in [0, 1] mapped through the colormap along the track.
            z = np.linspace(0, 1, data.length)
            norm = plt.Normalize(z.min(), z.max())
            lc = LineCollection(
                segments, cmap='autumn', norm=norm, alpha=1,
                linewidths=2, picker=8, capstyle='round',
                joinstyle='round'
            )
            setattr(lc, 'data_id', i)
            lc.set_array(z)
            self.ax.add_artist(lc)
            self.artists.append(lc)
        # The initial limits are the outer bounds for zooming and panning.
        self.default_xlim = self.ax.get_xlim()
        self.default_ylim = self.ax.get_ylim()

        self.canvas.draw()

        self.cid_motion = self.fig.canvas.mpl_connect(
            'motion_notify_event', self.motion_event
        )
        self.cid_button = self.fig.canvas.mpl_connect(
            'button_press_event', self.pan_press
        )
        self.cid_zoom = self.fig.canvas.mpl_connect(
            'scroll_event', self.zoom
        )

        layout = QVBoxLayout()
        layout.addWidget(self.canvas)
        self.setLayout(layout)

    def _apply_limits(self, cur_xlim, cur_ylim, new_xlim, new_ylim):
        """Clamp the requested limits and redraw only if the view changes."""
        # intercept new plot parameters if they are out of bounds
        new_xlim, new_ylim = check_limits(
            self.default_xlim, self.default_ylim, new_xlim, new_ylim
        )

        if cur_xlim != tuple(new_xlim) or cur_ylim != tuple(new_ylim):
            self.ax.set_xlim(new_xlim)
            self.ax.set_ylim(new_ylim)

            self.canvas.draw_idle()

    def _clear_highlight(self):
        """Erase the hover overlay and forget the highlighted artist."""
        self.canvas.restore_region(self.canvas.bg_cache)
        if self.ax.last_artist is not None:
            self.ax.last_artist.set_path_effects([PathEffects.Normal()])
            self.ax.last_artist = None
        self.ax.annot.set_visible(False)

    def zoom(self, event):
        """Scale the view around the cursor position on a scroll event."""
        if event.inaxes is not self.ax:
            return

        scale_factor = np.power(self.zoom_factor, -event.step)
        xdata = event.xdata
        ydata = event.ydata
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        # Distances from the cursor to each edge; scaling them keeps the
        # data point under the cursor fixed.
        x_left = xdata - cur_xlim[0]
        x_right = cur_xlim[1] - xdata
        y_top = ydata - cur_ylim[0]
        y_bottom = cur_ylim[1] - ydata

        new_xlim = [
            xdata - x_left * scale_factor, xdata + x_right * scale_factor
        ]
        new_ylim = [
            ydata - y_top * scale_factor, ydata + y_bottom * scale_factor
        ]
        self._apply_limits(cur_xlim, cur_ylim, new_xlim, new_ylim)

    def motion_event(self, event):
        """Pan while button 1 is reported on the event, otherwise hover."""
        if event.button == 1:
            self.pan_move(event)
        else:
            self.hover(event)

    def pan_press(self, event):
        """Store the pan anchor, or clear it for presses outside the axes."""
        if event.inaxes is self.ax:
            self.x_press = event.xdata
            self.y_press = event.ydata
        else:
            # Drop any stale anchor so a drag that starts outside the axes
            # cannot move the view relative to an old press position.
            self.x_press = None
            self.y_press = None

    def pan_move(self, event):
        """Shift the view so the anchor point stays under the cursor."""
        if event.inaxes is not self.ax:
            return
        # No valid anchor: the drag did not start inside the axes.
        if self.x_press is None or self.y_press is None:
            return

        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        dx = event.xdata - self.x_press
        dy = event.ydata - self.y_press
        new_xlim = [cur_xlim[0] - dx, cur_xlim[1] - dx]
        new_ylim = [cur_ylim[0] - dy, cur_ylim[1] - dy]
        self._apply_limits(cur_xlim, cur_ylim, new_xlim, new_ylim)

    def update_annot(self, event, artist):
        """Move the annotation to the cursor, label it and draw it."""
        self.ax.annot.xy = (event.xdata, event.ydata)
        text = f'Data #{artist.data_id}'
        self.ax.annot.set_text(text)
        self.ax.annot.set_visible(True)
        self.ax.draw_artist(self.ax.annot)

    def hover(self, event):
        """Highlight the first artist under the cursor without a full redraw.

        The overlay is painted on top of the background cached by the
        canvas; moving off an artist restores that background.
        """
        vis = self.ax.annot.get_visible()
        if event.inaxes is self.ax:
            cont = False
            for artist in self.artists:
                cont, _ = artist.contains(event)
                if not cont:
                    continue
                if artist is not self.ax.last_artist:
                    if self.ax.last_artist is not None:
                        # Remove the overlay of the previously hovered one.
                        self._clear_highlight()
                    artist.set_path_effects(
                        [PathEffects.withStroke(
                            linewidth=7, foreground="c", alpha=0.4
                        )]
                    )
                    self.ax.last_artist = artist
                    self.ax.draw_artist(artist)
                    self.update_annot(event, artist)
                # Only the first hit is considered.
                break

            if vis and not cont:
                self._clear_highlight()
        elif vis:
            self._clear_highlight()
        self.canvas.update()
        self.canvas.flush_events()


if __name__ == '__main__':
    # Create the Qt application, show the dialog and hand control to the
    # Qt event loop; its return code becomes the process exit status.
    qt_app = QApplication(sys.argv)
    dialog = Test()
    dialog.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
like image 799
mapf Avatar asked Jan 19 '26 02:01

mapf


1 Answers

You can find which artists are in the current area of the axes if you focus on the data the artists are plotting.

For example if you put your points data (a and b arrays) in a numpy array like this:

# Example data: a (1000, 2) integer array, one (x, y) row per point.
self.points = np.random.randint(0, 100, (1000, 2))

you can get the list of points inside the current x and y limits:

# Axis limits can be inverted (imshow, as used in the question, puts the
# origin at the top, so get_ylim() returns bottom > top).  Sort them so
# that min <= max holds and the range test below is never trivially empty.
xmin, xmax = sorted(self.ax.get_xlim())
ymin, ymax = sorted(self.ax.get_ylim())

p = self.points

# Boolean mask of the points lying strictly inside the current view.
in_view = (
    (p[:, 0] > xmin) & (p[:, 0] < xmax)
    & (p[:, 1] > ymin) & (p[:, 1] < ymax)
)
indices_of_visible_points = np.argwhere(in_view).flatten()

You can then use indices_of_visible_points to index your related self.artists list.

like image 184
Guglie Avatar answered Jan 21 '26 17:01

Guglie



Donate For Us

If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!