Skip to content

Commit

Permalink
add VNNI unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 6, 2022
1 parent 6cc8009 commit f88c31e
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions tests/python/unittest/test_tir_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import pytest
import tvm
import tvm.testing
from tvm import tir
from tvm import tir, te
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
from tvm.tir.tensor_intrin.vnni import INTRIN_NAME as VNNI_INTRIN

# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
Expand Down Expand Up @@ -531,5 +532,40 @@ def test_tensorize_with_annotation():
verify_trace_roundtrip(sch=s, mod=func)


def test_tensorize_vnni():
n, m, k = 128, 128, 128
X = te.placeholder((m, k), name="X", dtype="uint8")
packed_W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8")

ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
(m, n),
lambda i, j: te.sum(
X[i, ak].astype("int32")
* packed_W[
tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4
].astype("int32"),
axis=ak,
),
name="compute",
)

func = te.create_prim_func([X, packed_W, matmul])

sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
_, j, k = sch.get_loops(block)

_, ji = sch.split(j, factors=[None, 16])
ko, ki = sch.split(k, factors=[None, 4])
sch.reorder(ko, ji, ki)

sch.decompose_reduction(block, ko)
sch.tensorize(ji, VNNI_INTRIN)

verify_trace_roundtrip(sch=sch, mod=func)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
# sys.exit(pytest.main([__file__] + sys.argv[1:]))
test_tensorize_vnni()

0 comments on commit f88c31e

Please sign in to comment.