diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 2b913f6281fd..1d91f6a4969c 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -27,7 +27,6 @@ #include #include -#include #include #include "../../support/utils.h" @@ -521,84 +520,84 @@ class IndexTransformer : public StmtExprMutator { var_map[sp_iter_var->var.get()] = loop_var; } - // Step 4. Collet block iters and iter bindings. - std::set in_stack; + // Step 4. Collect block iters and iter bindings. + /* Whether the axis appears in the stack. */ + std::unordered_set in_stack; /* A stack that stores block itervars in each block. */ - std::stack> block_iters_st; + std::vector> block_iters_st; /* A stack that stores itervar bindings in each block. */ - std::stack> iter_bindings_st; + std::vector> iter_bindings_st; /* A stack that stores generated loop vars in each block. */ - std::stack> loop_vars_st; + std::vector> loop_vars_st; /* A stack that stores whether to place init block in each block. */ - std::stack place_init_st; + std::vector place_init_st; /* An indicator that records whether init block has been set. */ bool init_set = false; - do { - /* Block itervars of current block. */ - Array block_iters; - /* Itervar bindings of current block. */ - Array iter_bindings; - /* Axis names of current block. */ - Array blk_axes; - /* Generated loop vars of current block. */ - Array loop_vars; - /* An indicator that records whether there is reduction axis in current block. */ - bool has_reduction_var = false; - for (int i = 0; i < n_iter; ++i) { - SpIterVar sp_it_var = sp_block->sp_iter_vars[i]; - Axis axis = sp_it_var->axis; - - /* Add itervar to current block when - * - it's not used yet (not in stack) and - * - it's parent axis was used in outer blocks or - * - it's an iterator to a fixed axis. - */ - auto parent = axis->GetParentAxis(); - bool emit_iter_var = true; - if (in_stack.find(axis.get()) != - in_stack.end()) { // the iter var has already been emitted. - emit_iter_var = false; + /* Block itervars of current block. */ + Array block_iters; + /* Itervar bindings of current block. */ + Array iter_bindings; + /* Generated loop vars of current block. */ + Array loop_vars; + /* Whether the axis appears in the cuurent block. */ + std::unordered_set in_block; + /* An indicator that records whether there is reduction axis in current block. */ + bool has_reduction_var = false; + + auto UpdateStack = [&]() { + block_iters_st.emplace_back(std::move(block_iters)); + iter_bindings_st.emplace_back(std::move(iter_bindings)); + loop_vars_st.emplace_back(std::move(loop_vars)); + if (init_set) { + place_init_st.emplace_back(false); + } else { + place_init_st.emplace_back(has_reduction_var); + init_set |= has_reduction_var; + } + }; + + for (int i = 0; i < n_iter; ++i) { + SpIterVar sp_it_var = sp_block->sp_iter_vars[i]; + Axis axis = sp_it_var->axis; + auto parent = axis->GetParentAxis(); + bool create_new_blk = false; + bool is_fixed_axis = axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed; + if (!is_fixed_axis && parent.defined()) { + const AxisNode* parent_node = parent.value().get(); + if (in_block.find(parent_node) != in_block.end()) { + /* parent node is in the current block, need to create new block. */ + create_new_blk = true; + } else if (in_stack.find(parent_node) != in_stack.end()) { + /* parent node is in the previous blocks in the stack, no need to create new block. */ + create_new_blk = false; } else { - if (parent.defined()) { // has parent - if (in_stack.find(parent.value().get()) == in_stack.end()) { // parent not emitted yet - if (axis->kind() == AxisKind::kDenseVariable || - axis->kind() == AxisKind::kSparseVariable) { // is not fixed axis. - emit_iter_var = false; - } - } - } + CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before " << axis->GetName() << " when defining a sparse block."; } - // LOG(INFO) << axis->name << " " << (parent.defined() ? parent.value()->name : "no-parent") - // << " " << emit_iter_var; - if (emit_iter_var) { - loop_vars.push_back(all_loop_vars[i]); - blk_axes.push_back(axis); - block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map)); - iter_bindings.push_back(all_loop_vars[i]); - has_reduction_var |= sp_it_var->is_reduction; + } + if (create_new_blk) { + /* update in stack set. */ + for (const AxisNode* node : in_block) { + in_stack.insert(node); } + /* Update stack. */ + UpdateStack(); + /* Reset block states. */ + loop_vars = {}; + block_iters = {}; + iter_bindings = {}; + has_reduction_var = false; + in_block.clear(); } - /* Tag axes in current block as "in-stack". */ - for (const Axis&& axis : blk_axes) { - in_stack.insert(axis.get()); - } + loop_vars.push_back(all_loop_vars[i]); + block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map)); + iter_bindings.push_back(all_loop_vars[i]); + has_reduction_var |= sp_it_var->is_reduction; + in_block.insert(axis.get()); + } - /* Update stack. */ - if (!block_iters.empty()) { - block_iters_st.push(std::move(block_iters)); - iter_bindings_st.push(std::move(iter_bindings)); - loop_vars_st.push(std::move(loop_vars)); - if (init_set) { - place_init_st.push(false); - } else { - place_init_st.push(has_reduction_var); - init_set |= has_reduction_var; - } - } else { - break; - } - } while (true); + // Update the last block. + UpdateStack(); // Step 5. Generate the read-region and write-retion of the block. Array reads{}; @@ -608,14 +607,14 @@ class IndexTransformer : public StmtExprMutator { // Step 6. Generate nested blocks and loops from innermost to outermost. int blk_counter = 0; while (!block_iters_st.empty()) { - Array block_iters = std::move(block_iters_st.top()); - Array iter_bindings = std::move(iter_bindings_st.top()); - Array loop_vars = std::move(loop_vars_st.top()); - bool place_init = place_init_st.top(); - block_iters_st.pop(); - iter_bindings_st.pop(); - loop_vars_st.pop(); - place_init_st.pop(); + Array block_iters = std::move(block_iters_st.back()); + Array iter_bindings = std::move(iter_bindings_st.back()); + Array loop_vars = std::move(loop_vars_st.back()); + bool place_init = place_init_st.back(); + block_iters_st.pop_back(); + iter_bindings_st.pop_back(); + loop_vars_st.pop_back(); + place_init_st.pop_back(); Map mapping; mapping.Set("sparse", Bool(true)); diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index 1388e9044dfc..086b8094be7b 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -40,7 +40,7 @@ def csrmm( A = T.match_sparse_buffer(a, (I, J), "float32") B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") C = T.match_sparse_buffer(c, (I, K), "float32") - with T.iter([I, J, K], "SRS", "csrmm") as [vi, vj, vk]: + with T.iter([I, K, J], "SSR", "csrmm") as [vi, vk, vj]: with T.init(): C[vi, vk] = 0.0 C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] @@ -180,12 +180,12 @@ def bsrmm( B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") C = T.match_sparse_buffer(c, (I, BI, F), "float32") - with T.iter([I, J, BI, BJ, F], "SRSRS", "bsrmm") as [ + with T.iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [ vi, - vj, vbi, vbj, vf, + vj, ]: with T.init(): C[vi, vbi, vf] = 0.0 @@ -314,7 +314,6 @@ def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices def test_csrmm(): mod = tvm.IRModule.from_expr(csrmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) - print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr") @@ -338,14 +337,12 @@ def test_csrmm(): def test_csrmm_dense_iter(): mod = tvm.IRModule.from_expr(csrmm_dense_iter) mod = tvm.tir.transform.LowerSparseTIR()(mod) - print(mod["main"].script()) # tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) def test_segment_reduce(): mod = tvm.IRModule.from_expr(segment_reduce) mod = tvm.tir.transform.LowerSparseTIR()(mod) - print(mod["main"].script()) def test_csr_reduce(): @@ -412,7 +409,6 @@ def test_bsrmm(): def test_ellpack_mm(): mod = tvm.IRModule.from_expr(ellpack_mm) mod = tvm.tir.transform.LowerSparseTIR()(mod) - print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], lowered_ellpack_mm, True) nnz_cols = 4