tvm.relay.transform

The Relay IR namespace containing transformations.

Functions

AlterOpLayout()

Alternate the layouts of operators or replace primitive operators with other expressions.

AnnotateTarget(target)

Annotate ops in an expression with a provided compiler/target and then use it for codegen.

BackwardFoldScaleAxis()

Backward fold axis scaling into weights of conv2d/dense.

CanonicalizeCast()

Canonicalize cast expressions to make operator fusion more efficient.

CanonicalizeOps()

Canonicalize special operators to basic operators.

CombineParallelConv2D([min_num_branches])

Combine multiple conv2d operators into one.

CombineParallelDense([min_num_branches])

Combine multiple dense operators into one.

ConvertLayout(desired_layout)

Given a destination layout, this pass transforms the expression so that most of the ops' input data layouts are changed to that layout.

DeadCodeElimination([inline_once])

Remove expressions that do not have any users (dead code).

EliminateCommonSubexpr([fskip])

Eliminate common subexpressions.

EtaExpand([expand_constructor, …])

Add abstraction over a constructor or global variable bound to a function.

FastMath()

Converts expensive non-linear functions to their fast but approximate counterparts.

FoldConstant()

Fold the constant expressions in a Relay program.

FoldScaleAxis()

Fold the scaling of axis into weights of conv2d/dense.

ForwardFoldScaleAxis()

Fold the scaling of axis into weights of conv2d/dense.

FuseOps([fuse_opt_level])

Fuse operators in an expr to a larger operator according to some rules.

InferType()

Infer the type of an expr.

Inline()

Perform inlining on the given Relay IR module.

LambdaLift()

Lift closures to global functions.

LazyGradientInit()

Reduces memory usage of gradient tensors.

Legalize([legalize_map_attr_name])

Legalizes an expression with another expression.

MergeCompilerRegions()

Merge together compiler regions.

MergeComposite(pattern_table)

Merge multiple operators into a single composite relay function.

PartialEvaluate()

Evaluate the static fragment of the code.

PartitionGraph()

Partition a Relay program into regions that can be executed on different backends.

PrintIR([show_meta_data])

Print the IR for a module to help debugging.

RemoveUnusedFunctions([entry_functions])

Remove unused global relay functions in a relay module.

RewriteAnnotatedOps(fallback_device)

Rewrite the annotated program where annotation operators, e.g. on_device, mark the device an expression should be scheduled to.

SimplifyInference()

Simplify the data-flow graph for inference phase.

ToANormalForm()

Turn Graph Normal Form expression into A Normal Form Expression.

ToCPS(expr[, mod])

Turn an expression into continuation-passing style (CPS).

ToGraphNormalForm()

Turn a Relay program in A Normal Form into Graph Normal Form.

build_config([opt_level, fallback_device, …])

Configure the build behavior by setting config variables.

function_pass([pass_func, opt_level, name, …])

Decorate a function pass.

gradient(expr[, mod, mode])

Transform the input function, returning a function that calculates the original result, paired with the gradient of the input.

module_pass([pass_func, opt_level, name, …])

Decorate a module pass.

to_cps(func[, mod])

Turn an expression into a CPS expression.

un_cps(func)

Turn a CPS function into a Function without the continuation argument.

Classes

ChangeBatch(data[, batch_size])

Change the batch size.

FunctionPass

A pass that works on each tvm.relay.Function in a module.

ModulePass

A pass that works on tvm.IRModule.

Pass

The base class of all passes.

PassContext([opt_level, fallback_device, …])

The basis on which a Relay optimization/analysis runs.

PassInfo(opt_level, name[, required])

The class contains the meta data required by a pass.

Sequential([passes, opt_level, name, required])

A pass that works on a sequence of pass objects.

tvm.relay.transform.AlterOpLayout()

Alternate the layouts of operators or replace primitive operators with other expressions. This pass can be used for computing convolution in custom layouts or other general weight pre-transformation.

Returns

ret – The registered pass that alters the layout of operators.

Return type

tvm.relay.Pass

tvm.relay.transform.AnnotateTarget(target)

Annotate ops in an expression with a provided compiler/target and then use it for codegen.

Parameters

target (String) – The target compiler used for codegen.

Returns

ret – The annotated pass that wraps ops with subgraph_start and subgraph_end.

Return type

tvm.relay.Pass

tvm.relay.transform.BackwardFoldScaleAxis()

Backward fold axis scaling into weights of conv2d/dense.

Returns

ret – The registered pass to backward fold expressions.

Return type

tvm.relay.Pass

Note

It is recommended to call backward_fold_scale_axis before using forward_fold_scale_axis as backward folding targets the common conv->bn pattern.

tvm.relay.transform.CanonicalizeCast()

Canonicalize cast expressions to make operator fusion more efficient.

Returns

ret – The registered pass that canonicalizes cast expression.

Return type

tvm.relay.Pass

tvm.relay.transform.CanonicalizeOps()

Canonicalize special operators to basic operators. This can simplify followed analysis, e.g. expanding bias_add to expand_dims and broadcast_add.

Returns

ret – The registered pass performing the canonicalization.

Return type

tvm.relay.Pass

class tvm.relay.transform.ChangeBatch(data, batch_size=16)

Change the batch size.

Parameters
  • data (Dict[relay.Var, int]) – A dictionary of all the params to change. The keys are all params, and the values are which dimension holds the batch.

  • batch_size (int) – The batch size to change to.

Returns

pass – The pass.

Return type

FunctionPass

tvm.relay.transform.CombineParallelConv2D(min_num_branches=3)

Combine multiple conv2d operators into one.

Parameters

min_num_branches (int) – The minimum number of required parallel branches for performing this optimization.

Returns

ret – The registered pass that combines parallel conv2d operators.

Return type

tvm.relay.Pass

tvm.relay.transform.CombineParallelDense(min_num_branches=3)

Combine multiple dense operators into one. For example, parallel dense branches that share the same input can be combined into a single batch_matmul, with their following element-wise ops combined accordingly.

Parameters

min_num_branches (int) – The minimum number of required parallel branches for performing this optimization.

Returns

ret – The registered pass that combines parallel dense operators.

Return type

tvm.relay.Pass
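To make the rewrite concrete, here is a hedged, plain-Python analogue (not the actual pass, which operates on the Relay IR): two parallel dense branches over the same input are computed as one batched sweep over stacked weight matrices, which is the shape of graph CombineParallelDense produces.

```python
# Conceptual illustration only: two parallel dense ops that share an
# input can be evaluated as one batched matmul-style operation.

def dense(x, w):
    # x: input vector; w: weight matrix as a list of rows
    return [sum(xi * wij for xi, wij in zip(x, row)) for row in w]

def batched_dense(x, ws):
    # one batched sweep over all stacked weight matrices
    return [dense(x, w) for w in ws]

x = [1.0, 2.0]
w1 = [[1.0, 0.0], [0.0, 1.0]]
w2 = [[2.0, 3.0], [4.0, 5.0]]

separate = [dense(x, w1), dense(x, w2)]   # two parallel branches
combined = batched_dense(x, [w1, w2])     # one combined op
assert separate == combined
```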

tvm.relay.transform.ConvertLayout(desired_layout)

Given a destination layout, this pass transforms the expression so that most of the ops' input data layouts are changed to that layout. In the ideal situation, there are only two layout transforms, one at the start and one at the end.

This pass is not a part of relay.build and is expected to be called between the framework-to-Relay parser and the relay.build call. This is very helpful for hardware backends that support or prefer only one type of data layout.

RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009

This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define new layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout using the InferCorrectLayout infrastructure.

Parameters

desired_layout (str) – The desired layout for the transformed expr.

Returns

pass – The pass.

Return type

FunctionPass

tvm.relay.transform.DeadCodeElimination(inline_once=False)

Remove expressions that do not have any users (dead code).

Parameters

inline_once (Optional[Bool]) – Whether to inline binding that occurs only once.

Returns

ret – The registered pass that eliminates the dead code in a Relay program.

Return type

tvm.relay.Pass
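As a rough sketch of the idea (plain Python, not the Relay implementation): dead code elimination over a list of let-bindings walks backwards from the live result variables and drops any binding that is never used.

```python
# Minimal sketch: bindings is a list of (name, free_vars) pairs in
# definition order; anything not reachable from result_vars is dead.

def dce(bindings, result_vars):
    live = set(result_vars)
    kept = []
    for name, uses in reversed(bindings):
        if name in live:          # the binding is used downstream
            kept.append((name, uses))
            live |= set(uses)     # its own inputs become live too
    return list(reversed(kept))

prog = [("a", []), ("b", ["a"]), ("c", [])]  # "c" has no users
assert dce(prog, ["b"]) == [("a", []), ("b", ["a"])]
```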

tvm.relay.transform.EliminateCommonSubexpr(fskip=None)

Eliminate common subexpressions.

Parameters

fskip (Callable) – The callback function that decides whether an expression should be skipped.

Returns

ret – The registered pass that eliminates common subexpressions.

Return type

tvm.relay.Pass
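For intuition, a hedged plain-Python sketch of the technique (the real pass works on Relay expressions and additionally consults fskip): structurally identical subexpression nodes are hashed so that duplicates collapse into a single shared instance.

```python
# Illustrative sketch: expressions are nested tuples such as
# ("add", ("mul", "x", "y"), ("mul", "x", "y")).

def cse(expr, seen=None):
    if seen is None:
        seen = {}
    if not isinstance(expr, tuple):
        return expr                      # leaves stay as-is
    node = (expr[0],) + tuple(cse(e, seen) for e in expr[1:])
    # reuse a previously built identical node instead of a duplicate
    return seen.setdefault(node, node)

r = cse(("add", ("mul", "x", "y"), ("mul", "x", "y")))
assert r[1] is r[2]  # both operands are now the same shared object
```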

tvm.relay.transform.EtaExpand(expand_constructor=False, expand_global_var=False)

Add abstraction over a constructor or global variable bound to a function.

Parameters
  • expand_constructor (bool) – Whether to expand constructors.

  • expand_global_var (bool) – Whether to expand global variables.

Returns

ret – The registered pass that eta expands an expression.

Return type

tvm.relay.Pass

tvm.relay.transform.FastMath()

Converts expensive non-linear functions to their fast but approximate counterparts.

Returns

ret – The registered pass to perform fast math operations.

Return type

tvm.relay.Pass

tvm.relay.transform.FoldConstant()

Fold the constant expressions in a Relay program.

Returns

ret – The registered pass for constant folding.

Return type

tvm.relay.Pass
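A hedged plain-Python sketch of what constant folding does (the actual pass evaluates Relay constant subgraphs): any subexpression whose operands are all constants is reduced to a single constant at compile time.

```python
import operator

OPS = {"add": operator.add, "mul": operator.mul}

def fold(expr):
    # expr: nested tuples like ("add", "x", ("mul", 2, 3)), or atoms
    if not isinstance(expr, tuple):
        return expr
    op, args = expr[0], [fold(a) for a in expr[1:]]
    if all(isinstance(a, (int, float)) for a in args):
        return OPS[op](*args)          # every operand is constant: fold
    return (op,) + tuple(args)         # otherwise keep the node

# (x + (2 * 3)) folds to (x + 6); the variable x stays symbolic
assert fold(("add", "x", ("mul", 2, 3))) == ("add", "x", 6)
```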

tvm.relay.transform.FoldScaleAxis()

Fold the scaling of axis into weights of conv2d/dense. This pass will invoke both forward and backward scale folding.

Returns

ret – The registered pass to fold expressions.

Return type

tvm.relay.Pass

Note

Internally, we will call backward_fold_scale_axis before using forward_fold_scale_axis as backward folding targets the common conv->bn pattern.
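The algebraic identity the pass exploits can be checked in plain Python (an illustration only, not the pass itself): scaling the output of a dense op equals pre-scaling its constant weights, so a trailing multiply can be folded into the weight tensor at compile time.

```python
def dense(x, w):
    # x: input vector; w: weight matrix as a list of rows
    return [sum(xi * wij for xi, wij in zip(x, row)) for row in w]

x = [1.0, 2.0]
w = [[3.0, 4.0], [5.0, 6.0]]
s = 2.0

# scale applied after the dense op...
scaled_output = [s * y for y in dense(x, w)]
# ...equals the dense op with the scale folded into the weights
folded = dense(x, [[s * wij for wij in row] for row in w])
assert scaled_output == folded
```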

tvm.relay.transform.ForwardFoldScaleAxis()

Fold the scaling of axis into weights of conv2d/dense.

Returns

ret – The registered pass to forward fold expressions.

Return type

tvm.relay.Pass

Note

It is recommended to call backward_fold_scale_axis before using forward_fold_scale_axis, as backward folding targets the common conv->bn pattern.

class tvm.relay.transform.FunctionPass

A pass that works on each tvm.relay.Function in a module. A function pass class should be created through function_pass.

tvm.relay.transform.FuseOps(fuse_opt_level=-1)

Fuse operators in an expr to a larger operator according to some rules.

Parameters

fuse_opt_level (int) – The level of fuse optimization. -1 indicates that the level will be inferred from pass context.

Returns

ret – The registered pass for operator fusion.

Return type

tvm.relay.Pass
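Conceptually (a plain-Python sketch, not the Relay rewrite), fusing two element-wise ops into one traversal avoids materializing the intermediate tensor, which is the benefit FuseOps delivers on the graph:

```python
def add_one(xs):
    return [x + 1 for x in xs]

def relu(xs):
    return [max(x, 0) for x in xs]

def fused_add_one_relu(xs):
    # one pass over the data, no intermediate list allocated
    return [max(x + 1, 0) for x in xs]

data = [-3, -1, 0, 2]
assert relu(add_one(data)) == fused_add_one_relu(data)
```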

tvm.relay.transform.InferType()

Infer the type of an expr.

Returns

ret – The registered type inference pass.

Return type

tvm.relay.Pass

tvm.relay.transform.Inline()

Perform inlining on the given Relay IR module. The global functions that are marked as inline should always be inlined. A cost model will be needed in the future to decide if it is profitable to inline a function.

Returns

ret – The registered pass that performs inlining for a Relay IR module.

Return type

tvm.relay.Pass

tvm.relay.transform.LambdaLift()

Lift closures to global functions.

Returns

ret – The registered pass that lifts the lambda function.

Return type

tvm.relay.Pass

tvm.relay.transform.LazyGradientInit()

Reduces memory usage of gradient tensors.

Returns

ret – A pass which delays and/or reduces memory allocation by lazily allocating zero- or one-filled tensors.

Return type

tvm.relay.Pass

tvm.relay.transform.Legalize(legalize_map_attr_name='FTVMLegalize')

Legalizes an expression with another expression. This pass can be used to replace an expr with another expr for target-dependent optimizations. For example, one expr, though semantically equivalent to the other, can have better performance on a given target.

Parameters

legalize_map_attr_name (str) – The Op’s attr name which corresponds to the legalize rule function.

Returns

ret – The registered pass that rewrites an expr.

Return type

tvm.relay.Pass

tvm.relay.transform.MergeCompilerRegions()

Merge together compiler regions.

Returns

ret – The registered pass that merges compiler regions.

Return type

tvm.relay.Pass

tvm.relay.transform.MergeComposite(pattern_table)

Merge multiple operators into a single composite relay function.

Parameters

pattern_table (list(tuple)) – A list of (pattern_name, pattern) tuples. The order of the patterns in the list will determine the order of priority in which they are matched.

Returns

ret – The registered pass that merges operators into a single composite relay function.

Return type

tvm.relay.Pass

class tvm.relay.transform.ModulePass

A pass that works on tvm.IRModule. Users don’t need to interact with this class directly. Instead, a module pass should be created through module_pass, because the design of the module_pass API is flexible enough to handle the creation of a module pass in different manners. In addition, all members of a module pass can be accessed from the base class. The same rule applies to FunctionPass as well.

tvm.relay.transform.PartialEvaluate()

Evaluate the static fragment of the code.

Note

This transformation could be either Module -> Module or Expr -> Expr. It will directly transform the input expression to a new one if the target expression is provided. Otherwise, it will rely on the pass manager to carry out the transformation.

Returns

ret – The registered pass that performs partial evaluation on an expression.

Return type

tvm.relay.Pass

tvm.relay.transform.PartitionGraph()

Partition a Relay program into regions that can be executed on different backends.

Returns

ret – The registered pass that partitions the Relay program.

Return type

tvm.relay.Pass

class tvm.relay.transform.Pass

The base class of all passes. All methods here are just simple wrappers that are implemented in the backend. They are defined for users to conveniently interact with the base class.

Attributes

info

Get the pass meta.

property info

Get the pass meta.

class tvm.relay.transform.PassContext(opt_level=2, fallback_device=cpu(0), required_pass=None, disabled_pass=None, trace=None)

The basis on which a Relay optimization/analysis runs. Each pass context contains auxiliary information used to help an optimization pass, such as an error reporter that records errors raised during the optimization.

opt_level : Optional[int]

The optimization level of this pass.

fallback_device : Optional[Union[int, str, TVMContext]]

The fallback device type. It is also used as the default device for operators that are not annotated during heterogeneous execution.

required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]

The list of passes that are required by a certain pass.

disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]

The list of passes that are disabled.

Methods

current()

Return the current pass context.

static current()

Return the current pass context.

class tvm.relay.transform.PassInfo(opt_level, name, required=None)

The class contains the meta data required by a pass. It is the container of information needed by running an optimization or analysis. This class can be extended by adding new members when more meta data is needed.

Parameters
  • opt_level (int) – The optimization level of this pass.

  • name (str) – The pass name.

  • required (List[str]) – The list of passes that are required by a certain pass.

tvm.relay.transform.PrintIR(show_meta_data=True)

Print the IR for a module to help debugging.

Parameters

show_meta_data (bool) – A boolean flag to indicate if meta data should be printed.

Returns

ret – The registered pass that prints the module IR.

Return type

tvm.relay.Pass

tvm.relay.transform.RemoveUnusedFunctions(entry_functions=None)

Remove unused global relay functions in a relay module.

Parameters

entry_functions (list[string]) – The set of entry functions to start from.

Returns

ret – The registered pass to remove unused functions.

Return type

tvm.relay.Pass

tvm.relay.transform.RewriteAnnotatedOps(fallback_device)

Rewrite the annotated program where annotation operators, e.g. on_device, mark which device an expression should be scheduled to. This pass helps heterogeneous execution, where different operators may need to be allocated on various devices.

Parameters

fallback_device (int) – The fallback device type. It is also used as the default device for operators with no annotated device.

Returns

ret – The registered pass that rewrites an expression with annotated on_device operators.

Return type

tvm.relay.Pass

class tvm.relay.transform.Sequential(passes=None, opt_level=2, name='sequential', required=None)

A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class.

Typical usages of the sequential pass are: 1. Users provide a list of passes for optimization. 2. Only an optimization level is provided, so that the backend system has to glob all passes at this level and below to perform the optimizations.

Note that users can also provide a series of passes that they don’t want to apply when running a sequential pass. Pass dependency will be resolved in the backend as well.

Parameters
  • passes (Optional[List[Pass]]) – A sequence of passes candidate for optimization.

  • opt_level (Optional[int]) – The optimization level of this sequential pass.

  • name (Optional[str]) – The name of the sequential pass.

  • required (Optional[List[str]]) – The list of passes that the sequential pass is dependent on.

tvm.relay.transform.SimplifyInference()

Simplify the data-flow graph for the inference phase. A simplified expression that is semantically equal to the input expression will be returned.

Returns

ret – The registered pass to perform operator simplification.

Return type

tvm.relay.Pass

tvm.relay.transform.ToANormalForm()

Turn a Graph Normal Form expression into an A Normal Form expression. The scope of the root expression is the global scope. The scope of any non-root expression is the least common ancestor of all its scopes. Values are ordered by post-DFS order in each scope.

Returns

ret – The registered pass that transforms an expression into A Normal Form.

Return type

Union[tvm.relay.Pass, tvm.relay.Expr]
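A hedged plain-Python sketch of the A-normal-form idea (the pass itself operates on Relay scopes): every compound subexpression gets bound to a fresh variable, making evaluation order and sharing explicit.

```python
import itertools

def to_anf(expr, bindings, fresh=None):
    # expr: nested tuples; returns an atom, appending let-bindings
    if fresh is None:
        fresh = itertools.count(1)
    if not isinstance(expr, tuple):
        return expr
    args = [to_anf(a, bindings, fresh) for a in expr[1:]]
    name = "t%d" % next(fresh)
    bindings.append((name, (expr[0],) + tuple(args)))
    return name

lets = []
result = to_anf(("add", ("mul", "x", "y"), "z"), lets)
# the nested mul is now a named intermediate, evaluated before the add
assert lets == [("t1", ("mul", "x", "y")), ("t2", ("add", "t1", "z"))]
assert result == "t2"
```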

tvm.relay.transform.ToCPS(expr, mod=None)

Turn an expression into continuation-passing style (CPS).

Every intermediate compute will be passed to a continuation.

Returns

result – The registered pass that transforms an expression into CPS.

Return type

tvm.relay.Pass
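The shape of CPS code can be sketched in plain Python (an illustration of the style, not the Relay pass): each step passes its result to a continuation instead of returning it.

```python
def add_cps(a, b, k):
    return k(a + b)      # hand the sum to the continuation k

def mul_cps(a, b, k):
    return k(a * b)      # hand the product to the continuation k

# direct style: (2 + 3) * 4
# CPS: the add's continuation feeds the mul, whose continuation is
# the identity (the "final" continuation).
result = add_cps(2, 3, lambda s: mul_cps(s, 4, lambda p: p))
assert result == 20
```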

tvm.relay.transform.ToGraphNormalForm()

Turn a Relay program in A Normal Form into Graph Normal Form

Returns

ret – The registered pass that transforms an expression into Graph Normal Form.

Return type

tvm.relay.Pass

tvm.relay.transform.build_config(opt_level=2, fallback_device=cpu(0), required_pass=None, disabled_pass=None, trace=None)

Configure the build behavior by setting config variables.

Parameters
  • opt_level (int, optional) –

    Optimization level. The optimization pass name and level are as the following:

    OPT_PASS_LEVEL = {
        "SimplifyInference": 0,
        "OpFusion": 1,
        "FoldConstant": 2,
        "FoldScaleAxis": 3,
        "AlterOpLayout": 3,
        "CanonicalizeOps": 3,
        "CanonicalizeCast": 3,
        "EliminateCommonSubexpr": 3,
        "CombineParallelConv2D": 4,
        "CombineParallelDense": 4,
        "FastMath": 4
    }
    

  • fallback_device (int, str, or tvmContext, optional) – The fallback device. It is also used as the default device for operators without specified device during heterogeneous execution.

  • required_pass (set of str, optional) – Optimization passes that are required regardless of optimization level.

  • disabled_pass (set of str, optional) – Optimization passes to be disabled during optimization.

  • trace (Callable[[IRModule, PassInfo, bool], None]) – A tracing function for debugging or introspection.

Returns

pass_context – The pass context for optimizations.

Return type

PassContext

tvm.relay.transform.function_pass(pass_func=None, opt_level=None, name=None, required=None)

Decorate a function pass.

This function returns a callback when pass_func is provided. Otherwise, it returns the created function pass using the given optimization function.

Parameters
  • pass_func (Optional[Callable[(Function, Module, PassContext) -> Function]]) – The transformation function or class.

  • opt_level (int) – The optimization level of this module pass.

  • name (Optional[str]) – The name of the function pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

  • required (Optional[List[str]]) – The list of passes that the module pass is dependent on.

Returns

create_function_pass – A decorator will be returned if pass_func is not provided, otherwise return the decorated result. The returned decorator has two behaviors depending on the input: A new FunctionPass will be returned when we decorate a pass function. A new FunctionPass class will be returned when we decorate a class type.

Return type

Union[Callable, FunctionPass]

Examples

The following code block decorates a function pass class.

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)

The following code creates a function pass by decorating a user defined transform function.

@relay.transform.function_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, relay.transform.FunctionPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the following:
updated_mod = function_pass(m)
# Now the transformation has been applied to every function in
# the provided module m, and the updated module is returned.
tvm.relay.transform.gradient(expr, mod=None, mode='higher_order')

Transform the input function, returning a function that calculates the original result, paired with the gradient of the input.

Parameters
  • expr (tvm.relay.Expr) – The input expression, which is a Function or a GlobalVar.

  • mod (Optional[tvm.IRModule]) –

  • mode (Optional[String]) – The mode of the automatic differentiation algorithm. 'first_order' only works on first-order code but will not produce references or closures. 'higher_order' works on all code, using references and closures.

Returns

expr – The transformed expression.

Return type

tvm.relay.Expr
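As a conceptual analogue (plain Python, not the Relay pass): forward-mode automatic differentiation with dual numbers likewise returns the original result paired with a derivative, which mirrors the value/gradient pairing that gradient() produces for a Relay function.

```python
class Dual:
    """A value carried together with its derivative."""
    def __init__(self, val, grad):
        self.val, self.grad = val, grad
    def __add__(self, other):
        return Dual(self.val + other.val, self.grad + other.grad)
    def __mul__(self, other):
        # product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.grad * other.val + self.val * other.grad)

def f(x):
    return x * x + x          # f(x) = x^2 + x, so f'(x) = 2x + 1

out = f(Dual(3.0, 1.0))       # seed the input derivative with 1
assert (out.val, out.grad) == (12.0, 7.0)
```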

tvm.relay.transform.module_pass(pass_func=None, opt_level=None, name=None, required=None)

Decorate a module pass.

This function returns a callback when pass_func is provided. Otherwise, it serves as a decorator function.

pass_func can also be a class type with a method transform_module. This function will create a decorated ModulePass using transform_module as the pass function.

Parameters
  • pass_func (Optional[Callable[(Module, PassContext) -> Module]]) – The transformation function or class.

  • opt_level (int) – The optimization level of this module pass.

  • name (Optional[str]) – The name of the module pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

  • required (Optional[List[str]]) – The list of passes that the module pass is dependent on.

Returns

create_module_pass – A decorator will be returned if pass_func is not provided, otherwise return the decorated result. The returned decorator has two behaviors depending on the input: A new ModulePass will be returned when we decorate a pass function. A new ModulePass class will be returned when we decorate a class type.

Return type

Union[Callable, ModulePass]

Examples

The following code block decorates a module pass class.

@relay.transform.module_pass
class CustomPipeline:
    def __init__(self, enable_fold):
        self.enable_fold = enable_fold
        self.cse = relay.transform.EliminateCommonSubexpr()
        self.const_fold = relay.transform.FoldConstant()

    def transform_module(self, mod, ctx):
        mod = self.cse(mod, ctx)
        if self.enable_fold:
            mod = self.const_fold(mod, ctx)
        return mod

# create an instance of customized pipeline
pipeline = CustomPipeline(enable_fold=False)
assert isinstance(pipeline, relay.transform.ModulePass)
# run the pipeline.
output_module = pipeline(input_module)

The following code creates a module pass by decorating a user defined transform function.

@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
    tp = relay.TensorType((10,), "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("var")
    func = relay.Function([x], relay.abs(x))
    new_mod = tvm.IRModule({gv: func})
    new_mod.update(mod)
    return new_mod

module_pass = transform
assert isinstance(module_pass, relay.transform.ModulePass)
assert module_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the following:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
tvm.relay.transform.to_cps(func, mod=None)

Turn an expression into a CPS expression.

Every intermediate compute will be passed to a continuation.

Parameters
  • func (tvm.relay.Function) – The input function.

  • mod (Optional[tvm.IRModule]) – The global module.

Returns

result – The output function.

Return type

tvm.relay.Function

tvm.relay.transform.un_cps(func)

Turn a CPS function into a Function without the continuation argument.

Note that this will not give exactly the same interface as before CPS:

If the input/output is higher order, it will still be in CPS form.

Parameters

func (tvm.relay.Function) – The input function

Returns

result – The output function

Return type

tvm.relay.Function