From 54efdc6b3c73268ecaf6e7ec3659cfe31da94977 Mon Sep 17 00:00:00 2001 From: Y Date: Wed, 28 Apr 2021 01:09:55 +0800 Subject: [PATCH] [Tensorize] Fix compute reusing (#7920) --- src/te/operation/tensorize.cc | 2 +- .../python/unittest/test_te_schedule_tensorize.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index ea713220eddd..0aa279fb9246 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -327,7 +327,7 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, ana.Bind(compute_intrin_iter_space); for (size_t i = 0; i < body.size(); ++i) { - PrimExpr lhs = ana.Simplify(body[i]); + PrimExpr lhs = ana.Simplify(Substitute(body[i], value_map)); // run substitution because the intrin body could depend on outer loop vars. PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map)); if (lhs.dtype() != rhs.dtype()) { diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index fdafdb74fc0b..e2c2f7f7e0e5 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -146,8 +146,22 @@ def check_cache_write(m, factor): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [x, y, z]) + def check_compute_reuse(): + x, y, z = add(32) + + def _intrin_vadd(): + def _intrin_func(ins, outs): + return tvm.tir.call_packed("vadd", ins[0], ins[1], outs[0]) + + return tvm.te.decl_tensor_intrin(z.op, _intrin_func) + + s = tvm.te.create_schedule(z.op) + s[z].tensorize(z.op.axis[0], _intrin_vadd()) + tvm.lower(s, [x, y, z]) + check(128, 16) check_cache_write(129, 16) + check_compute_reuse() def test_tensorize_matmul():