diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 105dd7009355..106eef20ba93 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1288,7 +1288,7 @@ class SparseBlockNode : public StmtNode { /*! \brief The sparse data structures */ Array sp_structs; /*! \brief The mapping from sparse data structures to the PrimFunc parameters */ - Map> sp_struct2param_map; + Map> sp_struct_param_map; /*! \brief The name of the block */ String name; /*! \brief The body of the block */ @@ -1299,7 +1299,7 @@ class SparseBlockNode : public StmtNode { void VisitAttrs(AttrVisitor* v) { v->Visit("sp_iter_vars", &sp_iter_vars); v->Visit("sp_structs", &sp_structs); - v->Visit("sp_struct2param_map", &sp_struct2param_map); + v->Visit("sp_struct_param_map", &sp_struct_param_map); v->Visit("name", &name); v->Visit("body", &body); v->Visit("init", &init); diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 819e9cc5a8d7..caeee77e3889 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1278,7 +1278,7 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo std::vector sp_buf_docs; for (const ObjectRef& obj : sp_block->sp_structs) { - Array params = sp_block->sp_struct2param_map.Get(obj).value(); + Array params = sp_block->sp_struct_param_map.Get(obj).value(); Doc doc; doc << Print(obj) << " = " << tir_prefix_ << "."; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index c8dbfdffe30b..a5957a475cd6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -974,7 +974,7 @@ SparseBlock::SparseBlock(Array sp_iter_vars, Array sp_stru CHECK_EQ(sp_structs.size(), sp_struct_params.size()) << "ValueError: The length of `sp_struct_params` is expected to be equal to the length " "`sp_structs`, which is the number of sparse data structures"; - Map> sp_struct2param_map; + Map> sp_struct_param_map; for (int i = 0; i < static_cast(sp_structs.size()); ++i) { ObjectRef obj = sp_structs[i]; Array params = sp_struct_params[i]; @@ -998,13 +998,13 @@ SparseBlock::SparseBlock(Array sp_iter_vars, Array sp_stru LOG(FATAL) << "ValueError: " << obj->_type_key << " is not a sparse data structure"; } - sp_struct2param_map.Set(obj, params); + sp_struct_param_map.Set(obj, params); } ObjectPtr node = make_object(); node->sp_iter_vars = std::move(sp_iter_vars); node->sp_structs = std::move(sp_structs); - node->sp_struct2param_map = std::move(sp_struct2param_map); + node->sp_struct_param_map = std::move(sp_struct_param_map); node->name = std::move(name); node->body = std::move(body); node->init = std::move(init); diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 79a293713a32..e2ac8f7bd987 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -33,25 +33,40 @@ namespace tvm { namespace tir { +/*! + * \brief Get the mapping from Var to corresponding Buffer's. + * \param f The primitive function to visit. + * \return The map. + */ Map UpdateBufferMap(PrimFunc f) { struct BufferMapUpdater : public StmtVisitor { explicit BufferMapUpdater(Map buffer_map) : buffer_map_(std::move(buffer_map)) {} + /*! + * \brief Visit function to collect var to buffer mapping in a sparse block. + * \param sp_block The sparse block to collect. + */ void VisitStmt_(const SparseBlockNode* sp_block) { - for (const auto& it : sp_block->sp_struct2param_map) { - if (const auto* dv_axis = it.first.as()) { - ICHECK_EQ(it.second.size(), 1); - buffer_map_.Set(it.second[0], dv_axis->indptr); - } else if (const auto* sf_axis = it.first.as()) { - ICHECK_EQ(it.second.size(), 1); - buffer_map_.Set(it.second[0], sf_axis->indices); - } else if (const auto* sv_axis = it.first.as()) { - ICHECK_EQ(it.second.size(), 2); - buffer_map_.Set(it.second[0], sv_axis->indptr); - buffer_map_.Set(it.second[1], sv_axis->indices); - } else if (const auto* sp_buffer = it.first.as()) { - ICHECK_EQ(it.second.size(), 1); - buffer_map_.Set(it.second[0], sp_buffer->data); + for (const auto& it : sp_block->sp_struct_param_map) { + const ObjectRef& sp_struct = it.first; + const Array& params = it.second; + if (const auto* dv_axis = sp_struct.as()) { + // collect indptr buffer of dense variable axis. + ICHECK_EQ(params.size(), 1); + buffer_map_.Set(params[0], dv_axis->indptr); + } else if (const auto* sf_axis = sp_struct.as()) { + // collect indices buffer of sparse fixed axis. + ICHECK_EQ(params.size(), 1); + buffer_map_.Set(params[0], sf_axis->indices); + } else if (const auto* sv_axis = sp_struct.as()) { + // collect indptr and indices buffer of sparse variable axis. + ICHECK_EQ(params.size(), 2); + buffer_map_.Set(params[0], sv_axis->indptr); + buffer_map_.Set(params[1], sv_axis->indices); + } else if (const auto* sp_buffer = sp_struct.as()) { + // collect data buffer for sparse buffers. + ICHECK_EQ(params.size(), 1); + buffer_map_.Set(params[0], sp_buffer->data); } } return; @@ -67,9 +82,10 @@ Map UpdateBufferMap(PrimFunc f) { /*! * \brief Check whether a given SparseBuffer contains the given axis. - * \brief buffer The SparseBuffer to be checked - * \brief axis The axis to be checked - * \return A boolean indicating whether the given SparseBuffer contains the given axis + * \param buffer The SparseBuffer to be checked. + * \param axis The axis to be checked. + * \return A boolean indicating whether the given SparseBuffer contains the + * given axis */ bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis) { for (int i = 0; i < static_cast(buffer->axes.size()); ++i) { @@ -80,24 +96,36 @@ bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis) { return false; } +/*! + * \brief Maps sparse buffers to the array of sparse iterators we used to index the buffer. + */ using BufferAccessMap = Map>; +/*! + * \brief Maps sparse_iter to (sparse_buffer, i), indicates sparse_iter was used + * in the i-th dimension of sparse_buffer. + */ using DependencyMap = std::unordered_map, ObjectPtrHash, ObjectPtrEqual>; /*! - * \brief For each sparse-fixed or sparse-variable iterator, collect the iterators that it depends - * on. + * \brief For each sparse-fixed or sparse-variable iterator, collect the + * iterators that it depends on. */ class AccessAndDependencyCollector : public StmtExprVisitor { public: + /*! + * \brief Collect access and dependency information from the given statement. + * \param stmt The statement node to collect in the AST. + */ void Collect(Stmt stmt) { VisitStmt(std::move(stmt)); for (const std::pair>& kv_pair : buffer_access_map_) { const SparseBuffer& buffer = kv_pair.first; - int ndim = static_cast(kv_pair.second.size()); + const Array& sp_iters = kv_pair.second; + int ndim = static_cast(sp_iters.size()); for (int k = 0; k < ndim; ++k) { - const SpIterVar& sp_iter = kv_pair.second[k]; + const SpIterVar& sp_iter = sp_iters[k]; if (sp_iter->kind == SpIterKind::kDenseFixed || !BufferContainsAxis(buffer, sp_iter->axis)) { continue; @@ -109,21 +137,31 @@ class AccessAndDependencyCollector : public StmtExprVisitor { } else { const Array& dependent_iters = buffer_access_map_[it->second.first]; for (int i = 0; i < k; ++i) { - CHECK(kv_pair.second[i].same_as(dependent_iters[i])) - << "ValueError: A SpIterVar can only depend on a fixed set of iterators"; + CHECK(sp_iters[i].same_as(dependent_iters[i])) + << "ValueError: A SpIterVar can only depend on a fixed set of " + "iterators"; } } } } } + /*! + * \brief Collect the dependent buffer and iterators current sparse iterator depends on. + * \param sp_iter The sparse iterator. + * \param iterated_buffer The sparse buffer that given sparse iterator depends on. + * \param dependent_iters The sparse iterators that given sparse iterator depends on in the + * program. + * \note iterated_buffer and dependent_iters were pointers used as return values. + */ void GetIteratedBufferAndDependentIters(const SpIterVar& sp_iter, SparseBuffer* iterated_buffer, Array* dependent_iters) { - std::pair dependent_pair = dependency_map_[sp_iter]; - Array buffer_access_iters = buffer_access_map_[dependent_pair.first]; - int n_dependent = dependent_pair.second; + SparseBuffer dependent_buf; + int n_dependent; + std::tie(dependent_buf, n_dependent) = dependency_map_[sp_iter]; + Array buffer_access_iters = buffer_access_map_[dependent_buf]; - *iterated_buffer = std::move(dependent_pair.first); + *iterated_buffer = std::move(dependent_buf); *dependent_iters = Array(); dependent_iters->reserve(n_dependent); for (int i = 0; i < n_dependent; ++i) { @@ -131,14 +169,27 @@ class AccessAndDependencyCollector : public StmtExprVisitor { } } + /*! + * \brief Get sparse iterator corresponding to the given variable. + * \param index The variable + */ SpIterVar GetSpIterFromIndex(PrimExpr index) { - auto it = var2sp_iter_map_.find(index.as()); - CHECK(it != var2sp_iter_map_.end()) + auto it = var_sp_iter_map_.find(index.as()); + CHECK(it != var_sp_iter_map_.end()) << "ValueError: Currently an index is only allowed to be SpIterVar"; return it->second; } private: + /*! + * \brief Update the buffer access map given a sparse buffer access pattern. + * \param buffer The buffer to be accessed. + * \param indices The indices used to access the sparse buffer. + * \note We don't support use two set of indices to access the same buffer, and will throw + * an error in this case. For example, we can not access sparse buffer A with A[i, j] + * and A[j, i] in the same program. + * TODO(zihao, ruihang): fix the behavior in the future. + */ void AddAccessPattern(const SparseBuffer& buffer, const Array& indices) { int ndim = buffer->ndim(); CHECK_EQ(static_cast(indices.size()), ndim); @@ -156,38 +207,61 @@ class AccessAndDependencyCollector : public StmtExprVisitor { ICHECK_EQ(static_cast((*it).second.size()), ndim); for (int i = 0; i < ndim; ++i) { CHECK((*it).second[i].same_as(iters[i])) - << "ValueError: Currently all accesses to a same buffer are required to be the same"; + << "ValueError: Currently all accesses to a same buffer are " + "required to be the same"; } } } + /*! + * \brief The visit function to collect variable to sparse iterator mapping for sparse block node. + * \param sp_block The sparse block node in AST. + */ void VisitStmt_(const SparseBlockNode* sp_block) final { for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) { - var2sp_iter_map_[sp_iter->var.get()] = sp_iter; + var_sp_iter_map_[sp_iter->var.get()] = sp_iter; } StmtVisitor::VisitStmt_(sp_block); } + /*! + * \brief The visit function to collect buffer access pattern from sparse buffer stores. + * \param store The sparse buffer store node in AST. + */ void VisitStmt_(const SparseBufferStoreNode* store) final { ExprVisitor::VisitExpr(store->value); AddAccessPattern(store->buffer, store->indices); } + /*! + * \brief The visit function to collect buffer access pattern from sparse buffer loads. + * \param load The sparse buffer load node in AST. + */ void VisitExpr_(const SparseBufferLoadNode* load) final { AddAccessPattern(load->buffer, load->indices); } BufferAccessMap buffer_access_map_; DependencyMap dependency_map_; - std::unordered_map var2sp_iter_map_; + std::unordered_map var_sp_iter_map_; }; +/*! + * \brief Rewrite indices in sparse buffers to indices in corresponding data + * buffers. + */ class IndexTransformer : public StmtExprMutator { public: explicit IndexTransformer(AccessAndDependencyCollector collector) : collector_(std::move(collector)) {} private: + /*! + * \brief The lowered absolute offset of an sparse buffer access pattern. + * \param sp_buffer The sparse buffer to be accessed. + * \param indices The sparse indices to access the buffer. + * \return The lowered absolute offset to the start of flattened data in given sparse buffer. + */ PrimExpr LowerIndices(SparseBuffer sp_buffer, const Array& indices) { int ndim = sp_buffer->ndim(); int n_lower = static_cast(indices.size()); @@ -203,88 +277,100 @@ class IndexTransformer : public StmtExprMutator { SpIterVar sp_iter = collector_.GetSpIterFromIndex(index); PrimExpr sp_index{nullptr}; - PrimExpr l = AccumulateLowerIndex(lowered_index, sp_buffer, i, 0); - PrimExpr r = AccumulateLowerIndex(add(lowered_index, 1), sp_buffer, i, 0); + PrimExpr l = PartialLowerIndex(lowered_index, sp_buffer->axes[i], 0); + PrimExpr r = PartialLowerIndex(add(lowered_index, 1), sp_buffer->axes[i], 0); - SpIterKind kind = sp_iter->kind; - if (kind == SpIterKind::kDenseFixed) { - CHECK(!axis->IsInstance()); - if (const auto* df_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); + switch (sp_iter->kind) { + case SpIterKind::kDenseFixed: { + CHECK(!axis->IsInstance()); + if (const auto* df_axis = axis.as()) { + CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); + sp_index = sp_iter; + } else { + Var buffer_var; + if (const auto* sf_axis = axis.as()) { + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); + buffer_var = sf_axis->indices->data; + } else if (const auto* sv_axis = axis.as()) { + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); + buffer_var = sv_axis->indices->data; + } else { + LOG(FATAL) << "Cannot reach here"; + } + sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r)); + } + break; + } + case SpIterKind::kDenseVariable: { + const auto* dv_axis = axis.as(); + CHECK(dv_axis != nullptr); + CHECK(sp_iter->axis.defined()); sp_index = sp_iter; - } else { - Var buffer_var; - if (const auto* sf_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); - buffer_var = sf_axis->indices->data; + break; + } + case SpIterKind::kSparseFixed: { + CHECK(!axis->IsInstance()); + CHECK(sp_iter->axis.defined()); + const Axis& iterated_axis = sp_iter->axis; + if (axis->IsInstance()) { + sp_index = GetDenseValue(sp_iter); + } else if (const auto* sf_axis = axis.as()) { + if (iterated_axis.get() == sf_axis) { + sp_index = sp_iter; + } else { + sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l), + std::move(r)); + } } else if (const auto* sv_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); - buffer_var = sv_axis->indices->data; + sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), + std::move(r)); } else { LOG(FATAL) << "Cannot reach here"; } - sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r)); + break; } - } else if (kind == SpIterKind::kDenseVariable) { - const auto* dv_axis = axis.as(); - CHECK(dv_axis != nullptr); - CHECK(sp_iter->axis.defined()); - sp_index = sp_iter; - } else if (kind == SpIterKind::kSparseFixed) { - CHECK(!axis->IsInstance()); - CHECK(sp_iter->axis.defined()); - const Axis& iterated_axis = sp_iter->axis; - if (axis->IsInstance()) { - sp_index = GetDenseValue(sp_iter); - } else if (const auto* sf_axis = axis.as()) { - if (iterated_axis.get() == sf_axis) { - sp_index = sp_iter; - } else { + default: { // kind == SpIterKind::kSparseVariable + CHECK(!axis->IsInstance()); + CHECK(sp_iter->axis.defined()); + const Axis& iterated_axis = sp_iter->axis; + if (const auto* df_axis = axis.as()) { + CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); + sp_index = GetDenseValue(sp_iter); + } else if (const auto* sf_axis = axis.as()) { + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l), std::move(r)); - } - } else if (const auto* sv_axis = axis.as()) { - sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), - std::move(r)); - } else { - LOG(FATAL) << "Cannot reach here"; - } - } else { - CHECK(kind == SpIterKind::kSparseVariable); - CHECK(!axis->IsInstance()); - CHECK(sp_iter->axis.defined()); - const Axis& iterated_axis = sp_iter->axis; - if (const auto* df_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); - sp_index = GetDenseValue(sp_iter); - } else if (const auto* sf_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); - sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l), - std::move(r)); - } else if (const auto* sv_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); - if (iterated_axis.get() == sv_axis) { - sp_index = sp_iter; + } else if (const auto* sv_axis = axis.as()) { + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); + if (iterated_axis.get() == sv_axis) { + sp_index = sp_iter; + } else { + sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), + std::move(r)); + } } else { - sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), - std::move(r)); + LOG(FATAL) << "Cannot reach here"; } - } else { - LOG(FATAL) << "Cannot reach here"; + break; } } // Stage 2. Accumulate the lowered index. lowered_index = - AccumulateLowerIndex(std::move(lowered_index), sp_buffer, i, std::move(sp_index)); + PartialLowerIndex(std::move(lowered_index), sp_buffer->axes[i], std::move(sp_index)); } return lowered_index; } - PrimExpr AccumulateLowerIndex(PrimExpr prev_lowered_index, const SparseBuffer& sp_buffer, int dim, - PrimExpr index) { - const Axis& axis = sp_buffer->axes[dim]; + /*! + * \brief Compupte the partially lowered index. + * \param prev_lowered_index The lowered index accumulated over all axis prior to current axis. + * \param axis Current axis. + * \param index The sparse index on current axis. + * \return The lowered index. + */ + PrimExpr PartialLowerIndex(PrimExpr prev_lowered_index, const Axis& axis, PrimExpr index) { if (axis->IsInstance()) { return ana_.Simplify(std::move(prev_lowered_index) * axis->length + std::move(index)); } else if (const auto* sf_axis = axis.as()) { @@ -300,6 +386,10 @@ class IndexTransformer : public StmtExprMutator { throw; } + /*! + * \brief Convert sparse iteration positions to dense coordinates. + * \param sp_iter The sparse iterator. + */ PrimExpr GetDenseValue(SpIterVar sp_iter) { SpIterKind kind = sp_iter->kind; CHECK(kind == SpIterKind::kSparseFixed || kind == SpIterKind::kSparseVariable); @@ -321,12 +411,20 @@ class IndexTransformer : public StmtExprMutator { } } + /*! + * \brief Convert sparse buffer load node to buffer load node. + * \param load The sparse buffer load node in AST. + */ PrimExpr VisitExpr_(const SparseBufferLoadNode* load) final { buffer_read_.insert(load->buffer.get()); PrimExpr lowered_indices = LowerIndices(load->buffer, load->indices); return BufferLoad(load->buffer->data, {std::move(lowered_indices)}); } + /*! + * \brief Convert sparse buffer store node to buffer store node. + * \param store The sparse buffer store node in AST. + */ Stmt VisitStmt_(const SparseBufferStoreNode* store) final { buffer_write_.insert(store->buffer.get()); PrimExpr value = ExprMutator::VisitExpr(store->value); @@ -334,6 +432,10 @@ class IndexTransformer : public StmtExprMutator { return BufferStore(store->buffer->data, std::move(value), {std::move(lowered_indices)}); } + /*! + * \brief Rewrite sparse block to ordinary block. + * \param sp_block The sparse block to be rewritten. + */ Stmt VisitStmt_(const SparseBlockNode* sp_block) { int n_iter = static_cast(sp_block->sp_iter_vars.size()); buffer_read_.clear(); @@ -361,7 +463,7 @@ class IndexTransformer : public StmtExprMutator { block_iters.reserve(n_iter); iter_bindings.reserve(n_iter); for (int i = 0; i < n_iter; ++i) { - block_iters.push_back(SpIterVar2IterVar(sp_block->sp_iter_vars[i], var_map)); + block_iters.push_back(SpIterVarToIterVar(sp_block->sp_iter_vars[i], var_map)); iter_bindings.push_back(loop_vars[i]); } @@ -381,8 +483,14 @@ class IndexTransformer : public StmtExprMutator { return loop; } - IterVar SpIterVar2IterVar(const SpIterVar& sp_iter, - const std::unordered_map& var_map) { + /*! + * \brief Convert sparse iterable variable to ordinary iterable variable. + * \param sp_iter The sparse iterable variable to convert. + * \param var_map The mapping from sparse iterable variable to corresponding ordinary iterable + * variable. + */ + IterVar SpIterVarToIterVar(const SpIterVar& sp_iter, + const std::unordered_map& var_map) { PrimExpr extent{nullptr}; SpIterKind kind = sp_iter->kind; @@ -407,6 +515,12 @@ class IndexTransformer : public StmtExprMutator { sp_iter->is_reduction ? kCommReduce : kDataPar); } + /*! + * \brief generate read and write regions for sparse blocks. + * \param sp_block the sparse blocks + * \param reads pointer of array to read buffer regions. + * \param writes pointer of array to write buffer regions. + */ void GenerateReadWriteRegions(const SparseBlockNode* sp_block, Array* reads, Array* writes) { for (const ObjectRef& obj : sp_block->sp_structs) { @@ -428,6 +542,11 @@ class IndexTransformer : public StmtExprMutator { } } + /*! + * \brief generated nested for loops for sparse block. + * \param block_iters The iterators defined in sparse blocks. + * \param loop_vars The loop variables binded with block iterators. + */ Stmt GenerateLoops(Stmt body, const Array& block_iters, const Array& loop_vars) { int n_iter = static_cast(block_iters.size()); for (int i = n_iter - 1; i >= 0; --i) { @@ -443,12 +562,22 @@ class IndexTransformer : public StmtExprMutator { std::unordered_set buffer_write_; }; +/*! + * \brief Wrap the body statement with an empty root block. + * \param body The body statements to wrap with. + * \return The wrapped block. + */ Stmt WrapWithRootBlock(Stmt body) { Block root_block({}, {}, {}, "root", std::move(body)); body = BlockRealize({}, const_true(), std::move(root_block)); return Stmt(body); } +/*! + * \brief Rewrite the given primitive function + * \param f The Sparse-TIR primitive function to lower. + * \return lowered primitive function in TIR. + */ PrimFunc LowerSparseTIR(PrimFunc f) { // Only apply this pass to TIR that is not from TE schedules if (!IsFromLegacyTESchedule(f)) { @@ -470,6 +599,9 @@ PrimFunc LowerSparseTIR(PrimFunc f) { namespace transform { +/*! + * \brief The lowering pass from TIR to Sparse TIR. + */ Pass LowerSparseTIR() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { return LowerSparseTIR(std::move(f));