Introduction to TOPI

Author: Ehsan M. Kermani

This is an introductory tutorial to TVM Operator Inventory (TOPI). TOPI provides numpy-style generic operations and schedules with higher abstractions than TVM. In this tutorial, we will see how TOPI can save us from writing boilerplate code in TVM.

from __future__ import absolute_import, print_function

import tvm
import topi
import numpy as np

Basic example

Let’s revisit the sum-of-rows operation (equivalent to B = numpy.sum(A, axis=1)). To compute the sum of the rows of a two-dimensional TVM tensor A, we should specify the symbolic operation as well as the schedule as follows

# Symbolic (variable) dimensions so the kernel works for any 2-D input size.
n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name='A')
# Reduction axis ranging over the columns of A (axis 1).
k = tvm.reduce_axis((0, m), "k")
# B[i] = sum over k of A[i, k] -- i.e. the sum of each row of A.
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
# Default (naive) schedule for the computation of B.
s = tvm.create_schedule(B.op)

and to examine the IR code in human readable format, we can do

print(tvm.lower(s, [A], simple_mode=True))

Out:

// attr [B] storage_scope = "global"
allocate B[float32 * n]
produce B {
  for (i, 0, n) {
    B[i] = 0.000000f
    for (k, 0, m) {
      B[i] = (B[i] + A[((i*m) + k)])
    }
  }
}

However, for such a common operation we had to define the reduce axis ourselves as well as the explicit computation with tvm.compute. Imagine how many details we would need to provide for more complicated operations. Fortunately, we can replace those two lines with a simple topi.sum, much like numpy.sum

# Same row-sum as above, but TOPI declares the reduction in one line --
# no explicit reduce_axis or tvm.compute needed.
C = topi.sum(A, axis=1)
ts = tvm.create_schedule(C.op)
print(tvm.lower(ts, [A], simple_mode=True))

Out:

// attr [A_red] storage_scope = "global"
allocate A_red[float32 * n]
produce A_red {
  for (ax0, 0, n) {
    A_red[ax0] = 0.000000f
    for (k1, 0, m) {
      A_red[ax0] = (A_red[ax0] + A[((ax0*m) + k1)])
    }
  }
}

Numpy-style operator overloading

We can add two tensors using topi.broadcast_add, provided they have broadcastable shapes. Even shorter, TOPI provides operator overloading for such common operations. For example,

# (100, 10, 10) and (10, 10) are broadcast-compatible shapes.
x, y = 100, 10
a = tvm.placeholder((x, y, y), name="a")
b = tvm.placeholder((y, y), name="b")
c = a + b  # same as topi.broadcast_add
d = a * b  # same as topi.broadcast_mul

Overloaded with the same syntax, TOPI also handles broadcasting a primitive (int, float) against a tensor, e.g. d - 3.14.

Generic schedules and fusing operations

Up to now, we have seen an example of how TOPI can save us from writing explicit computations in lower level API. But it doesn’t stop here. Still we did the scheduling as before. TOPI also provides higher level scheduling recipes depending on a given context. For example, for CUDA, we can schedule the following series of operations ending with topi.sum using only topi.generic.schedule_reduce

# Chain of element-wise ops ending in a full reduction: (c + d) / 2 summed
# to a scalar.
e = topi.elemwise_sum([c, d])
f = e / 2.0
g = topi.sum(f)
# Under a CUDA target scope, the generic reduce-schedule recipe picks the
# CUDA-specific implementation and schedules the whole chain at once.
with tvm.target.cuda():
    sg = topi.generic.schedule_reduce(g)
    print(tvm.lower(sg, [a, b], simple_mode=True))

Out:

