I am trying to convert an ONNX model to TFLite, and I'm facing an error when executing the line `tf_rep.export_graph(tf_model_path)`. This question has been asked on SO before, but none of the answers provided a definitive solution.
Requirements installed: tensorflow 2.12.0, onnx 1.14.0, onnx-tf 1.10.0, Python 3.10.12
import torch
import onnx
import tensorflow as tf
import onnx_tf
from onnx_tf.backend import prepare
from torchvision.models import resnet50

# Load the PyTorch ResNet50 model and switch it to inference mode
pytorch_model = resnet50(pretrained=True)
pytorch_model.eval()

# Export the PyTorch model to ONNX format, tracing it with a random
# tensor of the expected input shape
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)
onnx_model_path = 'resnet50.onnx'
torch.onnx.export(pytorch_model, dummy_input, onnx_model_path, opset_version=12, verbose=False)

# Load the ONNX model (once; the original snippet loaded it twice)
onnx_model = onnx.load(onnx_model_path)

# Convert the ONNX model to TensorFlow format
tf_model_path = 'resnet50.pb'
tf_rep = prepare(onnx_model)
# NOTE: no input_names were passed to the export above, so the graph input
# is auto-named 'input.1'; this call is where the KeyError below is raised.
tf_rep.export_graph(tf_model_path) #ERROR
Error:
WARNING:absl:`input.1` is not a valid tf.function parameter name. Sanitizing to `input_1`.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-4-f35b83c104b8> in <cell line: 8>()
6 tf_model_path = 'resnet50'
7 tf_rep = prepare(onnx_model)
----> 8 tf_rep.export_graph(tf_model_path)
35 frames
/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py in tf__conv(cls, node, input_dict, transpose)
17 do_return = False
18 retval_ = ag__.UndefinedReturnValue()
---> 19 x = ag__.ld(input_dict)[ag__.ld(node).inputs[0]]
20 x_rank = ag__.converted_call(ag__.ld(len), (ag__.converted_call(ag__.ld(x).get_shape, (), None, fscope),), None, fscope)
21 x_shape = ag__.converted_call(ag__.ld(tf_shape), (ag__.ld(x), ag__.ld(tf).int32), None, fscope)
KeyError: in user code:
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend_tf_module.py", line 99, in __call__ *
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op *
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/handler.py", line 59, in handle *
return ver_handle(node, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv.py", line 15, in version_11 *
return cls.conv(node, kwargs["tensor_dict"])
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py", line 29, in conv *
x = input_dict[node.inputs[0]]
KeyError: 'input.1'
The problem was with an input name in the ONNX model.
import onnx

# Load the exported model and collect the name of every graph input
onnx_model = onnx.load(onnx_model_path)
input_names = []
for graph_input in onnx_model.graph.input:
    input_names.append(graph_input.name)
print("Model Inputs: ", input_names)
Model Inputs: ['input.1']
Here the TensorFlow/TFLite conversion cannot handle the name `input.1`, so it has to be replaced by `input_1`. The following code does that:
import onnx

onnx_model = onnx.load(onnx_model_path)

# Define a mapping from old names to new names
name_map = {"input.1": "input_1"}

# Rename the matching graph inputs in place. Mutating the existing entry
# keeps its element type and full shape (including symbolic/dynamic
# dimensions), which rebuilding it from `dim_value` alone would lose.
for inp in onnx_model.graph.input:
    if inp.name in name_map:
        inp.name = name_map[inp.name]

# Go through all nodes in the model and replace the old input name with
# the new one, so every consumer still points at the renamed graph input
for node in onnx_model.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name in name_map:
            node.input[i] = name_map[input_name]

# Save the renamed ONNX model
onnx.save(onnx_model, 'resnet50-new.onnx')
The new parameter looks like:
Model Inputs: ['input_1']
The output model (for the TFLite conversion) is now generated without error:
import onnx
from onnx_tf.backend import prepare

# Source (renamed) ONNX file and destination for the TensorFlow export
onnx_model_path = 'resnet50-new.onnx'
tf_model_path = 'resnet50'

# Load the renamed model, build its TensorFlow representation, export it
onnx_model = onnx.load(onnx_model_path)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)
It is indeed due to bad input tensor names which TF cannot support, but there is a much simpler fix if you are using pytorch. (And I imagine something similar exists for TF)
Simply export like:
# Export with explicit, TensorFlow-safe tensor names and a dynamic batch
# dimension (axis 0) on both the input and the output.
torch.onnx.export(
    pytorch_model,      # the model to export
    dummy_input,        # example input used to trace the model
    onnx_model_path,    # destination .onnx file
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)
This example is for an MLP but I think it'll work for more complicated models. It does 2 things: it fixes the input/output names to be something simple (i.e. 'input'/'output') and it exports models with dynamic batch dimension (assuming that is the first dimension).
If you don't do this then your models' batch dimension will be fixed!!
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With