
tf_rep.export_graph(tf_model_path): KeyError: 'input.1'

I am trying to convert an ONNX model to TFLite, but I am hitting an error when executing the line tf_rep.export_graph(tf_model_path). This question has been asked on SO before, but none of the answers provided a definitive solution.

Requirements installed: tensorflow 2.12.0, onnx 1.14.0, onnx-tf 1.10.0, Python 3.10.12

  import torch
  import onnx
  import tensorflow as tf
  import onnx_tf
  from torchvision.models import resnet50

  # Load the PyTorch ResNet50 model
  pytorch_model = resnet50(pretrained=True)
  pytorch_model.eval()

  # Export the PyTorch model to ONNX format
  input_shape = (1, 3, 224, 224)
  dummy_input = torch.randn(input_shape)
  onnx_model_path = 'resnet50.onnx'
  torch.onnx.export(pytorch_model, dummy_input, onnx_model_path, opset_version=12, verbose=False)

  # Load the ONNX model
  onnx_model = onnx.load(onnx_model_path)

  # Convert the ONNX model to TensorFlow format
  tf_model_path = 'resnet50.pb'

  from onnx_tf.backend import prepare

  tf_rep = prepare(onnx_model)
  tf_rep.export_graph(tf_model_path)    # ERROR

Error:

WARNING:absl:`input.1` is not a valid tf.function parameter name. Sanitizing to `input_1`.
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-f35b83c104b8> in <cell line: 8>()
    6 tf_model_path = 'resnet50'
    7 tf_rep = prepare(onnx_model)
----> 8 tf_rep.export_graph(tf_model_path)

35 frames
/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py in tf__conv(cls, node, input_dict, transpose)
    17                 do_return = False
    18                 retval_ = ag__.UndefinedReturnValue()
---> 19                 x = ag__.ld(input_dict)[ag__.ld(node).inputs[0]]
    20                 x_rank = ag__.converted_call(ag__.ld(len), (ag__.converted_call(ag__.ld(x).get_shape, (), None, fscope),), None, fscope)
    21                 x_shape = ag__.converted_call(ag__.ld(tf_shape), (ag__.ld(x), ag__.ld(tf).int32), None, fscope)

KeyError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend_tf_module.py", line 99, in __call__  *
        output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op  *
        return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/handler.py", line 59, in handle  *
        return ver_handle(node, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv.py", line 15, in version_11  *
        return cls.conv(node, kwargs["tensor_dict"])
    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py", line 29, in conv  *
        x = input_dict[node.inputs[0]]

    KeyError: 'input.1'
asked by afsara_ben

2 Answers

The problem was with the input tensor name in the ONNX model. You can inspect the graph inputs to see it:

import onnx

onnx_model = onnx.load(onnx_model_path)
print("Model Inputs: ", [inp.name for inp in onnx_model.graph.input])

Model Inputs: ['input.1']

TensorFlow cannot accept input.1 as a tf.function parameter name (hence the sanitizing warning), so the input has to be renamed to something like input_1 in the ONNX graph itself. The following code does that:

import onnx
from onnx import helper

onnx_model = onnx.load(onnx_model_path)

# Define a mapping from old names to new names
name_map = {"input.1": "input_1"}

# Initialize a list to hold the new inputs
new_inputs = []

# Iterate over the inputs and change their names if needed
for inp in onnx_model.graph.input:
    if inp.name in name_map:
        # Create a new ValueInfoProto with the new name
        new_inp = helper.make_tensor_value_info(name_map[inp.name],
                                                inp.type.tensor_type.elem_type,
                                                [dim.dim_value for dim in inp.type.tensor_type.shape.dim])
        new_inputs.append(new_inp)
    else:
        new_inputs.append(inp)

# Clear the old inputs and add the new ones
onnx_model.graph.ClearField("input")
onnx_model.graph.input.extend(new_inputs)

# Go through all nodes in the model and replace the old input name with the new one
for node in onnx_model.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name in name_map:
            node.input[i] = name_map[input_name]

# Save the renamed ONNX model
onnx.save(onnx_model, 'resnet50-new.onnx')

The renamed input now looks like:

Model Inputs: ['input_1']
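
As an optional sanity check before re-running the conversion, the onnx checker can confirm that the rename left a valid graph (a small sketch, assuming the 'resnet50-new.onnx' file saved above):

import onnx

renamed_model = onnx.load('resnet50-new.onnx')
# Raises an exception if the renamed graph is structurally inconsistent
onnx.checker.check_model(renamed_model)
print("Model Inputs: ", [inp.name for inp in renamed_model.graph.input])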

The TensorFlow export (and the subsequent TFLite conversion) now completes without error:

import onnx

onnx_model_path = 'resnet50-new.onnx'
onnx_model = onnx.load(onnx_model_path)
from onnx_tf.backend import prepare

tf_model_path = 'resnet50'
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)
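
The code above stops at the SavedModel export; the final TFLite step is not shown. Here is a minimal sketch of that step, assuming the 'resnet50' SavedModel directory written by export_graph above:

import tensorflow as tf

# Convert the SavedModel directory produced by tf_rep.export_graph('resnet50')
converter = tf.lite.TFLiteConverter.from_saved_model('resnet50')
tflite_model = converter.convert()

# Write the TFLite flatbuffer to disk
with open('resnet50.tflite', 'wb') as f:
    f.write(tflite_model)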
answered by afsara_ben


It is indeed due to input tensor names that TF cannot support, but there is a much simpler fix if you are exporting from PyTorch. (And I imagine something similar exists for TF.)

Simply export like:

torch.onnx.export(pytorch_model, dummy_input, onnx_model_path,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})

This example is for an MLP, but I think it will work for more complicated models. It does two things: it sets the input/output names to something simple ('input'/'output'), and it exports the model with a dynamic batch dimension (assuming that is the first dimension).

If you don't do this, your model's batch dimension will be fixed!
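
To confirm both effects you can inspect the exported graph; a dynamic dimension shows up as a dim_param string instead of a fixed dim_value. A quick check (a sketch, assuming the export above was written to 'resnet50.onnx'):

import onnx

onnx_model = onnx.load('resnet50.onnx')
inp = onnx_model.graph.input[0]
print(inp.name)  # 'input'
# Dynamic dims carry dim_param (e.g. 'batch_size'); fixed dims carry dim_value
print([d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim])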

answered by profPlum