Skip to main content

Guangzhou, China

Open Neural Network Exchange (ONNX)

Github Repository

ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.

Netron

Inspect your ONNX model using Netron. Netron is a viewer for neural networks, deep learning and machine learning models. Netron supports ONNX, TensorFlow Lite, Core ML, Keras, Caffe, Darknet, MXNet, PaddlePaddle, ncnn, MNN and TensorFlow.js. Netron has experimental support for PyTorch, TorchScript, TensorFlow, OpenVINO, RKNN, MediaPipe, ML.NET and scikit-learn.

Installation on Linux using the latest AppImage:

wget https://github.com/lutzroeder/netron/releases/download/v7.0.8/Netron-7.0.8.AppImage
# Only the execute bit is required; mode 777 would also make the binary world-writable.
chmod +x Netron-7.0.8.AppImage
./Netron-7.0.8.AppImage

Open Neural Network Exchange (ONNX)

Loading an ONNX Model

import tensorflow as tf
import tf2onnx

import onnx
from onnx import shape_inference, TensorProto
from onnx import version_converter
from onnx import numpy_helper
import onnx.helper as helper

from onnxsim import simplify

from onnx_tf.backend import prepare

import json
import numpy as np
import sys

TensorFlow to ONNX

# Load the saved Keras model from 'mobilenet_model' and print its layer summary.
tf_model = tf.keras.saving.load_model('mobilenet_model')
tf_model.summary()
# Signature handed to the converter: a single float32 tensor named 'x'
# with shape [1, 224, 224, 3].
input_signature = [tf.TensorSpec([1, 224, 224, 3], tf.float32, name='x')]
# Use from_function for tf functions
# from_keras returns a pair; only its first element (the ONNX model) is kept.
onnx_model, _ = tf2onnx.convert.from_keras(tf_model, input_signature, opset=12)
onnx.save(onnx_model, "mobilenet_model.onnx")

AttributeError: 'FuncGraph' object has no attribute '_captures'

Known bug with TensorFlow 2.13 / Python 3.11 (see GITHUB ISSUE1, GITHUB ISSUE1): the conversion fails with "AttributeError: 'FuncGraph' object has no attribute '_captures'. Did you mean: 'captures'?"

Replace the following in:

  • ~/.local/lib/python3.11/site-packages/tf2onnx/tf_loader.py
  • ~/.local/lib/python3.11/site-packages/tf2onnx/convert.py
    #graph_captures = concrete_func.graph._captures  # pylint: disable=protected-access
#captured_inputs = [t_name.name for t_val, t_name in graph_captures.values()]

# Prefer the public `captures` attribute when the graph has one (the error
# message above suggests it as the replacement for `_captures`); otherwise
# fall back to the private mapping. The two are iterated differently:
# `captures` directly as pairs, `_captures` through its .values().
if not hasattr(concrete_func.graph, "captures"):
    graph_captures = concrete_func.graph._captures  # pylint: disable=protected-access
    captured_inputs = [placeholder.name for _, placeholder in graph_captures.values()]
else:
    graph_captures = concrete_func.graph.captures
    captured_inputs = [placeholder.name for _, placeholder in graph_captures]

TensorFlow Lite to ONNX

# Convert the TFLite model at tflite_path to ONNX with opset 12.
# Presumably the result is written to output_path: the same file name is
# loaded with onnx.load in the "ONNX to Tensorflow" step.
tf2onnx.convert.from_tflite(
tflite_path='mobilenet_model.tflite',
output_path='mobilenet_model_tflite.onnx',
opset=12
)

ONNX to TensorFlow

# ONNX file produced from the Keras model: load it, run the ONNX checker,
# wrap it with the onnx-tf backend and export the result to 'deploy_model'.
onnx_model = onnx.load("mobilenet_model.onnx")
onnx.checker.check_model(onnx_model)
tf_prep = prepare(onnx_model)
tf_prep.export_graph('deploy_model')
# Same steps for the ONNX file produced from the TFLite model, exported to
# 'deploy_lite_model'. Note that tf_prep is rebound here.
onnx_lite_model = onnx.load("mobilenet_model_tflite.onnx")
onnx.checker.check_model(onnx_lite_model)
tf_prep = prepare(onnx_lite_model)
tf_prep.export_graph('deploy_lite_model')

