According to the release highlights of PyTorch 1.1.0, the latest JIT compiler now supports the Dict type (source: https://jaxenter.com/pytorch-1-1-158332.html):
Dictionary and list support in TorchScript: Lists and dictionary types behave like Python lists and dictionaries.
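For context, this highlighted support belongs to the TorchScript compiler (torch.jit.script), which reads type annotations such as Dict[str, Tensor]. A minimal sketch, where the function name total_sum and the sample inputs are illustrative assumptions:

```python
from typing import Dict

import torch


# The TorchScript *compiler* accepts Dict[str, Tensor] as an argument type.
@torch.jit.script
def total_sum(feats: Dict[str, torch.Tensor]) -> torch.Tensor:
    total = torch.zeros(1)
    # Iterating over the keys of a typed dict is supported inside TorchScript.
    for key in feats.keys():
        total = total + feats[key].sum()
    return total


feats = {"feat0": torch.ones(2, 3), "feat2": torch.ones(4)}
print(total_sum(feats))  # tensor([10.])
```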
Unfortunately, I can't get this improvement to work. The following code is a simple example that exports a Feature Pyramid Network (FPN) to TensorBoard, which uses the JIT compiler:
from collections import OrderedDict
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
torchWriter = SummaryWriter(log_dir=".tensorboard/example1")
m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
# get some dummy data
x = OrderedDict()
x['feat0'] = torch.rand(1, 10, 64, 64)
x['feat2'] = torch.rand(1, 20, 16, 16)
x['feat3'] = torch.rand(1, 30, 8, 8)
# compute the FPN on top of x
output = m(x)
print([(k, v.shape) for k, v in output.items()])
torchWriter.add_graph(m, input_to_model=x)
When I run it, I get the following error:
Traceback (most recent call last):
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 276, in graph
trace, _ = torch.jit.get_trace_graph(model, args)
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 231, in get_trace_graph
return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 284, in forward
in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got collections.OrderedDict
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/peng/git-drone/gate_detection/python/gate_detection/errorcase/tb.py", line 36, in <module>
torchWriter.add_graph(m, input_to_model=x)
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/writer.py", line 534, in add_graph
self._get_file_writer().add_graph(graph(model, input_to_model, verbose, **kwargs))
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/utils/tensorboard/_pytorch_graph.py", line 279, in graph
_ = model(*args) # don't catch, just print the error message
File "/home/shared/virtualenv/dl-torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given
From the error message, it appears that the support is still pending. Can I trust the release highlights, or am I not using the API properly?
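The release notes are accurate, albeit a little vague. The dictionary, list, and user-defined class support described in that link (and in the official release notes, which include some code examples) applies only to the TorchScript compiler. SummaryWriter, however, runs the TorchScript tracer by default on whatever module you pass to it, and the tracer only supports Tensors and lists/tuples of Tensors as inputs. That is what the first error says: "Only tuples, lists and Variables supported as JIT inputs". The second TypeError comes from add_graph's fallback call model(*args): unpacking the OrderedDict with * passes its three keys as separate arguments, so forward() receives four positional arguments (self plus three strings) instead of two.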
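The second error can be reproduced without PyTorch at all. In this sketch, FakeModule is a hypothetical stand-in for nn.Module whose __call__ simply forwards positional arguments to forward(), roughly as nn.Module does; the point is that *-unpacking a dict yields its keys, not the dict itself:

```python
from collections import OrderedDict


class FakeModule:
    """Hypothetical stand-in for an nn.Module whose forward() takes one input."""

    def forward(self, x):
        return x

    def __call__(self, *args):
        # Mimics nn.Module.__call__ passing positional args on to forward().
        return self.forward(*args)


x = OrderedDict()
x["feat0"] = "tensor0"
x["feat2"] = "tensor2"
x["feat3"] = "tensor3"

# *x expands to the keys, so forward() sees self + 3 strings = 4 arguments.
keys = list(x)
try:
    FakeModule()(*x)
    err_msg = ""
except TypeError as err:
    err_msg = str(err)
print(keys)
print(err_msg)

# Wrapping the dict in a 1-tuple passes it through as a single argument.
same = FakeModule()(*(x,))
print(same is x)
```

Note that wrapping the input as (x,) would only fix this second error; the tracer itself would still reject the OrderedDict inside the tuple.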
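So the fix would be to use the TorchScript compiler rather than the tracer, but that requires:
1. the model (here, torchvision's FeaturePyramidNetwork) to be compatible with the TorchScript compiler, and
2. support for compiled TorchScript modules (ScriptModule) in TensorBoard.
You should file an issue for (2), and there is ongoing work to fix (1), but as far as I know this won't work in the short term for that model.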