Skip to content

Commit

Permalink
[TOPI] Fix SME conv2d schedule import and intrin argument (#17040)
Browse files Browse the repository at this point in the history
Fixes a merge conflict between #16981 and #17003.

Change-Id: Ifcc983ef0b8c00250568a048fd682933adfdcde4
  • Loading branch information
lhutton1 authored May 29, 2024
1 parent d9240e4 commit 8bdd54b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
# pylint: disable=import-outside-toplevel
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
from tvm.tir.tensor_intrin.arm_cpu import (
ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE,
ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
Expand All @@ -743,7 +743,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
sch.parallel(b)
sch.reorder(b, ko, mo, ki, mi)
sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE)
sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE)

# Split and reorder the loops of the GeMM for tensorization
b, m, n, k = sch.get_loops(gemm_block)
Expand All @@ -760,7 +760,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded),
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype),
override=True,
)
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
Expand Down

0 comments on commit 8bdd54b

Please sign in to comment.