From 437d88c3d08ff67e82290f4547cce70e47e980e4 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Wed, 15 Dec 2021 09:42:50 -0800
Subject: [PATCH] Syntax simplification (#34)

---
 include/tvm/tir/sparse.h | 64 ++-----
 include/tvm/tir/transform.h | 2 +-
 python/tvm/script/tir/special_stmt.py | 35 +++-
 python/tvm/tir/sparse.py | 48 +++--
 python/tvm/tir/transform/transform.py | 10 +-
 src/printer/tvmscript_printer.cc | 19 +-
 src/tir/ir/sparse.cc | 75 +++-----
 .../primitive/sparse_loop_transformation.cc | 2 +-
 src/tir/transforms/lower_sparse_tir.cc | 173 ++++++++----------
 .../sparsetir/test_tir_sparse_buffer.py | 10 +-
 .../python/sparsetir/test_tir_sparse_lower.py | 138 +++++++-------
 .../test_tir_sparse_nnz_inference.py | 31 +---
 .../test_tir_sparse_script_roundtrip.py | 45 +++--
 13 files changed, 279 insertions(+), 373 deletions(-)

diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h
index cd0fca704871..6043dce377de 100644
--- a/include/tvm/tir/sparse.h
+++ b/include/tvm/tir/sparse.h
@@ -40,6 +40,8 @@ enum class AxisKind : int {
   kSparseVariable = 3
 };
 
+class Axis;
+
 /*!
  * \brief Base type for axis in sparse formats.
  */
@@ -73,6 +75,7 @@ class AxisNode : public Object {
   String GetName() const { return name; }
   PrimExpr GetLength() const { return length; }
   DataType GetIndexType() const { return length->dtype; }
+  virtual Optional<Axis> GetParentAxis() const = 0;
 
   virtual AxisKind kind() const = 0;
   virtual PrimExpr nnz() const = 0;
@@ -137,6 +140,8 @@ class DenseFixedAxisNode : public DenseAxisNode {
 
   PrimExpr nnz() const final { return length; }
 
+  Optional<Axis> GetParentAxis() const final { return NullOpt; }
+
   static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
   TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
 };
@@ -238,6 +243,7 @@ class DenseVariableAxisNode : public DenseAxisNode {
  public:
   Buffer indptr;
   PrimExpr nnz_;
+  Axis parent_;
 
   void VisitAttrs(AttrVisitor* v) {
     DenseAxisNode::VisitAttrs(v);
@@ -257,6 +263,8 @@ class DenseVariableAxisNode : public DenseAxisNode {
 
   PrimExpr nnz() const final { return nnz_; }
 
+  Optional<Axis> GetParentAxis() const final { return parent_; }
+
   static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
   TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
 };
@@ -267,7 +275,8 @@ class DenseVariableAxisNode : public DenseAxisNode {
  */
 class DenseVariableAxis : public DenseAxis {
  public:
-  TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr);
+  TVM_DLL explicit DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
+                                     Buffer indptr);
 
   TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
 };
@@ -280,6 +289,7 @@ class SparseFixedAxisNode : public SparseAxisNode {
   Buffer indices;
   /* fixed number of non-zero columns of current sparse axis.
*/
   PrimExpr nnz_cols;
+  Axis parent_;
 
   void VisitAttrs(AttrVisitor* v) {
     SparseAxisNode::VisitAttrs(v);
@@ -302,6 +312,8 @@ class SparseFixedAxisNode : public SparseAxisNode {
 
   AxisKind kind() const final { return AxisKind::kSparseFixed; }
 
+  Optional<Axis> GetParentAxis() const final { return parent_; }
+
   static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
   TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
 };
@@ -312,7 +324,8 @@ class SparseFixedAxisNode : public SparseAxisNode {
  */
 class SparseFixedAxis : public SparseAxis {
  public:
-  TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols);
+  TVM_DLL explicit SparseFixedAxis(String name, Axis parent, PrimExpr length, Buffer indices,
+                                   PrimExpr nnz_cols);
 
   TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
 };
@@ -324,6 +337,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
  public:
   Buffer indptr;
   Buffer indices;
+  Axis parent_;
 
   void VisitAttrs(AttrVisitor* v) {
     SparseAxisNode::VisitAttrs(v);
@@ -346,6 +360,8 @@ class SparseVariableAxisNode : public SparseAxisNode {
 
   AxisKind kind() const final { return AxisKind::kSparseVariable; }
 
+  Optional<Axis> GetParentAxis() const final { return parent_; }
+
   static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
   TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
 };
@@ -356,52 +372,12 @@ class SparseVariableAxisNode : public SparseAxisNode {
  */
 class SparseVariableAxis : public SparseAxis {
  public:
-  TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);
+  TVM_DLL explicit SparseVariableAxis(String name, Axis parent, PrimExpr length, Buffer indptr,
+                                      Buffer indices);
 
   TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
 };
 
-/*!
- * \brief Axis Dependency Tree.
- */
-class AxisTreeNode : public Object {
- public:
-  // unordered map that stores the parent relationship between axes.
-  Map<String, String> parent;
-  // unordered map that stores the children relationship between axes.
-  Map<String, Array<String>> children;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("parent", &parent);
-    v->Visit("children", &children);
-  }
-
-  bool SEqualReduce(const AxisTreeNode* other, SEqualReducer equal) const {
-    return equal(parent, other->parent) && equal(children, other->children);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(parent);
-    hash_reduce(children);
-  }
-
-  static constexpr const char* _type_key = "tir.sparse.AxisTree";
-  static constexpr const bool _type_has_method_sequal_reduce = true;
-  static constexpr const bool _type_has_method_shash_reduce = true;
-  TVM_DECLARE_FINAL_OBJECT_INFO(AxisTreeNode, Object);
-};
-
-/*!
- * \brief Managed reference to AxisRefNode.
- * \sa AxisTreeNode
- */
-class AxisTree : public ObjectRef {
- public:
-  TVM_DLL AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent_names);
-
-  TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
-};
-
 /*!
  * \brief Class of sparse buffer.
  */
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 02815fe9e217..e552ae117445 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -489,7 +489,7 @@ TVM_DLL Pass ConvertForLoopsToSerial();
  * \param axis_tree The axis dependency tree.
  * \return The pass.
*/
-TVM_DLL Pass LowerSparseTIR(AxisTree axis_tree);
+TVM_DLL Pass LowerSparseTIR();
 
 }  // namespace transform
 }  // namespace tir
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
index b1cfdf37c907..1acfa85767e1 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/tir/special_stmt.py
@@ -920,6 +920,7 @@ class DenseVariable(SpecialStmt):
 
     def __init__(self):
         def dense_variable(
+            parent_axis: Axis,
             shape: Tuple[PrimExpr, PrimExpr],
             indptr_var: tvm.tir.Var,
             idtype: str = "int32",
@@ -931,11 +932,12 @@ def dense_variable(
                     f"`dense_variable` expected assign to only one var, but got {names}", span
                 )
 
-            length, indptr_len, nnz = shape
+            length, nnz = shape
+            indptr_len = parent_axis.length + 1
             indptr_buf = tvm.tir.decl_buffer(
                 (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
             )
-            axis = DenseVariableAxis(names[0], length, nnz, indptr_buf)
+            axis = DenseVariableAxis(names[0], parent_axis, length, nnz, indptr_buf)
             self.context.sp_struct.append(axis)
             self.context.sp_struct_params.append([indptr_var])
             self.context.update_symbol(names[0], axis, self.node)
@@ -950,7 +952,8 @@ class SparseFixed(SpecialStmt):
 
     def __init__(self):
         def sparse_fixed(
-            shape: Tuple[PrimExpr, PrimExpr, PrimExpr],
+            parent_axis: Axis,
+            shape: Tuple[PrimExpr, PrimExpr],
             indices_var: tvm.tir.Var,
             idtype: str = "int32",
             span: Optional[Span] = None,
@@ -961,11 +964,12 @@ def sparse_fixed(
                     f"`sparse_fixed` expected assign to only one var, but got {names}", span
                 )
 
-            length, nnz, nnz_cols = shape
+            length, nnz_cols = shape
+            nnz = parent_axis.nnz * nnz_cols
             indices_buf = tvm.tir.decl_buffer(
                 (nnz,), dtype=idtype, name=names[0] + "_indices", span=span
             )
-            axis = SparseFixedAxis(names[0], length, indices_buf, nnz_cols)
+            axis = SparseFixedAxis(names[0], parent_axis, length, indices_buf, nnz_cols)
             self.context.sp_struct.append(axis)
             self.context.sp_struct_params.append([indices_var])
             self.context.update_symbol(names[0], axis, self.node)
@@ -980,7 +984,8 @@ class SparseVariable(SpecialStmt):
 
     def __init__(self):
         def sparse_variable(
-            shape: Tuple[PrimExpr, PrimExpr, PrimExpr],
+            parent_axis: Axis,
+            shape: Tuple[PrimExpr, PrimExpr],
             data: Tuple[tvm.tir.Var, tvm.tir.Var],
             idtype: str = "int32",
             span: Optional[Span] = None,
@@ -991,7 +996,8 @@ def sparse_variable(
                     f"`sparse_variable` expected assign to only one var, but got {names}", span
                 )
 
-            length, indptr_len, nnz = shape
+            length, nnz = shape
+            indptr_len = parent_axis.nnz + 1
             indptr_var, indices_var = data
             indptr_buf = tvm.tir.decl_buffer(
                 (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
@@ -999,7 +1005,7 @@ def sparse_variable(
             indices_buf = tvm.tir.decl_buffer(
                 (nnz,), dtype=idtype, name=names[0] + "_indices", span=span
             )
-            axis = SparseVariableAxis(names[0], length, indptr_buf, indices_buf)
+            axis = SparseVariableAxis(names[0], parent_axis, length, indptr_buf, indices_buf)
             self.context.sp_struct.append(axis)
             self.context.sp_struct_params.append([indptr_var, indices_var])
             self.context.update_symbol(names[0], axis, self.node)
@@ -1017,10 +1023,19 @@ def __init__(self):
         def match_sparse_buffer(
             param: tvm.tir.Var,
             axes: List[Axis],
-            nnz: PrimExpr,
             dtype: str = "float32",
             span: Optional[Span] = None,
         ):
+            def infer_nnz(axes: List[Axis]) -> PrimExpr:
+                """Infer the number of non-zero elements in a sparse buffer."""
+                ret = axes[0].nnz
+                for axis in axes[1:]:
+                    if isinstance(axis, DenseFixedAxis):
+                        ret = ret * axis.nnz
+                    else:
+                        ret = axis.nnz
+                return ret
+
             if not isinstance(self.node,
ast.Assign) or not len(self.node.lhs) == 1: self.context.report_error( "`match_sparse_buffer` must be assigned to a single sparse buffer, " @@ -1035,7 +1050,7 @@ def match_sparse_buffer( ) if param in self.context.func_params: - data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span) + data = tvm.tir.decl_buffer(infer_nnz(axes), dtype, buffer_name + "_data", span=span) buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name) self.context.sp_struct.append(buffer) self.context.sp_struct_params.append([param]) diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 9b136d037412..d4df74f9685c 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -42,6 +42,10 @@ def length(self): @property def idtype(self): return _ffi_api.GetAxisIndexType(self) + + @property + def nnz(self): + return _ffi_api.GetNNZ(self) @tvm._ffi.register_object("tir.sparse.DenseAxis") @@ -117,6 +121,9 @@ class DenseVariableAxis(DenseAxis): ---------- name : str The name of the axis + + parent : Axis + The parent axis length : PrimExpr The length of the axis @@ -126,13 +133,14 @@ class DenseVariableAxis(DenseAxis): """ name: str + parent: Axis length: PrimExpr nnz: PrimExpr indptr: Buffer - def __init__(self, name, length, nnz, indptr): + def __init__(self, name, parent, length, nnz, indptr): self.__init_handle_by_constructor__( - _ffi_api.DenseVariableAxis, name, length, nnz, indptr # type: ignore + _ffi_api.DenseVariableAxis, name, parent, length, nnz, indptr # type: ignore ) @@ -145,6 +153,9 @@ class SparseFixedAxis(DenseAxis): name : str The name of the axis + parent : Axis + The parent axis + length : PrimExpr The length of the axis @@ -156,13 +167,14 @@ class SparseFixedAxis(DenseAxis): """ name: str + parent: Axis length: PrimExpr indices: Buffer nnz_cols: PrimExpr - def __init__(self, name, length, indices, nnz_cols): + def __init__(self, name, parent, length, indices, nnz_cols): self.__init_handle_by_constructor__( - _ffi_api.SparseFixedAxis, name, length, indices, nnz_cols # type: ignore + _ffi_api.SparseFixedAxis, name, parent, length, indices, nnz_cols # type: ignore ) @@ -174,6 +186,9 @@ class SparseVariableAxis(DenseAxis): ---------- name : str The name of the axis + + parent : Axis + The parent axis length : PrimExpr The length of the axis @@ -186,33 +201,14 @@ class SparseVariableAxis(DenseAxis): """ name: str + parent: Axis length: PrimExpr indptr: Buffer indices: Buffer - def __init__(self, name, length, indptr, indices): - self.__init_handle_by_constructor__( - _ffi_api.SparseVariableAxis, name, length, indptr, indices # type: ignore - ) - - -@tvm._ffi.register_object("tir.sparse.AxisTree") -class AxisTree(Object): - """AxisTree node - - Parameters - ---------- - axis_parent_map: Dict - A dictionary that maps axis name to parent axis name, value is None if there is not parent axis. 
-    """
-
-    axis_parent_map: Dict[str, Optional[str]]
-
-    def __init__(self, axis_parent_map) -> None:
-        keys = list(axis_parent_map.keys())
-        values = list(axis_parent_map.values())
+    def __init__(self, name, parent, length, indptr, indices):
         self.__init_handle_by_constructor__(
-            _ffi_api.AxisTree, keys, values  # type:ignore
+            _ffi_api.SparseVariableAxis, name, parent, length, indptr, indices  # type: ignore
         )
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 96ae275dc156..0073503e4fc3 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -19,7 +19,6 @@ from typing import Optional
 from . import _ffi_api
 from . import function_pass as _fpass
-from ..sparse import AxisTree
 
 
 def Apply(ftransform):
@@ -752,17 +751,12 @@ def ConvertForLoopsToSerial():
     return _ffi_api.ConvertForLoopsToSerial()  # type: ignore
 
 
-def LowerSparseTIR(axis_tree: AxisTree):
+def LowerSparseTIR():
     """Lower SparseTIR to TIR
 
-    Parameters
-    ----------
-    axis_tree : AxisTree
-        The axis dependency tree.
-
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.LowerSparseTIR(axis_tree)  # type: ignore
+    return _ffi_api.LowerSparseTIR()  # type: ignore
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index afdc0e6a34f1..85b25156aa19 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1321,7 +1321,7 @@ Doc TVMScriptPrinter::PrintSparseBlockName(const SparseBlockNode* op) {
     const SpIterVar& sp_iter = op->sp_iter_vars[i];
     const Axis& axis = sp_iter->axis;
     Doc iter_doc;
-
+    std::string axis_repr = sp_iter->axis->name;
     if (axis->is_derived_axis) {
       if (const DenseFromSparseAxisNode* dfs_axis = axis.as<DenseFromSparseAxisNode>()) {
@@ -1399,7 +1399,7 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
       }
       doc << "match_sparse_buffer(" << Print(params[0]) << ", (" << axes_doc << "), "
-          << Print(sp_buffer->data->shape[0]) << ", " << PrintDType(sp_buffer->data->dtype) << ")";
+          << PrintDType(sp_buffer->data->dtype) << ")";
       sp_buf_docs.push_back(doc);
       continue;
     }
@@ -1409,20 +1409,19 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
       doc << "dense_fixed(" << Print(df_axis->length) << ")";
     } else if (const auto* dv_axis = obj.as<DenseVariableAxisNode>()) {
       ICHECK_EQ(params.size(), 1);
-      doc << "dense_variable((" << Print(dv_axis->length) << ", "
-          << Print(dv_axis->indptr->shape[0]) << "), " << Print(params[0]) << ", "
+      doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length) << ", "
+          << Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", "
           << PrintDType(dv_axis->indptr->dtype) << ")";
     } else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
       ICHECK_EQ(params.size(), 1);
-      doc << "sparse_fixed((" << Print(sf_axis->length) << ", " << Print(sf_axis->indices->shape[0])
-          << ", " << Print(sf_axis->nnz_cols) << "), " << Print(params[0]) << ", "
+      doc << "sparse_fixed(" << sf_axis->parent_->name << ", (" << Print(sf_axis->length) << ", "
+          << Print(sf_axis->nnz_cols) << "), " << Print(params[0]) << ", "
          << PrintDType(sf_axis->indices->dtype) << ")";
     } else if (const auto* sv_axis = obj.as<SparseVariableAxisNode>()) {
       ICHECK_EQ(params.size(), 2);
-      doc << "sparse_variable((" << Print(sv_axis->length) << ", "
-          << Print(sv_axis->indptr->shape[0]) << ", " << Print(sv_axis->indices->shape[0]) << "), ("
-          << Print(params[0]) << ", " << Print(params[1]) << "), "
-          << PrintDType(sv_axis->indptr->dtype) << ")";
+      doc << "sparse_variable(" <<
sv_axis->parent_->name << ", (" << Print(sv_axis->length) << ", "
+          << Print(sv_axis->nnz()) << "), (" << Print(params[0]) << ", " << Print(params[1])
+          << "), " << PrintDType(sv_axis->indptr->dtype) << ")";
     } else {
       ICHECK(false) << "Cannot reach here";
     }
diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc
index 18e7cf8b4f2a..7c6225f59cdb 100644
--- a/src/tir/ir/sparse.cc
+++ b/src/tir/ir/sparse.cc
@@ -43,6 +43,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis)
   return DLDataType2String(axis->GetIndexType());
 });
 
+TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->nnz(); });
+
 /******** DenseFixedAxis ********/
 
 /*! \brief Default constructor of DenseFixedAxis */
@@ -68,9 +70,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 /******** DenseVariableAxis ********/
 
 /*! \brief Default constructor of DenseVariableAxis */
-DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
+DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
+                                     Buffer indptr) {
   ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
   node->name = std::move(name);
+  node->parent_ = std::move(parent);
   node->length = std::move(length);
   node->nnz_ = std::move(nnz);
   node->indptr = std::move(indptr);
@@ -80,8 +84,9 @@ DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz,
 TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);
 
 TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
-    .set_body_typed([](String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
-      return DenseVariableAxis(std::move(name), std::move(length), std::move(nnz), std::move(indptr));
+    .set_body_typed([](String name, Axis parent, PrimExpr length, PrimExpr nnz, Buffer indptr) {
+      return DenseVariableAxis(std::move(name), std::move(parent), std::move(length),
+                               std::move(nnz), std::move(indptr));
     });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -161,9 +166,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 /******** SparseFixedAxis ********/
 
 /*!
\brief Default constructor of SparseFixedAxis */
-SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) {
+SparseFixedAxis::SparseFixedAxis(String name, Axis parent, PrimExpr length, Buffer indices,
+                                 PrimExpr nnz_cols) {
   ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
   node->name = std::move(name);
+  node->parent_ = std::move(parent);
   node->length = std::move(length);
   node->indices = std::move(indices);
   node->nnz_cols = std::move(nnz_cols);
@@ -173,24 +180,27 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, P
 TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);
 
 TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
-    .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) {
-      return SparseFixedAxis(std::move(name), std::move(length), std::move(indices), std::move(nnz_cols));
+    .set_body_typed([](String name, Axis parent, PrimExpr length, Buffer indices,
+                       PrimExpr nnz_cols) {
+      return SparseFixedAxis(std::move(name), std::move(parent), std::move(length),
+                             std::move(indices), std::move(nnz_cols));
    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<SparseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const SparseFixedAxisNode*>(node.get());
-      p->stream << "sparse_fixed(" << op->name << ", " << op->length << ", " << op->nnz_cols << ", "
-                << op->indices->name << ")";
+      p->stream << "sparse_fixed(" << op->name << ", " << op->GetParentAxis().value()->name << ", "
+                << op->length << ", " << op->nnz_cols << ", " << op->indices->name << ")";
     });
 
 /******** SparseVariableAxis ********/
 
 /*! \brief Default constructor of SparseVariableAxis */
-SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr,
+SparseVariableAxis::SparseVariableAxis(String name, Axis parent, PrimExpr length, Buffer indptr,
                                        Buffer indices) {
   ObjectPtr<SparseVariableAxisNode> node = make_object<SparseVariableAxisNode>();
   node->name = std::move(name);
+  node->parent_ = std::move(parent);
   node->length = std::move(length);
   node->indptr = std::move(indptr);
   node->indices = std::move(indices);
@@ -200,8 +210,9 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indp
 TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);
 
 TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
-    .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
-      return SparseVariableAxis(std::move(name), std::move(length), std::move(indptr), std::move(indices));
+    .set_body_typed([](String name, Axis parent, PrimExpr length, Buffer indptr, Buffer indices) {
+      return SparseVariableAxis(std::move(name), std::move(parent), std::move(length),
+                                std::move(indptr), std::move(indices));
     });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -211,48 +222,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << ", " << op->indices->name << ")";
     });
 
-/******** AxisTree ********/
-
-/*!
\brief Default constructor of AxisTree */
-AxisTree::AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent_names) {
-  CHECK_EQ(axis_names.size(), axis_parent_names.size())
-      << "ValueError: The axis_names array should have the same length as "
-         "axis_parent_names "
-         "array.";
-  ObjectPtr<AxisTreeNode> node = make_object<AxisTreeNode>();
-  Map<String, String> parent;
-  Map<String, Array<String>> children;
-  for (size_t i = 0; i < axis_names.size(); i++) {
-    // update parent map & children map
-    String axis_name = axis_names[i];
-    String parent_name("root");
-    if (axis_parent_names[i].defined()) {
-      parent_name = axis_parent_names[i].value();
-    }
-    parent.Set(axis_name, parent_name);
-
-    auto it = children.find(parent_name);
-    if (it != children.end()) {
-      Array<String> value = (*it).second;
-      value.push_back(axis_name);
-      children.Set(parent_name, std::move(value));
-    } else {
-      Array<String> value{axis_name};
-      children.Set(parent_name, std::move(value));
-    }
-  }
-  node->parent = std::move(parent);
-  node->children = std::move(children);
-  data_ = std::move(node);
-}
-
-TVM_REGISTER_NODE_TYPE(AxisTreeNode);
-
-TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
-    .set_body_typed([](Array<String> axis_names, Array<Optional<String>> axis_parent_names) {
-      return AxisTree(std::move(axis_names), std::move(axis_parent_names));
-    });
-
 /******** SparseBuffer ********/
 
 /*! \brief Default constructor of SparseBuffer */
diff --git a/src/tir/schedule/primitive/sparse_loop_transformation.cc b/src/tir/schedule/primitive/sparse_loop_transformation.cc
index b7da781fe3e3..bfb57bd4aa88 100644
--- a/src/tir/schedule/primitive/sparse_loop_transformation.cc
+++ b/src/tir/schedule/primitive/sparse_loop_transformation.cc
@@ -100,7 +100,7 @@ SparseBlock SparseReorder(ScheduleState self, const SparseBlock& block,
   CheckValidInputIterators(self, new_order, block->sp_iter_vars);
 
   // Step 2. Check whether the new order does not break the iterator dependency.
-  // TODO(zihao): use axis dependency tree instead
+  // TODO(zihao): rewrite this part.
   // CheckDependency(self, block, new_order);
 
   // Step 3. Create the new SparseBlock.
diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc
index a9de88264546..2b913f6281fd 100644
--- a/src/tir/transforms/lower_sparse_tir.cc
+++ b/src/tir/transforms/lower_sparse_tir.cc
@@ -92,7 +92,7 @@ Map<Var, Buffer> UpdateBufferMap(PrimFunc f)
  * \param ana_ The analyzer used for simplifying expressions. TODO(zihao): make it cleaner.
  * \return The lowered index.
  */
-PrimExpr AggregateOffset(PrimExpr prev_offset, const Axis& axis, PrimExpr index,
+PrimExpr AggregateOffset(PrimExpr prev_offset, Axis axis, PrimExpr index,
                          arith::Analyzer* ana_ = nullptr) {
   PrimExpr new_offset;
   switch (axis->kind()) {
@@ -132,45 +132,18 @@ class SparseBlockCtx {
     explicit Scope(Scope&& other)
         : sp_iter_var_map_(std::move(other.sp_iter_var_map_)),
           offset_(std::move(other.offset_)),
-          parent_(std::move(parent_)),
-          blk_name_(std::move(blk_name_)),
+          axis_to_sp_iter_var_(std::move(other.axis_to_sp_iter_var_)),
+          blk_name_(std::move(other.blk_name_)),
           ana_(std::move(other.ana_)) {}
 
     // default constructor
-    explicit Scope(String blk_name, Array<SpIterVar> sp_iter_vars, AxisTree tree, arith::Analyzer* ana)
+    explicit Scope(String blk_name, Array<SpIterVar> sp_iter_vars, arith::Analyzer* ana)
         : blk_name_(std::move(blk_name)), ana_(ana) {
-      std::unordered_map<String, SpIterVar> axis_name_sp_iter_map_;
       // initialize sparse iter var dependency map.
for (const SpIterVar& sp_iter_var : sp_iter_vars) {
-        axis_name_sp_iter_map_[sp_iter_var->axis->name] = sp_iter_var;
+        axis_to_sp_iter_var_[sp_iter_var->axis.get()] = sp_iter_var;
         sp_iter_var_map_[sp_iter_var->var.get()] = sp_iter_var;
       }
-
-      // collect parents.
-      for (const SpIterVar& sp_iter_var : sp_iter_vars) {
-        String axis_name = sp_iter_var->axis->name;
-        const SpIterVarNode* node = sp_iter_var.get();
-        if (sp_iter_var->axis->is_derived_axis) {
-          // The axis is a derived axis.
-          parent_[node] = nullptr;
-        } else {
-          auto opt = tree->parent.Get(axis_name);
-          CHECK(opt.defined()) << "Cannot find parent of axis " << axis_name << ".";
-          String parent_axis_name = opt.value();
-          if (parent_axis_name != "root") {
-            auto it = axis_name_sp_iter_map_.find(parent_axis_name);
-            CHECK(it != axis_name_sp_iter_map_.end())
-                << "Cannot find sparse iter vars corresponding to parent axis " << parent_axis_name
-                << " in current sparse block " << blk_name;
-            parent_[node] = (it->second).get();
-          } else {
-            parent_[node] = nullptr;
-          }
-        }
-      }
-
-      // init offset_
-      offset_[nullptr] = Integer(0);
     }
 
     /*!
@@ -193,7 +166,7 @@ class SparseBlockCtx {
      * \param sp_iter_var The compressed iterator.
      * \return A PrimExpr representing the coordinate.
      */
-    PrimExpr GetCoordinate(const SpIterVarNode* sp_iter_var) {
+    PrimExpr GetCoordinate(SpIterVar sp_iter_var) {
       const Axis& axis = sp_iter_var->axis;
       AxisKind kind = axis->kind();
       if (kind == AxisKind::kDenseFixed || kind == AxisKind::kDenseVariable) {
@@ -214,14 +187,18 @@
     * \param sp_iter_var The sparse iter var to lookup.
     * \return A PrimExpr representing the offset.
     */
-    PrimExpr GetOffset(const SpIterVarNode* sp_iter_var) {
-      auto it = offset_.find(sp_iter_var);
+    PrimExpr GetOffset(Optional<SpIterVar> sp_iter_var) {
+      if (!sp_iter_var.defined()) {
+        return Integer(0);
+      }
+      SpIterVar sp_iter_var_ = sp_iter_var.value();
+      auto it = offset_.find(sp_iter_var_.get());
       if (it != offset_.end()) {
         return it->second;
       } else {
-        PrimExpr prev_off = GetOffset(parent_[sp_iter_var]);
-        PrimExpr new_off = AggregateOffset(prev_off, sp_iter_var->axis, sp_iter_var->var, ana_);
-        offset_[sp_iter_var] = new_off;
+        PrimExpr prev_off = GetOffset(GetParentSpIterVar(sp_iter_var_));
+        PrimExpr new_off = AggregateOffset(prev_off, sp_iter_var_->axis, sp_iter_var_->var, ana_);
+        offset_[sp_iter_var_.get()] = new_off;
         return new_off;
       }
     }
@@ -232,13 +209,24 @@
     * \return A tuple of PrimExpr, the first element refers to the start position, and the second
     * element refers to the end position.
     */
-    std::tuple<PrimExpr, PrimExpr> GetIndicesRange(const SpIterVarNode* sp_iter_var) {
-      PrimExpr prev_off = GetOffset(parent_[sp_iter_var]);
+    std::tuple<PrimExpr, PrimExpr> GetIndicesRange(SpIterVar sp_iter_var) {
+      PrimExpr prev_off = GetOffset(GetParentSpIterVar(sp_iter_var));
       const Axis& axis = sp_iter_var->axis;
       return {AggregateOffset(prev_off, axis, Integer(0), ana_),
               AggregateOffset(add(prev_off, 1), axis, Integer(0), ana_)};
     }
 
+    Optional<SpIterVar> GetParentSpIterVar(SpIterVar sp_iter_var) {
+      Axis axis = std::move(sp_iter_var->axis);
+      auto parent = axis->GetParentAxis();
+      if (parent.defined()) {
+        Axis parent_ = parent.value();
+        return axis_to_sp_iter_var_[parent_.get()];
+      } else {
+        return NullOpt;
+      }
+    }
+
    /*!
     * \brief Get the current block name.
     */
@@ -247,17 +235,17 @@
    private:
    std::unordered_map<const VarNode*, SpIterVar> sp_iter_var_map_;
    std::unordered_map<const SpIterVarNode*, PrimExpr> offset_;
-    std::unordered_map<const SpIterVarNode*, const SpIterVarNode*> parent_;
+    std::unordered_map<const AxisNode*, SpIterVar> axis_to_sp_iter_var_;
    String blk_name_;
    arith::Analyzer* ana_;
  };
 
  /*!
\brief default constructor */
-  explicit SparseBlockCtx(AxisTree tree, arith::Analyzer* ana) : tree_(std::move(tree)), ana_(ana) {}
+  explicit SparseBlockCtx(arith::Analyzer* ana) : ana_(ana) {}
 
   /*! \brief enter new scope */
   void EnterScope(const SparseBlockNode* sp_block) {
-    stack_.emplace_back(sp_block->name, sp_block->sp_iter_vars, tree_, ana_);
+    stack_.emplace_back(sp_block->name, sp_block->sp_iter_vars, ana_);
   }
 
   /*! \brief exit current scope */
@@ -269,10 +257,10 @@
   }
 
   /*! \brief call GetCoordinate in the top scope. */
-  PrimExpr GetCoordinate(const SpIterVarNode* node) { return top()->GetCoordinate(node); }
+  PrimExpr GetCoordinate(SpIterVar sp_iter_var) { return top()->GetCoordinate(sp_iter_var); }
 
   /*! \brief call GetIndicesRange in the top scope. */
-  std::tuple<PrimExpr, PrimExpr> GetIndicesRange(const SpIterVarNode* sp_iter_var) {
+  std::tuple<PrimExpr, PrimExpr> GetIndicesRange(SpIterVar sp_iter_var) {
     return top()->GetIndicesRange(sp_iter_var);
   }
 
@@ -281,7 +269,6 @@
  private:
   std::vector<Scope> stack_;
-  AxisTree tree_;
   arith::Analyzer* ana_;
 
   /*! \brief the top scope in the sparse block stack. */
@@ -303,8 +290,12 @@ class SparseBufferCtx {
           ana_(std::move(other.ana_)) {}
 
     /*! \brief default constructor */
-    explicit Scope(String buf_name, Array<Axis> axes, const SparseBlockCtx* sp_blk_ctx, arith::Analyzer* ana)
-        : buf_name_(std::move(buf_name)), axes_(std::move(axes)), sp_blk_ctx_(sp_blk_ctx), ana_(ana) {
+    explicit Scope(String buf_name, Array<Axis> axes, const SparseBlockCtx* sp_blk_ctx,
+                   arith::Analyzer* ana)
+        : buf_name_(std::move(buf_name)),
+          axes_(std::move(axes)),
+          sp_blk_ctx_(sp_blk_ctx),
+          ana_(ana) {
       offsets_.emplace_back(Integer(0));
       matches_.emplace_back(true);
     }
@@ -338,9 +329,7 @@
     }
 
     /*! \brief get the axis given dimension index of current buffer. */
-    Axis GetAxis(int dim) const {
-      return axes_[dim];
-    }
+    Axis GetAxis(int dim) const { return axes_[dim]; }
 
     /*! \brief whether the index access pattern of current buffer aligns with current block */
     const inline bool MatchWithSpBlock() const { return matches_.back(); }
@@ -362,10 +351,10 @@
   };
 
   /*! \brief default constructor */
-  explicit SparseBufferCtx(AxisTree tree, arith::Analyzer* ana) : tree_(std::move(tree)), ana_(ana) {}
+  explicit SparseBufferCtx(arith::Analyzer* ana) : ana_(ana) {}
 
   /*! \brief enter new scope */
-  void EnterScope(const SparseBuffer& sp_buf, const SparseBlockCtx* sp_blk_ctx) {
+  void EnterScope(SparseBuffer sp_buf, const SparseBlockCtx* sp_blk_ctx) {
     stack_.emplace_back(sp_buf->name, sp_buf->axes, sp_blk_ctx, ana_);
   }
 
@@ -373,9 +362,7 @@
   void ExitScope() { stack_.pop_back(); }
 
   /*! \brief call GetAxis in top scope. */
-  Axis GetAxis(int dim) const {
-    return top()->GetAxis(dim);
-  }
+  Axis GetAxis(int dim) const { return top()->GetAxis(dim); }
 
   /*! \brief call MatchWithSpBlock in top scope.
*/
  const inline bool MatchWithSpBlock() const { return top()->MatchWithSpBlock(); }
@@ -389,7 +376,6 @@
  }
 
  private:
-  AxisTree tree_;
  std::vector<Scope> stack_;
  arith::Analyzer* ana_;
@@ -403,8 +389,7 @@
 */
class IndexTransformer : public StmtExprMutator {
 public:
-  explicit IndexTransformer(const AxisTree& axis_tree)
-      : sp_blk_ctx_(axis_tree, &ana_), sp_buf_ctx_(axis_tree, &ana_), axis_tree_(axis_tree) {}
+  explicit IndexTransformer() : sp_blk_ctx_(&ana_), sp_buf_ctx_(&ana_) {}
 
 private:
  // Sparse block context stack;
@@ -456,7 +441,7 @@
   * \param sp_buffer The sparse buffer to access.
   * \param indices The array of indices.
   */
-  PrimExpr ComputeOffset(SparseBuffer sp_buffer, const Array<PrimExpr>& indices) {
+  PrimExpr ComputeOffset(SparseBuffer sp_buffer, Array<PrimExpr> indices) {
    int num_lowered_indices = static_cast<int>(indices.size());
    ICHECK_LE(num_lowered_indices, sp_buffer->ndim());
 
@@ -477,7 +462,7 @@
  PrimExpr VisitExpr_(const VarNode* v) final {
    auto it = sp_blk_ctx_.GetSparseIterVar(v);
    if (it.defined()) {
-      return sp_blk_ctx_.GetCoordinate(it.value().get());
+      return sp_blk_ctx_.GetCoordinate(it.value());
    } else {
      return GetRef<PrimExpr>(v);
    }
@@ -537,8 +522,7 @@
    }
 
    // Step 4. Collect block iters and iter bindings.
-    std::set<String> in_stack;
-    in_stack.insert("root");
+    std::set<const AxisNode*> in_stack;
    /* A stack that stores block itervars in each block. */
    std::stack<Array<IterVar>> block_iters_st;
    /* A stack that stores itervar bindings in each block. */
@@ -555,33 +539,40 @@
      /* Itervar bindings of current block. */
      Array<PrimExpr> iter_bindings;
      /* Axis names of current block. */
-      Array<String> axis_names;
+      Array<Axis> blk_axes;
      /* Generated loop vars of current block. */
      Array<Var> loop_vars;
      /* An indicator that records whether there is a reduction axis in current block. */
      bool has_reduction_var = false;
      for (int i = 0; i < n_iter; ++i) {
        SpIterVar sp_it_var = sp_block->sp_iter_vars[i];
-        String axis_name = sp_it_var->axis->name;
-        String parent_axis_name;
-        if (sp_it_var->axis->is_derived_axis) {
-          // derived axis doesn't appear in the axis tree.
-          parent_axis_name = "root";
-        } else {
-          auto&& parent_axis = axis_tree_->parent.Get(axis_name);
-          CHECK(parent_axis.defined()) << "Sparse IterVar not defined in Axis Tree.";
-          parent_axis_name = parent_axis.value();
-        }
-        bool is_fixed_axis = (sp_it_var->axis->kind() == AxisKind::kDenseFixed || sp_it_var->axis->kind() == AxisKind::kSparseFixed);
+        Axis axis = sp_it_var->axis;
        /* Add itervar to current block when
         * - it's not used yet (not in stack) and
         * - its parent axis was used in outer blocks or
         * - it's an iterator to a fixed axis.
         */
-        if ((is_fixed_axis || in_stack.find(parent_axis_name) != in_stack.end()) &&
-            in_stack.find(axis_name) == in_stack.end()) {
+        auto parent = axis->GetParentAxis();
+        bool emit_iter_var = true;
+        if (in_stack.find(axis.get()) !=
+            in_stack.end()) {  // the iter var has already been emitted.
+          emit_iter_var = false;
+        } else {
+          if (parent.defined()) {  // has parent
+            if (in_stack.find(parent.value().get()) == in_stack.end()) {  // parent not emitted yet
+              if (axis->kind() == AxisKind::kDenseVariable ||
+                  axis->kind() == AxisKind::kSparseVariable) {  // is not fixed axis.
+                emit_iter_var = false;
+              }
+            }
+          }
+        }
+        // LOG(INFO) << axis->name << " " << (parent.defined() ?
parent.value()->name : "no-parent")
+        //           << " " << emit_iter_var;
+        if (emit_iter_var) {
          loop_vars.push_back(all_loop_vars[i]);
-          axis_names.push_back(std::move(axis_name));
+          blk_axes.push_back(axis);
          block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map));
          iter_bindings.push_back(all_loop_vars[i]);
          has_reduction_var |= sp_it_var->is_reduction;
@@ -589,8 +580,8 @@
      }
 
      /* Tag axes in current block as "in-stack". */
-      for (const String&& axis_name : axis_names) {
-        in_stack.insert(std::move(axis_name));
+      for (const Axis& axis : blk_axes) {
+        in_stack.insert(axis.get());
      }
 
      /* Update stack. */
@@ -644,7 +635,8 @@
                  /*span=*/sp_block->span);
      BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));
      // Generate outer loop and the block binding.
-      Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars);
+      Stmt loop =
+          GenerateLoops(std::move(block_realize), std::move(block_iters), std::move(loop_vars));
      body = loop;
      blk_counter += 1;
    }
@@ -661,7 +653,7 @@
   * \param var_map The mapping from sparse iterable variable to corresponding ordinary iterable
   * variable.
   */
-  IterVar SpIterVarToIterVar(const SpIterVar& sp_iter_var,
+  IterVar SpIterVarToIterVar(SpIterVar sp_iter_var,
                             const std::unordered_map<const VarNode*, PrimExpr>& var_map) {
    PrimExpr extent{nullptr};
    AxisKind kind = sp_iter_var->axis->kind();
@@ -676,11 +668,9 @@
        break;
      }
      case AxisKind::kDenseVariable:
-        // TODO(zihao): need discussion.
-        break;
      case AxisKind::kSparseVariable: {
        PrimExpr l, r;
-        std::tie(l, r) = sp_blk_ctx_.GetIndicesRange(sp_iter_var.get());
+        std::tie(l, r) = sp_blk_ctx_.GetIndicesRange(sp_iter_var);
        extent = sub(r, l);
        break;
      }
@@ -724,7 +714,7 @@
   * \param loop_vars The loop variables bound to block iterators.
   * \return The outermost loop.
   */
-  Stmt GenerateLoops(Stmt body, const Array<IterVar>& block_iters, const Array<Var>& loop_vars) {
+  Stmt GenerateLoops(Stmt body, Array<IterVar> block_iters, Array<Var> loop_vars) {
    int n_iter = static_cast<int>(block_iters.size());
    for (int i = n_iter - 1; i >= 0; --i) {
      const Range& dom = block_iters[i]->dom;
@@ -733,7 +723,6 @@
    return body;
  }
 
-  AxisTree axis_tree_;
  arith::Analyzer ana_;
  std::unordered_set<const SparseBufferNode*> buffer_read_;
  std::unordered_set<const SparseBufferNode*> buffer_write_;
@@ -752,18 +741,17 @@ Stmt WrapWithRootBlock(Stmt body) {
 
 /*!
 * \brief Rewrite the given primitive function.
- * \param axis_tree The axis dependency tree.
 * \param f The Sparse-TIR primitive function to lower.
 * \return lowered primitive function in TIR.
 */
-PrimFunc LowerSparseTIR(AxisTree axis_tree, PrimFunc f) {
+PrimFunc LowerSparseTIR(PrimFunc f) {
  // Only apply this pass to TIR that is not from TE schedules
  if (!IsFromLegacyTESchedule(f)) {
    PrimFuncNode* fptr = f.CopyOnWrite();
    // Step 1. Update the PrimFunc's buffer map.
    fptr->buffer_map = UpdateBufferMap(f);
    // Step 2. Lower indices.
-    fptr->body = IndexTransformer(axis_tree)(std::move(f->body));
+    fptr->body = IndexTransformer()(std::move(f->body));
    // Step 3. Wrap the function body with a root block.
    fptr->body = WrapWithRootBlock(std::move(fptr->body));
    return f;
@@ -776,11 +764,10 @@ namespace transform {
 
 /*!
 * \brief The lowering pass from SparseTIR to TIR.
- * \param axis_tree The axis dependency tree.
*/ -Pass LowerSparseTIR(AxisTree axis_tree) { +Pass LowerSparseTIR() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return LowerSparseTIR(std::move(axis_tree), std::move(f)); + return LowerSparseTIR(std::move(f)); }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {}); } diff --git a/tests/python/sparsetir/test_tir_sparse_buffer.py b/tests/python/sparsetir/test_tir_sparse_buffer.py index 4bc7423e31d6..a3a099ff1f44 100644 --- a/tests/python/sparsetir/test_tir_sparse_buffer.py +++ b/tests/python/sparsetir/test_tir_sparse_buffer.py @@ -17,18 +17,12 @@ import tvm import tvm.tir as tir -def test_format_tree_creation(): +def test_axis_creation(): i = tir.sparse.DenseFixedAxis('i', 128) j = tir.sparse.DenseFixedAxis('j', 128) k = tir.sparse.DenseFixedAxis('k', 128) - tree = tir.sparse.AxisTree({ - 'i': None, - 'j': None, - 'k': None - }) - print(tree) print(i, j, k) if __name__ == "__main__": - test_format_tree_creation() + test_axis_creation() diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index da58f0608a35..1388e9044dfc 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -20,7 +20,6 @@ import scipy.sparse as sp import numpy as np from tvm.script import tir as T -from tvm.tir.sparse import AxisTree @T.prim_func @@ -36,11 +35,11 @@ def csrmm( nnz: T.int32, ) -> None: I = T.dense_fixed(n) - J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (T.dense(J), K), m * k, "float32") - C = T.match_sparse_buffer(c, (I, K), n * k, "float32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") with T.iter([I, J, K], "SRS", "csrmm") as [vi, vj, vk]: with T.init(): C[vi, vk] = 0.0 @@ -60,17 +59,35 @@ def csrmm_dense_iter( nnz: T.int32, ) -> None: I = T.dense_fixed(n) - J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (T.dense(J), K), m * k, "float32") - C = T.match_sparse_buffer(c, (I, K), n * k, "float32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") with T.iter([I, T.dense(J), K], "SRS", "csrmm") as [vi, vj, vk]: with T.init(): C[vi, vk] = 0.0 C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] +@T.prim_func +def segment_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + n: T.int32, + nnz: T.int32, +) -> None: + I = T.dense_fixed(n) + J = T.dense_variable(I, (100, nnz), indptr, "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + with T.iter([I, J], "SR", "segment_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0. 
+ B[vi] = B[vi] + A[vi, vj] + + @T.prim_func def lowered_csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, m: T.int32, k: T.int32, nnz: T.int32) -> None: A_data = T.match_buffer(a, [nnz], dtype="float32") @@ -109,9 +126,9 @@ def csr_reduce( nnz: T.int32, ) -> None: I = T.dense_fixed(n) - J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (I,), n, "float32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") with T.iter([I, J], "SR", "csr_reduce") as [vi, vj]: with T.init(): B[vi] = 0.0 @@ -155,13 +172,13 @@ def bsrmm( feat_size: T.int32, ) -> None: I = T.dense_fixed(nb) - J = T.sparse_variable((mb, nb + 1, nnzb), (indptr, indices), "int32") + J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") BI = T.dense_fixed(blk) BJ = T.dense_fixed(blk) F = T.dense_fixed(feat_size) - A = T.match_sparse_buffer(a, (I, J, BI, BJ), nnzb * blk * blk, "float32") - B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), mb * blk * feat_size, "float32") - C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") with T.iter([I, J, BI, BJ, F], "SRSRS", "bsrmm") as [ vi, @@ -211,18 +228,17 @@ def ellpack_mm( nb: T.int32, mb: T.int32, feat_size: T.int32, - nnz: T.int32, col: T.int32, blk: T.int32, ) -> None: I = T.dense_fixed(nb) - J = T.sparse_fixed((mb, nnz, col), indices, "int32") + J = T.sparse_fixed(I, (mb, col), indices, "int32") F = T.dense_fixed(feat_size) BI = T.dense_fixed(blk) BJ = T.dense_fixed(blk) - A = T.match_sparse_buffer(a, (I, J, BI, BJ), nnz * blk * blk, "float32") - B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), mb * blk * feat_size, "float32") - C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") with T.iter([I, J, BI, BJ, F], "SRSRS", "ellmm") as [ vi, @@ -237,22 +253,22 @@ def ellpack_mm( @T.prim_func -def lowered_ellpack_mm(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, nb: T.int32, mb: T.int32, feat_size: T.int32, nnz: T.int32, col: T.int32, blk: T.int32) -> None: - A_data = T.match_buffer(a, [nnz * blk * blk], dtype="float32") +def lowered_ellpack_mm(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, nb: T.int32, mb: T.int32, feat_size: T.int32, col: T.int32, blk: T.int32) -> None: + A_data = T.match_buffer(a, [nb * col * blk * blk], dtype="float32") B_data = T.match_buffer(b, [mb * blk * feat_size], dtype="float32") C_data = T.match_buffer(c, [nb * blk * feat_size], dtype="float32") - J_indices = T.match_buffer(indices, [nnz], dtype="int32") + J_indices = T.match_buffer(indices, [nb * col], dtype="int32") + # body + # with T.block("root") for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size): with T.block("ellmm"): vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) - T.reads([J_indices[0: nnz], A_data[0: nnz * blk * blk], - B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]]) - T.writes([C_data[0: nb * blk * feat_size]]) - T.block_attr({"sparse": True}) + 
T.reads([J_indices[0 : nb * col], A_data[0 : nb * col * blk * blk], B_data[0 : mb * blk * feat_size], C_data[0 : nb * blk * feat_size]]) + T.writes([C_data[0 : nb * blk * feat_size]]) + T.block_attr({"sparse":True}) with T.init(): C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) - C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[((vi * - col + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] + C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[((vi * col + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] @T.prim_func @@ -266,9 +282,9 @@ def csr_element_wise( nnz: T.int32, ) -> None: I = T.dense_fixed(m) - J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32") - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (I, J), nnz, "float32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I, J), "float32") with T.iter([I, J], "SS", "csr_element_wise") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.5 @@ -297,12 +313,8 @@ def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices def test_csrmm(): mod = tvm.IRModule.from_expr(csrmm) - t = AxisTree({ - "J": "I", - "I": None, - "K": None - }) - mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr") @@ -325,23 +337,20 @@ def test_csrmm(): def test_csrmm_dense_iter(): mod = tvm.IRModule.from_expr(csrmm_dense_iter) - t = AxisTree({ - "J": "I", - "I": None, - "K": None - }) - mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + mod = tvm.tir.transform.LowerSparseTIR()(mod) print(mod["main"].script()) # tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) +def test_segment_reduce(): + mod = tvm.IRModule.from_expr(segment_reduce) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + print(mod["main"].script()) + + def test_csr_reduce(): mod = tvm.IRModule.from_expr(csr_reduce) - t = AxisTree({ - "J": "I", - "I": None - }) - mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_csr_reduce, True) A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") @@ -362,14 +371,7 @@ def test_csr_reduce(): def test_bsrmm(): mod = tvm.IRModule.from_expr(bsrmm) - t = AxisTree({ - "J": "I", - "I": None, - "BJ": None, - "BI": None, - "F": None - }) - mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_bsrmm, True) block_size = 16 @@ -409,14 +411,8 @@ def test_bsrmm(): def test_ellpack_mm(): mod = tvm.IRModule.from_expr(ellpack_mm) - t = AxisTree({ - "J": "I", - "I": None, - "F": None, - "BI": None, - "BJ": None - }) - mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], lowered_ellpack_mm, True) nnz_cols = 4 @@ -439,14 +435,13 @@ def test_ellpack_mm(): y_ground_truth = A * x y = np.zeros((n * feat_size,)).astype("float32") - v_nb, v_mb, v_feat_size, v_nnz, v_col, v_blk = ellpack_mm.params[-6:] + v_nb, v_mb, 
v_feat_size, v_col, v_blk = ellpack_mm.params[-5:] f = tvm.build( mod["main"].specialize( { v_nb: nb, v_mb: mb, v_feat_size: feat_size, - v_nnz: nnz, v_col: nnz_cols, v_blk: block_size, } @@ -465,11 +460,7 @@ def test_ellpack_mm(): def test_csr_element_wise(): mod = tvm.IRModule.from_expr(csr_element_wise) - t = AxisTree({ - "J": "I", - "I": None - }) - mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + mod = tvm.tir.transform.LowerSparseTIR()(mod) tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True) A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") @@ -491,6 +482,7 @@ def test_csr_element_wise(): if __name__ == "__main__": test_csrmm() test_csrmm_dense_iter() + test_segment_reduce() test_csr_reduce() test_bsrmm() test_ellpack_mm() diff --git a/tests/python/sparsetir/test_tir_sparse_nnz_inference.py b/tests/python/sparsetir/test_tir_sparse_nnz_inference.py index 6ec3e819f934..f521503775ca 100644 --- a/tests/python/sparsetir/test_tir_sparse_nnz_inference.py +++ b/tests/python/sparsetir/test_tir_sparse_nnz_inference.py @@ -19,7 +19,6 @@ import scipy.sparse as sp import numpy as np from tvm.script import tir as T -from tvm.tir.sparse import AxisTree @T.prim_func def csr2bsr_cnt_nnz( @@ -27,9 +26,9 @@ def csr2bsr_cnt_nnz( new_cord: T.handle, glb_counter: T.handle, n: T.int32, m: T.int32, nnz: T.int32) -> None: I = T.dense_fixed(n) - J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") K = T.dense_fixed(2) - New_cord = T.match_sparse_buffer(new_cord, (I, J, K), nnz * 2, "int32") + New_cord = T.match_sparse_buffer(new_cord, (I, J, K), "int32") with T.iter([I, J], "SS", "csr2bsr_cnt_nnz") as [vi, vj]: New_cord[vi, vj, 0] = 0 New_cord[vi, vj, 1] = 1 @@ -42,13 +41,13 @@ def csr2bsr(indptr_1: T.handle, indices_1: T.handle, indptr_2: T.handle, indices n: T.int32, m: T.int32, nnz: T.int32, nb: T.int32, mb: T.int32, nnzb: T.int32) -> None: I = T.dense_fixed(n) - J = T.sparse_variable((m, n + 1, nnz), (indptr_1, indices_1), "int32") + J = T.sparse_variable(I, (m, nnz), (indptr_1, indices_1), "int32") Ibo = T.dense_fixed(nb) - Jbo = T.sparse_variable((mb, nb + 1, nnzb), (indptr_2, indices_2), "int32") + Jbo = T.sparse_variable(Ibo, (mb, nnzb), (indptr_2, indices_2), "int32") Ibi = T.dense_fixed(block_size) Jbi = T.dense_fixed(block_size) - A_csr = T.match_sparse_buffer(a_csr, (I, J), nnz, "float32") - A_bsr = T.match_sparse_buffer(a_bsr, (Ibo, Jbo, Ibi, Jbi), nnzb * block_size * block_size, "float32") + A_csr = T.match_sparse_buffer(a_csr, (I, J), "float32") + A_bsr = T.match_sparse_buffer(a_bsr, (Ibo, Jbo, Ibi, Jbi), "float32") with T.iter([I, J], "SS", "csr2bsrm") as [vi, vj]: A_bsr[T.floordiv(vi, block_size), T.floordiv(vj, block_size), T.floormod(vi, block_size), T.floormod(vj, block_size)] =\ A_csr[vi, vj] @@ -56,27 +55,13 @@ def csr2bsr(indptr_1: T.handle, indices_1: T.handle, indptr_2: T.handle, indices def test_cnt_nnz(): mod = tvm.IRModule.from_expr(csr2bsr_cnt_nnz) - t = AxisTree({ - "J": "I", - "I": None, - "K": None - }) - mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + mod = tvm.tir.transform.LowerSparseTIR()(mod) print(mod['main'].script()) def test_csr2bsr(): mod = tvm.IRModule.from_expr(csr2bsr) - t = AxisTree({ - "J": "I", - "I": None, - "K": None, - "Ibo": None, - "Jbo": "Ibo", - "Ibi": None, - "Ibo": None, - }) - mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + mod = tvm.tir.transform.LowerSparseTIR()(mod) print(mod['main'].script()) diff --git 
a/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py b/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py index 5ea544470526..aedc2c22ec3c 100644 --- a/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py +++ b/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py @@ -27,11 +27,11 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha k = T.var("int32") nnz = T.var("int32") I = T.dense_fixed(n) - J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (T.dense(J), K), m * k, "float32") - C = T.match_sparse_buffer(c, (I, K), n * k, "float32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") with T.iter([I, J, K], "SRS", "csrmm") as [vi, vj, vk]: with T.init(): C[vi, vk] = 0.0 @@ -45,11 +45,11 @@ def csrmm_dense_iter(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, in k = T.var("int32") nnz = T.var("int32") I = T.dense_fixed(n) - J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") K = T.dense_fixed(k) - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (T.dense(J), K), m * k, "float32") - C = T.match_sparse_buffer(c, (I, K), n * k, "float32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") with T.iter([I, T.dense(J), K], "SRS", "csrmm") as [vi, vj, vk]: with T.init(): C[vi, vk] = 0.0 @@ -62,9 +62,9 @@ def csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle) -> m = T.var("int32") nnz = T.var("int32") I = T.dense_fixed(n) - J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (I,), n, "float32") + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") with T.iter([I, J], "SR", "csr_reduce") as [vi, vj]: with T.init(): B[vi] = 0.0 @@ -79,13 +79,13 @@ def bsrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha blk = T.var("int32") feat_size = T.var("int32") I = T.dense_fixed(nb) - J = T.sparse_variable((mb, nb + 1, nnzb), (indptr, indices), "int32") + J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") BI = T.dense_fixed(blk) BJ = T.dense_fixed(blk) F = T.dense_fixed(feat_size) - A = T.match_sparse_buffer(a, (I, J, BI, BJ), nnzb * blk * blk, "float32") - B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), mb * blk * feat_size, "float32") - C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") with T.iter([I, J, BI, BJ, F], "SRSSS", "bsrmm") as [ vi, @@ -104,17 +104,16 @@ def ellpack_mm(a: T.handle, b: T.handle, c: T.handle, indices: T.handle) -> None nb = T.var("int32") mb = T.var("int32") feat_size = T.var("int32") - nnz = T.var("int32") col = T.var("int32") blk = T.var("int32") I = T.dense_fixed(nb) - J = T.sparse_fixed((mb, nnz, col), indices, "int32") + J = T.sparse_fixed(I, 
(mb, col), indices, "int32") F = T.dense_fixed(feat_size) BI = T.dense_fixed(blk) BJ = T.dense_fixed(blk) - A = T.match_sparse_buffer(a, (I, J, BI, BJ), nnz * blk * blk, "float32") - B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), mb * blk * feat_size, "float32") - C = T.match_sparse_buffer(c, (I, BI, F), nb * blk * feat_size, "float32") + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") with T.iter([I, J, BI, BJ, F], "SRSSS", "bsrmm") as [ vi, @@ -134,9 +133,9 @@ def csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices: T.hand n = T.var("int32") nnz = T.var("int32") I = T.dense_fixed(m) - J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32") - A = T.match_sparse_buffer(a, (I, J), nnz, "float32") - B = T.match_sparse_buffer(b, (I, J), nnz, "float32") + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I, J), "float32") with T.iter([I, J], "SS", "csr_element_wise") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 + 1.0
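
Usage note: a minimal before/after sketch of the surface-syntax change this patch makes, drawn
from the csrmm tests above (no new API is introduced here; everything shown appears verbatim in
the diff). An axis now names its parent axis explicitly, so the lengths that were previously
spelled out are inferred from it:

    # old syntax: the indptr length (n + 1) and the buffer's nnz are passed explicitly
    I = T.dense_fixed(n)
    J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32")
    A = T.match_sparse_buffer(a, (I, J), nnz, "float32")

    # new syntax: J declares its parent axis I, so the indptr length (I.nnz + 1) is
    # inferred by sparse_variable, and A's flattened nnz by match_sparse_buffer
    I = T.dense_fixed(n)
    J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32")
    A = T.match_sparse_buffer(a, (I, J), "float32")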