Compile ONNX Models

Author: Joshua Z. Zhang

This article is an introductory tutorial on deploying ONNX models with Relay.

To begin, the ONNX package must be installed.

A quick solution is to install the protobuf compiler, and then run

pip install onnx --user

or refer to the official site.
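Before running the rest of the tutorial, it can help to confirm that the required packages are importable. The following is a minimal sketch using only the standard library; the helper name `missing_packages` and the list of package names are illustrative assumptions, not part of the original tutorial.

```python
import importlib.util


def missing_packages(names):
    """Return the names that cannot be found as importable top-level modules."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed list of top-level modules imported later in this tutorial.
    required = ["onnx", "numpy", "tvm", "PIL", "matplotlib"]
    missing = missing_packages(required)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```

If `onnx` shows up as missing, the `pip install onnx --user` command above is the first thing to try.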

import onnx
import numpy as np
import tvm
import tvm.relay as relay
from tvm.contrib.download import download_testdata

Load pretrained ONNX model

The example super resolution model used here is exactly the same model as in the ONNX tutorial. We skip the PyTorch model construction part and download the saved ONNX model.

model_url = ''.join([''])
model_path = download_testdata(model_url, 'super_resolution.onnx', module='onnx')
# now you have super_resolution.onnx on disk
onnx_model = onnx.load(model_path)


File /workspace/.tvm_test_data/onnx/super_resolution.onnx exists, skip.

Load a test image

A single cat dominates the examples!

from PIL import Image
img_url = ''
img_path = download_testdata(img_url, 'cat.png', module='data')
img = Image.open(img_path).resize((224, 224))
img_ycbcr = img.convert("YCbCr")  # convert to YCbCr
img_y, img_cb, img_cr = img_ycbcr.split()
x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
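The last line above turns the 2-D luminance channel into a 4-D array with batch and channel axes, which is the layout the model input takes here. A small sketch with a synthetic array in place of the cat image illustrates the shape change; the helper `to_nchw` is a hypothetical name introduced only for this example.

```python
import numpy as np


def to_nchw(channel_2d):
    """Add batch and channel axes to a 2-D (H, W) array, giving (1, 1, H, W)."""
    arr = np.asarray(channel_2d)
    if arr.ndim != 2:
        raise ValueError("expected a 2-D array, got %d dimensions" % arr.ndim)
    return arr[np.newaxis, np.newaxis, :, :]


# Synthetic stand-in for the Y channel of a 224x224 image.
fake_y = np.zeros((224, 224), dtype=np.uint8)
x_demo = to_nchw(fake_y)
print(x_demo.shape)  # (1, 1, 224, 224)
```

Indexing with `np.newaxis` creates a view, so no pixel data is copied at this step.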


File /workspace/.tvm_test_data/data/cat.png exists, skip.

Compile the model with Relay

target = 'llvm'

input_name = '1'
shape_dict = {input_name: x.shape}
sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)

with relay.build_config(opt_level=1):
    intrp = relay.build_module.create_executor('graph', sym, tvm.cpu(0), target)
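The key in `shape_dict` has to match the name of the model's graph input, which is `'1'` for this file. If the name is not known in advance, one way to find it is to list the graph inputs that are not weights. The sketch below assumes the loaded model exposes `graph.input` and `graph.initializer` entries with a `name` attribute, as an ONNX `ModelProto` does; the helper `free_input_names` and the stand-in object are hypothetical and exist only so the example runs without the `onnx` package.

```python
from types import SimpleNamespace


def free_input_names(model):
    """Return names of graph inputs that are not backed by an initializer."""
    weight_names = {init.name for init in model.graph.initializer}
    return [inp.name for inp in model.graph.input if inp.name not in weight_names]


# Stand-in mimicking only the attributes used above. With the real package,
# pass the object returned by onnx.load(model_path) instead.
stub_model = SimpleNamespace(
    graph=SimpleNamespace(
        input=[SimpleNamespace(name="1"), SimpleNamespace(name="2")],
        initializer=[SimpleNamespace(name="2")],
    )
)
print(free_input_names(stub_model))  # ['1']
```

Filtering out initializers matters because some exporters list the weights among the graph inputs as well.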

Execute on TVM

dtype = 'float32'
tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
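The input is cast to `float32` before it is passed to the executor, and the floating-point output later has to be brought back to 8-bit pixel values. The following sketch shows that round trip with NumPy alone; the helper names `to_model_input` and `to_pixels` are assumptions made for illustration and are not TVM functions.

```python
import numpy as np


def to_model_input(pixels, dtype="float32"):
    """Cast pixel data to the floating-point type the model is fed with."""
    return np.asarray(pixels).astype(dtype)


def to_pixels(values):
    """Clip to the valid 8-bit range first, then cast to uint8."""
    return np.uint8(np.asarray(values).clip(0, 255))


raw = np.array([0, 128, 255], dtype=np.uint8)
as_float = to_model_input(raw)
print(as_float.dtype)  # float32

# Model outputs may fall slightly outside [0, 255]; clipping keeps them valid.
print(to_pixels(np.array([-5.0, 12.7, 300.0])).tolist())  # [0, 12, 255]
```

Clipping before the cast is the important part: converting out-of-range floats straight to `uint8` does not saturate and can produce arbitrary pixel values.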

Display results

We put the input and output images side by side.

from matplotlib import pyplot as plt
out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode='L')
out_cb = img_cb.resize(out_y.size, Image.BICUBIC)
out_cr = img_cr.resize(out_y.size, Image.BICUBIC)
result = Image.merge('YCbCr', [out_y, out_cb, out_cr]).convert('RGB')
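# white canvas: input image in the top-left of the left half, result in the right half
canvas = np.full((672, 672*2, 3), 255)
canvas[0:224, 0:224, :] = np.asarray(img)
canvas[:, 672:, :] = np.asarray(result)
plt.imshow(canvas.astype(np.uint8))
plt.show()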
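The canvas layout used above can be checked without the model or the image. This is a sketch under the assumption that the output is a 672x672 RGB image and the input is 224x224, as the slice bounds in the snippet suggest; `side_by_side` is a hypothetical helper, and the solid-color arrays stand in for the real images.

```python
import numpy as np


def side_by_side(small, large):
    """Paste `small` into the top-left of the left half of a white canvas
    and `large` into the right half. Both are (H, W, 3) uint8 arrays."""
    h, w = large.shape[:2]
    if small.shape[0] > h or small.shape[1] > w:
        raise ValueError("small image does not fit into the left half")
    canvas = np.full((h, 2 * w, 3), 255, dtype=np.uint8)
    canvas[: small.shape[0], : small.shape[1], :] = small
    canvas[:, w:, :] = large
    return canvas


# Solid-color stand-ins: black for the input, grey for the upscaled result.
small_img = np.zeros((224, 224, 3), dtype=np.uint8)
large_img = np.full((672, 672, 3), 128, dtype=np.uint8)
demo_canvas = side_by_side(small_img, large_img)
print(demo_canvas.shape)  # (672, 1344, 3)
```

Allocating the canvas as `uint8` up front avoids the separate cast that is otherwise needed before handing the array to `plt.imshow`.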

Total running time of the script: ( 0 minutes 3.208 seconds)