// attr [tensor_red] storage_scope = "global"
allocate tensor_red[float32 * 1]
produce tensor_red {
  // attr [iter_var(threadIdx.x, Range(min=0, extent=512), threadIdx.x)] thread_extent = 512
  // attr [tensor_red.rf] storage_scope = "local"
  allocate tensor_red.rf[float32 * 1]
  // attr [reduce_temp0] storage_scope = "local"
  allocate reduce_temp0[float32 * 1]
  produce tensor_red.rf {
    tensor_red.rf[0] = 0.000000f
    for (k0.k1.fused.k2.fused.outer, 0, 20) {
      if ((threadIdx.x < (10000 - (k0.k1.fused.k2.fused.outer*512)))) {
        tensor_red.rf[0] = (tensor_red.rf[0] + (((a[((((threadIdx.x + (k0.k1.fused.k2.fused.outer*512))/100)*100) + (((((threadIdx.x + (k0.k1.fused.k2.fused.outer*512))/10) % 10)*10) + ((threadIdx.x + (k0.k1.fused.k2.fused.outer*2)) % 10)))] + b[(((((threadIdx.x + (k0.k1.fused.k2.fused.outer*512))/10) % 10)*10) + ((threadIdx.x + (k0.k1.fused.k2.fused.outer*2)) % 10))]) + (a[((((threadIdx.x + (k0.k1.fused.k2.fused.outer*512))/100)*100) + (((((threadIdx.x + (k0.k1.fused.k2.fused.outer*512))/10) % 10)*10) + ((threadIdx.x + (k0.k1.fused.k2.fused.outer*2)) % 10)))]*b[(((((threadIdx.x + (k0.k1.fused.k2.fused.outer*512))/10) % 10)*10) + ((threadIdx.x + (k0.k1.fused.k2.fused.outer*2)) % 10))]))*0.500000f))
      }
    }
  }
  // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0.000000f])] reduce_scope = reinterpret((uint64)0)
  tvm_thread_allreduce((uint32)1, tensor_red.rf[0], (uint1)1, reduce_temp0, threadIdx.x)
  if ((threadIdx.x == 0)) {
    tensor_red[0] = reduce_temp0[0]
  }
}

As you can see, scheduled stages of computation have been accumulated and we can examine them by

print(sg.stages)

Out:

[stage(a, 0x65604da0), stage(b, 0xaf799c80), stage(tensor, 0x3bdc3bd0), stage(tensor, 0x4def62f0), stage(tensor, 0xadc8c900), stage(tensor, 0xaa0d8650), stage(tensor_red.rf, 0x6560f4f0), stage(tensor_red, 0x640c8450)]

We can test the correctness by comparing with numpy result as follows

# Build the scheduled computation for CUDA and check it against numpy.
func = tvm.build(sg, [a, b, g], 'cuda')
ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
# Reference result: sum(((a + b) + a * b) / 2) -- mirrors e, f, g above.
g_np = np.sum(np.add(a_np + b_np, a_np * b_np) / 2.0)
a_nd = tvm.nd.array(a_np, ctx)
b_nd = tvm.nd.array(b_np, ctx)
# Output buffer on the GPU with the same (scalar) shape/dtype as g_np.
g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), ctx)
func(a_nd, b_nd, g_nd)
tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-5)

TOPI also provides common neural-network operations such as softmax with an optimized schedule

tarray = tvm.placeholder((512, 512), name="tarray")
softmax_topi = topi.nn.softmax(tarray)
# The generic softmax-schedule dispatches to the CUDA implementation
# inside the target scope.
with tvm.target.create("cuda"):
    sst = topi.generic.schedule_softmax(softmax_topi)
    print(tvm.lower(sst, [tarray], simple_mode=True))

Out:

