Skip to content

Commit

Permalink
Add docstring for sparse tir lowering (apache#21)
Browse files Browse the repository at this point in the history
* add docstring

* upd
  • Loading branch information
yzh119 authored Nov 19, 2021
1 parent 332df3e commit 9f6a2cd
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 102 deletions.
4 changes: 2 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ class SparseBlockNode : public StmtNode {
/*! \brief The sparse data structures */
Array<ObjectRef> sp_structs;
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
Map<ObjectRef, Array<Var>> sp_struct2param_map;
Map<ObjectRef, Array<Var>> sp_struct_param_map;
/*! \brief The name of the block */
String name;
/*! \brief The body of the block */
Expand All @@ -1299,7 +1299,7 @@ class SparseBlockNode : public StmtNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("sp_iter_vars", &sp_iter_vars);
v->Visit("sp_structs", &sp_structs);
v->Visit("sp_struct2param_map", &sp_struct2param_map);
v->Visit("sp_struct_param_map", &sp_struct_param_map);
v->Visit("name", &name);
v->Visit("body", &body);
v->Visit("init", &init);
Expand Down
2 changes: 1 addition & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1278,7 +1278,7 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
std::vector<Doc> sp_buf_docs;

for (const ObjectRef& obj : sp_block->sp_structs) {
Array<Var> params = sp_block->sp_struct2param_map.Get(obj).value();
Array<Var> params = sp_block->sp_struct_param_map.Get(obj).value();

Doc doc;
doc << Print(obj) << " = " << tir_prefix_ << ".";
Expand Down
6 changes: 3 additions & 3 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_stru
CHECK_EQ(sp_structs.size(), sp_struct_params.size())
<< "ValueError: The length of `sp_struct_params` is expected to be equal to the length "
"`sp_structs`, which is the number of sparse data structures";
Map<ObjectRef, Array<Var>> sp_struct2param_map;
Map<ObjectRef, Array<Var>> sp_struct_param_map;
for (int i = 0; i < static_cast<int>(sp_structs.size()); ++i) {
ObjectRef obj = sp_structs[i];
Array<Var> params = sp_struct_params[i];
Expand All @@ -998,13 +998,13 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_stru
LOG(FATAL) << "ValueError: " << obj->_type_key << " is not a sparse data structure";
}

sp_struct2param_map.Set(obj, params);
sp_struct_param_map.Set(obj, params);
}

ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
node->sp_iter_vars = std::move(sp_iter_vars);
node->sp_structs = std::move(sp_structs);
node->sp_struct2param_map = std::move(sp_struct2param_map);
node->sp_struct_param_map = std::move(sp_struct_param_map);
node->name = std::move(name);
node->body = std::move(body);
node->init = std::move(init);
Expand Down
Loading

0 comments on commit 9f6a2cd

Please sign in to comment.