# ValueError: Tried to convert 'x' to a tensor and failed. Error: None values not supported.
# NOTE(review): error recorded by the author; this snippet does not show
# which of the calls above raises it -- confirm before relying on either export.

Generate ONNX Prototext

def dump_normal(elem, indent, file):
    """Write the text form of *elem* to *file*, one prefixed line at a time.

    Args:
        elem: Any object; ``str(elem)`` is what gets written.
        indent: String prepended to every emitted line.
        file: Writable text stream passed to ``print``.

    Nothing is written when ``str(elem)`` is empty.
    """
    text = str(elem)
    for line in text.splitlines():
        print(indent + line, file=file)
def dump_initializer(elem, indent, file):
    """Write an initializer to *file*, leaving out the data of large tensors.

    The element count is the product of ``elem.dims`` (1 when there are no
    dims). Tensors with at most 32 elements are written in full through
    ``dump_normal``; for larger ones only ``dims``, ``data_type`` and
    ``name`` are written.

    Args:
        elem: Object exposing ``dims``, ``data_type`` and ``name``.
        indent: String prepended to every emitted line.
        file: Writable text stream passed to ``print``.
    """
    element_count = 1
    for extent in elem.dims:
        element_count *= extent

    # Small tensors are cheap to show, so emit them verbatim.
    if element_count <= 32:
        dump_normal(elem, indent, file)
        return

    # Large tensors: metadata only, which keeps the dump readable.
    for extent in elem.dims:
        print(indent + " dims: " + json.dumps(extent), file=file)
    print(indent + " data_type: " + json.dumps(elem.data_type), file=file)
    print(indent + " name: " + json.dumps(elem.name), file=file)
def onnx2prototxt(onnx_path):
    """Dump the ONNX model at *onnx_path* as text to ``onnx_path + ".prototxt"``.

    The output holds the model header fields, then a ``graph { ... }``
    section listing every node, initializer, input and output, then one
    ``opset_import`` entry per imported opset. Initializers go through
    ``dump_initializer`` and everything else through ``dump_normal``.
    The model's ``domain`` and ``doc_string`` are not written.

    Args:
        onnx_path: Path of the ``.onnx`` file to read.
    """
    out_path = onnx_path + ".prototxt"
    print("+ creating " + out_path)
    print(" from " + onnx_path + " ...")

    model = onnx.load(onnx_path)
    graph = model.graph

    # (opening line, elements, dump function) per graph section, in output order.
    sections = (
        (" node {", graph.node, dump_normal),
        (" initializer {", graph.initializer, dump_initializer),
        (" input {", graph.input, dump_normal),
        (" output {", graph.output, dump_normal),
    )

    with open(out_path, "w") as f:
        print("ir_version: " + json.dumps(model.ir_version), file=f)
        print("producer_name: " + json.dumps(model.producer_name), file=f)
        print("producer_version: " + json.dumps(model.producer_version), file=f)
        print("model_version: " + json.dumps(model.model_version), file=f)
        print("graph {", file=f)
        print(" name: " + json.dumps(graph.name), file=f)

        for opening, elements, dump in sections:
            for e in elements:
                print(opening, file=f)
                dump(e, " ", f)
                print(" }", file=f)

        print("}", file=f)

        for e in model.opset_import:
            print("opset_import {", file=f)
            print(" version: " + json.dumps(e.version), file=f)
            print("}", file=f)
def show_usage(script):
    """Print the command-line usage line for *script* to standard output."""
    print("usage: python " + script + " input.onnx [more.onnx ..]")
# Dump the converted MobileNet model; onnx2prototxt writes its output next to
# the input, at onnx_path + ".prototxt".
onnx_path ="mobilenet_model.onnx"

onnx2prototxt(onnx_path)

ONNX to NovaONNX

# Settings consumed by the to_nova_onnx(...) call at the end of this section.
INPUT = 'mobilenet_model.onnx'
OUTPUT = 'deploy.onnx'
SKIP_FUSE_BN = True
SKIP_ONNX_SIM = False
SKIP_MODIFY_IDX = False
onnx_model = onnx.load(INPUT)
# Each dimension of the input shape must be greater than zero.
# (Bare expressions like the next line only display a value in an
# interactive session; as a script they have no effect.)
onnx_model.graph.input
# Named graph_input so the builtin input() is not shadowed at module scope.
for graph_input in onnx_model.graph.input:
    print(graph_input.type.tensor_type.shape.dim)
