tvm.autotvm

The auto-tuning module of tvm

This module includes:

  • Tuning space definition API
  • Efficient auto-tuners
  • Tuning result and database support
  • Distributed measurement to scale up tuning

tvm.autotvm.measure

User-facing API for specifying how to measure the generated code

class tvm.autotvm.measure.MeasureInput

Stores all the necessary inputs for a measurement.

Parameters:
  • target (tvm.target.Target) – The target device
  • task (task.Task) – Task function
  • config (ConfigEntity) – Specific configuration

class tvm.autotvm.measure.MeasureResult

Stores all the results of a measurement

Parameters:
  • costs (Array of float or Array of Exception) – If no error occurs during measurement, it is an array of measured running times. If an error occurs during measurement, it is an array of the raised exception objects.
  • error_no (int) – Denote error type, defined by MeasureErrorNo
  • all_cost (float) – All cost of this measure, including rpc, compilation, test runs
  • timestamp (float) – The absolute time stamp when we finish measurement.
tvm.autotvm.measure.measure_option(builder, runner)

Set options for measurement. To measure a config, we build it and then run it, so options must be set for both steps. Each step has its own options for timeout, parallelism, etc.

Parameters:
  • builder (Builder) – Specify how to build programs
  • runner (Runner) – Specify how to run programs

Examples

# example setting for using local devices
>>> measure_option = autotvm.measure_option(
>>>     builder=autotvm.LocalBuilder(),    # use all local cpu cores for compilation
>>>     runner=autotvm.LocalRunner(        # measure them sequentially
>>>         number=10,
>>>         timeout=5))

# example setting for using remote devices
>>> measure_option = autotvm.measure_option(
>>>     builder=autotvm.LocalBuilder(),    # use all local cpu cores for compilation
>>>     runner=autotvm.RPCRunner(
>>>         'rasp3b', 'localhost', 9190,   # device key, host and port of the rpc tracker
>>>         number=4,
>>>         timeout=4))  # timeout of a run on the device; RPC request waiting time is excluded

Note

To make measurement results accurate, you should pick correct values for the arguments number and repeat in Runner(). Using min_repeat_ms to dynamically adjust number is recommended; a typical value for NVIDIA GPUs is 100 ms. An example follows.
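Examples

A hedged sketch of a measure option using min_repeat_ms; the specific values shown are illustrative, not canonical defaults:

>>> measure_option = autotvm.measure_option(
>>>     builder=autotvm.LocalBuilder(),
>>>     runner=autotvm.LocalRunner(
>>>         number=10, repeat=3,
>>>         min_repeat_ms=100,  # let the runner raise `number` until a timed run lasts >= 100 ms
>>>         timeout=4))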

tvm.autotvm.measure.create_measure_batch(task, option)

Get a standard measure_batch function.

Parameters:
  • task (tvm.autotvm.task.Task) – The tuning task
  • option (dict) – The option for measuring generated code. You should use the return value of function measure_option for this argument.
Returns:

measure_batch – a callback function to measure a batch of configs

Return type:

callable
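Examples

A minimal sketch of using the returned measure_batch callable directly; most users go through Tuner.tune() instead. Here task, measure_option, and the list configs are assumed to exist:

>>> measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
>>> inputs = [autotvm.measure.MeasureInput(task.target, task, cfg)
>>>           for cfg in configs]
>>> results = measure_batch(inputs)  # one MeasureResult per MeasureInput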

class tvm.autotvm.measure.measure_methods.LocalBuilder(timeout=10, n_parallel=None, build_func='default')

Run compilation on the local machine

Parameters:
  • timeout (float) – The timeout of a compilation
  • n_parallel (int) – The number of tasks to run in parallel. None will use all CPU cores
  • build_func (callable or str) – If ‘default’, use the default build function. If ‘ndk’, use the build function for Android NDK. If callable, use it as a custom build function.
class tvm.autotvm.measure.measure_methods.RPCRunner(key, host, port, priority=1, timeout=10, n_parallel=None, number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1, check_correctness=False)

Run generated code on remote devices. This runner asks an RPC Tracker for a device on which to run measurements.

Parameters:
  • timeout (float) – The timeout of a run
  • n_parallel (int) – The number of tasks to run in parallel. None will use all CPU cores
  • key (str) – The key of the device registered in the tracker
  • host (str) – The host address of RPC Tracker
  • port (int) – The port of RPC Tracker
  • number (int, optional) – Number of times to run the generated code for taking the average
  • repeat (int, optional) – Number of times to repeat the measurement. In total, the generated code will be run (1 + number x repeat) times, where the first one is a warm-up. The returned result contains repeat costs, each of which is the average of number test runs.
  • min_repeat_ms (float, optional) – Minimum duration of a timer measurement in milliseconds. When the run time of a measurement trial falls below this time, the number parameter will be automatically increased. Set this to improve the accuracy of perf measurement, e.g., when timers are not precise enough to capture short-running tasks. This parameter is also critical when devices need a certain minimum running time to “warm up,” such as GPUs that need time to reach a performance power state.
  • cooldown_interval (float, optional) – The cool down interval between two measurements.
  • check_correctness (bool, optional) – Whether check correctness after measurement. This will use llvm cpu target to call your template and get the reference output. This can work for TOPI templates, but may not work for your custom template.
class tvm.autotvm.measure.measure_methods.LocalRunner(timeout=10, number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1, check_correctness=False)

Run generated code on local devices.

Parameters:
  • timeout (float) – The timeout of a run
  • number (int, optional) – Number of times to run the generated code for taking the average
  • repeat (int, optional) – Number of times to repeat the measurement. In total, the generated code will be run (1 + number x repeat) times, where the first one is a warm-up. The returned result contains repeat costs, each of which is the average of number test runs.
  • min_repeat_ms (float, optional) – Minimum duration of a timer measurement in milliseconds. When the run time of a measurement trial falls below this time, the number parameter will be automatically increased. Set this to improve the accuracy of perf measurement, e.g., when timers are not precise enough to capture short-running tasks. This parameter is also critical when devices need a certain minimum running time to “warm up,” such as GPUs that need time to reach a performance power state.
  • cooldown_interval (float, optional) – The cool down interval between two measurements.
  • check_correctness (bool, optional) – Whether check correctness after measurement. This will use llvm cpu target to call your template and get the reference output. This can work for TOPI templates, but may not work for your custom template.

Note

This is a “fake” local mode: we start a silent RPC tracker and RPC server for the user, so that we can reuse the timeout/isolation mechanisms of the RPC infrastructure.

tvm.autotvm.tuner

A tuner takes a task as input. It proposes some promising ConfigEntity instances from the ConfigSpace and measures them on real hardware. Then it proposes the next batch of ConfigEntity instances according to the measurement results. This tuning loop is repeated.

class tvm.autotvm.tuner.Tuner(task, **kwargs)

Base class for tuners

Parameters:task (autotvm.task.Task) – Tuning Task
has_next()

Whether there is a next untried config in the space

Returns:has_next
Return type:bool
load_history(data_set)

load history data for transfer learning

Parameters:data_set (Array of (MeasureInput, MeasureResult) pair) – Previous tuning records
next_batch(batch_size)

Get the next batch of configs to be measured on real hardware

Parameters:batch_size (int) – The size of the batch
Returns:
Return type:a batch of configs
reset()

Reset the status of the tuner

tune(n_trial, measure_option, early_stopping=None, callbacks=())

Begin tuning

Parameters:
  • n_trial (int) – Maximum number of configs to try (measure on real hardware)
  • measure_option (dict) – The options for how to measure generated code. You should use the return value of autotvm.measure_option for this argument.
  • early_stopping (int, optional) – Stop tuning early when no better config is found within this number of trials
  • callbacks (List of callable) – A list of callback functions. The signature of a callback function is (Tuner, List of MeasureInput, List of MeasureResult) with no return value. These callback functions will be called on every measurement pair. See autotvm/tuner/callback.py for some examples.
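Examples

A minimal tuning-loop sketch; task and measure_option are assumed to be created as shown earlier, and the log file name is illustrative:

>>> tuner = autotvm.tuner.RandomTuner(task)
>>> tuner.tune(n_trial=100,
>>>            measure_option=measure_option,
>>>            callbacks=[autotvm.callback.log_to_file('tuning.log')])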
update(inputs, results)

Update parameters of the tuner according to measurement results

Parameters:
  • inputs (Array of autotvm.measure.MeasureInput) – The inputs for measurement
  • results (Array of autotvm.measure.MeasureResult) – The results of measurement
class tvm.autotvm.tuner.RandomTuner(task)

Enumerate the search space in a random order

has_next()

Whether there is a next untried config in the space

Returns:has_next
Return type:bool
load_history(data_set)

load history data for transfer learning

Parameters:data_set (Array of (MeasureInput, MeasureResult) pair) – Previous tuning records
next_batch(batch_size)

Get the next batch of configs to be measured on real hardware

Parameters:batch_size (int) – The size of the batch
Returns:
Return type:a batch of configs
reset()

Reset the status of the tuner

tune(n_trial, measure_option, early_stopping=None, callbacks=())

Begin tuning

Parameters:
  • n_trial (int) – Maximum number of configs to try (measure on real hardware)
  • measure_option (dict) – The options for how to measure generated code. You should use the return value of autotvm.measure_option for this argument.
  • early_stopping (int, optional) – Stop tuning early when no better config is found within this number of trials
  • callbacks (List of callable) – A list of callback functions. The signature of a callback function is (Tuner, List of MeasureInput, List of MeasureResult) with no return value. These callback functions will be called on every measurement pair. See autotvm/tuner/callback.py for some examples.
update(inputs, results)

Update parameters of the tuner according to measurement results

Parameters:
  • inputs (Array of autotvm.measure.MeasureInput) – The inputs for measurement
  • results (Array of autotvm.measure.MeasureResult) – The results of measurement
class tvm.autotvm.tuner.GridSearchTuner(task)

Enumerate the search space in a grid search order

has_next()

Whether there is a next untried config in the space

Returns:has_next
Return type:bool
load_history(data_set)

load history data for transfer learning

Parameters:data_set (Array of (MeasureInput, MeasureResult) pair) – Previous tuning records
next_batch(batch_size)

Get the next batch of configs to be measured on real hardware

Parameters:batch_size (int) – The size of the batch
Returns:
Return type:a batch of configs
reset()

Reset the status of the tuner

tune(n_trial, measure_option, early_stopping=None, callbacks=())

Begin tuning

Parameters:
  • n_trial (int) – Maximum number of configs to try (measure on real hardware)
  • measure_option (dict) – The options for how to measure generated code. You should use the return value of autotvm.measure_option for this argument.
  • early_stopping (int, optional) – Stop tuning early when no better config is found within this number of trials
  • callbacks (List of callable) – A list of callback functions. The signature of a callback function is (Tuner, List of MeasureInput, List of MeasureResult) with no return value. These callback functions will be called on every measurement pair. See autotvm/tuner/callback.py for some examples.
update(inputs, results)

Update parameters of the tuner according to measurement results

Parameters:
  • inputs (Array of autotvm.measure.MeasureInput) – The inputs for measurement
  • results (Array of autotvm.measure.MeasureResult) – The results of measurement
class tvm.autotvm.tuner.GATuner(task, pop_size=100, elite_num=3, mutation_prob=0.1)

Tuner with a genetic algorithm. This tuner does not have a cost model, so it always runs measurements on real machines. This tuner encodes a ConfigEntity as a gene.

Parameters:
  • pop_size (int) – number of genes in one generation
  • elite_num (int) – number of elites to keep
  • mutation_prob (float) – probability of mutation of a knob in a gene
has_next()

Whether there is a next untried config in the space

Returns:has_next
Return type:bool
load_history(data_set)

load history data for transfer learning

Parameters:data_set (Array of (MeasureInput, MeasureResult) pair) – Previous tuning records
next_batch(batch_size)

Get the next batch of configs to be measured on real hardware

Parameters:batch_size (int) – The size of the batch
Returns:
Return type:a batch of configs
reset()

Reset the status of the tuner

tune(n_trial, measure_option, early_stopping=None, callbacks=())

Begin tuning

Parameters:
  • n_trial (int) – Maximum number of configs to try (measure on real hardware)
  • measure_option (dict) – The options for how to measure generated code. You should use the return value of autotvm.measure_option for this argument.
  • early_stopping (int, optional) – Stop tuning early when no better config is found within this number of trials
  • callbacks (List of callable) – A list of callback functions. The signature of a callback function is (Tuner, List of MeasureInput, List of MeasureResult) with no return value. These callback functions will be called on every measurement pair. See autotvm/tuner/callback.py for some examples.
update(inputs, results)

Update parameters of the tuner according to measurement results

Parameters:
  • inputs (Array of autotvm.measure.MeasureInput) – The inputs for measurement
  • results (Array of autotvm.measure.MeasureResult) – The results of measurement
class tvm.autotvm.tuner.XGBTuner(task, plan_size=64, feature_type='itervar', loss_type='rank', num_threads=None, optimizer='sa', diversity_filter_ratio=None, log_interval=50)

Tuner that uses xgboost as cost model

Parameters:
  • task (Task) – The tuning task
  • plan_size (int) – The size of a plan. After plan_size trials, the tuner will refit a new cost model and do planning for the next plan_size trials.
  • feature_type (str, optional) –

    If ‘itervar’, use features extracted from IterVar (loop variable). If ‘knob’, use the flattened ConfigEntity directly. If ‘curve’, use sampled curve features (relation features).

    Note on choosing the feature type: for single-task tuning, ‘itervar’ and ‘knob’ are good; ‘itervar’ is more accurate but ‘knob’ is much faster. There are some constraints on ‘itervar’; if you meet problems with feature extraction when using ‘itervar’, you can switch to ‘knob’. For cross-shape tuning (e.g. many convolutions with different shapes), ‘itervar’ and ‘curve’ have better transferability while ‘knob’ is faster. For cross-device or cross-operator tuning, you can only use ‘curve’.

  • loss_type (str) – If ‘reg’, use regression loss to train the cost model; the model predicts normalized flops. If ‘rank’, use pairwise rank loss; the model predicts a relative rank score.
  • num_threads (int, optional) – The number of threads.
  • optimizer (str or ModelOptimizer, optional) – If ‘sa’, use the default simulated annealing optimizer. Otherwise it should be a ModelOptimizer object.
  • diversity_filter_ratio (int or float, optional) – If is not None, the tuner will first select top-(plan_size * diversity_filter_ratio) candidates according to the cost model and then pick batch_size of them according to the diversity metric.
  • log_interval (int, optional) – The verbosity level. If 0, output nothing. Otherwise, output debug information every log_interval iterations.
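Examples

A hedged construction sketch with non-default model settings; task and measure_option are assumed from earlier examples:

>>> tuner = autotvm.tuner.XGBTuner(task, plan_size=32,
>>>                                feature_type='knob', loss_type='rank')
>>> tuner.tune(n_trial=500, measure_option=measure_option)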
has_next()

Whether there is a next untried config in the space

Returns:has_next
Return type:bool
load_history(data_set)

load history data for transfer learning

Parameters:data_set (Array of (MeasureInput, MeasureResult) pair) – Previous tuning records
next_batch(batch_size)

Get the next batch of configs to be measured on real hardware

Parameters:batch_size (int) – The size of the batch
Returns:
Return type:a batch of configs
reset()

Reset the status of the tuner

tune(*args, **kwargs)

Begin tuning

Parameters:
  • n_trial (int) – Maximum number of configs to try (measure on real hardware)
  • measure_option (dict) – The options for how to measure generated code. You should use the return value of autotvm.measure_option for this argument.
  • early_stopping (int, optional) – Stop tuning early when no better config is found within this number of trials
  • callbacks (List of callable) – A list of callback functions. The signature of a callback function is (Tuner, List of MeasureInput, List of MeasureResult) with no return value. These callback functions will be called on every measurement pair. See autotvm/tuner/callback.py for some examples.
update(inputs, results)

Update parameters of the tuner according to measurement results

Parameters:
  • inputs (Array of autotvm.measure.MeasureInput) – The inputs for measurement
  • results (Array of autotvm.measure.MeasureResult) – The results of measurement

Namespace of callback utilities of AutoTVM

class tvm.autotvm.tuner.callback.Monitor

A monitor to collect statistics during tuning

trial_scores()

Get the scores (currently flops) of all trials

trial_timestamps()

Get the wall clock timestamps of all trials
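Examples

A sketch of collecting statistics with a Monitor callback; tuner and measure_option are assumed from earlier examples:

>>> monitor = autotvm.callback.Monitor()
>>> tuner.tune(n_trial=100, measure_option=measure_option, callbacks=[monitor])
>>> scores = monitor.trial_scores()       # flops score of every trial
>>> stamps = monitor.trial_timestamps()   # wall clock time stamp of every trial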

tvm.autotvm.tuner.callback.log_to_database(db)

Save the tuning records to a database object.

Parameters:db (Database) – The database
tvm.autotvm.tuner.callback.log_to_file(file_out, protocol='json')

Log the tuning records to a file. The rows of the log are stored in the format of autotvm.record.encode.

Parameters:
  • file_out (File or str) – The file to log to.
  • protocol (str, optional) – The log protocol. Can be ‘json’ or ‘pickle’
Returns:

callback – Callback function to do the logging.

Return type:

callable

tvm.autotvm.tuner.callback.progress_bar(total, prefix='')

Display a progress bar during tuning

Parameters:
  • total (int) – The total number of trials
  • prefix (str) – The prefix of output message
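Examples

A typical callback combination: show a progress bar while logging every record to a file (names are illustrative):

>>> n_trial = 1000
>>> tuner.tune(n_trial=n_trial,
>>>            measure_option=measure_option,
>>>            callbacks=[autotvm.callback.progress_bar(n_trial, prefix='matmul'),
>>>                       autotvm.callback.log_to_file('matmul.log')])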

tvm.autotvm.task

Task is a tunable composition of template functions.

Tuner takes a tunable task and optimizes the joint configuration space of all the template functions in the task. This module defines the task data structure, as well as a collection (zoo) of typical tasks of interest.

Definition of task function.

A Task can be constructed from a tuple of func, args, and kwargs. func is a stateless function, or a string that identifies a registered standard task.

exception tvm.autotvm.task.task.FlopCalculationError

Error raised when FLOP estimation for a compute op fails

class tvm.autotvm.task.task.Task(name, args)

A Tunable Task

Parameters:
  • name (str) – The name of the task.
  • args (Tuple) – Positional arguments of func
instantiate(config)

Instantiate this task function (template) with a config. Returns the corresponding schedule.

Parameters:config (template.ConfigEntity) – parameter config for this template
Returns:
  • sch (tvm.schedule.Schedule) – The tvm schedule
  • arg_bufs (Array of tvm.tensor.Tensor) – The input/output buffers
tvm.autotvm.task.task.args_to_workload(x, topi_compute_func=None)

Convert an argument list to a hashable workload tuple. This function converts lists to tuples, TVM nodes to Python values, and flattens tvm.tensor.Tensor into a tuple.

Parameters:
  • x (primitive hashable types or tensor.Tensor) – The original value
  • topi_compute_func (topi compute function) – The function name will be added as the first element of the workload tuple
Returns:

ret – The hashable value

Return type:

hashable

tvm.autotvm.task.task.compute_flop(sch)

Calculate the number of FLOP (floating point operations) of the compute ops in a schedule

Parameters:sch (tvm.schedule.Schedule) – schedule
Returns:flop – number of FLOP in this schedule
Return type:int
tvm.autotvm.task.task.create(func_name, args, target, target_host=None, template_key=None)

Create a tuning task and initialize its search space

Parameters:
  • func_name (str or callable) – The task function
  • args (List) – Positional arguments
  • target (Target) – The compilation target
  • target_host (Target, optional) – The compilation target for the host side
  • template_key (str, optional) – The template key identifying which registered template variant to use
Returns:

tsk – a task object

Return type:

Task

tvm.autotvm.task.task.get_config()

Get the current config object

Returns:cfg – The current config
Return type:ConfigSpace or ConfigEntity
tvm.autotvm.task.task.register(name, func=None, override=False)

Register a task function.

Parameters:
  • name (str) – The name to identify the task.
  • func (callable) – The function to be registered.
  • override (bool) – Whether to override an existing registration.
Returns:

func – The registered function

Return type:

callable

tvm.autotvm.task.task.template(func)

Decorate a function as a tunable schedule template

Parameters:func (callable) – A callable template function. Its arguments should be hashable values. Its return value should be a tuple of (Schedule, Array of Tensor)
Returns:func – The decorated function
Return type:callable

Examples

The following code is a tunable template for a blocked matrix multiplication

@autotvm.template
def matmul(N, L, M, dtype):
    A = tvm.placeholder((N, L), name='A', dtype=dtype)
    B = tvm.placeholder((L, M), name='B', dtype=dtype)

    k = tvm.reduce_axis((0, L), name='k')
    C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
    s = tvm.create_schedule(C.op)

    # schedule
    y, x = s[C].op.axis
    k = s[C].op.reduce_axis[0]

    ##### define space begin #####
    cfg = autotvm.get_config()
    cfg.define_split("tile_y", y, num_outputs=2)
    cfg.define_split("tile_x", x, num_outputs=2)
    ##### define space end #####

    # schedule according to config
    yo, yi = cfg["tile_y"].apply(s, C, y)
    xo, xi = cfg["tile_x"].apply(s, C, x)

    s[C].reorder(yo, xo, k, yi, xi)

    return s, [A, B, C]
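Given the template above, a task can be created and its search space inspected; the target string 'llvm' is illustrative:

>>> task = autotvm.task.create(matmul, args=(512, 512, 512, 'float32'),
>>>                            target='llvm')
>>> print(task.config_space)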

Template configuration space.

Each template function can be parametrized by a ConfigSpace. The space is declared when we invoke the template function with a ConfigSpace. During evaluation, we pass in a ConfigEntity, which is a specific point in the space with deterministic parameters.

class tvm.autotvm.task.space.AnnotateEntity(anns)

An annotation operation with detailed parameters that can be applied to axes

Parameters:anns (Array of string) – The annotations of axes
apply(sch, op, axes, axis_lens=None, max_unroll=None, vec_size=None, cfg=None, source=None)

Apply annotation to an array of axes

Parameters:
  • sch (tvm.schedule.Schedule) – The tvm schedule
  • op (tvm.tensor.Operation) – The stage to be applied
  • axes (Array of tvm.schedule.IterVar) – the axes to annotate
  • axis_lens (Array of int, optional) – the length of axes
  • max_unroll (int, optional) – maximum unroll step
  • vec_size (Array of int, optional) – valid vector lanes for vectorization
  • cfg (ConfigEntity, optional) – cfg for recording error
  • source (Array of Array tensor, optional) – source tensor for attaching cache
Returns:

axes – The transformed axes

Return type:

list of tvm.schedule.IterVar

class tvm.autotvm.task.space.AnnotateSpace(axes, policy, **kwargs)

The parameter space for annotating an array of axes

static get_num_output(axes, policy, **kwargs)

get number of output axes after this transform

Returns:n – number of output axes
Return type:int
class tvm.autotvm.task.space.Axis(space, index)
index

Alias for field number 1

space

Alias for field number 0

class tvm.autotvm.task.space.ConfigEntity(index, code_hash, template_key, entity_map, constraints)

A configuration with detailed parameters

Parameters:
  • index (int) – index of this config in space
  • code_hash (str) – hash of schedule code
  • template_key (str) – The specific template key
  • entity_map (dict) – map name to transform entity
  • constraints (list) – List of constraints
static from_json_dict(json_dict)

Build a ConfigEntity from json serializable dictionary

Parameters:json_dict (dict) – Json serializable dictionary. This should be the return value of to_json_dict.
Returns:config – The corresponding config object
Return type:ConfigEntity
get_flatten_feature()

flatten entities to a numerical one-dimensional feature vector

Returns:fea – one dimensional float32 array
Return type:np.array
get_other_option()
Returns:other_option – other tunable parameters (tunable parameters defined by cfg.define_knob)
Return type:dict
to_json_dict()

convert to a json serializable dictionary

Returns:json_dict – a json serializable dictionary
Return type:dict
class tvm.autotvm.task.space.ConfigSpace

The configuration space of a schedule. Pass it as config in a template to collect the transformation space and build the transform graph of axes.

add_flop(flop)

Add float operation statistics for this tuning task

Parameters:flop (int or float) – number of float operations
static axis(var)

get a virtual axis (axis placeholder)

Parameters:var (int or tvm.schedule.IterVar) – If int, return an axis whose length is the provided argument. If IterVar, return an axis whose length is extracted from the IterVar’s extent domain.
define_annotate(name, axes, policy, **kwargs)

Define a new tunable knob which annotates a list of axes

Parameters:
  • name (str) – name to index the entity of this space
  • axes (Array of tvm.schedule.IterVar) – axes to annotate
  • policy (str) – name of the policy. If ‘unroll’, unroll the axes. If ‘try_unroll’, try to unroll the axes. If ‘try_unroll_vec’, try to unroll or vectorize the axes. If ‘bind_gpu’, bind the first few axes to GPU threads. If ‘locate_cache’, choose n axes to attach shared/local cache.
  • kwargs (dict) – extra arguments for policy
define_knob(name, candidate)

Define a tunable knob with a list of candidates

Parameters:
  • name (str) – name key of that option
  • candidate (list) – list of candidates
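Examples

A hedged sketch of knobs with explicit candidates, in the style of the GPU templates (the knob names are illustrative):

>>> cfg.define_knob('auto_unroll_max_step', [0, 512, 1500])
>>> cfg.define_knob('unroll_explicit', [0, 1])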
define_reorder(name, axes, policy, **kwargs)

Define a new tunable knob which reorders a list of axes

Parameters:
  • name (str) – name to index the entity of this space
  • axes (Array of tvm.schedule.IterVar) – axes to reorder
  • policy (str) – name of the policy. If ‘identity’, do an identity permutation. If ‘all’, try all permutations. If ‘interval_all’, try all permutations of an interval of axes. If ‘candidate’, try the listed candidates. If ‘interleave’, interleave chains of spatial axes and chains of reduction axes.
  • kwargs (dict) – extra arguments for policy
define_split(name, axis, policy='all', **kwargs)

Define a new tunable knob which splits an axis into a list of axes

Parameters:
  • name (str) – name to index the entity of this space
  • axis (tvm.schedule.IterVar) – axis to split
  • policy (str) – name of the policy. If ‘all’, the tuner will try all divisible factors. If ‘candidate’, try the listed candidates.
  • kwargs (dict) – extra arguments for the policy; see the examples below for how to use filter

Examples

>>> # use custom candidates
>>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])
>>> # use a filter that only accepts split schemes whose innermost tile is at most 4
>>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4)
get(index)

Get a config entity with detailed parameters from this space

Parameters:index (int) – index in the space
raise_error(msg)

Register an error in the config. Use this to actively detect errors during scheduling; otherwise these errors would occur at runtime, which costs more time.

Parameters:msg (str) – The error message
static reduce_axis(var)

get a virtual axis (axis placeholder)

Parameters:var (int or tvm.schedule.IterVar) – If int, return an axis whose length is the provided argument. If IterVar, return an axis whose length is extracted from the IterVar’s extent domain.
valid()

Check whether the config meets all the constraints.

Note: this check should be called after the instantiation of the task, because the ConfigEntity/ConfigSpace collects errors during instantiation.

Returns:valid – whether the config meets all the constraints
Return type:bool
class tvm.autotvm.task.space.FallbackConfigEntity

The config entity created to support fallback

fallback_split(name, constraints)

Fallback a split knob

Parameters:
  • name (str) – name of the knob
  • constraints (List of int) – The maximum tile size for every dimension. Value -1 means no constraint.

Examples

If you use cfg.define_split('tile_0', 128, num_outputs=3), then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [4, 8, 4].

If you use cfg.define_split('tile_0', 49, num_outputs=3), then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1].

fallback_with_reference_log(ref_log)

A data driven fallback mechanism. We use tuned parameters from TopHub as reference data. For an unseen shape, we find the most similar tuned one from TopHub and mimic its parameters.

Parameters:ref_log (List of (MeasureInput, MeasureResult)) – The reference log
exception tvm.autotvm.task.space.InstantiationError

An error actively detected when instantiating a template with a config, raised by cfg.raise_error, e.g. too much unrolling or too many threads in a block

class tvm.autotvm.task.space.OtherOptionEntity(val)

The parameter entity for general option, with a detailed value

class tvm.autotvm.task.space.OtherOptionSpace(axes, policy, **kwargs)

The parameter space for general option

static get_num_output(axes, policy, **kwargs)

get number of output axes after this transform

Returns:n – number of output axes
Return type:int
class tvm.autotvm.task.space.ReorderEntity(perm)

A reorder operation with detailed parameters that can be applied to axes

Parameters:perm (Array of int) – define the permutation
apply(sch, op, axes)

Apply reorder to an array of axes

Parameters:
  • sch (tvm.schedule.Schedule) – The tvm schedule
  • op (tvm.tensor.Operation) – The stage to be applied
  • axes (Array of Axis) – The axes to reorder
Returns:

axes – The transformed axes.

Return type:

list of Axis

class tvm.autotvm.task.space.ReorderSpace(axes, policy, **kwargs)

The parameter space for ordering an array of axes

static get_num_output(axes, policy, **kwargs)

get number of output axes after this transform

Returns:n – number of output axes
Return type:int
class tvm.autotvm.task.space.SplitEntity(size)

A split operation with detailed parameters that can be applied to an axis

Parameters:size (Array of int) – the size of every axis after the split. E.g. for an axis of extent 128 split into 3 axes, a possible size is [4, 4, 8] (4x4x8 = 128).
apply(sch, op, axis)

Apply split to an axis

Parameters:
  • sch (tvm.schedule.Schedule) – The tvm schedule
  • op (tvm.tensor.Operation) – The stage to be applied
  • axis (Axis) – The axis to split
Returns:

axes – The transformed axes.

Return type:

list of Axis

class tvm.autotvm.task.space.SplitSpace(axes, policy, **kwargs)

Split an axis several times

static get_num_output(axes, policy, **kwargs)

get number of output axes after this transform

Returns:n – number of output axes
Return type:int
class tvm.autotvm.task.space.TransformSpace

Base class for transform spaces. A TransformSpace is a node in the computation graph of axes.

Note

We can regard our schedule code as a transformation graph of axes. Starting from raw axes in the definition of tvm.compute, we can transform these axes with operators such as ‘split’, ‘reorder’ and ‘annotate’. Each operator has some tunable parameters (e.g. the split factor). The tuning process is then just finding good parameters for these operators.

All the combinations of the parameters of these operators form our search space.

Naming convention: We call the set of all possible values as XXXSpace. (XXX can be Split, Reorder, Config …) We call a specific entity in a space as XXXEntity.

static get_num_output()

get number of output axes after this transform

Returns:n – number of output axes
Return type:int
class tvm.autotvm.task.space.VirtualAxis(var, name=None)

Axis placeholder in template

Parameters:
  • var (int or tvm.schedule.IterVar) – If int, return a virtual axis whose length is the provided argument. If IterVar, return a virtual axis whose length is extracted from the IterVar’s extent domain.
  • name (str) –
static get_num_output(var, name=None)

get number of output axes after this transform

Returns:n – number of output axes
Return type:int
tvm.autotvm.task.space.get_factors(n)

return all factors of an integer

Parameters:n (int) – integer to factorize
Returns:factors – List of all factors
Return type:list
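Examples

A small usage sketch; factors are assumed to be returned in ascending order:

>>> from tvm.autotvm.task.space import get_factors
>>> get_factors(12)
[1, 2, 3, 4, 6, 12]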

Template dispatcher module.

A dispatcher is a function that can contain multiple behaviors. Its specific behavior is controlled by a DispatchContext.

DispatchContext is used in two ways, usually via different implementations of the DispatchContext base class.

  • During search, we can use it to pass the current proposal from the tuner.
  • During evaluation, we can use it to pick the best policy.
class tvm.autotvm.task.dispatcher.ApplyConfig(config)

Apply a deterministic config entity for all queries.

Parameters:config (ConfigSpace or ConfigEntity) – The specific configuration we care about.
update(target, workload, cfg)

Override update

class tvm.autotvm.task.dispatcher.ApplyGraphBest(records)

Load the graph level tuning optimal schedules.

The input records should be in ascending order of the node index of the target operator. Usually this can be obtained with the graph tuner.

This context maintains an internal counter to indicate the current node index.

update(target, workload, cfg)

Update context with a specific config.

Parameters:
  • target (Target) – The current target
  • workload (Workload) – The current workload.
  • cfg (ConfigSpace) – The specific configuration.

Note

This interface is for cases when TVM decides to replace an operator in the graph. For example, the AlterOpLayout pass (enabled when opt_level = 3) replaces NCHW convolution with an NCHW[x]c implementation on x86 CPUs. Thus in TOPI, we first query the schedule using the original NCHW workload, then update the dispatcher with the new NCHW[x]c workload, so that later on the NCHW[x]c convolution can get its schedule from the dispatcher using its own workload directly.

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo):
    workload = get_conv2d_workload(...)
    dispatch_ctx = autotvm.task.DispatchContext.current
    target = tvm.target.current_target()
    config = dispatch_ctx.query(target, workload)

    # Get conv2d_NCHWc workload from config
    # new_workload = ...
    # new_inputs = ...
    # new_attrs = ...

    # Store altered operator's config
    dispatch_ctx.update(target, new_workload, config)
    return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)

We directly store the config back because conv2d_NCHW and conv2d_NCHWc share the same schedule parameters. One can construct a new ConfigEntity if this is not the case.

class tvm.autotvm.task.dispatcher.ApplyHistoryBest(records)

Apply the history best config

Parameters:records (str or iterator of (MeasureInput, MeasureResult)) – Collection of tuning records. If a str, it should be the filename of a records log file; each row of this file is an encoded record pair. Otherwise, it is an iterator.

load(records)

Load records to this dispatch context

Parameters:records (str or iterator of (MeasureInput, MeasureResult)) – Collection of tuning records. If a str, it should be the filename of a records log file; each row of this file is an encoded record pair. Otherwise, it is an iterator.

update(target, workload, cfg)

Update context with a specific config.

Parameters:
  • target (Target) – The current target
  • workload (Workload) – The current workload.
  • cfg (ConfigSpace) – The specific configuration.

Note

This interface is for cases when TVM decides to replace an operator in the graph. For example, the AlterOpLayout pass (enabled when opt_level = 3) replaces NCHW convolution with an NCHW[x]c implementation on x86 CPUs. Thus in TOPI, we first query the schedule using the original NCHW workload, then update the dispatcher with the new NCHW[x]c workload, so that later on the NCHW[x]c convolution can get its schedule from the dispatcher using its own workload directly.

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo):
    workload = get_conv2d_workload(...)
    dispatch_ctx = autotvm.task.DispatchContext.current
    target = tvm.target.current_target()
    config = dispatch_ctx.query(target, workload)

    # Get conv2d_NCHWc workload from config
    # new_workload = ...
    # new_inputs = ...
    # new_attrs = ...

    # Store altered operator's config
    dispatch_ctx.update(target, new_workload, config)
    return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)

We directly store the config back because conv2d_NCHW and conv2d_NCHWc share the same schedule parameters. One can construct a new ConfigEntity if this is not the case.
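Examples

A common usage sketch via the top-level helper autotvm.apply_history_best, which builds an ApplyHistoryBest context from a log file so compilation picks up the best found configs (the template and file name are illustrative):

>>> with autotvm.apply_history_best('matmul.log'):
>>>     with tvm.target.create('llvm'):
>>>         s, arg_bufs = matmul(512, 512, 512, 'float32')
>>>         func = tvm.build(s, arg_bufs)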

class tvm.autotvm.task.dispatcher.DispatchContext

Base class of dispatch context.

DispatchContext enables the target and workload specific dispatch mechanism for templates.

query(target, workload)

Query the context to get the specific config for a template. If the result cannot be found inside this context, this function will query it from the upper contexts.

Parameters:
  • target (Target) – The current target
  • workload (Workload) – The current workload.
Returns:

cfg – The specific configuration.

Return type:

ConfigSpace

update(target, workload, cfg)

Update context with a specific config.

Parameters:
  • target (Target) – The current target
  • workload (Workload) – The current workload.
  • cfg (ConfigSpace) – The specific configuration.

Note

This interface is for cases when TVM decides to replace an operator in the graph. For example, the AlterOpLayout pass (enabled when opt_level = 3) replaces NCHW convolution with an NCHW[x]c implementation on x86 CPUs. Thus in TOPI, we first query the schedule using the original NCHW workload, then update the dispatcher with the new NCHW[x]c workload, so that later on the NCHW[x]c convolution can get its schedule from the dispatcher using its own workload directly.

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo):
    workload = get_conv2d_workload(...)
    dispatch_ctx = autotvm.task.DispatchContext.current
    target = tvm.target.current_target()
    config = dispatch_ctx.query(target, workload)

    # Get conv2d_NCHWc workload from config
    # new_workload = ...
    # new_inputs = ...
    # new_attrs = ...

    # Store altered operator's config
    dispatch_ctx.update(target, new_workload, config)
    return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)

We directly store the config back because conv2d_NCHW and conv2d_NCHWc share the same schedule parameters. One can construct a new ConfigEntity if this is not the case.

class tvm.autotvm.task.dispatcher.FallbackContext

A fallback dispatch context.

Any tunable template can be called under this context. This is the root context.

clear_cache(target, workload)

Clear the fallback cache. Pass the same arguments as _query_inside to this function to clear the cache.

Parameters:
  • target (Target) – The current target
  • workload (Workload) – The current workload.
update(target, workload, cfg)

Update context with a specific config.

Parameters:
  • target (Target) – The current target
  • workload (Workload) – The current workload.
  • cfg (ConfigSpace) – The specific configuration.

Note

This interface is for cases when TVM decides to replace an operator in the graph. For example, the AlterOpLayout pass (enabled when opt_level = 3) replaces NCHW convolution with an NCHW[x]c implementation on x86 CPUs. Thus in TOPI, we first query the schedule using the original NCHW workload, then update the dispatcher with the new NCHW[x]c workload, so that later on the NCHW[x]c convolution can get its schedule from the dispatcher using its own workload directly.

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo):
    workload = get_conv2d_workload(...)
    dispatch_ctx = autotvm.task.DispatchContext.current
    target = tvm.target.current_target()
    config = dispatch_ctx.query(target, workload)

    # Get conv2d_NCHWc workload from config
    # new_workload = ...
    # new_inputs = ...
    # new_attrs = ...

    # Store altered operator's config
    dispatch_ctx.update(target, new_workload, config)
    return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)

We directly store the config back because conv2d_NCHW and conv2d_NCHWc share the same schedule parameters. One can construct a new ConfigEntity if this is not the case.

tvm.autotvm.task.dispatcher.clear_fallback_cache(target, workload)

Clear the fallback cache. Pass the same arguments as _query_inside to this function to clear the cache.

Parameters:
  • target (Target) – The current target
  • workload (Workload) – The current workload.

Note

This is used in alter_op_layout to clear the bad cache created before calling the TOPI compute function

tvm.autotvm.task.dispatcher.dispatcher(fworkload)

Wrap a workload dispatcher function.

Parameters:fworkload (function) – The workload extraction function from arguments.
Returns:fdispatcher – A wrapped dispatcher function, which will dispatch based on DispatchContext and the current workload.
Return type:function

Decorators for registering tunable templates to TOPI.

These decorators make a simple implementation able to use different configurations for different workloads. Here we directly use all arguments to the TOPI call as the “workload”, so make sure all the arguments (except tvm.Tensor) in your calls are hashable. tvm.Tensor arguments will be serialized to a hashable tuple.

See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.

tvm.autotvm.task.topi_integration.register_topi_compute(topi_compute, target_keys, template_keys, func=None)

Register a tunable template for a topi compute function.

After registration, this TOPI compute function becomes a configuration dispatcher. It uses all of its arguments as the workload and dispatches configurations according to the input workload.

It also stores this “workload” in its final ComputeOp, which can be used to reconstruct the “workload” in the following topi_schedule call.

Parameters:
  • topi_compute (GenericFunc) – The topi compute function that will be overloaded
  • target_keys (str or list of str) – The compilation target. The same as the argument of GenericFunc.register.
  • template_keys (str or list of str) – The template key. We might have several strategies for a single operator (e.g. direct, im2col, winograd). The template key is used to identify the algorithm strategy. Every operator must have a “direct” template, which is used by default.
  • func (None or callable) – If None, return a decorator. If callable, decorate this function.
Returns:

decorator – A decorator

Return type:

callable

Examples

See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.

tvm.autotvm.task.topi_integration.register_topi_schedule(topi_schedule, target_keys, template_keys, func=None)

Register a tunable template for a topi schedule function.

After registration, this TOPI schedule function becomes a configuration dispatcher. It dispatches configurations according to the input workload.

Note that this function will try to find the “workload” from all the ComputeOps in the input. You can attach a “workload” to your compute op by using register_topi_compute.

Parameters:
  • topi_schedule (GenericFunc) – The topi schedule function that will be overloaded
  • target_keys (str or list of str) – The compilation target
  • template_keys (str or list of str) – The template key. We might have several strategies for a single operator (e.g. direct, im2col, winograd). The template key is used to identify the algorithm strategy. Every operator must have a “direct” template, which is used by default.
  • func (None or callable) – If None, return a decorator. If callable, decorate this function.
Returns:

decorator – A decorator

Return type:

callable

Examples

See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.

Decorator and utilities for the integration with TOPI and NNVM

class tvm.autotvm.task.nnvm_integration.TaskExtractEnv

Global environment for extracting tuning tasks from an nnvm graph

static get()

Get the single instance of TaskExtractEnv

Returns:env – The single instance of TaskExtractEnv
Return type:TaskExtractEnv
get_tasks()

Get collected tasks

Returns:tasks – A list of tasks extracted from the nnvm graph
Return type:List of tuple(name, args)
reset(wanted_topi_funcs)

Reset task collections

Parameters:wanted_topi_funcs (List of function) – The TOPI functions to be extracted
tvm.autotvm.task.nnvm_integration.deserialize_args(args)

The inverse function of serialize_args.

Parameters:args (list of hashable or Tensor) –
tvm.autotvm.task.nnvm_integration.extract_from_graph(graph, shape, dtype, target, symbols, target_host=None)

Extract tuning tasks from a nnvm graph.

This function collects tuning tasks by building the graph with a “tracing” target and tracing all the calls to topi.

Parameters:
  • graph (Graph) – The graph to tune
  • shape (dict of str to tuple) – The input shape to the graph
  • dtype (str or dict of str to str) – The input types to the graph
  • target (tvm.target.Target) – The compilation target
  • symbols (Array of nnvm.symbol) – Array of nnvm symbols to be tuned
  • target_host (tvm.target.Target) – The host compilation target
Returns:

task – collected tasks

Return type:

Array of autotvm.task.Task

tvm.autotvm.task.nnvm_integration.extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None)

Extract tuning tasks from multiple nnvm graphs.

This function is the multiple graph version of extract_from_graph

Parameters:
  • graphs (List of Graph) – The list of graphs to tune
  • shapes (List of dict of str to tuple) – The input shapes to the graphs
  • dtypes (List of str or dict of str to str) – The input types to the graphs
  • target (tvm.target.Target) – The compilation target
  • symbols (Array of nnvm.symbol) – Array of nnvm symbols to be tuned
  • target_host (tvm.target.Target) – The host compilation target
Returns:

task – collected tasks

Return type:

Array of autotvm.task.Task

tvm.autotvm.task.nnvm_integration.serialize_args(args)

serialize arguments of a topi function to a hashable tuple.

Parameters:args (list of hashable or Tensor) –

tvm.autotvm.record

Tuning record and serialization format

tvm.autotvm.record.decode(row, protocol='json')

Decode encoded record string to python object

Parameters:
  • row (str) – a row in the logger file
  • protocol (str) – log protocol, json or pickle
Returns:

  • input (autotvm.tuner.MeasureInput)
  • result (autotvm.tuner.MeasureResult)

tvm.autotvm.record.encode(inp, result, protocol='json')

encode (MeasureInput, MeasureResult) pair to a string

Parameters:
  • inp (autotvm.tuner.MeasureInput) –
  • result (autotvm.tuner.MeasureResult) – pair of input/result
  • protocol (str) – log protocol, json or pickle
Returns:

row – a row in the logger file

Return type:

str

tvm.autotvm.record.load_from_file(filename)

Load records from a file. This is a generator that yields the records.

Parameters:

filename (str) –

Yields:
  • input (autotvm.tuner.MeasureInput)
  • result (autotvm.tuner.MeasureResult)
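Examples

A sketch of scanning a log for successful measurements; the file name is illustrative, and error_no == 0 denotes success (see MeasureErrorNo):

>>> for inp, res in autotvm.record.load_from_file('matmul.log'):
>>>     if res.error_no == 0:
>>>         print(inp.config, sum(res.costs) / len(res.costs))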
tvm.autotvm.record.measure_str_key(inp, include_config=True)

get unique str key for MeasureInput

Parameters:
  • inp (MeasureInput) – input for the measure
  • include_config (bool, optional) – whether to include the config in the str key
Returns:

key – The str representation of key

Return type:

str

tvm.autotvm.record.pick_best(in_file, out_file)

Pick the best entries from a file and store them in another file. This distills the useful log entries from a large log file.

Parameters:
  • in_file (str) – The filename of input
  • out_file (str or file) – The filename of output
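Examples

A sketch of distilling a raw tuning log into only the best record per workload (file names are illustrative):

>>> autotvm.record.pick_best('matmul.log.tmp', 'matmul.log')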
tvm.autotvm.record.split_workload(in_file, clean=True)

Split a log file into separate files, each of which contains only a single workload. This function can also delete duplicated records in the log file.

Parameters:
  • in_file (str) – input filename
  • clean (bool) – whether to delete duplicated items