Tuning High Performance Convolution on NVIDIA GPUs

Author: Lianmin Zheng

This is an advanced tutorial for writing a high-performance tunable template for NVIDIA GPUs. By running the auto-tuner on this template, we can outperform the vendor-provided library cuDNN in many cases.

Install dependencies

To use the autotvm package in TVM, we need to install some extra dependencies (change “3” to “2” if you use Python 2):

pip3 install --user psutil xgboost tornado
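Before tuning, you may want to confirm that these packages are importable from the interpreter you will run the tutorial with. Below is a minimal sketch that uses only the standard library. The helper name `missing_modules` is hypothetical and is not part of TVM.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module `names` that the import system cannot find."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # The three extra packages autotvm needs, as listed above
    missing = missing_modules(["psutil", "xgboost", "tornado"])
    if missing:
        print("missing:", ", ".join(missing))
    else:
        print("all autotvm dependencies found")
```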

To make TVM run faster during tuning, we recommend using Cython as the FFI of TVM. In the root directory of TVM, execute:

pip3 install --user cython
sudo make cython3

Now return to the Python code and import the packages.

import logging
import sys
import numpy as np

import tvm
import topi
from topi.testing import conv2d_nchw_python

from tvm import autotvm

Step 1: Define the search space

There are plenty of useful schedule primitives in TVM. Some tutorials describe them in more detail, such as (1) How to optimize convolution on GPU and (2) Optimizing DepthwiseConv on NVIDIA GPU.

However, their implementations are manually tuned for a few specific input shapes. In this section, we build a space large enough to cover the techniques used in those tutorials, and then rely on the efficient auto-tuner to search through this space and pick good configurations.

If you are familiar with writing CUDA schedules, you will find the following template very general. In fact, it can easily be modified to tune other operators such as depthwise convolution and GEMM. To fully understand this template, you should be familiar with the schedule primitives and the auto-tuning API; you can refer to the tutorials above and the autotvm tutorial.

It is worth noting that the search space for a conv2d operator can be very large (on the order of 10^9 configurations for some input shapes).
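To see where such numbers come from, note that with `policy=all` (the policy shown in the ConfigSpace printout in Step 2), `define_split` enumerates every ordered way to write an axis length as a product of `num_outputs` factors. The full space is the product of the sizes of all knobs. The sketch below recomputes this count in plain Python for the ResNet layer tuned in Step 2, which has 512 filters, a 7x7 output, 512 input channels and a 3x3 kernel. `count_splits` and `space_size` are hypothetical helpers, not autotvm API.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def count_splits(n, parts):
    """Number of ordered tuples of `parts` positive integers whose product is n."""
    if parts == 1:
        return 1
    # Choose the first factor d (any divisor of n), then split n // d recursively
    return sum(count_splits(n // d, parts - 1) for d in range(1, n + 1) if n % d == 0)


def space_size(CO, OH, OW, CI, KH, KW, n_unroll_steps=3, n_unroll_explicit=2):
    """Size of this template's space: three 4-way spatial splits,
    three 3-way reduction splits, and the two unrolling knobs."""
    size = 1
    for extent in (CO, OH, OW):   # tile_f, tile_y, tile_x
        size *= count_splits(extent, 4)
    for extent in (CI, KH, KW):   # tile_rc, tile_ry, tile_rx
        size *= count_splits(extent, 3)
    return size * n_unroll_steps * n_unroll_explicit


print(count_splits(512, 4))                  # 220 (tile_f)
print(space_size(512, 7, 7, 512, 3, 3))      # 10454400
```

The total of 10,454,400 matches the `ConfigSpace (len=10454400, ...)` line printed in Step 2. Axis lengths with many small prime factors have many more ordered factorizations, which is why the space grows so quickly for larger shapes.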

@autotvm.template
def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
    assert N == 1, "Only consider batch_size = 1 in this template"

    data = tvm.placeholder((N, CI, H, W), name='data')
    kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
    s = tvm.create_schedule([conv.op])

    ##### space definition begin #####
    n, f, y, x = s[conv].op.axis
    rc, ry, rx = s[conv].op.reduce_axis

    cfg = autotvm.get_config()
    cfg.define_split("tile_f", f, num_outputs=4)
    cfg.define_split("tile_y", y, num_outputs=4)
    cfg.define_split("tile_x", x, num_outputs=4)
    cfg.define_split("tile_rc", rc, num_outputs=3)
    cfg.define_split("tile_ry", ry, num_outputs=3)
    cfg.define_split("tile_rx", rx, num_outputs=3)
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    cfg.define_knob("unroll_explicit", [0, 1])
    ##### space definition end #####

    # inline padding
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    data, raw_data = pad_data, data

    output = conv
    OL = s.cache_write(conv, 'local')

    # create cache stage
    AA = s.cache_read(data, 'shared', [OL])
    WW = s.cache_read(kernel, 'shared', [OL])
    AL = s.cache_read(AA, 'local', [OL])
    WL = s.cache_read(WW, 'local', [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
    kernel_scope = n  # this is the scope to attach global config inside this kernel

    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
    s[output].bind(by, tvm.thread_axis("blockIdx.y"))
    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[output].bind(vf, tvm.thread_axis("vthread"))
    s[output].bind(vy, tvm.thread_axis("vthread"))
    s[output].bind(vx, tvm.thread_axis("vthread"))
    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
    ryo, rym, ryi = cfg['tile_ry'].apply(s, OL, ry)
    rxo, rxm, rxi = cfg['tile_rx'].apply(s, OL, rx)
    s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    s[AL].compute_at(s[OL], rxm)
    s[WL].compute_at(s[OL], rxm)

    # cooperative fetching
    for load in [AA, WW]:
        n, f, y, x = s[load].op.axis
        fused = s[load].fuse(n, f, y, x)
        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

    # tune unroll
    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

    return s, [raw_data, kernel, conv]
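In this template, each spatial split is read as [block, virtual thread, thread, inner]. The first factor is bound to `blockIdx`, the second to `vthread`, the third to `threadIdx`, and the last remains a serial loop inside each thread. The plain-Python sketch below needs no TVM and uses hypothetical helpers. It turns three split configurations into the CUDA grid and block sizes they imply and checks those sizes against CUDA's per-block thread limits: at most 1024 threads per block, with blockDim.z at most 64. It also computes how many elements each thread copies in the cooperative-fetching loop, which splits the fused shared-memory copy with `nparts` equal to the thread counts along z, y and x. The sketch models only these limits. The autotvm verifier may also reject configurations for other reasons, such as shared-memory usage.

```python
# CUDA per-block limits (compute capability >= 2.0)
MAX_THREADS_PER_BLOCK = 1024
MAX_BLOCK_DIM = {"x": 1024, "y": 1024, "z": 64}


def launch_config(tile_f, tile_y, tile_x):
    """Map [block, vthread, thread, inner] splits to (grid, block, vthreads)."""
    grid = {"x": tile_x[0], "y": tile_y[0], "z": tile_f[0]}    # blockIdx.x/y/z
    vthreads = tile_f[1] * tile_y[1] * tile_x[1]               # virtual threads
    block = {"x": tile_x[2], "y": tile_y[2], "z": tile_f[2]}   # threadIdx.x/y/z
    return grid, block, vthreads


def within_thread_limits(block):
    """True if the block shape respects the total and per-axis thread limits."""
    total = block["x"] * block["y"] * block["z"]
    return total <= MAX_THREADS_PER_BLOCK and all(
        block[axis] <= MAX_BLOCK_DIM[axis] for axis in block)


def loads_per_thread(n_elements, block):
    """Serial extent left per thread after splitting a fused copy of
    n_elements with nparts = threadIdx.z, then .y, then .x."""
    rest = n_elements
    for axis in ("z", "y", "x"):
        k = block[axis]
        rest = (rest + k - 1) // k   # split(..., nparts=k) leaves ceil(rest / k)
    return rest


# Best configuration found in Step 2
grid, block, vthreads = launch_config([8, 1, 64, 1], [1, 1, 1, 7], [1, 1, 7, 1])
print(grid, block, vthreads)        # {'x': 1, 'y': 1, 'z': 8} {'x': 7, 'y': 1, 'z': 64} 1
print(within_thread_limits(block))  # True (448 threads per block)
print(loads_per_thread(384, block))  # 1: 448 threads share a 384-element tile

# tile_f = [4, 1, 128, 1], as in trial No. 17 of the log, asks for 128 threads along z
_, bad_block, _ = launch_config([4, 1, 128, 1], [1, 7, 1, 1], [1, 7, 1, 1])
print(within_thread_limits(bad_block))  # False
```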

Step 2: Search through the space

We pick the last layer of ResNet as the test case. Since our space is very large, XGBTuner, which is guided by an XGBoost cost model, is the most suitable tuner here. We run only 20 trials for demonstration; in practice, about 1000 trials usually find good kernels for this template.
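A quick estimate shows why exhaustive search is not an option. The sketch assumes roughly 4 seconds per measurement; this figure is an assumption, since the valid trials in the log below took about 3 to 10 seconds each, including compilation. At that rate, measuring all 10,454,400 configurations of this layer would take more than a year, and 1000 trials would cover only about 0.01% of the space. A tuner therefore has to choose promising candidates. XGBTuner does this by training an XGBoost cost model on the measurements collected so far and using that model to select the next batch. `exhaustive_days` is a hypothetical helper.

```python
SECONDS_PER_DAY = 24 * 60 * 60


def exhaustive_days(space_size, seconds_per_trial):
    """Days needed to measure every configuration one after another."""
    return space_size * seconds_per_trial / SECONDS_PER_DAY


space = 10454400  # ConfigSpace len printed in the tuning output
print(exhaustive_days(space, 4.0))        # 484.0
print("%.4f%%" % (100.0 * 1000 / space))  # 0.0096%
```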

# logging config (for printing tuning log to screen)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

# the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
task = autotvm.task.create(conv2d_no_batching,
                           args=(N, H, W, CO, CI, KH, KW, strides, padding),
                           target='cuda')
print(task.config_space)

# Use the local GPU. Each config is measured in 3 repeats, and each repeat runs the
# kernel for at least 100 ms (min_repeat_ms) to reduce variance.
# The timeout for compiling a program is 10 seconds; the timeout for running it is 4 seconds.
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(timeout=10),
    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
)

# Begin tuning, log records to file `conv2d.log`
# During tuning we will also try many invalid configs, so you should expect to
# see many error reports. As long as some trials report non-zero GFLOPS, it is okay.
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20,
           measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('conv2d.log')])

Out:

ConfigSpace (len=10454400, space_map=
   0 tile_f: Split(policy=all, product=512, num_outputs=4) len=220
   1 tile_y: Split(policy=all, product=7, num_outputs=4) len=4
   2 tile_x: Split(policy=all, product=7, num_outputs=4) len=4
   3 tile_rc: Split(policy=all, product=512, num_outputs=3) len=55
   4 tile_ry: Split(policy=all, product=3, num_outputs=3) len=3
   5 tile_rx: Split(policy=all, product=3, num_outputs=3) len=3
   6 auto_unroll_max_step: OtherOption([0, 512, 1500]) len=3
   7 unroll_explicit: OtherOption([0, 1]) len=2
)
Get devices for measurement successfully!
No: 1   GFLOPS: 0.00/0.00       result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.735872745513916, timestamp=1553483904.6039436)   [('tile_f', [1, 8, 8, 8]), ('tile_y', [1, 7, 1, 1]), ('tile_x', [1, 7, 1, 1]), ('tile_rc', [1, 256, 2]), ('tile_ry', [1, 1, 3]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 1500), ('unroll_explicit', 1)],,None,9163817
No: 2   GFLOPS: 2.71/2.71       result: MeasureResult(costs=(0.08541304375,), error_no=0, all_cost=4.892483949661255, timestamp=1553483909.061465)      [('tile_f', [2, 32, 8, 1]), ('tile_y', [1, 1, 1, 7]), ('tile_x', [7, 1, 1, 1]), ('tile_rc', [64, 4, 2]), ('tile_ry', [1, 1, 3]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 0), ('unroll_explicit', 0)],,None,430132
No: 3   GFLOPS: 50.39/50.39     result: MeasureResult(costs=(0.004594036181818182,), error_no=0, all_cost=3.148590326309204, timestamp=1553483910.01365)        [('tile_f', [8, 1, 32, 2]), ('tile_y', [1, 7, 1, 1]), ('tile_x', [1, 7, 1, 1]), ('tile_rc', [128, 4, 1]), ('tile_ry', [1, 3, 1]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 512), ('unroll_explicit', 1)],,None,7171430
No: 4   GFLOPS: 0.00/50.39      result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.3991889953613281, timestamp=1553483907.1609626)  [('tile_f', [4, 8, 2, 8]), ('tile_y', [7, 1, 1, 1]), ('tile_x', [1, 1, 7, 1]), ('tile_rc', [2, 2, 128]), ('tile_ry', [1, 3, 1]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 0), ('unroll_explicit', 0)],,None,371506
No: 5   GFLOPS: 61.13/61.13     result: MeasureResult(costs=(0.0037870489736842104,), error_no=0, all_cost=4.7902915477752686, timestamp=1553483915.8739607)    [('tile_f', [256, 1, 2, 1]), ('tile_y', [1, 1, 7, 1]), ('tile_x', [1, 1, 1, 7]), ('tile_rc', [32, 8, 2]), ('tile_ry', [1, 3, 1]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 512), ('unroll_explicit', 0)],,None,1984850
No: 6   GFLOPS: 1.24/61.13      result: MeasureResult(costs=(0.18687951025,), error_no=0, all_cost=5.949897527694702, timestamp=1553483918.8701487)     [('tile_f', [64, 2, 1, 4]), ('tile_y', [7, 1, 1, 1]), ('tile_x', [7, 1, 1, 1]), ('tile_rc', [4, 32, 4]), ('tile_ry', [3, 1, 1]), ('tile_rx', [1, 3, 1]), ('auto_unroll_max_step', 0), ('unroll_explicit', 1)],,None,5892581
No: 7   GFLOPS: 145.13/145.13   result: MeasureResult(costs=(0.0015950836666666667,), error_no=0, all_cost=3.728163957595825, timestamp=1553483919.8190858)     [('tile_f', [8, 1, 64, 1]), ('tile_y', [1, 1, 1, 7]), ('tile_x', [1, 1, 7, 1]), ('tile_rc', [256, 2, 1]), ('tile_ry', [1, 1, 3]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 512), ('unroll_explicit', 0)],,None,2135585
No: 8   GFLOPS: 2.77/145.13     result: MeasureResult(costs=(0.08359992375,), error_no=0, all_cost=4.9966466426849365, timestamp=1553483921.4892585)    [('tile_f', [128, 4, 1, 1]), ('tile_y', [1, 7, 1, 1]), ('tile_x', [1, 1, 7, 1]), ('tile_rc', [4, 8, 16]), ('tile_ry', [3, 1, 1]), ('tile_rx', [1, 1, 3]), ('auto_unroll_max_step', 0), ('unroll_explicit', 0)],,None,1293822
No: 9   GFLOPS: 6.96/145.13     result: MeasureResult(costs=(0.03326607125,), error_no=0, all_cost=9.728787422180176, timestamp=1553483931.8199544)     [('tile_f', [16, 16, 2, 1]), ('tile_y', [1, 7, 1, 1]), ('tile_x', [7, 1, 1, 1]), ('tile_rc', [32, 1, 16]), ('tile_ry', [3, 1, 1]), ('tile_rx', [1, 1, 3]), ('auto_unroll_max_step', 1500), ('unroll_explicit', 0)],,None,4766314
No: 10  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.4989957809448242, timestamp=1553483930.662188)   [('tile_f', [8, 4, 16, 1]), ('tile_y', [1, 1, 1, 7]), ('tile_x', [7, 1, 1, 1]), ('tile_rc', [2, 8, 32]), ('tile_ry', [3, 1, 1]), ('tile_rx', [1, 1, 3]), ('auto_unroll_max_step', 512), ('unroll_explicit', 0)],,None,3056056
No: 11  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.47098588943481445, timestamp=1553483930.6622982) [('tile_f', [4, 32, 2, 2]), ('tile_y', [7, 1, 1, 1]), ('tile_x', [7, 1, 1, 1]), ('tile_rc', [4, 2, 64]), ('tile_ry', [3, 1, 1]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 0), ('unroll_explicit', 0)],,None,161989
No: 12  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.17226195335388184, timestamp=1553483930.662431)  [('tile_f', [4, 2, 8, 8]), ('tile_y', [1, 1, 7, 1]), ('tile_x', [1, 7, 1, 1]), ('tile_rc', [2, 256, 1]), ('tile_ry', [3, 1, 1]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 0), ('unroll_explicit', 0)],,None,29635
No: 13  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.15493106842041016, timestamp=1553483933.1147625) [('tile_f', [4, 4, 8, 4]), ('tile_y', [1, 7, 1, 1]), ('tile_x', [7, 1, 1, 1]), ('tile_rc', [1, 32, 16]), ('tile_ry', [1, 3, 1]), ('tile_rx', [1, 1, 3]), ('auto_unroll_max_step', 0), ('unroll_explicit', 1)],,None,6720023
No: 14  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.39814043045043945, timestamp=1553483933.3197)    [('tile_f', [2, 16, 1, 16]), ('tile_y', [1, 1, 7, 1]), ('tile_x', [1, 1, 1, 7]), ('tile_rc', [32, 1, 16]), ('tile_ry', [1, 3, 1]), ('tile_rx', [1, 1, 3]), ('auto_unroll_max_step', 0), ('unroll_explicit', 0)],,None,1478128
No: 15  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.5038576126098633, timestamp=1553483933.5677843)  [('tile_f', [2, 2, 64, 2]), ('tile_y', [1, 7, 1, 1]), ('tile_x', [1, 7, 1, 1]), ('tile_rc', [32, 2, 8]), ('tile_ry', [1, 1, 3]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 1500), ('unroll_explicit', 1)],,None,9198955
No: 16  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.12298727035522461, timestamp=1553483933.607738)  [('tile_f', [4, 32, 4, 1]), ('tile_y', [7, 1, 1, 1]), ('tile_x', [1, 1, 1, 7]), ('tile_rc', [4, 128, 1]), ('tile_ry', [1, 3, 1]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 0), ('unroll_explicit', 0)],,None,220904
No: 17  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.38558197021484375, timestamp=1553483935.0005279) [('tile_f', [4, 1, 128, 1]), ('tile_y', [1, 7, 1, 1]), ('tile_x', [1, 7, 1, 1]), ('tile_rc', [1, 1, 512]), ('tile_ry', [3, 1, 1]), ('tile_rx', [1, 1, 3]), ('auto_unroll_max_step', 512), ('unroll_explicit', 1)],,None,8322429
No: 18  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.4398198127746582, timestamp=1553483935.4740965)  [('tile_f', [4, 1, 4, 32]), ('tile_y', [1, 1, 7, 1]), ('tile_x', [1, 7, 1, 1]), ('tile_rc', [4, 128, 1]), ('tile_ry', [3, 1, 1]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 1500), ('unroll_explicit', 0)],,None,3510954
No: 19  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.23189949989318848, timestamp=1553483935.4742439) [('tile_f', [8, 2, 8, 4]), ('tile_y', [1, 7, 1, 1]), ('tile_x', [7, 1, 1, 1]), ('tile_rc', [2, 8, 32]), ('tile_ry', [3, 1, 1]), ('tile_rx', [1, 1, 3]), ('auto_unroll_max_step', 0), ('unroll_explicit', 1)],,None,6540502
No: 20  GFLOPS: 0.00/145.13     result: MeasureResult(costs=(InstantiationError('Traceback (most recent call last):\n  [bt] (1) /workspace/build/libtvm.so(TVMFuncCall+0x61) [0x7f09c4952f01]\n  [bt] (0) /workspace/build/libtvm.so(+0x8e5a7b) [0x7f09c494fa7b]\n  File "/workspace/docs/../python/tvm/_ffi/_ctypes/function.py", line 55, in cfun\n    rv = local_pyfunc(*pyargs)\n  File "/workspace/docs/../python/tvm/autotvm/measure/measure_methods.py", line 579, in verify_pass\n    raise InstantiationError("Skipped because of invalid gpu kernel")\ntvm.autotvm.task.space.InstantiationError: Skipped because of invalid gpu kernel',),), error_no=1, all_cost=0.35700488090515137, timestamp=1553483935.519801)  [('tile_f', [1, 1, 4, 128]), ('tile_y', [1, 1, 7, 1]), ('tile_x', [1, 1, 1, 7]), ('tile_rc', [8, 8, 8]), ('tile_ry', [3, 1, 1]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 0), ('unroll_explicit', 0)],,None,108895
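The GFLOPS column is the operator's floating-point operation count divided by the measured mean time. As a rough cross-check, a direct convolution performs one multiply and one add for each combination of output element, input channel, kernel row and kernel column. The sketch below therefore counts 2 * N * CO * OH * OW * CI * KH * KW operations and divides that count by the cost recorded for trial No. 7. autotvm derives its own count from the compute definition, so its figure may differ slightly from this textbook formula. `conv2d_flops` is a hypothetical helper.

```python
def conv2d_flops(N, H, W, CO, CI, KH, KW, strides, padding):
    """Operation count (2 per multiply-add) of a dense NCHW conv2d without dilation."""
    OH = (H + 2 * padding[0] - KH) // strides[0] + 1
    OW = (W + 2 * padding[1] - KW) // strides[1] + 1
    return 2 * N * CO * OH * OW * CI * KH * KW


flops = conv2d_flops(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1))
print(flops)                                 # 231211008
cost = 0.0015950836666666667                 # mean seconds of trial No. 7 in the log
print("%.2f GFLOPS" % (flops / cost / 1e9))  # 144.95 GFLOPS (the log reports 145.13)
```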

Finally, we inspect the best config from the log file, check correctness, and measure the running time.

# inspect the best config
dispatch_context = autotvm.apply_history_best("conv2d.log")
best_config = dispatch_context.query(task.target, task.workload)
print("\nBest config:")
print(best_config)

# apply history best from log file
with autotvm.apply_history_best('conv2d.log'):
    with tvm.target.create("cuda"):
        s, arg_bufs = conv2d_no_batching(N, H, W, CO, CI, KH, KW, strides, padding)
        func = tvm.build(s, arg_bufs)

# check correctness
a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)

ctx = tvm.gpu()
a_tvm = tvm.nd.array(a_np, ctx=ctx)
w_tvm = tvm.nd.array(w_np, ctx=ctx)
c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
func(a_tvm, w_tvm, c_tvm)

tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)

# Evaluate the running time. We use a large number of runs (number=400) to reduce
# noise and amortize the kernel-launch overhead. The printed value is the mean time
# per run in seconds. You can also use nvprof to validate the result.
evaluator = func.time_evaluator(func.entry_name, ctx, number=400)
print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)

Out:

Best config:
[('tile_f', [8, 1, 64, 1]), ('tile_y', [1, 1, 1, 7]), ('tile_x', [1, 1, 7, 1]), ('tile_rc', [256, 2, 1]), ('tile_ry', [1, 1, 3]), ('tile_rx', [3, 1, 1]), ('auto_unroll_max_step', 512), ('unroll_explicit', 0)],,None,2135585
Time cost of this operator: 0.001603
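The correctness check above compares against `conv2d_nchw_python` from `topi.testing`. To make explicit what such a reference computes, here is a minimal NumPy sketch of a direct NCHW convolution with zero padding and stride, without dilation or groups. `conv2d_nchw_reference` is a hypothetical name. It is not TOPI's implementation, and it trades speed for readability.

```python
import numpy as np


def conv2d_nchw_reference(data, kernel, stride, padding):
    """Naive NCHW conv2d. data: (N, CI, H, W), kernel: (CO, CI, KH, KW),
    stride and padding: (height, width) tuples."""
    n, ci, h, w = data.shape
    co, ci_k, kh, kw = kernel.shape
    assert ci == ci_k, "input channels of data and kernel must match"
    sh, sw = stride
    ph, pw = padding
    # Zero-pad only the two spatial dimensions
    padded = np.pad(data, ((0, 0), (0, 0), (ph, ph), (pw, pw)), mode="constant")
    oh = (h + 2 * ph - kh) // sh + 1
    ow = (w + 2 * pw - kw) // sw + 1
    out = np.zeros((n, co, oh, ow), dtype=np.result_type(data, kernel))
    for y in range(oh):
        for x in range(ow):
            # Receptive field for output pixel (y, x): shape (N, CI, KH, KW)
            window = padded[:, :, y * sh:y * sh + kh, x * sw:x * sw + kw]
            # Contract over CI, KH, KW for every output channel -> shape (N, CO)
            out[:, :, y, x] = np.tensordot(window, kernel, axes=([1, 2, 3], [1, 2, 3]))
    return out


# All-ones 3x3 input and 3x3 kernel with padding 1: each output counts in-bounds taps
print(conv2d_nchw_reference(np.ones((1, 1, 3, 3)), np.ones((1, 1, 3, 3)), (1, 1), (1, 1))[0, 0])
```

Each output value equals the number of kernel taps that land inside the unpadded input: 4 at the corners, 6 on the edges and 9 in the center.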

Total running time of the script: ( 0 minutes 41.909 seconds)