# Opset versions 8 to 12 are supported.
onnx_model.opset_import[0].version
### Supported Layer Types
# ONNX op_type values that do not trigger the "Unsupported Layer Type"
# warning in to_nova_onnx. Each name appears once.
SUPPORTED_OP_TYPE_LIST = [
    'Abs',
    'Add',
    'AveragePool',
    'BatchNormalization',
    'Clip',
    'Conv',
    'ConvTranspose',
    'Concat',
    'Flatten',
    'Gemm',
    'GlobalAveragePool',
    'GlobalMaxPool',
    'LeakyRelu',
    'LSTM',
    'MatMul',
    'Max',
    'MaxPool',
    'MaxRoiPool',
    'Mul',
    'Pad',
    'PRelu',
    'ReduceMean',
    'Relu',
    'Resize',
    'Sigmoid',
    'Softmax',
    'Sub',
    'Tanh',
    'Transpose',
    'Upsample',
    'Reshape',
    'Slice',
    'Split',
    'Neg',
    'Sqrt',
    'Exp',
    'Div',
    'Log',
    'Pow',
    'Sin',
    'Floor',
    'Round',
    'Squeeze',
    # Legacy spelling kept in case other code compares against it.
    'UnSqueeze',
    # Spelling used by the ONNX operator set; op_type comparison is
    # case-sensitive, so 'UnSqueeze' alone never matched a real node.
    'Unsqueeze',
]

Support Functions

def onnx_attribute_to_dict(onnx_attr):
    """Return ``(name, value)`` for an ONNX attribute.

    The first populated field decides the value, checked in this order:
    the tensor field ``t`` (converted with ``numpy_helper.to_array``), the
    scalar fields ``f``, ``i``, ``s``, then the repeated fields ``floats``,
    ``ints``, ``strings`` (returned as a ``list``).

    Args:
        onnx_attr: Attribute message exposing ``HasField`` and the fields
            listed above.

    Returns:
        A ``(name, value)`` tuple, where ``name`` is ``None`` when the
        attribute has no name set; or ``None`` when none of the handled
        fields is populated.
    """
    # Bind the name unconditionally: previously an attribute without a name
    # left the variable undefined and the returns below raised
    # UnboundLocalError.
    name = onnx_attr.name if onnx_attr.HasField('name') else None

    if onnx_attr.HasField('t'):
        return name, numpy_helper.to_array(onnx_attr.t)

    for field in ('f', 'i', 's'):
        if onnx_attr.HasField(field):
            return name, getattr(onnx_attr, field)

    # Repeated fields are tested by truthiness rather than HasField.
    for field in ('floats', 'ints', 'strings'):
        values = getattr(onnx_attr, field)
        if values:
            return name, list(values)

    return None