// attr [compute] storage_scope = "global"
allocate compute[float32 * 512]
// attr [compute] storage_scope = "global"
allocate compute[float32 * 512]
// attr [compute] storage_scope = "global"
allocate compute[float32 * 262144]
produce compute {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 512
  compute[blockIdx.x] = -340282346638528859811704183484516925440.000000f
  for (k, 0, 512) {
    compute[blockIdx.x] = max(compute[blockIdx.x], tarray[((blockIdx.x*512) + k)])
  }
}
produce compute {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 512
  // attr [compute.rf] storage_scope = "local"
  allocate compute.rf[float32 * 1]
  // attr [reduce_temp0] storage_scope = "local"
  allocate reduce_temp0[float32 * 1]
  // attr [iter_var(threadIdx.x, Range(min=0, extent=64), threadIdx.x)] thread_extent = 64
  produce compute.rf {
    compute.rf[0] = 0.000000f
    for (k.outer, 0, 8) {
      compute.rf[0] = (compute.rf[0] + exp((tarray[(((blockIdx.x*512) + threadIdx.x) + (k.outer*64))] - compute[blockIdx.x])))
    }
  }
  // attr [comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0.000000f])] reduce_scope = reinterpret((uint64)0)
  tvm_thread_allreduce((uint32)1, compute.rf[0], (uint1)1, reduce_temp0, threadIdx.x)
  if ((threadIdx.x == 0)) {
    compute[blockIdx.x] = reduce_temp0[0]
  }
}
produce compute {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 512
  // attr [iter_var(threadIdx.x, Range(min=0, extent=64), threadIdx.x)] thread_extent = 64
  for (i1.inner, 0, 8) {
    compute[((((blockIdx.x*64) + threadIdx.x)*8) + i1.inner)] = (exp((tarray[((((blockIdx.x*64) + threadIdx.x)*8) + i1.inner)] - compute[blockIdx.x]))/compute[blockIdx.x])
  }
}

Fusing convolutions

We can fuse topi.nn.conv2d and topi.nn.relu together.

Note

TOPI functions are all generic functions. They have different implementations for different backends to optimize for performance. For each backend, it is necessary to call them under a target scope for both compute declaration and schedule. TVM will choose the right function to call with the target information.

# NCHW input: batch 1, 3 channels, 224x224 spatial.
data = tvm.placeholder((1, 3, 224, 224))
# OIHW weights: 10 output channels, 3 input channels, 5x5 kernel.
kernel = tvm.placeholder((10, 3, 5, 5))

# Both the compute declaration and the schedule must run inside the target
# scope so TVM picks the CUDA implementations (see the note above).
with tvm.target.create("cuda"):
    conv = topi.nn.conv2d(data, kernel, strides=1, padding=2, dilation=1)
    out = topi.nn.relu(conv)
    # Scheduling the relu output fuses it with the conv2d stage.
    sconv = topi.generic.nn.schedule_conv2d_nchw(out)
    print(tvm.lower(sconv, [data, kernel], simple_mode=True))

Out:

