Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How can I plot the label on the line of a lineplot?

I would like to plot labels on a line of a lineplot in matplotlib.

Minimal example

#!/usr/bin/env python
# NOTE: recent seaborn versions no longer expose matplotlib as `sns.plt`, so
# pyplot is imported directly. gaussian_filter1d is imported from
# scipy.ndimage because the scipy.ndimage.filters namespace is deprecated.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

for i in range(8):
    # Create data: a random walk (cumsum without an axis flattens to 1-D),
    # circularly shifted by a random amount and then smoothed.
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    plt.plot(y, label=str(i))
plt.legend()
plt.show()

generates

[Image: the resulting plot, with the lines identified by a separate legend box]

Instead, I would prefer something like this:

[Image: the desired plot, with each line's label drawn directly on the line]

like image 289
Martin Thoma Avatar asked Apr 23 '17 16:04

Martin Thoma


1 Answer

Maybe a bit hacky, but does this solve your problem?

#!/usr/bin/env python
# NOTE: recent seaborn versions no longer expose matplotlib as `sns.plt`, so
# pyplot is imported directly. gaussian_filter1d is imported from
# scipy.ndimage because the scipy.ndimage.filters namespace is deprecated.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

for i in range(8):
    # Create data: a smoothed, randomly shifted random walk
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    p = plt.plot(y, label=str(i))
    color = p[0].get_color()
    # Put the label at three fixed x-positions. The line is plotted against
    # its indices, so y[x] is the line's value at x. A white disc is drawn
    # first as a background, then the label (a mathtext marker) in the
    # line's color.
    for x in [250, 500, 750]:
        y2 = y[x]
        plt.plot(x, y2, 'o', color='white', markersize=9)
        # No 'k' format string: it would set black only to be overridden by
        # color=, and matplotlib warns about the redundant color.
        plt.plot(x, y2, marker="$%s$" % str(i), color=color, markersize=7)
plt.legend()
plt.show()

Here's the result I get:

[Image: my result, with each line's number drawn on the line at three x-positions]

Edit: I gave this some more thought and came up with a solution that automatically tries to find the best position for each label. It avoids placing labels at x-values where two lines are very close to each other, since that could, e.g., make the labels overlap:

#!/usr/bin/env python
# Imports first, then the seaborn styling they configure. gaussian_filter1d
# comes from scipy.ndimage: the scipy.ndimage.filters namespace is deprecated.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

# -----------------------------------------------------------------------------

def inline_legend(lines, n_markers=1):
    """
    Draw each line's label directly on the line instead of in a legend box.

    Take a list containing the lines of a plot (typically the result of
    calling plt.gca().get_lines()) and put each line's label n_markers times
    on the line itself. Each label is placed at the x-value where the line
    is farthest from its closest neighbouring line, so that labels of
    different lines are unlikely to overlap.
    [Source of problem: https://stackoverflow.com/q/43573623/4100721]

    Parameters
    ----------
    lines : sequence of matplotlib Line2D
        The lines to label. Each label is drawn on the axes the line belongs
        to (``line.axes``); pyplot's current axes are used only if the line
        is not attached to any axes.
    n_markers : int, optional
        How many times each label is drawn on its line (default 1). Values
        smaller than 1 draw nothing.

    Returns
    -------
    None
    """
    import numpy as np
    from scipy.interpolate import interp1d

    def chunkify(a, n):
        """
        Split sequence a into n approximately equally sized chunks and return
        the (start, end) indices of those chunks.
        [Idea: Props to http://stackoverflow.com/a/2135920/4100721 :)]
        """
        k, m = divmod(len(a), n)
        return [(i * k + min(i, m), (i + 1) * k + min(i + 1, m))
                for i in range(n)]

    if n_markers < 1:
        # Nothing to draw; chunkify would also divide by zero.
        return

    lines = list(lines)

    # Linear interpolation of every line. This is needed to compare the
    # values of lines that use different x-values. Outside a line's x-range
    # the interpolant returns +inf instead of raising ValueError: a line that
    # does not exist at x cannot collide with a label placed there.
    interpolations = [interp1d(line.get_xdata(), line.get_ydata(),
                               bounds_error=False, fill_value=np.inf)
                      for line in lines]

    for idx, line in enumerate(lines):

        # Basic properties of the current line
        label = line.get_label()
        color = line.get_color()
        x_values = np.asarray(line.get_xdata(), dtype=float)
        y_values = np.asarray(line.get_ydata(), dtype=float)

        # Interpolants of every line except the current one
        other_functions = interpolations[:idx] + interpolations[idx + 1:]

        # Split the points into regions in which to put labels. Creating 3
        # times as many chunks as requested and using only every third one
        # ensures that no two labels of the same line are too close.
        chunks = chunkify(x_values, 3 * n_markers)[::3]

        axes = getattr(line, "axes", None)
        if axes is None:
            import matplotlib.pyplot as plt
            axes = plt.gca()

        for chunk_start, chunk_end in chunks:
            if chunk_end <= chunk_start:
                # The line has fewer points than chunks: nothing to place.
                continue

            xs = x_values[chunk_start:chunk_end]
            ys = y_values[chunk_start:chunk_end]

            # For every x in the chunk, the vertical distance from this line
            # to the closest other line at that same x.
            if other_functions:
                distances = np.min([np.abs(ys - f(xs))
                                    for f in other_functions], axis=0)
            else:
                distances = np.full(len(xs), np.inf)

            # The best position is where that distance is maximal (argmax
            # returns the first maximum). The offset converts the
            # chunk-relative index into an index into the whole line.
            best_pos = chunk_start + int(np.argmax(distances))

            # Use the line's actual coordinates, not the point index.
            x = x_values[best_pos]
            y = y_values[best_pos]

            # White disc as background, then the label as a mathtext marker
            # in the line's color. No 'k' fmt: it would be overridden by
            # color= and triggers matplotlib's redundant-color warning.
            axes.plot(x, y, 'o', color='white', markersize=9)
            axes.plot(x, y, marker="$%s$" % label, color=color,
                      markersize=7)

# -----------------------------------------------------------------------------

# Plot eight smoothed, randomly shifted random walks, then label each one
# inline. `plt` is used directly because recent seaborn versions no longer
# expose matplotlib as `sns.plt`.
for i in range(8):
    # Create data
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    plt.plot(y, label=str(i))

inline_legend(plt.gca().get_lines(), n_markers=3)
plt.show()

Example output of this solution (note that the labels are no longer all at the same x-positions):

[Image: improved solution]

If you want to avoid using scipy.interpolate.interp1d, you could instead take each x-value of line A and find the x-value of line B that is closest to it. However, this might be problematic if the lines use very different and/or sparse grids.

like image 142
der_herr_g Avatar answered Sep 28 '22 03:09

der_herr_g