 

How to create grid plot with inner subplots?

I have configured a 5 x 1 stack of subplots, shown in Fig. 1 and produced by Figure block A in the MWE. I am trying to repeat this block n times so that the copies are arranged in a grid whose numbers of rows and columns are given by the function fitPlots (taken from the Stack Overflow answer linked in the MWE), giving the output shown in Fig. 2.

[Fig. 1: Initial plot]

[Fig. 2: Repeated plot (desired output)]

What would be the best way to repeat the code block to create a grid plot with inner subplots? The MWE creates one PDF page per block, but I want all of them on a single page.

MWE

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import numpy as np
import math

x = np.arange(1, 100, 0.2)
y_a = np.sqrt(x)
y_b = np.sin(x)
y_c = np.sin(x)
y_d = np.cos(x) * np.cos(x)
y_e = 1/x


########## Figure block A #####################
with PdfPages('./plot_grid.pdf') as plot_grid_loop:
    fig, (a, b, c, d, e) = plt.subplots(5, 1, sharex=True, gridspec_kw={'height_ratios': [5, 1, 1, 1, 1]})
    a.plot(x, y_a)
    b.plot(x, y_b)
    c.plot(x, y_c)
    d.plot(x, y_d)
    e.plot(x, y_e)
    plot_grid_loop.savefig()
    plt.close(fig)
########## Figure block A #####################

# from https://stackoverflow.com/a/43366784/4576447
def fitPlots(N, aspect=(16,9)):
    width = aspect[0]
    height = aspect[1]
    area = width*height*1.0
    factor = (N/area)**(1/2.0)
    cols = math.floor(width*factor)
    rows = math.floor(height*factor)
    rowFirst = width < height
    while rows*cols < N:
        if rowFirst:
            rows += 1
        else:
            cols += 1
        rowFirst = not(rowFirst)
    return rows, cols
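As a quick check of what fitPlots returns with its default 16:9 aspect, the snippet below prints the grid it chooses for a few block counts. The function is copied unchanged from the MWE so the snippet runs on its own.

```python
import math

# fitPlots copied unchanged from the MWE
def fitPlots(N, aspect=(16,9)):
    width = aspect[0]
    height = aspect[1]
    area = width*height*1.0
    factor = (N/area)**(1/2.0)
    cols = math.floor(width*factor)
    rows = math.floor(height*factor)
    rowFirst = width < height
    while rows*cols < N:
        if rowFirst:
            rows += 1
        else:
            cols += 1
        rowFirst = not(rowFirst)
    return rows, cols

# Print the grid chosen for a few numbers of figure blocks
for n in (3, 5, 15):
    rows, cols = fitPlots(n)
    print(f"{n} blocks -> {rows} rows x {cols} cols")
# 3 blocks -> 1 rows x 3 cols
# 5 blocks -> 2 rows x 3 cols
# 15 blocks -> 3 rows x 6 cols
```

For the 15 blocks in the MWE this gives 3 rows by 6 columns, i.e. 18 slots, so three slots of the grid stay empty.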


n_plots = 15

n_rows, n_cols = fitPlots(n_plots)

with PdfPages('./plot_grid.pdf') as plot_grid_loop:
    for m in range(1, n_plots+1):
        fig, (a, b, c, d, e) = plt.subplots(5, 1, sharex=True, gridspec_kw={'height_ratios': [5, 1, 1, 1, 1]})
        a.plot(x, y_a)
        b.plot(x, y_b)
        c.plot(x, y_c)
        d.plot(x, y_d)
        e.plot(x, y_e)
        plot_grid_loop.savefig()
        plt.close(fig)
Tom Kurushingal asked Dec 13 '25 21:12



1 Answer

This can be done by creating a GridSpec with gs_fig = fig.add_gridspec() that has enough rows and columns to hold all the figure blocks (note that plt.subplots also creates a GridSpec, which can be accessed with ax.get_gridspec()). Each slot in this GridSpec can then be filled with a sub-GridSpec, created with gs_sub = gs_fig[i].subgridspec(), that holds the five subplots of one block. The trickier part is sharing the x-axis within each block. This can be done by first creating an empty Axes in the block, to which the x-axes of all five subplots are then joined.

The following example illustrates this with only three figure blocks. It is based on the code sample you shared, with some differences in the figure design: the number of rows is computed from the chosen number of columns, and the figure dimensions are set from a chosen figure width and aspect ratio. The code for saving the figure to a PDF file is not included.

import numpy as np               # v 1.19.2
import matplotlib.pyplot as plt  # v 3.3.4

# Create variables to plot
x = np.arange(1, 100, 0.2)
y_a = np.sqrt(x)
y_b = np.sin(x)
y_c = np.sin(x)
y_d = np.cos(x)*np.cos(x)
y_e = 1/x

# Set parameters for figure dimensions
nplots = 3  # arbitrary number of plots for this example
ncols = 2
nrows = int(np.ceil(nplots/ncols))
subp_w = 10/ncols  # 10 is the total figure width in inches
subp_h = 1*subp_w  # set subplot aspect ratio

# Create figure containing GridSpec object with appropriate dimensions
fig = plt.figure(figsize=(ncols*subp_w, nrows*subp_h))
gs_fig = fig.add_gridspec(nrows, ncols)

# Loop through GridSpec to add sub-GridSpec for each figure block
heights = [5, 1, 1, 1, 1]
for i in range(nplots):
    gs_sub = gs_fig[i].subgridspec(len(heights), 1, height_ratios=heights, hspace=0.2)
    ax = fig.add_subplot(gs_sub[0, 0])  # generate first empty Axes to enable sharex
    ax.axis('off')  # hide this placeholder Axes; the first subplot in the loop below is drawn over it
    # Loop through y variables to plot all the subplots with shared x-axis
    for j, y in enumerate([y_a, y_b, y_c, y_d, y_e]):
        ax = fig.add_subplot(gs_sub[j, 0], sharex=ax)
        ax.plot(x, y)
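        # Hide x tick labels except on the bottom subplot of the block
        # (Axes.is_last_row() is deprecated since Matplotlib 3.4, so compare the row index instead)
        if j < len(heights) - 1:
            ax.tick_params(labelbottom=False)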

[Image: subgridspec output of the example above]
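The empty placeholder Axes in the code above exists only to give the first subplot something to share its x-axis with. A variation not taken from the answer is to pass sharex=None for the first subplot of each block and share the remaining four with that first plotted Axes, which leaves exactly five Axes per block. This is a sketch under that assumption; make_grid is a hypothetical helper name, and the layout values (10-inch width, square blocks, hspace=0.2) are copied from the example above.

```python
import numpy as np
import matplotlib.pyplot as plt

# Same data as in the question
x = np.arange(1, 100, 0.2)
ys = [np.sqrt(x), np.sin(x), np.sin(x), np.cos(x)*np.cos(x), 1/x]
heights = [5, 1, 1, 1, 1]

def make_grid(nplots, ncols, width=10):
    """Hypothetical helper: nplots blocks of len(ys) stacked subplots each."""
    nrows = int(np.ceil(nplots/ncols))
    subp_w = width/ncols
    fig = plt.figure(figsize=(ncols*subp_w, nrows*subp_w))
    gs_fig = fig.add_gridspec(nrows, ncols)
    blocks = []
    for i in range(nplots):
        gs_sub = gs_fig[i].subgridspec(len(heights), 1, height_ratios=heights, hspace=0.2)
        first = None  # first plotted Axes of this block; the others share its x-axis
        axes = []
        for j, y in enumerate(ys):
            ax = fig.add_subplot(gs_sub[j, 0], sharex=first)
            ax.plot(x, y)
            if first is None:
                first = ax
            # Only the bottom subplot of each block keeps its x tick labels
            if j < len(ys) - 1:
                ax.tick_params(labelbottom=False)
            axes.append(ax)
        blocks.append(axes)
    return fig, blocks

fig, blocks = make_grid(nplots=3, ncols=2)
```

Because each block starts its own sharing group, zooming one block along x does not move the others, which matches the per-figure sharex=True in the original MWE.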


Reference: matplotlib tutorial GridSpec using SubplotSpec
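Since the question asks for all blocks on one PDF page, here is a sketch that combines fitPlots from the question with the subgridspec approach and saves a single figure through PdfPages. The 16 x 9 inch figure size (chosen to match fitPlots' default aspect) and the hspace/wspace values are assumptions to tune, not values from either post; sharing x with the first plotted Axes of each block is a variation on the answer's placeholder-Axes trick.

```python
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

x = np.arange(1, 100, 0.2)
ys = [np.sqrt(x), np.sin(x), np.sin(x), np.cos(x)*np.cos(x), 1/x]
heights = [5, 1, 1, 1, 1]

# fitPlots from the question, unchanged in behavior
def fitPlots(N, aspect=(16, 9)):
    width, height = aspect
    factor = (N/(width*height*1.0))**(1/2.0)
    cols = math.floor(width*factor)
    rows = math.floor(height*factor)
    rowFirst = width < height
    while rows*cols < N:
        if rowFirst:
            rows += 1
        else:
            cols += 1
        rowFirst = not rowFirst
    return rows, cols

n_plots = 15
n_rows, n_cols = fitPlots(n_plots)

# One figure holds every block (assumed 16 x 9 inch page)
fig = plt.figure(figsize=(16, 9))
gs_fig = fig.add_gridspec(n_rows, n_cols, hspace=0.3, wspace=0.3)

for i in range(n_plots):
    gs_sub = gs_fig[i].subgridspec(len(heights), 1, height_ratios=heights, hspace=0.1)
    first = None  # x-axis of this block is shared with its first Axes
    for j, y in enumerate(ys):
        ax = fig.add_subplot(gs_sub[j, 0], sharex=first)
        ax.plot(x, y)
        if first is None:
            first = ax
        if j < len(ys) - 1:
            ax.tick_params(labelbottom=False)

# Saving a single figure produces a single PDF page
with PdfPages('./plot_grid.pdf') as pdf:
    pdf.savefig(fig)
    n_pages = pdf.get_pagecount()
plt.close(fig)
print(f"pages written: {n_pages}")
```

Unlike the loop in the MWE, which calls savefig once per block and therefore adds one page each time, this version builds every block inside the same figure before saving it once.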

Patrick FitzGerald answered Dec 16 '25 13:12
