 

Subplot for go.Figure objects with multiple plots within them

How can I create a subplot from several go.Figure objects that each contain multiple lines and data points? For example:

# Data Visualization
from plotly.subplots import make_subplots
import plotly.graph_objects as go

epoch_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

loss_plots = [go.Scatter(x=epoch_list,
                         y=loss_list,
                         mode='lines',
                         name='Loss',
                         line=dict(width=4)),
              go.Scatter(x=epoch_list,
                         y=val_loss_list,
                         mode='lines',
                         name='Validation Loss',
                         line=dict(width=4))]

loss_figure = go.Figure(data=loss_plots)

error_plots = [go.Scatter(x=epoch_list,
                          y=error_rate,
                          mode='lines',
                          name='Error Rate',
                          line=dict(width=4)),
               go.Scatter(x=epoch_list,
                          y=val_error_rate,
                          mode='lines',
                          name='Validation Error Rate',
                          line=dict(width=4))]

error_figure = go.Figure(data=error_plots)

metric_figure = make_subplots(
    rows=3, cols=2,
    specs=[[{}, {}],
           [{}, {}],
           [{'colspan': 2}, None]])

metric_figure.append_trace(loss_figure, row=1, col=1)
metric_figure.append_trace(error_figure, row=1, col=2)
metric_figure.show()

The error I get when trying to create the subplot is “invalid element(s) received for the ‘data’ property of Invalid elements include: [Figure”. I think I know why the error occurs, but is there a way around it? I still want to change the layout of each graph and have multiple lines on a single graph.

Luleo_Primoc asked Oct 27 '25 01:10



1 Answer

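It's simply a case of looping over the traces in each figure and adding them to the required subplot. append_trace (and add_trace) expect a single trace object such as go.Scatter, not a whole go.Figure, but every figure exposes its traces through its data attribute.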

# Data Visualization
from plotly.subplots import make_subplots
import plotly.graph_objects as go

epoch_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

loss_plots = [go.Scatter(x=epoch_list,
                         y=loss_list,
                         mode='lines',
                         name='Loss',
                         line=dict(width=4)),
              go.Scatter(x=epoch_list,
                         y=val_loss_list,
                         mode='lines',
                         name='Validation Loss',
                         line=dict(width=4))]

loss_figure = go.Figure(data=loss_plots)

error_plots = [go.Scatter(x=epoch_list,
                          y=error_rate,
                          mode='lines',
                          name='Error Rate',
                          line=dict(width=4)),
               go.Scatter(x=epoch_list,
                          y=val_error_rate,
                          mode='lines',
                          name='Validation Error Rate',
                          line=dict(width=4))]

error_figure = go.Figure(data=error_plots)

metric_figure = make_subplots(
    rows=3, cols=2,
    specs=[[{}, {}],
           [{}, {}],
           [{'colspan': 2}, None]])  # None marks the cell covered by the colspan

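# A go.Figure cannot be added to the grid directly, but its .data holds its
# traces, which can be added one by one to the chosen cell.
# add_trace with row/col is used here because append_trace is deprecated.
for t in loss_figure.data:
    metric_figure.add_trace(t, row=1, col=1)
for t in error_figure.data:
    metric_figure.add_trace(t, row=1, col=2)
metric_figure.show()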
Rob Raymond answered Oct 29 '25 17:10