def add_input_from_initializer(model: onnx.ModelProto):
    """
    Currently onnx.shape_inference doesn't use the shape of initializers, so add
    that info explicitly as ValueInfoProtos.
    Mutates the model: every initializer that is not already listed in
    ``graph.input`` gets an input entry carrying its element type and dims.
    Subgraphs held in node attributes are processed the same way.
    Does nothing when ``model.ir_version`` is below 4.
    Args:
        model: The ModelProto to update.
    """
    # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
    if model.ir_version < 4:
        return

    def add_const_value_infos_to_graph(graph: onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        existing_info = {vi.name: vi for vi in graph.input}
        for init in graph.initializer:
            # Check it really is a constant, not an input
            if init.name in inputs:
                continue

            # The details we want to add
            elem_type = init.data_type
            shape = init.dims

            # Get existing or create new value info for this constant
            vi = existing_info.get(init.name)
            if vi is None:
                vi = graph.input.add()
                vi.name = init.name

            # Even though it would be weird, we will not overwrite info even if it doesn't match
            tt = vi.type.tensor_type
            if tt.elem_type == onnx.TensorProto.UNDEFINED:
                tt.elem_type = elem_type
            if not tt.HasField("shape"):
                # Ensure we set an empty list if the const is scalar (zero dims)
                tt.shape.dim.extend([])
                for dim in shape:
                    tt.shape.dim.add().dim_value = dim

        # Handle subgraphs
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so we don't need to do anything
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    add_const_value_infos_to_graph(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_const_value_infos_to_graph(g)

    return add_const_value_infos_to_graph(model.graph)
def ReplaceUpsampleWithResize(onnx_model):
    """Replace every ``Upsample`` node of the model with a ``Resize`` node, in place.

    Each replacement keeps the position, name and outputs of the node it
    replaces. Its inputs are the original data input, a newly added empty
    ``roi`` initializer (also registered in ``graph.value_info``) and the
    original second input.

    Args:
        onnx_model: Model whose graph is mutated. Each ``Upsample`` node is
            expected to have two inputs (data, scales).
    """
    graph = onnx_model.graph

    # Length is stable: every removal is paired with an insert at the same index.
    for i in range(len(graph.node)):
        old_node = graph.node[i]
        if old_node.op_type != 'Upsample':
            continue

        # Fall back to the output name so unnamed nodes do not all share "_roi".
        roi_name = (old_node.name or old_node.output[0]) + "_roi"
        graph.initializer.append(
            numpy_helper.from_array(np.empty([0], dtype=np.float32), roi_name)
        )
        graph.value_info.append(
            helper.make_tensor_value_info(roi_name, onnx.TensorProto.FLOAT, [0])
        )

        # 'nearest' is the Upsample default, so use it when the attribute is
        # absent; an empty mode string is not a valid Resize mode.
        mode_string = 'nearest'
        for attr in old_node.attribute:
            if attr.name == 'mode':
                mode_string = attr.s

        new_node = onnx.helper.make_node(
            "Resize",
            inputs=[old_node.input[0], roi_name, old_node.input[1]],
            outputs=old_node.output,
            name=old_node.name,
            coordinate_transformation_mode="asymmetric",
            cubic_coeff_a=-0.75,
            mode=mode_string,
            nearest_mode="floor",
        )
        graph.node.remove(old_node)
        graph.node.insert(i, new_node)
def check_shapes(onnx_model):
    """Verify that every node output has a named entry in the graph.

    An output counts as known when its name appears among the graph's
    inputs, outputs, initializers or ``value_info`` entries.

    Args:
        onnx_model: Model exposing ``graph`` with ``input``, ``output``,
            ``initializer``, ``value_info`` and ``node``.

    Raises:
        AssertionError: For the first node output whose name is unknown.
            It is raised explicitly, so the check also runs under
            ``python -O``, where ``assert`` statements are removed.
    """
    graph = onnx_model.graph

    known_names = set()
    for collection in (graph.input, graph.output, graph.initializer, graph.value_info):
        known_names.update(item.name for item in collection)

    for node in graph.node:
        for output in node.output:
            if output not in known_names:
                raise AssertionError(
                    "Shape checking error. Node: %s Type: %s, cannot get output shape, please check the attribute."
                    % (node.name, node.op_type)
                )
def Constant_to_initializer(onnxmodel):
    """Turn ``Constant`` nodes that carry a tensor ``value`` into initializers.

    For each such node a copy of the tensor is inserted at the front of
    ``graph.initializer``, named after the node's first output, and the node
    is removed. Copying the tensor message keeps its data type, dims and
    data. The earlier ``make_tensor(..., FLOAT, dims=0, ...)`` call forced
    float32, dropped the shape and failed because ``dims`` must be a
    sequence.

    ``Constant`` nodes without a tensor ``value`` attribute are left untouched.

    Args:
        onnxmodel: Model whose graph is mutated in place.
    """
    graph = onnxmodel.graph
    converted_indices = []

    for index, node in enumerate(graph.node):
        if node.op_type != "Constant":
            continue
        for attr in node.attribute:
            if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
                tensor = onnx.TensorProto()
                tensor.CopyFrom(attr.t)
                tensor.name = node.output[0]
                graph.initializer.insert(0, tensor)
                converted_indices.append(index)
                break

    # Delete from the back so earlier indices stay valid.
    for index in reversed(converted_indices):
        del graph.node[index]
def modify_layer_dix(graph):
    """Reassign ``layer_idx`` so graph outputs have ascending indices.

    For every graph output the producing node is looked up (the last node
    listing that output name wins). The node index is used as the starting
    layer index, i.e. the ``layer_idx`` attribute is assumed to equal the
    node's position. A selection sort over those indices then swaps the
    ``layer_idx`` attributes of the producing nodes, so that walking
    ``graph.output`` in order meets non-decreasing layer indices.

    Graph outputs with no producing node are skipped. Previously they left
    a gap in the lookup table and the sort raised ``KeyError``.

    Args:
        graph: Graph whose node attributes are mutated in place.

    Returns:
        The same ``graph`` object.
    """
    def _set_layer_idx(node, value):
        # Replace the node's existing layer_idx attribute; no-op if it has none.
        for k, attr in enumerate(node.attribute):
            if attr.name == 'layer_idx':
                replacement = onnx.helper.make_attribute("layer_idx", value)
                del node.attribute[k]
                node.attribute.extend([replacement])
                break

    # One [node_idx, layer_idx] entry per graph output that has a producer.
    entries = []
    for output in graph.output:
        producer = None
        for j, node in enumerate(graph.node):
            if output.name in node.output:
                producer = j
        if producer is not None:
            entries.append([producer, producer])

    for i in range(len(entries)):
        # First position holding the smallest remaining layer index.
        min_index = min(range(i, len(entries)), key=lambda n: entries[n][1])
        if min_index == i:
            continue

        # Exchange the layer indices on both nodes, then in the bookkeeping.
        _set_layer_idx(graph.node[entries[i][0]], entries[min_index][1])
        _set_layer_idx(graph.node[entries[min_index][0]], entries[i][1])
        entries[i][1], entries[min_index][1] = entries[min_index][1], entries[i][1]

    return graph
def to_nova_onnx(in_model_path, out_model_path, skip_fuse_bn, skip_onnx_sim, skip_modify_idx):
    """Convert an ONNX model file to the Nova ONNX layout and save it.

    Steps, in order:
      1. Load the model; return early if its producer name marks it as
         already converted.
      2. Reject graph inputs that have a dimension <= 0.
      3. Turn Constant nodes into initializers and add initializers as
         graph inputs.
      4. If any node uses a custom domain, add placeholder value_info
         entries for outputs that have none. Otherwise convert to opset 12
         if needed, replace Upsample with Resize, then either run shape
         inference plus ``check_shapes`` or run ``simplify``.
      5. Print a warning for every op type not in ``SUPPORTED_OP_TYPE_LIST``.
      6. Rename Conv weights/biases and all node outputs, tag every node
         with ``layer_idx`` (pool nodes also get ``pool_at_pad``), and apply
         the renames to graph outputs, value_info, inputs and initializers.
      7. Optionally reorder layer indices, stamp the producer fields and
         save.

    Args:
        in_model_path: Path of the ONNX file to read.
        out_model_path: Path the converted model is saved to.
        skip_fuse_bn: Forwarded to ``simplify`` as ``skip_fuse_bn``.
        skip_onnx_sim: When true, use shape inference and ``check_shapes``
            instead of ``simplify``.
        skip_modify_idx: When true, skip ``modify_layer_dix``.

    Raises:
        AssertionError: On a non-positive input dimension, an opset version
            outside 8..12, or a simplified model that fails validation.
            These are raised explicitly so they still fire under
            ``python -O``.
    """
    onnx_model = onnx.load(in_model_path)

    if onnx_model.producer_name in ('Novatek NovaOnnx Converter', 'Novatek Caffe2Onnx Converter'):
        print("INFO :: This model is already a nova onnx model, skip the conversion process...")
        return

    # check input shape
    for graph_input in onnx_model.graph.input:
        for d in graph_input.type.tensor_type.shape.dim:
            if d.dim_value <= 0:
                raise AssertionError(
                    "ERROR :: Each dimension of input shape must greater than zero, illegal input name = %s"
                    % graph_input.name
                )

    Constant_to_initializer(onnx_model)
    add_input_from_initializer(onnx_model)

    # NOTE(review): the nesting of the two branches below was reconstructed
    # from a copy of this code whose indentation was lost -- confirm.
    has_custom_op = any(
        node.domain != '' and node.domain != 'ai.onnx'
        for node in onnx_model.graph.node
    )
    if has_custom_op:
        # Names that already have shape information attached.
        tensor_names = {vi.name for vi in onnx_model.graph.value_info}
        tensor_names.update(output.name for output in onnx_model.graph.output)

        # Add missing tensor_value_info (fake shape)
        for node in onnx_model.graph.node:
            for output in node.output:
                if output in tensor_names:
                    continue
                if node.op_type == "Gemm" or node.op_type == "Flatten":
                    fake_shape = [-1, -1]
                else:
                    fake_shape = [-1, -1, -1, -1]
                tensor_names.add(output)
                onnx_model.graph.value_info.append(
                    helper.make_tensor_value_info(output, TensorProto.FLOAT, fake_shape)
                )
    else:
        # convert model to opset 12
        opset_version = onnx_model.opset_import[0].version
        if opset_version != 12:
            if opset_version > 12 or opset_version < 8:
                raise AssertionError(
                    ": Opset version of the input model is %d, novaonnx only supports Opset version 8 ~ 12."
                    % opset_version
                )
            print("WARNING :: Opset version of the input model is {}, novaonnx support Opset version 12.".format(opset_version))
            print("INFO :: Conversion from Opset version {} to Opset version 12.".format(opset_version))
            onnx_model = version_converter.convert_version(onnx_model, 12)

        # version_converter can not convert upsample (deprecated in opset 12), convert it to resize
        ReplaceUpsampleWithResize(onnx_model)

        if skip_onnx_sim:
            onnx_model = shape_inference.infer_shapes(onnx_model)
            check_shapes(onnx_model)
        else:
            # apply onnx simplify
            onnx_model, check = simplify(onnx_model, skip_fuse_bn=skip_fuse_bn)
            if not check:
                raise AssertionError("WARNING :: Simplified ONNX model could not be validated")

    for node in onnx_model.graph.node:
        if node.op_type not in SUPPORTED_OP_TYPE_LIST:
            print("WARNING :: Unsupported Layer Type ", node.op_type)

    graph = onnx_model.graph
    init_names = {initializer.name for initializer in graph.initializer}

    # old tensor name -> new tensor name, applied to the graph-level lists below.
    name_dict = {}

    for i, node in enumerate(graph.node):
        op_type = node.op_type

        # modify Conv weight / bias names
        if op_type == 'Conv':
            # Assign the value stored in name_dict so the node input always
            # matches the name the initializer is renamed to further down.
            if node.input[1] in init_names:
                node.input[1] = name_dict.setdefault(
                    node.input[1], op_type + "_" + node.input[1] + "_W"
                )
            if len(node.input) > 2 and node.input[2] in init_names:
                node.input[2] = name_dict.setdefault(
                    node.input[2], op_type + "_" + node.input[2] + "_B"
                )

        # point inputs at tensors that were renamed by earlier nodes
        for k in range(len(node.input)):
            if node.input[k] in name_dict:
                node.input[k] = name_dict[node.input[k]]

        # modify output tensor_name to (op_type)_(name)_Y
        for l in range(len(node.output)):
            old_name = node.output[l]
            if not old_name:
                # An empty name marks an omitted optional output; renaming it
                # would create a tensor and remap every empty optional input.
                continue
            new_name = op_type + "_" + old_name + "_Y"
            name_dict.setdefault(old_name, new_name)
            node.output[l] = new_name

        # Add layer_idx attribute for each node
        node.attribute.append(helper.make_attribute("layer_idx", i))

        # Tag pooling nodes with pool_at_pad
        if op_type == 'AveragePool' or op_type == 'MaxPool':
            node.attribute.append(helper.make_attribute("pool_at_pad", 1))

    # Apply the renames to graph outputs, value info, inputs and initializers.
    for collection in (graph.output, graph.value_info, graph.input, graph.initializer):
        for item in collection:
            if item.name in name_dict:
                item.name = name_dict[item.name]

    if not skip_modify_idx:
        graph = modify_layer_dix(graph)

    onnx_model.producer_name = 'Novatek NovaOnnx Converter'
    onnx_model.producer_version = '1.0'
    onnx.save(onnx_model, out_model_path)
    print("INFO :: Converted to NOVA ONNX!")
# Run the conversion using the module-level INPUT, OUTPUT and SKIP_* settings.
to_nova_onnx(INPUT, OUTPUT, SKIP_FUSE_BN, SKIP_ONNX_SIM, SKIP_MODIFY_IDX)