How can I combine several `go.Figure` objects, each of which already contains multiple lines (traces) and data points, into a single figure with subplots? Here is what I tried:
# Data Visualization
from plotly.subplots import make_subplots
import plotly.graph_objects as go
# Per-epoch training metrics (placeholder values).
epoch_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Figure with the training and validation loss curves.
loss_plots = [
    go.Scatter(x=epoch_list,
               y=loss_list,
               mode='lines',
               name='Loss',
               line=dict(width=4)),
    go.Scatter(x=epoch_list,
               y=val_loss_list,
               mode='lines',
               name='Validation Loss',
               line=dict(width=4)),
]
loss_figure = go.Figure(data=loss_plots)

# Figure with the training and validation error-rate curves. These traces
# must plot the error-rate lists, not the loss lists.
error_plots = [
    go.Scatter(x=epoch_list,
               y=error_rate,
               mode='lines',
               name='Error Rate',
               line=dict(width=4)),
    go.Scatter(x=epoch_list,
               y=val_error_rate,
               mode='lines',
               name='Validation Error Rate',
               line=dict(width=4)),
]
error_figure = go.Figure(data=error_plots)
# Build a 3x2 grid of subplots whose bottom row is meant to be a single
# cell spanning both columns.
# NOTE(review): plotly's make_subplots documentation uses None (not {}) for
# the cell covered by a colspan, i.e. [{'colspan': 2}, None] — confirm.
metric_figure = make_subplots(
rows=3, cols=2,
specs=[[{}, {}],
[{}, {}],
[{'colspan': 2}, {}]])
# These two calls are what fail: append_trace expects a single trace (such
# as go.Scatter), but a whole go.Figure is passed here. That produces the
# "Invalid element(s) received for the 'data' property" error described in
# the question text.
metric_figure.append_trace(loss_figure, row=1, col=1)
metric_figure.append_trace(error_figure, row=1, col=2)
metric_figure.show()
When I try to create the subplot, I get the error `Invalid element(s) received for the 'data' property of ... Invalid elements include: [Figure(...)]`. I think I know why it occurs, but is there a way around it? I still want to be able to change the layout of each graph and to have multiple lines on a single graph.
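It's simply a matter of looping over the traces in each figure and adding each one to the required subplot. `append_trace`/`add_trace` accepts individual traces such as `go.Scatter`, not a whole `go.Figure`, and a figure's traces are available through its `.data` attribute. Note that only the traces are copied, not the individual figures' layouts, so apply per-subplot layout changes on the combined figure (for example with `update_xaxes(..., row=..., col=...)`).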
# Data Visualization
from plotly.subplots import make_subplots
import plotly.graph_objects as go
# Per-epoch training metrics (placeholder values).
epoch_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Figure with the training and validation loss curves.
loss_plots = [
    go.Scatter(x=epoch_list,
               y=loss_list,
               mode='lines',
               name='Loss',
               line=dict(width=4)),
    go.Scatter(x=epoch_list,
               y=val_loss_list,
               mode='lines',
               name='Validation Loss',
               line=dict(width=4)),
]
loss_figure = go.Figure(data=loss_plots)

# Figure with the training and validation error-rate curves. These traces
# must plot the error-rate lists, not the loss lists.
error_plots = [
    go.Scatter(x=epoch_list,
               y=error_rate,
               mode='lines',
               name='Error Rate',
               line=dict(width=4)),
    go.Scatter(x=epoch_list,
               y=val_error_rate,
               mode='lines',
               name='Validation Error Rate',
               line=dict(width=4)),
]
error_figure = go.Figure(data=error_plots)
# Lay out a 3x2 grid whose bottom cell spans both columns. The position
# covered by the span is None, as the make_subplots documentation specifies.
metric_figure = make_subplots(
    rows=3, cols=2,
    specs=[[{}, {}],
           [{}, {}],
           [{'colspan': 2}, None]])

# A subplot cell takes individual traces, not whole figures, so copy each
# source figure's traces (its .data tuple) into the target cell.
# add_trace(row=, col=) replaces the deprecated append_trace.
# Only traces are copied; the source figures' layouts are not carried over.
for trace in loss_figure.data:
    metric_figure.add_trace(trace, row=1, col=1)
for trace in error_figure.data:
    metric_figure.add_trace(trace, row=1, col=2)

metric_figure.show()
If you like our content, you can donate via PayPal or buy us a coffee to help us maintain and grow the site. Thank you!
Donate with