// attr [compute] storage_scope = "global"
allocate compute[float32 * 501760]
produce compute {
  // attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = 5
  // attr [compute] storage_scope = "local"
  allocate compute[float32 * 16]
  // attr [pad_temp.shared] storage_scope = "shared"
  allocate pad_temp.shared[float32 * 128]
  // attr [placeholder.shared] storage_scope = "shared"
  allocate placeholder.shared[float32 * 2]
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 28
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 14
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
  produce compute {
    compute[0] = 0.000000f
    compute[1] = 0.000000f
    compute[2] = 0.000000f
    compute[3] = 0.000000f
    compute[4] = 0.000000f
    compute[5] = 0.000000f
    compute[6] = 0.000000f
    compute[7] = 0.000000f
    compute[8] = 0.000000f
    compute[9] = 0.000000f
    compute[10] = 0.000000f
    compute[11] = 0.000000f
    compute[12] = 0.000000f
    compute[13] = 0.000000f
    compute[14] = 0.000000f
    compute[15] = 0.000000f
    for (rc.outer, 0, 3) {
      for (ry.outer, 0, 5) {
        produce pad_temp.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          pad_temp.shared[(((threadIdx.x/2)*16) + ((threadIdx.x*8) % 16))] = tvm_if_then_else(((((((2 - (threadIdx.x/2)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (threadIdx.x/2)) - ry.outer))) && ((2 - ((threadIdx.x*8) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (226 - ((threadIdx.x*8) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((threadIdx.x/2)*14))*16) + ((threadIdx.x*8) % 16)) + -450)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 1)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 1)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 1)/16)) - ry.outer))) && ((2 - (((threadIdx.x*8) + 1) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (226 - (((threadIdx.x*8) + 1) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 1)/16)*14))*16) + (((threadIdx.x*8) + 1) % 16)) + -450)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 2)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 2)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 2)/16)) - ry.outer))) && ((2 - (((threadIdx.x*8) + 2) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (226 - (((threadIdx.x*8) + 2) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 2)/16)*14))*16) + (((threadIdx.x*8) + 2) % 16)) + -450)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 3)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 3)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 3)/16)) - ry.outer))) && ((2 - (((threadIdx.x*8) + 3) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (226 - (((threadIdx.x*8) + 3) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 3)/16)*14))*16) + (((threadIdx.x*8) + 3) % 16)) + -450)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 4)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 4)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 4)/16)) - ry.outer))) && ((2 - (((threadIdx.x*8) + 4) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (226 - (((threadIdx.x*8) + 4) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 4)/16)*14))*16) + (((threadIdx.x*8) + 4) % 16)) + -450)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 5)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 5)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 5)/16)) - ry.outer))) && ((2 - (((threadIdx.x*8) + 5) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (226 - (((threadIdx.x*8) + 5) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 5)/16)*14))*16) + (((threadIdx.x*8) + 5) % 16)) + -450)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 6)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 6)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 6)/16)) - ry.outer))) && ((2 - (((threadIdx.x*8) + 6) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (226 - (((threadIdx.x*8) + 6) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 6)/16)*14))*16) + (((threadIdx.x*8) + 6) % 16)) + -450)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 7)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 7)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 7)/16)) - ry.outer))) && ((2 - (((threadIdx.x*8) + 7) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (226 - (((threadIdx.x*8) + 7) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 7)/16)*14))*16) + (((threadIdx.x*8) + 7) % 16)) + -450)], 0.000000f)
        }
        produce placeholder.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          if (likely((threadIdx.x < 2))) {
            if (likely(((blockIdx.z*2) < (10 - threadIdx.x)))) {
              placeholder.shared[threadIdx.x] = placeholder[((((((blockIdx.z*6) + rc.outer)*5) + ry.outer) + (threadIdx.x*15))*5)]
            }
          }
        }
        compute[0] = (compute[0] + (pad_temp.shared[threadIdx.x]*placeholder.shared[0]))
        compute[1] = (compute[1] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[0]))
        compute[2] = (compute[2] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[0]))
        compute[3] = (compute[3] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[0]))
        compute[4] = (compute[4] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[0]))
        compute[5] = (compute[5] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[0]))
        compute[6] = (compute[6] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[0]))
        compute[7] = (compute[7] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[0]))
        compute[8] = (compute[8] + (pad_temp.shared[threadIdx.x]*placeholder.shared[1]))
        compute[9] = (compute[9] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[1]))
        compute[10] = (compute[10] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[1]))
        compute[11] = (compute[11] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[1]))
        compute[12] = (compute[12] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[1]))
        compute[13] = (compute[13] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[1]))
        compute[14] = (compute[14] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[1]))
        compute[15] = (compute[15] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[1]))
        produce pad_temp.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          pad_temp.shared[(((threadIdx.x/2)*16) + ((threadIdx.x*8) % 16))] = tvm_if_then_else(((((((2 - (threadIdx.x/2)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (threadIdx.x/2)) - ry.outer))) && ((1 - ((threadIdx.x*8) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (225 - ((threadIdx.x*8) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((threadIdx.x/2)*14))*16) + ((threadIdx.x*8) % 16)) + -449)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 1)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 1)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 1)/16)) - ry.outer))) && ((1 - (((threadIdx.x*8) + 1) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (225 - (((threadIdx.x*8) + 1) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 1)/16)*14))*16) + (((threadIdx.x*8) + 1) % 16)) + -449)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 2)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 2)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 2)/16)) - ry.outer))) && ((1 - (((threadIdx.x*8) + 2) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (225 - (((threadIdx.x*8) + 2) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 2)/16)*14))*16) + (((threadIdx.x*8) + 2) % 16)) + -449)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 3)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 3)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 3)/16)) - ry.outer))) && ((1 - (((threadIdx.x*8) + 3) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (225 - (((threadIdx.x*8) + 3) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 3)/16)*14))*16) + (((threadIdx.x*8) + 3) % 16)) + -449)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 4)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 4)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 4)/16)) - ry.outer))) && ((1 - (((threadIdx.x*8) + 4) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (225 - (((threadIdx.x*8) + 4) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 4)/16)*14))*16) + (((threadIdx.x*8) + 4) % 16)) + -449)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 5)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 5)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 5)/16)) - ry.outer))) && ((1 - (((threadIdx.x*8) + 5) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (225 - (((threadIdx.x*8) + 5) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 5)/16)*14))*16) + (((threadIdx.x*8) + 5) % 16)) + -449)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 6)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 6)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 6)/16)) - ry.outer))) && ((1 - (((threadIdx.x*8) + 6) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (225 - (((threadIdx.x*8) + 6) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 6)/16)*14))*16) + (((threadIdx.x*8) + 6) % 16)) + -449)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 7)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 7)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 7)/16)) - ry.outer))) && ((1 - (((threadIdx.x*8) + 7) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (225 - (((threadIdx.x*8) + 7) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 7)/16)*14))*16) + (((threadIdx.x*8) + 7) % 16)) + -449)], 0.000000f)
        }
        produce placeholder.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          if (likely((threadIdx.x < 2))) {
            if (likely(((blockIdx.z*2) < (10 - threadIdx.x)))) {
              placeholder.shared[threadIdx.x] = placeholder[(((((((blockIdx.z*6) + rc.outer)*5) + ry.outer) + (threadIdx.x*15))*5) + 1)]
            }
          }
        }
        compute[0] = (compute[0] + (pad_temp.shared[threadIdx.x]*placeholder.shared[0]))
        compute[1] = (compute[1] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[0]))
        compute[2] = (compute[2] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[0]))
        compute[3] = (compute[3] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[0]))
        compute[4] = (compute[4] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[0]))
        compute[5] = (compute[5] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[0]))
        compute[6] = (compute[6] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[0]))
        compute[7] = (compute[7] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[0]))
        compute[8] = (compute[8] + (pad_temp.shared[threadIdx.x]*placeholder.shared[1]))
        compute[9] = (compute[9] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[1]))
        compute[10] = (compute[10] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[1]))
        compute[11] = (compute[11] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[1]))
        compute[12] = (compute[12] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[1]))
        compute[13] = (compute[13] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[1]))
        compute[14] = (compute[14] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[1]))
        compute[15] = (compute[15] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[1]))
        produce pad_temp.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          pad_temp.shared[(((threadIdx.x/2)*16) + ((threadIdx.x*8) % 16))] = tvm_if_then_else(((((((2 - (threadIdx.x/2)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (threadIdx.x/2)) - ry.outer))) && ((0 - ((threadIdx.x*8) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (224 - ((threadIdx.x*8) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((threadIdx.x/2)*14))*16) + ((threadIdx.x*8) % 16)) + -448)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 1)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 1)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 1)/16)) - ry.outer))) && ((0 - (((threadIdx.x*8) + 1) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (224 - (((threadIdx.x*8) + 1) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 1)/16)*14))*16) + (((threadIdx.x*8) + 1) % 16)) + -448)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 2)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 2)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 2)/16)) - ry.outer))) && ((0 - (((threadIdx.x*8) + 2) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (224 - (((threadIdx.x*8) + 2) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 2)/16)*14))*16) + (((threadIdx.x*8) + 2) % 16)) + -448)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 3)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 3)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 3)/16)) - ry.outer))) && ((0 - (((threadIdx.x*8) + 3) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (224 - (((threadIdx.x*8) + 3) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 3)/16)*14))*16) + (((threadIdx.x*8) + 3) % 16)) + -448)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 4)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 4)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 4)/16)) - ry.outer))) && ((0 - (((threadIdx.x*8) + 4) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (224 - (((threadIdx.x*8) + 4) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 4)/16)*14))*16) + (((threadIdx.x*8) + 4) % 16)) + -448)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 5)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 5)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 5)/16)) - ry.outer))) && ((0 - (((threadIdx.x*8) + 5) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (224 - (((threadIdx.x*8) + 5) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 5)/16)*14))*16) + (((threadIdx.x*8) + 5) % 16)) + -448)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 6)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 6)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 6)/16)) - ry.outer))) && ((0 - (((threadIdx.x*8) + 6) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (224 - (((threadIdx.x*8) + 6) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 6)/16)*14))*16) + (((threadIdx.x*8) + 6) % 16)) + -448)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 7)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 7)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 7)/16)) - ry.outer))) && ((0 - (((threadIdx.x*8) + 7) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (224 - (((threadIdx.x*8) + 7) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 7)/16)*14))*16) + (((threadIdx.x*8) + 7) % 16)) + -448)], 0.000000f)
        }
        produce placeholder.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          if (likely((threadIdx.x < 2))) {
            if (likely(((blockIdx.z*2) < (10 - threadIdx.x)))) {
              placeholder.shared[threadIdx.x] = placeholder[(((((((blockIdx.z*6) + rc.outer)*5) + ry.outer) + (threadIdx.x*15))*5) + 2)]
            }
          }
        }
        compute[0] = (compute[0] + (pad_temp.shared[threadIdx.x]*placeholder.shared[0]))
        compute[1] = (compute[1] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[0]))
        compute[2] = (compute[2] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[0]))
        compute[3] = (compute[3] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[0]))
        compute[4] = (compute[4] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[0]))
        compute[5] = (compute[5] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[0]))
        compute[6] = (compute[6] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[0]))
        compute[7] = (compute[7] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[0]))
        compute[8] = (compute[8] + (pad_temp.shared[threadIdx.x]*placeholder.shared[1]))
        compute[9] = (compute[9] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[1]))
        compute[10] = (compute[10] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[1]))
        compute[11] = (compute[11] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[1]))
        compute[12] = (compute[12] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[1]))
        compute[13] = (compute[13] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[1]))
        compute[14] = (compute[14] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[1]))
        compute[15] = (compute[15] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[1]))
        produce pad_temp.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          pad_temp.shared[(((threadIdx.x/2)*16) + ((threadIdx.x*8) % 16))] = tvm_if_then_else(((((((2 - (threadIdx.x/2)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (threadIdx.x/2)) - ry.outer))) && ((-1 - ((threadIdx.x*8) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (223 - ((threadIdx.x*8) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((threadIdx.x/2)*14))*16) + ((threadIdx.x*8) % 16)) + -447)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 1)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 1)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 1)/16)) - ry.outer))) && ((-1 - (((threadIdx.x*8) + 1) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (223 - (((threadIdx.x*8) + 1) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 1)/16)*14))*16) + (((threadIdx.x*8) + 1) % 16)) + -447)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 2)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 2)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 2)/16)) - ry.outer))) && ((-1 - (((threadIdx.x*8) + 2) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (223 - (((threadIdx.x*8) + 2) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 2)/16)*14))*16) + (((threadIdx.x*8) + 2) % 16)) + -447)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 3)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 3)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 3)/16)) - ry.outer))) && ((-1 - (((threadIdx.x*8) + 3) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (223 - (((threadIdx.x*8) + 3) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 3)/16)*14))*16) + (((threadIdx.x*8) + 3) % 16)) + -447)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 4)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 4)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 4)/16)) - ry.outer))) && ((-1 - (((threadIdx.x*8) + 4) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (223 - (((threadIdx.x*8) + 4) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 4)/16)*14))*16) + (((threadIdx.x*8) + 4) % 16)) + -447)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 5)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 5)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 5)/16)) - ry.outer))) && ((-1 - (((threadIdx.x*8) + 5) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (223 - (((threadIdx.x*8) + 5) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 5)/16)*14))*16) + (((threadIdx.x*8) + 5) % 16)) + -447)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 6)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 6)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 6)/16)) - ry.outer))) && ((-1 - (((threadIdx.x*8) + 6) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (223 - (((threadIdx.x*8) + 6) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 6)/16)*14))*16) + (((threadIdx.x*8) + 6) % 16)) + -447)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 7)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 7)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 7)/16)) - ry.outer))) && ((-1 - (((threadIdx.x*8) + 7) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (223 - (((threadIdx.x*8) + 7) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 7)/16)*14))*16) + (((threadIdx.x*8) + 7) % 16)) + -447)], 0.000000f)
        }
        produce placeholder.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          if (likely((threadIdx.x < 2))) {
            if (likely(((blockIdx.z*2) < (10 - threadIdx.x)))) {
              placeholder.shared[threadIdx.x] = placeholder[(((((((blockIdx.z*6) + rc.outer)*5) + ry.outer) + (threadIdx.x*15))*5) + 3)]
            }
          }
        }
        compute[0] = (compute[0] + (pad_temp.shared[threadIdx.x]*placeholder.shared[0]))
        compute[1] = (compute[1] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[0]))
        compute[2] = (compute[2] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[0]))
        compute[3] = (compute[3] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[0]))
        compute[4] = (compute[4] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[0]))
        compute[5] = (compute[5] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[0]))
        compute[6] = (compute[6] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[0]))
        compute[7] = (compute[7] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[0]))
        compute[8] = (compute[8] + (pad_temp.shared[threadIdx.x]*placeholder.shared[1]))
        compute[9] = (compute[9] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[1]))
        compute[10] = (compute[10] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[1]))
        compute[11] = (compute[11] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[1]))
        compute[12] = (compute[12] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[1]))
        compute[13] = (compute[13] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[1]))
        compute[14] = (compute[14] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[1]))
        compute[15] = (compute[15] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[1]))
        produce pad_temp.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          pad_temp.shared[(((threadIdx.x/2)*16) + ((threadIdx.x*8) % 16))] = tvm_if_then_else(((((((2 - (threadIdx.x/2)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (threadIdx.x/2)) - ry.outer))) && ((-2 - ((threadIdx.x*8) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (222 - ((threadIdx.x*8) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((threadIdx.x/2)*14))*16) + ((threadIdx.x*8) % 16)) + -446)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 1)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 1)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 1)/16)) - ry.outer))) && ((-2 - (((threadIdx.x*8) + 1) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (222 - (((threadIdx.x*8) + 1) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 1)/16)*14))*16) + (((threadIdx.x*8) + 1) % 16)) + -446)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 2)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 2)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 2)/16)) - ry.outer))) && ((-2 - (((threadIdx.x*8) + 2) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (222 - (((threadIdx.x*8) + 2) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 2)/16)*14))*16) + (((threadIdx.x*8) + 2) % 16)) + -446)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 3)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 3)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 3)/16)) - ry.outer))) && ((-2 - (((threadIdx.x*8) + 3) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (222 - (((threadIdx.x*8) + 3) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 3)/16)*14))*16) + (((threadIdx.x*8) + 3) % 16)) + -446)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 4)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 4)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 4)/16)) - ry.outer))) && ((-2 - (((threadIdx.x*8) + 4) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (222 - (((threadIdx.x*8) + 4) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 4)/16)*14))*16) + (((threadIdx.x*8) + 4) % 16)) + -446)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 5)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 5)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 5)/16)) - ry.outer))) && ((-2 - (((threadIdx.x*8) + 5) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (222 - (((threadIdx.x*8) + 5) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 5)/16)*14))*16) + (((threadIdx.x*8) + 5) % 16)) + -446)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 6)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 6)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 6)/16)) - ry.outer))) && ((-2 - (((threadIdx.x*8) + 6) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (222 - (((threadIdx.x*8) + 6) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 6)/16)*14))*16) + (((threadIdx.x*8) + 6) % 16)) + -446)], 0.000000f)
          pad_temp.shared[((threadIdx.x*8) + 7)] = tvm_if_then_else(((((((2 - (((threadIdx.x*8) + 7)/16)) - ry.outer) <= (blockIdx.y*8)) && ((blockIdx.y*8) < ((226 - (((threadIdx.x*8) + 7)/16)) - ry.outer))) && ((-2 - (((threadIdx.x*8) + 7) % 16)) <= (blockIdx.x*16))) && ((blockIdx.x*16) < (222 - (((threadIdx.x*8) + 7) % 16)))), placeholder[((((((((blockIdx.y*112) + blockIdx.x) + (rc.outer*3136)) + (ry.outer*14)) + ((((threadIdx.x*8) + 7)/16)*14))*16) + (((threadIdx.x*8) + 7) % 16)) + -446)], 0.000000f)
        }
        produce placeholder.shared {
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
          if (likely((threadIdx.x < 2))) {
            if (likely(((blockIdx.z*2) < (10 - threadIdx.x)))) {
              placeholder.shared[threadIdx.x] = placeholder[(((((((blockIdx.z*6) + rc.outer)*5) + ry.outer) + (threadIdx.x*15))*5) + 4)]
            }
          }
        }
        compute[0] = (compute[0] + (pad_temp.shared[threadIdx.x]*placeholder.shared[0]))
        compute[1] = (compute[1] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[0]))
        compute[2] = (compute[2] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[0]))
        compute[3] = (compute[3] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[0]))
        compute[4] = (compute[4] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[0]))
        compute[5] = (compute[5] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[0]))
        compute[6] = (compute[6] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[0]))
        compute[7] = (compute[7] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[0]))
        compute[8] = (compute[8] + (pad_temp.shared[threadIdx.x]*placeholder.shared[1]))
        compute[9] = (compute[9] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[1]))
        compute[10] = (compute[10] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[1]))
        compute[11] = (compute[11] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[1]))
        compute[12] = (compute[12] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[1]))
        compute[13] = (compute[13] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[1]))
        compute[14] = (compute[14] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[1]))
        compute[15] = (compute[15] + (pad_temp.shared[(threadIdx.x + 112)]*placeholder.shared[1]))
      }
    }
  }
  compute[((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x)] = max(compute[0], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 224)] = max(compute[1], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 448)] = max(compute[2], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 672)] = max(compute[3], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 896)] = max(compute[4], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 1120)] = max(compute[5], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 1344)] = max(compute[6], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 1568)] = max(compute[7], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 50176)] = max(compute[8], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 50400)] = max(compute[9], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 50624)] = max(compute[10], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 50848)] = max(compute[11], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 51072)] = max(compute[12], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 51296)] = max(compute[13], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 51520)] = max(compute[14], 0.000000f)
  compute[(((((((blockIdx.z*56) + blockIdx.y)*112) + blockIdx.x)*16) + threadIdx.x) + 51744)] = max(compute[15], 0.000000f)
}

Summary

In this tutorial, we have seen

  • How to use the TOPI API for common operations with numpy-style operators.
  • How TOPI facilitates generic schedules and operator fusion for a given context, to generate optimized kernel code.

Total running time of the script: ( 0 minutes 0.638 seconds)

Gallery generated by Sphinx-Gallery