Compile TFLite Models

Author: Zhao Wu

This article is an introductory tutorial on deploying TFLite models with Relay.

To get started, the Flatbuffers and TFLite packages need to be installed as prerequisites.

A quick solution is to install Flatbuffers via pip:

pip install flatbuffers --user

To install the TFLite package, you can use our prebuilt wheel:

# For python3:
pip install tflite-0.0.1-py3-none-any.whl --user

# For python2:
pip install tflite-0.0.1-py2-none-any.whl --user

Alternatively, you can generate the TFLite package yourself. The steps are as follows:

# Get the flatc compiler.
# Please refer to the Flatbuffers documentation for details
# and make sure it is properly installed.
flatc --version

# Get the TFLite schema.

# Generate TFLite package.
flatc --python schema.fbs

# Add it to PYTHONPATH.
export PYTHONPATH=/path/to/tflite

Now check that the TFLite package is installed successfully by running python -c "import tflite".
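The same check can be scripted for both prerequisites at once. The following is a minimal sketch, not part of the original tutorial: the helper name missing_packages is hypothetical, and it only tests whether the packages can be located, not whether their versions are compatible.

```python
import importlib.util

def missing_packages(names):
    """Return the subset of top-level package names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# The two prerequisites named in this tutorial.
required = ["flatbuffers", "tflite"]
missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All prerequisites are importable.")
```

If tflite is reported missing after generating the package yourself, the most likely cause is that PYTHONPATH does not point at the directory containing the generated tflite folder.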

Below is an example of how to compile a TFLite model using TVM.

Utils for downloading and extracting zip files

def download(url, path, overwrite=False):
    import os
    # Skip the download when the file is already on disk.
    if os.path.isfile(path) and not overwrite:
        print('File {} existed, skip.'.format(path))
        return
    print('Downloading from url {} to {}'.format(url, path))
    try:
        # Python 3
        import urllib.request
        urllib.request.urlretrieve(url, path)
    except ImportError:
        # Python 2 fallback
        import urllib
        urllib.urlretrieve(url, path)

def extract(path):
    import tarfile
    if path.endswith("tgz") or path.endswith("gz"):
        tar = tarfile.open(path)
        tar.extractall()
        tar.close()
    else:
        raise RuntimeError('Could not decompress the file: ' + path)
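To see what the extraction helper does without downloading anything, the sketch below builds a tiny archive in a temporary directory and unpacks it. It is an illustrative variant, not the tutorial's own code: the name extract_to, the explicit destination directory, and the dummy file are assumptions introduced here.

```python
import os
import tarfile
import tempfile

def extract_to(archive_path, dest_dir):
    """Unpack a .tgz/.tar.gz archive into dest_dir and return the member names."""
    if not archive_path.endswith(("tgz", "gz")):
        raise RuntimeError("Could not decompress the file: " + archive_path)
    with tarfile.open(archive_path) as tar:
        names = tar.getnames()
        tar.extractall(dest_dir)
    return names

with tempfile.TemporaryDirectory() as tmp:
    # Build a tiny archive so the example does not depend on any download.
    payload = os.path.join(tmp, "dummy.tflite")
    with open(payload, "wb") as f:
        f.write(b"not a real model")
    archive = os.path.join(tmp, "dummy.tgz")
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(payload, arcname="dummy.tflite")

    out_dir = os.path.join(tmp, "out")
    os.makedirs(out_dir)
    members = extract_to(archive, out_dir)
    extracted_ok = os.path.isfile(os.path.join(out_dir, "dummy.tflite"))
    print(members, extracted_ok)
```

Using a with block closes the archive even if extraction fails, and an explicit destination avoids unpacking into whatever the current working directory happens to be.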

Load pretrained TFLite model

We load the MobileNet V1 TFLite model provided by Google.

model_url = ""

# download the model tar file and extract it to get mobilenet_v1_1.0_224.tflite
download(model_url, "mobilenet_v1_1.0_224.tgz", False)
extract("mobilenet_v1_1.0_224.tgz")

# now we have mobilenet_v1_1.0_224.tflite on disk and open it
tflite_model_file = "mobilenet_v1_1.0_224.tflite"
tflite_model_buf = open(tflite_model_file, "rb").read()

# get TFLite model from buffer
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)


File mobilenet_v1_1.0_224.tgz existed, skip.
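The "existed, skip" message comes from the cache check in the download helper: a file already on disk is not fetched again unless overwrite is set. The sketch below isolates that logic so it can be run offline. The name cached_download, the injected fetch callback, and the example.invalid URL are all hypothetical stand-ins.

```python
import os
import tempfile

def cached_download(url, path, fetch, overwrite=False):
    """Call fetch(url, path) unless path already exists; return True if fetched."""
    if os.path.isfile(path) and not overwrite:
        return False
    fetch(url, path)
    return True

def fake_fetch(url, path):
    # Hypothetical stand-in for urllib.request.urlretrieve: writes a placeholder file.
    with open(path, "wb") as f:
        f.write(b"placeholder")

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "model.tgz")
    url = "http://example.invalid/model.tgz"
    first = cached_download(url, target, fake_fetch)
    second = cached_download(url, target, fake_fetch)
    forced = cached_download(url, target, fake_fetch, overwrite=True)
    print(first, second, forced)
```

The first call fetches, the second is skipped because the file exists, and the third fetches again because overwrite is set.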

Load a test image

A single cat dominates the examples!

from PIL import Image
from matplotlib import pyplot as plt
import numpy as np

image_url = ''
download(image_url, 'cat.png')
resized_image = Image.open('cat.png').resize((224, 224))
image_data = np.asarray(resized_image).astype("float32")

# convert HWC to CHW
image_data = image_data.transpose((2, 0, 1))

# after expand_dims, we have format NCHW
image_data = np.expand_dims(image_data, axis=0)

# preprocess image: scale each channel from [0, 255] to [-1, 1]
image_data[:, 0, :, :] = 2.0 / 255.0 * image_data[:, 0, :, :] - 1
image_data[:, 1, :, :] = 2.0 / 255.0 * image_data[:, 1, :, :] - 1
image_data[:, 2, :, :] = 2.0 / 255.0 * image_data[:, 2, :, :] - 1
print('input', image_data.shape)


File cat.png existed, skip.
input (1, 3, 224, 224)


Input layout:

Currently, the TVM TFLite frontend accepts NCHW as the input layout.
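The layout conversion and scaling used for the test image can be checked on a small synthetic array. This is a minimal sketch under stated assumptions: the helper name to_nchw_normalized and the 4x4 image are hypothetical, and the scaling is applied to all channels in one expression rather than channel by channel.

```python
import numpy as np

def to_nchw_normalized(image_hwc):
    """Convert an HWC image with values in [0, 255] to a float32 NCHW batch in [-1, 1]."""
    data = np.asarray(image_hwc).astype("float32")
    data = data.transpose((2, 0, 1))      # HWC -> CHW
    data = np.expand_dims(data, axis=0)   # CHW -> NCHW with N = 1
    return 2.0 / 255.0 * data - 1.0       # same scaling on every channel

# Hypothetical 4x4 RGB image: black everywhere except one white pixel.
fake_image = np.zeros((4, 4, 3), dtype="uint8")
fake_image[0, 0, :] = 255

batch = to_nchw_normalized(fake_image)
print(batch.shape, batch.dtype)
```

A pixel value of 0 maps to -1 and 255 maps to 1, so the white pixel ends up at approximately 1.0 in all three channels.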

Compile the model with Relay

# TFLite input tensor name, shape and type
input_tensor = "input"
input_shape = (1, 3, 224, 224)
input_dtype = "float32"

# parse TFLite model and convert into Relay computation graph
from tvm import relay
func, params = relay.frontend.from_tflite(tflite_model,
                                          shape_dict={input_tensor: input_shape},
                                          dtype_dict={input_tensor: input_dtype})

# target x86 CPU
target = "llvm"
with relay.build_module.build_config(opt_level=3):
    graph, lib, params = relay.build(func, target, params=params)

Execute on TVM

import tvm
from tvm.contrib import graph_runtime as runtime

# create a runtime executor module
module = runtime.create(graph, lib, tvm.cpu())

# feed input data
module.set_input(input_tensor, tvm.nd.array(image_data))

# feed related params
module.set_input(**params)

# run
module.run()
# get output
tvm_output = module.get_output(0).asnumpy()

Display results

# load label file
label_file_url = ''.join([''])
label_file = "labels_mobilenet_quant_v1_224.txt"
download(label_file_url, label_file)

# map id to 1001 classes
labels = dict()
with open(label_file) as f:
    for id, line in enumerate(f):
        labels[id] = line

# convert result to 1D data
predictions = np.squeeze(tvm_output)

# get top 1 prediction
prediction = np.argmax(predictions)

# convert id to class name and show the result
print("The image prediction result is: id " + str(prediction) + " name: " + labels[prediction])


File labels_mobilenet_quant_v1_224.txt existed, skip.
The image prediction result is: id 283 name: tiger cat
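The top-1 lookup extends naturally to the k best classes. The sketch below is an illustration only: load_labels and top_k are hypothetical helpers, and the four-entry label list and score vector are made-up stand-ins for the real label file and model output. Labels are stripped so that the trailing newline from the file does not end up in the printed name.

```python
import io
import numpy as np

def load_labels(lines):
    """Map class id -> label text, stripping the trailing newline."""
    return {idx: line.strip() for idx, line in enumerate(lines)}

def top_k(scores, labels, k=5):
    """Return the k highest-scoring (id, label, score) triples, best first."""
    scores = np.squeeze(np.asarray(scores))
    order = np.argsort(scores)[::-1][:k]
    return [(int(i), labels[int(i)], float(scores[i])) for i in order]

# Hypothetical stand-ins for the label file and the model output.
fake_label_file = io.StringIO("background\ntabby\ntiger cat\nEgyptian cat\n")
fake_output = np.array([[0.01, 0.20, 0.70, 0.09]], dtype="float32")

labels = load_labels(fake_label_file)
result = top_k(fake_output, labels, k=3)
for class_id, name, score in result:
    print(class_id, name, round(score, 2))
```

With the real model, tvm_output and the labels read from the label file would take the place of fake_output and fake_label_file.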

Total running time of the script: ( 0 minutes 7.124 seconds)
