
[TIR] Disallow unused rhs vars in GetAutoTensorizeMapping (apache#12225)
vinx13 authored and xinetzone committed Nov 25, 2022
1 parent ae22a12 commit fcce475
Showing 2 changed files with 29 additions and 0 deletions.
src/tir/schedule/analysis/analysis.cc: 5 additions & 0 deletions

@@ -2460,6 +2460,7 @@ class AutoTensorizeMappingProposer {
     }
 
     // Step 3: Fuse LHS iters mapped to the same RHS iter
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_rhs_vars;
     for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) {
       const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var;
       const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var];
@@ -2472,12 +2473,16 @@
         PrimExpr updated_fused_lhs =
             fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
         fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
+        used_rhs_vars.insert(rhs_var);
       } else {
         // non-unique mapping is not supported
         return {};
       }
     }
     for (const auto& iter : extractor_->rhs_iters_) {
+      if (!used_rhs_vars.count(iter->var)) {
+        return {};
+      }
       index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));
     }
     // At most one mapping is supported.
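The added C++ lines track which RHS (intrinsic) iterator variables actually receive an LHS (workload) iterator during fusion, and reject the proposal when any RHS iterator is left unused. A minimal Python sketch of that coverage check, detached from TVM's data structures (the names `propose_mapping`, `lhs_to_rhs_candidates`, and `rhs_iters` are illustrative only, not TVM APIs):

```python
def propose_mapping(lhs_to_rhs_candidates, rhs_iters):
    """Toy version of the coverage check added to AutoTensorizeMappingProposer."""
    used_rhs = set()
    for candidates in lhs_to_rhs_candidates.values():
        if len(candidates) > 1:
            return None  # non-unique mapping is not supported
        if len(candidates) == 1:
            used_rhs.add(next(iter(candidates)))
        # an empty candidate set leaves this LHS iter unmapped, which is allowed
    for rhs_iter in rhs_iters:
        if rhs_iter not in used_rhs:
            return None  # the new check: an uncovered RHS iter rejects the whole proposal
    return used_rhs


# "i" has no candidate, so the intrinsic iter "io" is never covered -> rejected
assert propose_mapping({"i": set(), "j": {"jo"}, "k": {"ko"}}, ["io", "jo", "ko"]) is None
assert propose_mapping({"i": {"io"}, "j": {"jo"}, "k": {"ko"}}, ["io", "jo", "ko"]) == {"io", "jo", "ko"}
```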
tests/python/unittest/test_tir_schedule_analysis.py: 24 additions & 0 deletions

@@ -265,6 +265,9 @@ def check_index_map(workload, block_name, intrin_name, expected_index_map):
     block = s.get_block(block_name)
     desc_func = TensorIntrin.get(intrin_name).desc
     info = get_auto_tensorize_mapping_info(s, block, desc_func)
+    if expected_index_map is None:
+        assert info is None
+        return
     assert len(info.mappings) == 1
     assert IndexMap.from_func(expected_index_map).is_equivalent_to(info.mappings[0])
 
@@ -304,5 +307,26 @@ def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k):
     )
 
 
+@pytest.mark.parametrize(
+    "n,m,k,expected",
+    [
+        (
+            512,
+            512,
+            512,
+            lambda n, m, k: (
+                n,
+                m,
+                k,
+            ),
+        ),
+        (1, 32, 32, None),
+    ],
+)
+def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected):
+    matmul = create_prim_func(te_workload.matmul(n, m, k, in_dtype="float16", out_dtype="float32"))
+    check_index_map(matmul, "C", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
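The new parametrized test drives this through `check_index_map`; the same behavior can also be probed directly with the helpers this test file already imports. A sketch under that assumption (import paths may differ across TVM versions, and the `probe` helper is made up for illustration):

```python
import tvm
from tvm.te import create_prim_func
from tvm.tir import TensorIntrin
from tvm.tir.schedule.analysis import get_auto_tensorize_mapping_info
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN
from tvm.meta_schedule.testing import te_workload


def probe(n, m, k):
    """Return the auto-tensorize mapping info for an n x m x k fp16 matmul, or None."""
    func = create_prim_func(te_workload.matmul(n, m, k, in_dtype="float16", out_dtype="float32"))
    sch = tvm.tir.Schedule(func)
    desc = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f32_INTRIN).desc
    return get_auto_tensorize_mapping_info(sch, sch.get_block("C"), desc)


print(probe(512, 512, 512) is not None)  # True: a mapping is proposed, as in the first test case
print(probe(1, 32, 32) is None)          # True: a WMMA iter would be left unused, so no mapping
```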
