diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index b4d8250dc340..f14ca55974dc 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -453,176 +453,6 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops); -/******** SparseTIR Tools ********/ - -/*! - * \brief Maps sparse buffers to the array of sparse iterators we used to index the buffer. - */ -using BufferAccessMap = Map>; -/*! - * \brief Maps sparse_iter to (sparse_buffer, i), indicates sparse_iter was used - * in the i-th dimension of sparse_buffer. - */ -using DependencyMap = - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>; - -/*! - * \brief Check whether a given SparseBuffer contains the given axis. - * \param buffer The SparseBuffer to be checked. - * \param axis The axis to be checked. - * \return A boolean indicating whether the given SparseBuffer contains the - * given axis - */ -bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis); - -/*! - * \brief For each sparse-fixed or sparse-variable iterator, collect the - * iterators that it depends on. - */ -class AccessAndDependencyCollector : public StmtExprVisitor { - public: - /*! - * \brief Collect access and dependency information from the given statement. - * \param stmt The statement node to collect in the AST. - */ - void Collect(Stmt stmt) { - VisitStmt(std::move(stmt)); - - for (const std::pair>& kv_pair : buffer_access_map_) { - const SparseBuffer& buffer = kv_pair.first; - const Array& sp_iters = kv_pair.second; - int ndim = static_cast(sp_iters.size()); - for (int k = 0; k < ndim; ++k) { - const SpIterVar& sp_iter = sp_iters[k]; - if (sp_iter->kind == SpIterKind::kDenseFixed || - !BufferContainsAxis(buffer, sp_iter->axis)) { - continue; - } - - auto it = dependency_map_.find(sp_iter); - if (it == dependency_map_.end()) { - dependency_map_[sp_iter] = std::make_pair(buffer, k); - } else { - const Array& dependent_iters = buffer_access_map_[it->second.first]; - for (int i = 0; i < k; ++i) { - CHECK(sp_iters[i].same_as(dependent_iters[i])) - << "ValueError: A SpIterVar can only depend on a fixed set of " - "iterators"; - } - } - } - } - } - - /*! - * \brief Collect the dependent buffer and iterators current sparse iterator depends on. - * \param sp_iter The sparse iterator. - * \param iterated_buffer The sparse buffer that given sparse iterator depends on. - * \param dependent_iters The sparse iterators that given sparse iterator depends on in the - * program. - * \note iterated_buffer and dependent_iters were pointers used as return values. - */ - void GetIteratedBufferAndDependentIters(const SpIterVar& sp_iter, SparseBuffer* iterated_buffer, - Array* dependent_iters) { - SparseBuffer dependent_buf; - int n_dependent; - std::tie(dependent_buf, n_dependent) = dependency_map_[sp_iter]; - Array buffer_access_iters = buffer_access_map_[dependent_buf]; - - *iterated_buffer = std::move(dependent_buf); - *dependent_iters = Array(); - dependent_iters->reserve(n_dependent); - for (int i = 0; i < n_dependent; ++i) { - dependent_iters->push_back(buffer_access_iters[i]->var); - } - } - - /*! - * \brief Get sparse iterator corresponding to the given variable. - * \param index The variable - */ - SpIterVar GetSpIterFromIndex(PrimExpr index) { - auto it = var_sp_iter_map_.find(index.as()); - CHECK(it != var_sp_iter_map_.end()) - << "ValueError: Currently an index is only allowed to be SpIterVar"; - return it->second; - } - - private: - /*! - * \brief Update the buffer access map given a sparse buffer access pattern. - * \param buffer The buffer to be accessed. - * \param indices The indices used to access the sparse buffer. - * \note We don't support use two set of indices to access the same buffer, and will throw - * an error in this case. For example, we can not access sparse buffer A with A[i, j] - * and A[j, i] in the same program. - * TODO(zihao, ruihang): fix the behavior in the future. - */ - void AddAccessPattern(const SparseBuffer& buffer, const Array& indices) { - int ndim = buffer->ndim(); - CHECK_EQ(static_cast(indices.size()), ndim); - - Array iters; - iters.reserve(ndim); - for (int i = 0; i < ndim; ++i) { - iters.push_back(GetSpIterFromIndex(indices[i])); - } - - BufferAccessMap::iterator it = buffer_access_map_.find(buffer); - if (it == buffer_access_map_.end()) { - buffer_access_map_.Set(buffer, iters); - } else { - ICHECK_EQ(static_cast((*it).second.size()), ndim); - for (int i = 0; i < ndim; ++i) { - CHECK((*it).second[i].same_as(iters[i])) - << "ValueError: Currently all accesses to a same buffer are " - "required to be the same"; - } - } - } - - /*! - * \brief The visit function to collect variable to sparse iterator mapping for sparse block node. - * \param sp_block The sparse block node in AST. - */ - void VisitStmt_(const SparseBlockNode* sp_block) final { - for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) { - var_sp_iter_map_[sp_iter->var.get()] = sp_iter; - } - StmtVisitor::VisitStmt_(sp_block); - } - - /*! - * \brief The visit function to collect buffer access pattern from sparse buffer stores. - * \param store The sparse buffer store node in AST. - */ - void VisitStmt_(const SparseBufferStoreNode* store) final { - ExprVisitor::VisitExpr(store->value); - AddAccessPattern(store->buffer, store->indices); - } - - /*! - * \brief The visit function to collect buffer access pattern from sparse buffer loads. - * \param load The sparse buffer load node in AST. - */ - void VisitExpr_(const SparseBufferLoadNode* load) final { - AddAccessPattern(load->buffer, load->indices); - } - - BufferAccessMap buffer_access_map_; - DependencyMap dependency_map_; - std::unordered_map var_sp_iter_map_; -}; - -/*! - * \brief Check whether the new order satisfies the iterator dependency constraints - * \param self The schedule state - * \param block The sparse block, which is the source of the constraints - * \param new_order The new iterator order to be checked - */ -void CheckDependency(const ScheduleState& self, const SparseBlock& block, - const Array& new_order); - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 6034d9400db4..3c64533ba465 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1408,68 +1408,5 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) { } } -/******** SparseTIR Tools ********/ - -bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis) { - for (int i = 0; i < static_cast(buffer->axes.size()); ++i) { - if (buffer->axes[i].same_as(axis)) { - return true; - } - } - return false; -} - -void CheckDependency(const ScheduleState& self, const SparseBlock& block, - const Array& new_order) { - class DependentIterNotAppearError : public ScheduleError { - public: - explicit DependentIterNotAppearError(IRModule mod, SpIterVar iter, SpIterVar dependent_iter) - : mod_(std::move(mod)), - iter_(std::move(iter)), - dependent_iter_(std::move(dependent_iter)) {} - - String FastErrorString() const final { - return "ScheduleError: The new order violates some iterator dependency"; - } - - String DetailRenderTemplate() const final { - std::ostringstream os; - os << "ScheduleError: Iterator " << iter_ << " depends on " << dependent_iter_ - << ", while the latter iterator does not appear before the former iterator in the new " - "order"; - return os.str(); - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } - - IRModule mod_; - SpIterVar iter_; - SpIterVar dependent_iter_; - }; - - AccessAndDependencyCollector collector; - collector.Collect(block); - - for (int i = 0; i < static_cast(new_order.size()); ++i) { - const SpIterVar& sp_iter = new_order[i]; - if (sp_iter->kind == SpIterKind::kDenseFixed) { - continue; - } - - SparseBuffer iterated_buffer{nullptr}; - Array iters{nullptr}; - collector.GetIteratedBufferAndDependentIters(sp_iter, &iterated_buffer, &iters); - - for (const PrimExpr& index : iters) { - const SpIterVar dependent_iter = collector.GetSpIterFromIndex(index); - if (std::find(new_order.begin(), new_order.begin() + i, dependent_iter) == - new_order.begin() + i) { - throw DependentIterNotAppearError(self->mod, sp_iter, dependent_iter); - } - } - } -} - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/sparse_loop_transformation.cc b/src/tir/schedule/primitive/sparse_loop_transformation.cc index 5e322c9e6ba1..b7da781fe3e3 100644 --- a/src/tir/schedule/primitive/sparse_loop_transformation.cc +++ b/src/tir/schedule/primitive/sparse_loop_transformation.cc @@ -100,7 +100,8 @@ SparseBlock SparseReorder(ScheduleState self, const SparseBlock& block, CheckValidInputIterators(self, new_order, block->sp_iter_vars); // Step 2. Check whether the new order does not break the iterator dependency. - CheckDependency(self, block, new_order); + // TODO(zihao): use axis dependency tree instead + // CheckDependency(self, block, new_order); // Step 3. Create the new SparseBlock. ObjectPtr p_new_block = make_object(*block.get()); diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 3e02c9a59daa..e2526c130929 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -30,6 +30,7 @@ #include #include +#include "../../support/utils.h" #include "../schedule/analysis.h" #include "ir_utils.h" @@ -83,178 +84,216 @@ Map UpdateBufferMap(PrimFunc f) { return std::move(updater.buffer_map_); } +/*! \brief Storing the context information of a sparse block. */ +class SparseBlockCtx { + public: + explicit SparseBlockCtx(String blk_name, Array sp_iter_vars, const AxisTree& tree) + : blk_name_(std::move(blk_name)) { + // initialize sparse iter var dependency map. + std::unordered_map axis_name_sp_iter_map_; + for (const SpIterVar& sp_iter_var : sp_iter_vars) { + axis_name_sp_iter_map_[sp_iter_var->axis->name] = sp_iter_var; + } + + for (const SpIterVar& sp_iter_var : sp_iter_vars) { + String axis_name = sp_iter_var->axis->name; + const SpIterVarNode* node = sp_iter_var.get(); + if (support::EndsWith(axis_name, "_dense")) { + // ends with "_dense", the axis is generated via to_dense + parent_sp_iter_var_[node] = sp_iter_var; + } else { + auto opt = tree->parent.Get(axis_name); + CHECK(opt.defined()) << "Cannot find parent of axis " << axis_name << "."; + String parent_axis_name = opt.value(); + if (parent_axis_name != "root") { + auto it = axis_name_sp_iter_map_.find(parent_axis_name); + CHECK(it != axis_name_sp_iter_map_.end()) + << "Cannot find sparse iter vars corresponding to parent axis " << parent_axis_name + << " in current sparse block " << blk_name_; + parent_sp_iter_var_[node] = it->second; + } + } + } + } + + void AddSparseIterVar(const VarNode* node, const SpIterVar& sp_iter_var) { + sp_iter_var_map_[node] = sp_iter_var; + } + + Optional LookupSparseIterVar(const VarNode* var_node) const { + auto it = sp_iter_var_map_.find(var_node); + if (it != sp_iter_var_map_.end()) { + return it->second; + } else { + return NullOpt; + } + } + + void SetOffset(const SpIterVarNode* node, const PrimExpr& e) { sp_iter_var_offset_[node] = e; } + + Optional LookupParentSparseIterVar(const SpIterVarNode* node) const { + auto it = parent_sp_iter_var_.find(node); + if (it != parent_sp_iter_var_.end()) { + return it->second; + } else { + return NullOpt; + } + } + + Optional LookupOffset(const Optional& sp_iter_var) const { + if (!sp_iter_var.defined()) { + // the root + return Integer(0); + } else { + auto it = sp_iter_var_offset_.find(sp_iter_var.value().get()); + if (it != sp_iter_var_offset_.end()) { + return it->second; + } else { + return NullOpt; + } + } + } + + const String GetBlockName() const { return blk_name_; } + + private: + std::unordered_map sp_iter_var_map_; + std::unordered_map sp_iter_var_offset_; + std::unordered_map parent_sp_iter_var_; + String blk_name_; +}; + /*! * \brief Rewrite indices in sparse buffers to indices in corresponding data * buffers. */ class IndexTransformer : public StmtExprMutator { public: - explicit IndexTransformer(AccessAndDependencyCollector collector, AxisTree axis_tree) - : collector_(std::move(collector)), axis_tree_(std::move(axis_tree)) {} + explicit IndexTransformer(AxisTree axis_tree) : axis_tree_(std::move(axis_tree)), ctx_st({}) {} private: - /*! - * \brief The lowered absolute offset of an sparse buffer access pattern. - * \param sp_buffer The sparse buffer to be accessed. - * \param indices The sparse indices to access the buffer. - * \return The lowered absolute offset to the start of flattened data in given sparse buffer. - */ - PrimExpr LowerIndices(SparseBuffer sp_buffer, const Array& indices) { - int ndim = sp_buffer->ndim(); - int n_lower = static_cast(indices.size()); - ICHECK_LE(n_lower, ndim); + // Context stack; + std::vector ctx_st; + + PrimExpr ViewIndexInAxis(const Axis& axis, PrimExpr index) { + const SparseBlockCtx& ctx = ctx_st.back(); + const VarNode* var = index.as(); + if (var) { + // index is a single var node. + auto opt = ctx.LookupSparseIterVar(var); + CHECK(opt.defined()) << "var " << var->name_hint << " not appeared in the sparse block " + << ctx.GetBlockName() << "."; + const SpIterVar& sp_iter_var = opt.value(); + if (sp_iter_var->axis->name == axis->name) { + // if the iterator and sparse buffer refers to the same axis, + // no need to convert + return index; + } + } - PrimExpr lowered_index = Integer(0); + // decompress index to coordinate on iterator axis. + // the index might not be a single var node, use visitor to recursive construct the coordinate. + PrimExpr cord = ExprMutator::VisitExpr(index); - for (int i = 0; i < n_lower; ++i) { - const Axis& axis = sp_buffer->axes[i]; - const PrimExpr& index = indices[i]; + // compress coordinate to index on sparse buffer axis. + // TODO(zihao) + return cord; + } - // Stage 1. Get the sparse index. - SpIterVar sp_iter = collector_.GetSpIterFromIndex(index); - PrimExpr sp_index{nullptr}; - - PrimExpr l = PartialLowerIndex(lowered_index, sp_buffer->axes[i], 0); - PrimExpr r = PartialLowerIndex(add(lowered_index, 1), sp_buffer->axes[i], 0); - - switch (sp_iter->kind) { - case SpIterKind::kDenseFixed: { - CHECK(!axis->IsInstance()); - if (const auto* df_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); - sp_index = sp_iter; - } else { - Var buffer_var; - if (const auto* sf_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); - buffer_var = sf_axis->indices->data; - } else if (const auto* sv_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); - buffer_var = sv_axis->indices->data; - } else { - LOG(FATAL) << "Cannot reach here"; - } - sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r)); - } - break; - } - case SpIterKind::kDenseVariable: { - const auto* dv_axis = axis.as(); - CHECK(dv_axis != nullptr); - CHECK(sp_iter->axis.defined()); - sp_index = sp_iter; - break; - } - case SpIterKind::kSparseFixed: { - CHECK(!axis->IsInstance()); - CHECK(sp_iter->axis.defined()); - const Axis& iterated_axis = sp_iter->axis; - if (axis->IsInstance()) { - sp_index = GetDenseValue(sp_iter); - } else if (const auto* sf_axis = axis.as()) { - if (iterated_axis.get() == sf_axis) { - sp_index = sp_iter; - } else { - sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l), - std::move(r)); - } - } else if (const auto* sv_axis = axis.as()) { - sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), - std::move(r)); - } else { - LOG(FATAL) << "Cannot reach here"; - } - break; - } - default: { // kind == SpIterKind::kSparseVariable - CHECK(!axis->IsInstance()); - CHECK(sp_iter->axis.defined()); - const Axis& iterated_axis = sp_iter->axis; - if (const auto* df_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); - sp_index = GetDenseValue(sp_iter); - } else if (const auto* sf_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); - sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l), - std::move(r)); - } else if (const auto* sv_axis = axis.as()) { - CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); - if (iterated_axis.get() == sv_axis) { - sp_index = sp_iter; - } else { - sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), - std::move(r)); - } - } else { - LOG(FATAL) << "Cannot reach here"; - } - break; - } - } + PrimExpr ComputeOffset(SparseBuffer sp_buffer, const Array& indices) { + int num_lowered_indices = static_cast(indices.size()); + ICHECK_LE(num_lowered_indices, sp_buffer->ndim()); - // Stage 2. Accumulate the lowered index. - lowered_index = - PartialLowerIndex(std::move(lowered_index), sp_buffer->axes[i], std::move(sp_index)); + PrimExpr offset = Integer(0); + for (int i = 0; i < num_lowered_indices; ++i) { + const Axis& axis = sp_buffer->axes[i]; + const PrimExpr& index = indices[i]; + PrimExpr offset_i = ViewIndexInAxis(axis, index); + offset = AggregateOffset(std::move(offset), axis, std::move(offset_i)); } - - return lowered_index; + return offset; } /*! * \brief Compupte the partially lowered index. - * \param prev_lowered_index The lowered index accumulated over all axis prior to current axis. + * \param prev_offset The lowered index accumulated over all axis prior to current axis. * \param axis Current axis. * \param index The sparse index on current axis. * \return The lowered index. */ - PrimExpr PartialLowerIndex(PrimExpr prev_lowered_index, const Axis& axis, PrimExpr index) { + PrimExpr AggregateOffset(PrimExpr prev_offset, const Axis& axis, PrimExpr index) { if (axis->IsInstance()) { - return ana_.Simplify(std::move(prev_lowered_index) * axis->length + std::move(index)); + return ana_.Simplify(std::move(prev_offset) * axis->length + std::move(index)); } else if (const auto* sf_axis = axis.as()) { - return ana_.Simplify(std::move(prev_lowered_index) * sf_axis->nnz_cols + std::move(index)); + return ana_.Simplify(std::move(prev_offset) * sf_axis->nnz_cols + std::move(index)); } else if (const auto* dv_axis = axis.as()) { return ana_.Simplify( - add(BufferLoad(dv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index))); + add(BufferLoad(dv_axis->indptr, {std::move(prev_offset)}), std::move(index))); } else if (const auto* sv_axis = axis.as()) { return ana_.Simplify( - add(BufferLoad(sv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index))); + add(BufferLoad(sv_axis->indptr, {std::move(prev_offset)}), std::move(index))); } LOG(FATAL) << "Cannot reach here"; throw; } /*! - * \brief Convert sparse iteration positions to dense coordinates. - * \param sp_iter The sparse iterator. + * \brief Decompress coordinates from compressed indices. + * \param sp_iter_var The compressed iterator. */ - PrimExpr GetDenseValue(SpIterVar sp_iter) { - SpIterKind kind = sp_iter->kind; + PrimExpr GetCoordinate(SpIterVar sp_iter_var) { + SpIterKind kind = sp_iter_var->kind; + switch (kind) { + case SpIterKind::kDenseFixed: + case SpIterKind::kDenseVariable: + // if dense fixed or dense variable, just return the value. + return sp_iter_var->var; + break; + default: + break; + } CHECK(kind == SpIterKind::kSparseFixed || kind == SpIterKind::kSparseVariable); - Axis iterated_axis = sp_iter->axis; - SparseBuffer iterated_buffer{nullptr}; - Array iters{nullptr}; - - collector_.GetIteratedBufferAndDependentIters(sp_iter, &iterated_buffer, &iters); - iters.push_back(sp_iter); - PrimExpr lowered_indices = LowerIndices(std::move(iterated_buffer), iters); + SparseBlockCtx& ctx = this->ctx_st.back(); + Axis axis = sp_iter_var->axis; + Optional parent_sp_iter_var = ctx.LookupParentSparseIterVar(sp_iter_var.get()); + auto opt = ctx.LookupOffset(sp_iter_var); + PrimExpr prev_offset; + if (opt.defined()) { + prev_offset = opt.value(); + } else { + prev_offset = GetCoordinate(parent_sp_iter_var.value()); + ctx.SetOffset(parent_sp_iter_var.value().get(), prev_offset); + } + PrimExpr offset = AggregateOffset(prev_offset, axis, sp_iter_var->var); if (kind == SpIterKind::kSparseFixed) { - return BufferLoad(Downcast(iterated_axis)->indices, - {std::move(lowered_indices)}); + return BufferLoad(Downcast(axis)->indices, {std::move(offset)}); } else { - return BufferLoad(Downcast(iterated_axis)->indices, - {std::move(lowered_indices)}); + return BufferLoad(Downcast(axis)->indices, {std::move(offset)}); } } + /*! + * \brief Get coordinate of var node corresponding to sparse iter vars. + * \param v The variable node. + */ + PrimExpr VisitExpr_(const VarNode* v) final { + const SparseBlockCtx& ctx = ctx_st.back(); + auto opt = ctx.LookupSparseIterVar(v); + CHECK(opt.defined()) << "var " << v->name_hint << " not appeared in the sparse block " + << ctx.GetBlockName() << ", cannot get its corresponding coordinate."; + const SpIterVar& sp_iter_var = opt.value(); + return GetCoordinate(sp_iter_var); + } + /*! * \brief Convert sparse buffer load node to buffer load node. * \param load The sparse buffer load node in AST. */ PrimExpr VisitExpr_(const SparseBufferLoadNode* load) final { buffer_read_.insert(load->buffer.get()); - PrimExpr lowered_indices = LowerIndices(load->buffer, load->indices); + PrimExpr lowered_indices = ComputeOffset(load->buffer, load->indices); return BufferLoad(load->buffer->data, {std::move(lowered_indices)}); } @@ -265,7 +304,7 @@ class IndexTransformer : public StmtExprMutator { Stmt VisitStmt_(const SparseBufferStoreNode* store) final { buffer_write_.insert(store->buffer.get()); PrimExpr value = ExprMutator::VisitExpr(store->value); - PrimExpr lowered_indices = LowerIndices(store->buffer, store->indices); + PrimExpr lowered_indices = ComputeOffset(store->buffer, store->indices); return BufferStore(store->buffer->data, std::move(value), {std::move(lowered_indices)}); } @@ -278,22 +317,29 @@ class IndexTransformer : public StmtExprMutator { buffer_read_.clear(); buffer_write_.clear(); - // Step 1. Recursively mutate the `init` field and the block body. + // Step 1. Push new context to sparse block context stack. + SparseBlockCtx ctx(sp_block->name, sp_block->sp_iter_vars, axis_tree_); + for (const SpIterVar& sp_iter_var : sp_block->sp_iter_vars) { + ctx.AddSparseIterVar(sp_iter_var->var.get(), sp_iter_var); + } + this->ctx_st.push_back(ctx); + + // Step 2. Recursively mutate the `init` field and the block body. Optional init = sp_block->init.defined() ? VisitStmt(sp_block->init.value()) : Optional(NullOpt); Stmt body = VisitStmt(sp_block->body); - // Step 2. Create the new loop vars. + // Step 3. Create the new loop vars. std::unordered_map var_map; Array all_loop_vars; var_map.reserve(n_iter); - for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) { - Var loop_var("v_" + sp_iter->var->name_hint); + for (const SpIterVar& sp_iter_var : sp_block->sp_iter_vars) { + Var loop_var("v_" + sp_iter_var->var->name_hint); all_loop_vars.push_back(loop_var); - var_map[sp_iter->var.get()] = loop_var; + var_map[sp_iter_var->var.get()] = loop_var; } - // Step 3. Collet block iters and iter bindings. + // Step 4. Collet block iters and iter bindings. std::set in_stack; in_stack.insert("root"); /* A stack that stores block itervars in each block. */ @@ -360,12 +406,12 @@ class IndexTransformer : public StmtExprMutator { } } while (true); - // Step 4. Generate the read-region and write-retion of the block. - Array reads{nullptr}; - Array writes{nullptr}; + // Step 5. Generate the read-region and write-retion of the block. + Array reads{}; + Array writes{}; GenerateReadWriteRegions(sp_block, &reads, &writes); - // Step 5. Generate nested blocks and loops from innermost to outermost. + // Step 6. Generate nested blocks and loops from innermost to outermost. int blk_counter = 0; while (!block_iters_st.empty()) { Array block_iters = std::move(block_iters_st.top()); @@ -400,39 +446,41 @@ class IndexTransformer : public StmtExprMutator { blk_counter += 1; } + // Step 7: pop the sparse block context stack. + this->ctx_st.pop_back(); + return body; } /*! * \brief Convert sparse iterable variable to ordinary iterable variable. - * \param sp_iter The sparse iterable variable to convert. + * \param sp_iter_var The sparse iterable variable to convert. * \param var_map The mapping from sparse iterable variable to corresponding ordinary iterable * variable. */ - IterVar SpIterVarToIterVar(const SpIterVar& sp_iter, + IterVar SpIterVarToIterVar(const SpIterVar& sp_iter_var, const std::unordered_map& var_map) { PrimExpr extent{nullptr}; + const SparseBlockCtx& ctx = this->ctx_st.back(); - SpIterKind kind = sp_iter->kind; + SpIterKind kind = sp_iter_var->kind; if (kind == SpIterKind::kDenseFixed || kind == SpIterKind::kSparseFixed) { - extent = sp_iter->max_extent; + extent = sp_iter_var->max_extent; } else { - SparseBuffer iterated_buffer{nullptr}; - Array dependent_iters{nullptr}; - collector_.GetIteratedBufferAndDependentIters(sp_iter, &iterated_buffer, &dependent_iters); - PrimExpr lowered_indices = LowerIndices(std::move(iterated_buffer), dependent_iters); + SpIterVar parent_sp_iter = ctx.LookupParentSparseIterVar(sp_iter_var.get()).value(); + PrimExpr lowered_indices = GetCoordinate(parent_sp_iter); Buffer indptr{kind == SpIterKind::kDenseVariable - ? Downcast(sp_iter->axis)->indptr - : Downcast(sp_iter->axis)->indptr}; + ? Downcast(sp_iter_var->axis)->indptr + : Downcast(sp_iter_var->axis)->indptr}; PrimExpr l = BufferLoad(indptr, {lowered_indices}); PrimExpr r = BufferLoad(indptr, {add(lowered_indices, 1)}); extent = sub(r, l); } // Substitute the iteration vars in the expression with the loop vars. - return IterVar(Range::FromMinExtent(0, Substitute(std::move(extent), var_map)), sp_iter->var, - sp_iter->is_reduction ? kCommReduce : kDataPar); + return IterVar(Range::FromMinExtent(0, Substitute(std::move(extent), var_map)), + sp_iter_var->var, sp_iter_var->is_reduction ? kCommReduce : kDataPar); } /*! @@ -477,7 +525,6 @@ class IndexTransformer : public StmtExprMutator { return body; } - AccessAndDependencyCollector collector_; AxisTree axis_tree_; arith::Analyzer ana_; std::unordered_set buffer_read_; @@ -507,12 +554,9 @@ PrimFunc LowerSparseTIR(AxisTree axis_tree, PrimFunc f) { PrimFuncNode* fptr = f.CopyOnWrite(); // Step 1. Update the PrimFunc's buffer map. fptr->buffer_map = UpdateBufferMap(f); - // Step 2. Collect buffer access information and dependency. - AccessAndDependencyCollector collector; - collector.Collect(f->body); - // Step 3. Lower indices. - fptr->body = IndexTransformer(collector, axis_tree)(std::move(f->body)); - // Step 4. Wrap the function body with a root block. + // Step 2. Lower indices. + fptr->body = IndexTransformer(axis_tree)(std::move(f->body)); + // Step 3. Wrap the function body with a root block. fptr->body = WrapWithRootBlock(std::move(fptr->body)); return f; } else { diff --git a/tests/python/sparsetir/test_tir_sparse_correctness.py b/tests/python/sparsetir/test_tir_sparse_correctness.py index 63bcce46a4b2..a69170179dc6 100644 --- a/tests/python/sparsetir/test_tir_sparse_correctness.py +++ b/tests/python/sparsetir/test_tir_sparse_correctness.py @@ -24,21 +24,6 @@ from tvm.script import tir as T -@T.prim_func -def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: - I = T.dense_fixed(m) - J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32") - K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (T.to_dense(J), K), n * k, "float32") - C = T.match_sparse_buffer(c, (I, K), m * k, "float32") - with T.iter([T.cord(I), T.cord(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]: - T.block_attr({"sparse": True}) - with T.init(): - C[vi, vk] = 0.0 - C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] - - @T.prim_func def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) diff --git a/tests/python/sparsetir/test_tir_sparse_nnz_inference.py b/tests/python/sparsetir/test_tir_sparse_nnz_inference.py new file mode 100644 index 000000000000..e164621778c0 --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_nnz_inference.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +from tvm.script import tir as T +from tvm.tir.sparse import AxisTree + +@T.prim_func +def csr2bsr_nnz_inf( + indptr: T.handle, indices: T.handle, + new_cord: T.handle, glb_counter: T.handle, + n: T.int32, m: T.int32, nnz: T.int32, + max_nnz: T.int32) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + K = T.dense_fixed(2) + Glb_counter = T.match_buffer(glb_counter, (1,), "int32") + New_cord = T.match_sparse_buffer(new_cord, (I, J, K), nnz * 2, "int32") + with T.iter([T.pos(I), T.cord(J), ], "SS", "csr2bsr_nnz_inf") as [vi, vj]: + #offset = T.atomic_add(Glb_counter.data, 1) + New_cord[vi, vj, 0] = 0 + New_cord[vi, vj, 1] = 1 + + +@T.prim_func +def csr2bsr(indptr_1: T.handle, indices_1: T.handle, indptr_2: T.handle, indices_2: T.handle, + a_csr: T.handle, a_bsr: T.handle, + block_size: T.int32, + n: T.int32, m: T.int32, nnz: T.int32, + nb: T.int32, mb: T.int32, nnzb: T.int32) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable((m, n + 1, nnz), (indptr_1, indices_1), "int32") + Ibo = T.dense_fixed(nb) + Jbo = T.sparse_variable((mb, nb + 1, nnzb), (indptr_2, indices_2), "int32") + Ibi = T.dense_fixed(block_size) + Jbi = T.dense_fixed(block_size) + A_csr = T.match_sparse_buffer(a_csr, (I, J), nnz, "float32") + A_bsr = T.match_sparse_buffer(a_bsr, (Ibo, Jbo, Ibi, Jbi), nnzb * block_size * block_size, "float32") + with T.iter([T.pos(I), T.cord(J)], "SS", "csr2bsrm") as [vi, vj]: + A_bsr[T.floordiv(vi, block_size), T.floordiv(vj, block_size), T.floormod(vi, block_size), T.floormod(vj, block_size)] =\ + A_csr[vi, vj] + + +def test_csr2bsr(): + mod = tvm.IRModule.from_expr(csr2bsr) + t = AxisTree({ + "J": "I", + "I": None, + "K": None, + "Ibo": None, + "Jbo": "Ibo", + "Ibi": None, + "Ibo": None, + }) + mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + print(mod['main'].script()) + + +if __name__ == "__main__": + test_csr2bsr() \ No newline at end of file