[Bug] Long lowering time after #13217 #13508

Closed
comaniac opened this issue Nov 29, 2022 · 10 comments
Labels: needs-triage, type: bug

comaniac commented Nov 29, 2022

Expected behavior

The lowering time of the given case should be around 10 seconds.

Actual behavior

The lowering time is more than 550 seconds.

Environment

Any environment with commit 101e3a4 (#13217) or later.

Steps to reproduce

The script:

import time

import tvm
from tvm import topi

class Timer:
    """Print a label on construction and the elapsed wall-clock time on exit."""

    def __init__(self, msg):
        self.msg = msg
        print(f"{msg}...", flush=True)

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        print(f"{self.msg}...{time.time() - self.start:.2f}s", flush=True)

def resize2d_dx_compute(inp, dy):
    """compute definition for resize2d_dx op"""
    size = (64, 32)
    layout = "NCHW"
    method = "cubic"
    coord_trans = "half_pixel"
    rounding_method = ""
    cubic_alpha = -0.75
    cubic_exclude = 0
    out_dtype = "float32"

    out = topi.image.resize2d(
        inp,
        (None, None, None, None),
        size,
        layout,
        method,
        coord_trans,
        rounding_method,
        bicubic_alpha=cubic_alpha,
        bicubic_exclude=cubic_exclude,
        out_dtype=out_dtype,
    )
    grads = tvm.te.gradient(out, [inp], head=dy)
    return grads

inp = tvm.te.placeholder((32, 3, 32, 32), name="inp")
dy = tvm.te.placeholder((32, 3, 64, 32), name="dy")
with Timer("te.gradient"):
    grads = resize2d_dx_compute(inp, dy)

# This problem is platform-independent.
with Timer("schedule"):
    sch = topi.x86.injective.schedule_injective(grads)

with Timer("lower"):
    print(tvm.lower(sch, [inp, dy, grads[0]], simple_mode=True))
  1. Switch to a commit before 101e3a4 ([TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217) and run the script.
  2. Check out commit 101e3a4 ([TIR][Transform] Optional data-flow analysis in RemoveNoOp #13217) and run the script again.
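The script's `Timer` only prints durations. To compare the two checkouts programmatically (for example, to flag a run that exceeds a time budget), a variant like the one below could also record the elapsed time. This is a hypothetical helper, not part of the original report, and it swaps `time.time()` for the monotonic `time.perf_counter()`.

```python
import time


class Timer:
    """Context manager that prints and records wall-clock time for a labeled step.

    Hypothetical variant of the reproduction script's Timer.
    """

    def __init__(self, msg):
        self.msg = msg
        self.elapsed = None  # filled in on exit, in seconds

    def __enter__(self):
        print(f"{self.msg}...", flush=True)
        # perf_counter is monotonic, so it is not affected by system clock changes
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.msg}...{self.elapsed:.2f}s", flush=True)
        return False  # never swallow exceptions raised inside the block
```

Used as `with Timer("lower") as t:` around the `tvm.lower(...)` call, `t.elapsed` then holds the lowering time for the current checkout, which can be logged or compared against the roughly 10 s expected above.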

Here is the lowered IR without and with this commit:

Without this commit:

@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
             dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
             resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
  buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
  for (ax0.ax1.fused: int32, 0, 96) "parallel" {
    for (ax2: int32, 0, 32) {
      for (ax3.outer: int32, 0, 2) {
        resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
        for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
          for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
            for (ax3.inner.s: int32, 0, 16) {
              let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
              let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
              let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
              if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
                let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
                resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == 
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, 
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, 
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == 
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - 
@tir.floor(cse_var_1, dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, 
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == 
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, 
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
              }
            }
          }
        }
      }
    }
  }
}
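The weight expressions repeated throughout the dump appear to be the standard cubic convolution kernel with a = `cubic_alpha` = -0.75, evaluated at the half-pixel source coordinate (`cse_var_1` for height, where the scale is 32/64 = 0.5, and `cse_var_3` for width, where the scale is 1), with tap indices clamped to [0, 31]. The plain-Python sketch below is a reading of the IR under those assumptions, not TVM's implementation; `cubic_weights`, `taps`, `resize2d`, and `resize2d_dx` are hypothetical names. It computes the input gradient by scattering each `dy` value onto its 4x4 taps.

```python
import math

ALPHA = -0.75  # cubic_alpha from the reproduction script


def cubic_weights(t, a=ALPHA):
    """Weights for taps at floor(x)-1, floor(x), floor(x)+1, floor(x)+2, with t = x - floor(x)."""
    t2, t3 = t * t, t * t * t
    return (
        a * (t3 - 2 * t2 + t),                   # -0.75*(t^3 - 2t^2 + t) in the IR
        (a + 2) * t3 - (a + 3) * t2 + 1,         # 1.25t^3 - 2.25t^2 + 1
        -(a + 2) * t3 + (2 * a + 3) * t2 - a * t,  # -1.25t^3 + 1.5t^2 + 0.75t
        -a * t3 + a * t2,                        # 0.75t^3 - 0.75t^2
    )


def half_pixel(out_idx, in_size, out_size):
    """half_pixel coordinate transform: (i + 0.5) * scale - 0.5."""
    return (out_idx + 0.5) * (in_size / out_size) - 0.5


def taps(out_idx, in_size, out_size):
    """Yield (clamped input index, weight) pairs for one output position along one axis."""
    x = half_pixel(out_idx, in_size, out_size)
    base = math.floor(x)
    for k, w in zip(range(-1, 3), cubic_weights(x - base)):
        # Clamp to [0, in_size - 1], like max(min(..., 31), 0) in the IR
        yield max(min(base + k, in_size - 1), 0), w


def resize2d(inp, out_h, out_w):
    """Forward bicubic resize of a 2-D list (one N, C slice)."""
    in_h, in_w = len(inp), len(inp[0])
    return [
        [
            sum(wy * wx * inp[iy][ix]
                for iy, wy in taps(y, in_h, out_h)
                for ix, wx in taps(x, in_w, out_w))
            for x in range(out_w)
        ]
        for y in range(out_h)
    ]


def resize2d_dx(dy, in_h, in_w):
    """Gradient w.r.t. the input: scatter-add each dy[y][x] onto its 4x4 taps."""
    out_h, out_w = len(dy), len(dy[0])
    grad = [[0.0] * in_w for _ in range(in_h)]
    for y in range(out_h):
        for x in range(out_w):
            for iy, wy in taps(y, in_h, out_h):
                for ix, wx in taps(x, in_w, out_w):
                    grad[iy][ix] += dy[y][x] * wy * wx
    return grad
```

Because the resize is linear, `resize2d_dx` should satisfy the adjoint identity sum(resize2d(x) * dy) == sum(x * resize2d_dx(dy)), which is a cheap way to check it. The scatter form visits each output position once (16 taps each), whereas the `te.gradient` result above loops over every (input, output) pair and tests the clamped-index conditions in the innermost loop, which is where the very large boolean conditions in the dump come from.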

With this commit:

@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
             dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
             resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
  buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
  for (ax0.ax1.fused: int32, 0, 96) "parallel" {
    for (ax2: int32, 0, 32) {
      for (ax3.outer: int32, 0, 2) {
        resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
        for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
          for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
            for (ax3.inner.s: int32, 0, 16) {
              let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
              let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
              let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
              if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
                let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
                resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == 
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, 
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, 
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == 
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - 
@tir.floor(cse_var_1, dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, 
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == 
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, 
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
              }
            }
          }
        }
      }
    }
  }
}

The IRs are pretty much identical, so the slowdown is likely due to changes in the lowering passes.

cc @Lunderberg @masahi

Triage

  • needs-triage
@comaniac comaniac added type: bug needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it labels Nov 29, 2022
@masahi
Member

masahi commented Nov 29, 2022

Hmm, strange. The new flag use_dataflow_analysis in RemoveNoOp is set to false by default, so I thought it shouldn't affect the default lowering in any way.

@Lunderberg
Contributor

Thank you, and I'm seeing the same behavior with this example. Using f5a102c, the lowering step runs in 0.18s, while using 101e3a4 (just after PR#13217) it runs in 11.78s, a ~65x slowdown, on the same order as the ~55x difference you're seeing. As @masahi pointed out, this definitely shouldn't be the case, since the additional analysis is disabled by default. I'm investigating it.

@Lunderberg
Contributor

Looks like the performance degradation is from RemoveNoOp. Even though the data-flow analysis is disabled by default, the analyzer of IRMutatorWithAnalyzer still collects scoped information. Simplifications done by that analyzer don't show up in the output TIR unless they are used to prove a statement to be a no-op (e.g. by having a negative loop extent), but they still add to the lowering time.

It looks like a quick fix may be to disable arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches, which is currently enabled for the analyzer in RemoveNoOp; disabling it restored the performance in this test case. Can you check whether removing kApplyConstraintsToBooleanBranches from this line also improves the performance on your side?
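The thread does not spell out why this extension is so expensive on this input. As a rough intuition only, here is a toy cost model in plain Python. It is not TVM's `RewriteSimplifier` code, and `Leaf`, `Or`, `nested_or`, and `count_visits` are hypothetical names. It assumes that simplifying `A || B` may re-simplify each branch a second time with its sibling as a constraint; under that assumption the work doubles at every nesting level, and the boolean conditions in the IR above nest `||` and `&&` several levels deep.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Leaf:
    """An opaque boolean condition, e.g. `ax2 == max(min(..., 31), 0)`."""
    name: str


@dataclass(frozen=True)
class Or:
    """A binary `||` node (an `&&` node would be modelled the same way)."""
    a: "Expr"
    b: "Expr"


Expr = Union[Leaf, Or]


def nested_or(depth: int) -> Expr:
    """Build ((c0 || c1) || c2) || ... with `depth` Or nodes."""
    expr: Expr = Leaf("c0")
    for i in range(1, depth + 1):
        expr = Or(expr, Leaf(f"c{i}"))
    return expr


def count_visits(expr: Expr, revisit_branches: bool) -> int:
    """Count how many nodes a toy simplifier visits.

    revisit_branches=False: each child is simplified once.
    revisit_branches=True: worst case of a hypothetical rule that simplifies
    each branch a second time under its sibling as a constraint.
    """
    if isinstance(expr, Leaf):
        return 1
    passes = 2 if revisit_branches else 1
    children = count_visits(expr.a, revisit_branches) + count_visits(expr.b, revisit_branches)
    return 1 + passes * children


for depth in (4, 8, 12):
    e = nested_or(depth)
    # prints: "4 9 61", then "8 17 1021", then "12 25 16381"
    print(depth, count_visits(e, False), count_visits(e, True))
```

In this model the visit count grows linearly (2d + 1) without the rule and exponentially (2^(d+2) - 3) with it. The real simplifier's cost depends on when it actually re-visits a branch, so treat these numbers as illustrative only.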

I'm continuing to investigate whether this should be disabled, or whether something else is wrong with the simplifications. The lowered TIR has a lot of expressions that I would expect to be simplified. For example, the first @tir.floor in the if condition is tir.floor((((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32), dtype=float32), which is equivalent to floordiv(n0_n0_k2.shifted.shifted - 1, 2).
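The claimed equivalence can be checked by hand. With x = n0_n0_k2.shifted.shifted, the float expression is floor(x/2 + 1/4 - 1/2) = floor(x/2 - 1/4). For x = 2k this is k - 1, and for x = 2k + 1 it is k, which matches floordiv(x - 1, 2) in both cases. A quick brute-force check in Python follows; it uses float64 rather than float32, but every intermediate is exactly representable for indices this small. `float_index` and `int_index` are illustrative names only.

```python
import math


def float_index(x: int) -> int:
    """Mirror of the TIR expression: floor(((float(x) + 0.5) * 0.5) - 0.5)."""
    return math.floor(((float(x) + 0.5) * 0.5) - 0.5)


def int_index(x: int) -> int:
    """Integer form: floordiv(x - 1, 2). Python's // is floor division."""
    return (x - 1) // 2


# Exhaustively compare over a range of integers, including negatives.
mismatches = [x for x in range(-1000, 1000) if float_index(x) != int_index(x)]
print(mismatches)  # []
```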

@comaniac
Contributor Author

@Lunderberg your analysis makes a lot of sense. After removing kApplyConstraintsToBooleanBranches, the lowering time of the example became 13 seconds, which looks much more reasonable to me. It would be great if you could fix it by disabling the analyzer along with the flag.

In addition, just my two cents about the simplification (correct me if I'm wrong): low-level compilers (e.g., nvcc, LLVM) should be capable of simplifying this expression by themselves, so you might not see any performance improvement even if you apply this simplification at the TIR level.

@tqchen
Member

tqchen commented Nov 29, 2022

Ideally we should keep the simplifier lightweight. In this case, disabling kApplyConstraintsToBooleanBranches makes sense.

Lunderberg added a commit to Lunderberg/tvm that referenced this issue Nov 30, 2022
During a `tir.Simplify` pass, these extensions were conditionally
enabled based on the `PassContext`.  Prior to this commit, they were
enabled by default in the `tir.RemoveNoOp` pass, as the simplified
expressions were only used to prove/disprove a no-op, and did not
appear in the output TIR.  However, this caused performance issues for
some nested boolean expressions.

This PR disables the analyzer extensions for the analyzer used by
`tir.RemoveNoOp`.  The extensions are still used internally by
`ControlFlowGraph`, including during the data-flow analysis used if
`tir.transform.RemoveNoOpConfig.use_dataflow_analysis` is enabled, so
the opt-in data-dependent no-op removals are unaffected.

Related to issue apache#13508.
@Lunderberg
Contributor

@comaniac @tqchen I've submitted #13524, which disables the use of simplifier extensions by RemoveNoOp. My main concern was that it would prevent some of the planned simplifications in #13299, but all the test cases can either be handled without extensions or have data-flow analysis enabled, which uses all the extensions.

> In addition, just my two cents about the simplification, correct me if I'm wrong, low-level compilers (e.g., nvcc, llvm) should be capable of simplifying this expression by themselves, so you might not see any performance improvement even if you apply this simplification at TIR level.

I was curious about this, and ran some benchmarks after modifying resize.py to use only integer fractions when computing the indices and the linear/cubic interpolation weights. This gave about a 1000x improvement in execution speed on the LLVM backend. Apart from the floats, the main difference in the TIR was that tir.VectorizeLoop could identify an opportunity to vectorize the innermost loop for integer indices, but couldn't do so for floating-point indices.
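A minimal sketch of the idea, assuming the half_pixel coordinate transform from the reproducer (x_in = (x_out + 0.5) * in_size / out_size - 0.5). This is not the code from the topi.image.resize change; `half_pixel_float` and `half_pixel_int` are hypothetical helpers that compute the source pixel index and the interpolation weight, once through floats and once with integer arithmetic only.

```python
import math
from fractions import Fraction


def half_pixel_float(out_idx: int, in_size: int, out_size: int):
    """Floating-point mapping: x_in = (x_out + 0.5) * (in_size / out_size) - 0.5."""
    x = (out_idx + 0.5) * (in_size / out_size) - 0.5
    i = math.floor(x)
    return i, x - i  # (integer pixel index, interpolation weight in [0, 1))


def half_pixel_int(out_idx: int, in_size: int, out_size: int):
    """Same mapping using only integers.

    Multiplying through by 2 * out_size gives
        x_in = ((2 * x_out + 1) * in_size - out_size) / (2 * out_size),
    so the pixel index is a floor division and the weight is the remainder
    over a fixed denominator.
    """
    num = (2 * out_idx + 1) * in_size - out_size
    den = 2 * out_size
    i = num // den                     # floor division, like TIR's floordiv
    weight = Fraction(num % den, den)  # floor-mod remainder, always in [0, 1)
    return i, weight


# Upsample 32 -> 64 along one axis, as in the reproducer's height dimension.
for out_idx in range(4):
    print(out_idx, half_pixel_float(out_idx, 32, 64), half_pixel_int(out_idx, 32, 64))
```

In the integer form the pixel index is an affine floordiv of the output index, which is presumably the kind of pattern that lets tir.VectorizeLoop recognize the innermost loop as vectorizable; the weight stays an exact ratio until it is finally converted to float.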


Since there was such a benefit, I'm going to clean up and PR those changes to topi.image.resize. (It also has a 10x improvement in the time required to lower the schedule, so that's also a plus.)

@comaniac
Contributor Author

Thanks for the fix and investigation. Apparently the LLVM backend isn't aware of this transformation. If possible, could you also benchmark on GPU to test nvcc?

Based on the benchmark results, I agree that we should include this transformation once we make it reasonably lightweight.

vinx13 pushed a commit that referenced this issue Nov 30, 2022
@Lunderberg
Contributor

Lunderberg commented Nov 30, 2022

Testing on the GPU, with both the cuda and vulkan backends (nvidia-driver-470 on Ubuntu 21.04), shows a pretty similar effect. It isn't quite as dramatic, only 50x slower instead of 1000x, but it's still quite a large effect. Both GPU tests used the same compute definition, but with topi.cuda.injective.schedule_injective instead of topi.x86.injective.schedule_injective.


The specific fix here (the int vs float indexing) wasn't on the transformation side, but a change to the topi operator. The nice thing is that it can be a lot more aggressive: it can convert floating-point numbers to integer ratios (e.g. the -0.75 in the example into Fraction(-3, 4)) before they get folded too far to be recognized. The downside is that it isn't as general a solution.
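A sketch of what that conversion looks like in plain Python, using the standard-library `fractions` module. It is not the code from the topi change; `cubic_weights` is a hypothetical helper that evaluates the cubic convolution (Keys kernel) weights visible in the IR above, whose coefficients 1.25, 2.25, 1.5 and 0.75 all derive from alpha = -0.75. With exact ratios the four weights sum to exactly 1, which is not guaranteed bit-for-bit once they are computed in floating point.

```python
from fractions import Fraction

# A float constant with an exact binary representation converts exactly.
alpha = Fraction(-0.75)
print(alpha)  # -3/4

# Constants like 0.1 are not exact in binary; limit_denominator recovers
# the intended small ratio.
print(Fraction(0.1) == Fraction(1, 10))       # False
print(Fraction(0.1).limit_denominator(1000))  # 1/10


def cubic_weights(t: Fraction, a: Fraction):
    """Cubic convolution weights for taps at offsets -1, 0, +1, +2.

    These match the polynomials in the IR above with a = -0.75:
    1.25 = a + 2, 2.25 = a + 3, 1.5 = 2a + 3, 0.75 = -a.
    """
    t2, t3 = t * t, t * t * t
    w0 = a * (t3 - 2 * t2 + t)
    w1 = (a + 2) * t3 - (a + 3) * t2 + 1
    w2 = -(a + 2) * t3 + (2 * a + 3) * t2 - a * t
    w3 = -a * t3 + a * t2
    return w0, w1, w2, w3


weights = cubic_weights(Fraction(1, 4), alpha)
print(sum(weights))  # 1
```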

Lunderberg added a commit to Lunderberg/tvm that referenced this issue Dec 1, 2022
Prior to this commit, floating point expressions were used to map
between different-sized pixel arrays.  These floating point
expressions are less aggressively optimized by `RewriteSimplifier`,
which can prevent some optimizations.

This was first noticed during investigation into issue apache#13508.
Benchmarks of `topi.image.resize` showed 1000x and 50x performance
improvements using the LLVM and CUDA backends, respectively, by using
integer expressions instead of floating point.  This performance
improvement is partly driven by enabling
`tir.transform.VectorizeLoop` to recognize vectorizable indices,
where the round-trip through floating point previously prevented that
optimization.
@masahi
Member

masahi commented Dec 8, 2022

Can we close this?

@comaniac comaniac closed this as completed Dec 8, 2022
@comaniac
Contributor Author

comaniac commented Dec 8, 2022

Yeah this particular issue has been resolved. Closed.

Lunderberg added a commit to Lunderberg/tvm that referenced this issue Jan 4, 2023
Lunderberg added a commit to Lunderberg/tvm that referenced this issue Jan 11, 2023
Lunderberg added a commit to Lunderberg/tvm that referenced this issue Jan 24, 2023
Lunderberg added a commit to Lunderberg/tvm that referenced this issue Jan 25, 2023
Lunderberg added a commit to Lunderberg/tvm that referenced this issue Apr 5, 2023
Lunderberg added a commit to Lunderberg/tvm that referenced this issue Sep 11, 2024
Labels
needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it type: bug

4 participants