Skip to content

Commit

Permalink
Syntax simplification (apache#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 authored and MasterJH5574 committed Dec 22, 2021
1 parent 9a13573 commit 437d88c
Show file tree
Hide file tree
Showing 13 changed files with 279 additions and 373 deletions.
64 changes: 20 additions & 44 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ enum class AxisKind : int {
kSparseVariable = 3
};

class Axis;

/*!
* \brief Base type for axis in sparse formats.
*/
Expand Down Expand Up @@ -73,6 +75,7 @@ class AxisNode : public Object {
String GetName() const { return name; }
PrimExpr GetLength() const { return length; }
DataType GetIndexType() const { return length->dtype; }
virtual Optional<Axis> GetParentAxis() const = 0;

virtual AxisKind kind() const = 0;
virtual PrimExpr nnz() const = 0;
Expand Down Expand Up @@ -137,6 +140,8 @@ class DenseFixedAxisNode : public DenseAxisNode {

PrimExpr nnz() const final { return length; }

Optional<Axis> GetParentAxis() const final { return NullOpt; }

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -238,6 +243,7 @@ class DenseVariableAxisNode : public DenseAxisNode {
public:
Buffer indptr;
PrimExpr nnz_;
Axis parent_;

void VisitAttrs(AttrVisitor* v) {
DenseAxisNode::VisitAttrs(v);
Expand All @@ -257,6 +263,8 @@ class DenseVariableAxisNode : public DenseAxisNode {

PrimExpr nnz() const final { return nnz_; }

Optional<Axis> GetParentAxis() const final { return parent_; }

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};
Expand All @@ -267,7 +275,8 @@ class DenseVariableAxisNode : public DenseAxisNode {
*/
class DenseVariableAxis : public DenseAxis {
public:
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr);
TVM_DLL explicit DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
Buffer indptr);

TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
};
Expand All @@ -280,6 +289,7 @@ class SparseFixedAxisNode : public SparseAxisNode {
Buffer indices;
/* fixed number of non-zero columns of current sparse axis. */
PrimExpr nnz_cols;
Axis parent_;

void VisitAttrs(AttrVisitor* v) {
SparseAxisNode::VisitAttrs(v);
Expand All @@ -302,6 +312,8 @@ class SparseFixedAxisNode : public SparseAxisNode {

AxisKind kind() const final { return AxisKind::kSparseFixed; }

Optional<Axis> GetParentAxis() const final { return parent_; }

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
};
Expand All @@ -312,7 +324,8 @@ class SparseFixedAxisNode : public SparseAxisNode {
*/
class SparseFixedAxis : public SparseAxis {
public:
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols);
TVM_DLL explicit SparseFixedAxis(String name, Axis parent, PrimExpr length, Buffer indices,
PrimExpr nnz_cols);

TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
};
Expand All @@ -324,6 +337,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
public:
Buffer indptr;
Buffer indices;
Axis parent_;

void VisitAttrs(AttrVisitor* v) {
SparseAxisNode::VisitAttrs(v);
Expand All @@ -346,6 +360,8 @@ class SparseVariableAxisNode : public SparseAxisNode {

AxisKind kind() const final { return AxisKind::kSparseVariable; }

Optional<Axis> GetParentAxis() const final { return parent_; }

static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
};
Expand All @@ -356,52 +372,12 @@ class SparseVariableAxisNode : public SparseAxisNode {
*/
class SparseVariableAxis : public SparseAxis {
public:
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);
TVM_DLL explicit SparseVariableAxis(String name, Axis parent, PrimExpr length, Buffer indptr,
Buffer indices);

TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
};

/*!
* \brief Axis Dependency Tree.
*/
class AxisTreeNode : public Object {
public:
// unordered map that stores the parent relationship between axes.
Map<String, String> parent;
// unordered map that stores the children relationship between axes.
Map<String, Array<String>> children;

void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
v->Visit("children", &children);
}

bool SEqualReduce(const AxisTreeNode* other, SEqualReducer equal) const {
return equal(parent, other->parent) && equal(children, other->children);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(parent);
hash_reduce(children);
}

static constexpr const char* _type_key = "tir.sparse.AxisTree";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(AxisTreeNode, Object);
};

/*!
* \brief Managed reference to AxisRefNode.
* \sa AxisTreeNode
*/
class AxisTree : public ObjectRef {
public:
TVM_DLL AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent_names);

TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
};

/*!
* \brief Class of sparse buffer.
*/
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ TVM_DLL Pass ConvertForLoopsToSerial();
* \param axis_tree The axis dependency tree.
* \return The pass.
*/
TVM_DLL Pass LowerSparseTIR(AxisTree axis_tree);
TVM_DLL Pass LowerSparseTIR();

} // namespace transform
} // namespace tir
Expand Down
35 changes: 25 additions & 10 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,7 @@ class DenseVariable(SpecialStmt):

def __init__(self):
def dense_variable(
parent_axis: Axis,
shape: Tuple[PrimExpr, PrimExpr],
indptr_var: tvm.tir.Var,
idtype: str = "int32",
Expand All @@ -931,11 +932,12 @@ def dense_variable(
f"`dense_variable` expected assign to only one var, but got {names}", span
)

length, indptr_len, nnz = shape
length, nnz = shape
indptr_len = parent_axis.length + 1
indptr_buf = tvm.tir.decl_buffer(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
axis = DenseVariableAxis(names[0], length, nnz, indptr_buf)
axis = DenseVariableAxis(names[0], parent_axis, length, nnz, indptr_buf)
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var])
self.context.update_symbol(names[0], axis, self.node)
Expand All @@ -950,7 +952,8 @@ class SparseFixed(SpecialStmt):

def __init__(self):
def sparse_fixed(
shape: Tuple[PrimExpr, PrimExpr, PrimExpr],
parent_axis: Axis,
shape: Tuple[PrimExpr, PrimExpr],
indices_var: tvm.tir.Var,
idtype: str = "int32",
span: Optional[Span] = None,
Expand All @@ -961,11 +964,12 @@ def sparse_fixed(
f"`sparse_fixed` expected assign to only one var, but got {names}", span
)

length, nnz, nnz_cols = shape
length, nnz_cols = shape
nnz = parent_axis.nnz * nnz_cols
indices_buf = tvm.tir.decl_buffer(
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
)
axis = SparseFixedAxis(names[0], length, indices_buf, nnz_cols)
axis = SparseFixedAxis(names[0], parent_axis, length, indices_buf, nnz_cols)
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indices_var])
self.context.update_symbol(names[0], axis, self.node)
Expand All @@ -980,7 +984,8 @@ class SparseVariable(SpecialStmt):

def __init__(self):
def sparse_variable(
shape: Tuple[PrimExpr, PrimExpr, PrimExpr],
parent_axis: Axis,
shape: Tuple[PrimExpr, PrimExpr],
data: Tuple[tvm.tir.Var, tvm.tir.Var],
idtype: str = "int32",
span: Optional[Span] = None,
Expand All @@ -991,15 +996,16 @@ def sparse_variable(
f"`sparse_variable` expected assign to only one var, but got {names}", span
)

length, indptr_len, nnz = shape
length, nnz = shape
indptr_len = parent_axis.nnz + 1
indptr_var, indices_var = data
indptr_buf = tvm.tir.decl_buffer(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
indices_buf = tvm.tir.decl_buffer(
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
)
axis = SparseVariableAxis(names[0], length, indptr_buf, indices_buf)
axis = SparseVariableAxis(names[0], parent_axis, length, indptr_buf, indices_buf)
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var, indices_var])
self.context.update_symbol(names[0], axis, self.node)
Expand All @@ -1017,10 +1023,19 @@ def __init__(self):
def match_sparse_buffer(
param: tvm.tir.Var,
axes: List[Axis],
nnz: PrimExpr,
dtype: str = "float32",
span: Optional[Span] = None,
):
def infer_nnz(axes: List[Axis]) -> PrimExpr:
"""Inference the number of non-zero elements in a sparse buffer."""
ret = axes[0].nnz
for axis in axes[1:]:
if isinstance(axis, DenseFixedAxis):
ret = ret * axis.nnz
else:
ret = axis.nnz
return ret

if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
self.context.report_error(
"`match_sparse_buffer` must be assigned to a single sparse buffer, "
Expand All @@ -1035,7 +1050,7 @@ def match_sparse_buffer(
)

if param in self.context.func_params:
data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span)
data = tvm.tir.decl_buffer(infer_nnz(axes), dtype, buffer_name + "_data", span=span)
buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name)
self.context.sp_struct.append(buffer)
self.context.sp_struct_params.append([param])
Expand Down
48 changes: 22 additions & 26 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def length(self):
@property
def idtype(self):
return _ffi_api.GetAxisIndexType(self)

@property
def nnz(self):
return _ffi_api.GetNNZ(self)


@tvm._ffi.register_object("tir.sparse.DenseAxis")
Expand Down Expand Up @@ -117,6 +121,9 @@ class DenseVariableAxis(DenseAxis):
----------
name : str
The name of the axis
parent : Axis
The parent axis
length : PrimExpr
The length of the axis
Expand All @@ -126,13 +133,14 @@ class DenseVariableAxis(DenseAxis):
"""

name: str
parent: Axis
length: PrimExpr
nnz: PrimExpr
indptr: Buffer

def __init__(self, name, length, nnz, indptr):
def __init__(self, name, parent, length, nnz, indptr):
self.__init_handle_by_constructor__(
_ffi_api.DenseVariableAxis, name, length, nnz, indptr # type: ignore
_ffi_api.DenseVariableAxis, name, parent, length, nnz, indptr # type: ignore
)


Expand All @@ -145,6 +153,9 @@ class SparseFixedAxis(DenseAxis):
name : str
The name of the axis
parent : Axis
The parent axis
length : PrimExpr
The length of the axis
Expand All @@ -156,13 +167,14 @@ class SparseFixedAxis(DenseAxis):
"""

name: str
parent: Axis
length: PrimExpr
indices: Buffer
nnz_cols: PrimExpr

def __init__(self, name, length, indices, nnz_cols):
def __init__(self, name, parent, length, indices, nnz_cols):
self.__init_handle_by_constructor__(
_ffi_api.SparseFixedAxis, name, length, indices, nnz_cols # type: ignore
_ffi_api.SparseFixedAxis, name, parent, length, indices, nnz_cols # type: ignore
)


Expand All @@ -174,6 +186,9 @@ class SparseVariableAxis(DenseAxis):
----------
name : str
The name of the axis
parent : Axis
The parent axis
length : PrimExpr
The length of the axis
Expand All @@ -186,33 +201,14 @@ class SparseVariableAxis(DenseAxis):
"""

name: str
parent: Axis
length: PrimExpr
indptr: Buffer
indices: Buffer

def __init__(self, name, length, indptr, indices):
self.__init_handle_by_constructor__(
_ffi_api.SparseVariableAxis, name, length, indptr, indices # type: ignore
)


@tvm._ffi.register_object("tir.sparse.AxisTree")
class AxisTree(Object):
"""AxisTree node
Parameters
----------
axis_parent_map: Dict
A dictionary that maps axis name to parent axis name, value is None if there is not parent axis.
"""

axis_parent_map: Dict[str, Optional[str]]

def __init__(self, axis_parent_map) -> None:
keys = list(axis_parent_map.keys())
values = list(axis_parent_map.values())
def __init__(self, name, parent, length, indptr, indices):
self.__init_handle_by_constructor__(
_ffi_api.AxisTree, keys, values # type:ignore
_ffi_api.SparseVariableAxis, name, parent, length, indptr, indices # type: ignore
)


Expand Down
Loading

0 comments on commit 437d88c

Please sign in to comment.