diff --git a/python/tvm/tl/language.py b/python/tvm/tl/language.py index 6e3e78686cff..2ec9da78b8cd 100644 --- a/python/tvm/tl/language.py +++ b/python/tvm/tl/language.py @@ -136,13 +136,13 @@ def Kernel(*blocks: List[tir.PrimExpr], threads: Union[int, List[int], Tuple] = return _ffi_api.KernelLaunch(blocks, threads, attrs) -def use_swizzle(panel_size: int, order: str = "row"): +def use_swizzle(panel_size: int, order: str = "row", enable: bool = True): device_func = ( "rasterization2DRow" if order == "row" else "rasterization2DColumn" ) return T.attr( None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>" - ) + ) if enable else None def alloc_shared(shape, dtype, scope="shared.dyn"): diff --git a/src/tl/tl_templates/threadblock_swizzle.h b/src/tl/tl_templates/threadblock_swizzle.h index ed8b3935eb8f..4b08c2ebd65d 100644 --- a/src/tl/tl_templates/threadblock_swizzle.h +++ b/src/tl/tl_templates/threadblock_swizzle.h @@ -6,33 +6,33 @@ namespace tl { template TL_DEVICE dim3 rasterization2DRow() { - const int block_idx = blockIdx.x + blockIdx.y * gridDim.x; - const int grid_size = gridDim.x * gridDim.y; - const int panel_size = panel_width * gridDim.x; - const int panel_offset = block_idx % panel_size; - const int panel_idx = block_idx / panel_size; - const int total_panel = cutlass::ceil_div(grid_size, panel_size); - const int stride = + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.x; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size); + const unsigned int stride = panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.x; - const int col_idx = + const unsigned int col_idx = (panel_idx & 1) ? gridDim.x - 1 - panel_offset / stride : panel_offset / stride; - const int row_idx = panel_offset % stride + panel_idx * panel_width; + const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width; return {col_idx, row_idx, blockIdx.z}; } template TL_DEVICE dim3 rasterization2DColumn() { - const int block_idx = blockIdx.x + blockIdx.y * gridDim.x; - const int grid_size = gridDim.x * gridDim.y; - const int panel_size = panel_width * gridDim.y; - const int panel_offset = block_idx % panel_size; - const int panel_idx = block_idx / panel_size; - const int total_panel = cutlass::ceil_div(grid_size, panel_size); - const int stride = + const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; + const unsigned int grid_size = gridDim.x * gridDim.y; + const unsigned int panel_size = panel_width * gridDim.y; + const unsigned int panel_offset = block_idx % panel_size; + const unsigned int panel_idx = block_idx / panel_size; + const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size); + const unsigned int stride = panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.y; - const int row_idx = + const unsigned int row_idx = (panel_idx & 1) ? gridDim.y - 1 - panel_offset / stride : panel_offset / stride; - const int col_idx = panel_offset % stride + panel_idx * panel_width; + const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width; return {col_idx, row_idx, blockIdx.z}; }