Skip to content

Commit

Permalink
fix nnz cols
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Nov 16, 2021
1 parent d578964 commit 332df3e
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 19 deletions.
12 changes: 6 additions & 6 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,26 +198,26 @@ class DenseVariableAxis : public DenseAxis {
class SparseFixedAxisNode : public SparseAxisNode {
public:
Buffer indices;
/* fixed number of columns of current sparse axis. */
PrimExpr num_cols;
/* fixed number of non-zero columns of current sparse axis. */
PrimExpr nnz_cols;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("indptr", &indices);
v->Visit("num_cols", &num_cols);
v->Visit("nnz_cols", &nnz_cols);
}

bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(num_cols, other->num_cols);
equal(indices, other->indices) && equal(nnz_cols, other->nnz_cols);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
hash_reduce(indices);
hash_reduce(num_cols);
hash_reduce(nnz_cols);
}

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
Expand All @@ -230,7 +230,7 @@ class SparseFixedAxisNode : public SparseAxisNode {
*/
class SparseFixedAxis : public SparseAxis {
public:
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols);
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols);

TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
};
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,6 @@ def pos(axis: Axis, span: Optional[Span] = None):
elif isinstance(axis, DenseVariableAxis):
return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis)
elif isinstance(axis, SparseFixedAxis):
return SpIterVar(var_temp, axis.num_cols, SpIterVar.SparseFixed, False, axis)
return SpIterVar(var_temp, axis.nnz_cols, SpIterVar.SparseFixed, False, axis)
else:
return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis)
10 changes: 5 additions & 5 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,18 @@ class SparseFixedAxis(DenseAxis):
indices : Buffer
The indices buffer of the axis
num_cols : PrimExpr
The number of non-zero elements along the axis
nnz_cols : PrimExpr
The fixed number of non-zero elements along the axis
"""

name: str
length: PrimExpr
indices: Buffer
num_cols: PrimExpr
nnz_cols: PrimExpr

def __init__(self, name, length, indices, num_cols):
def __init__(self, name, length, indices, nnz_cols):
self.__init_handle_by_constructor__(
_ffi_api.SparseFixedAxis, name, length, indices, num_cols # type: ignore
_ffi_api.SparseFixedAxis, name, length, indices, nnz_cols # type: ignore
)


Expand Down
2 changes: 1 addition & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
} else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "sparse_fixed((" << Print(sf_axis->length) << ", " << Print(sf_axis->indices->shape[0])
<< ", " << Print(sf_axis->num_cols) << "), " << Print(params[0]) << ", "
<< ", " << Print(sf_axis->nnz_cols) << "), " << Print(params[0]) << ", "
<< PrintDType(sf_axis->indices->dtype) << ")";
} else if (const auto* sv_axis = obj.as<SparseVariableAxisNode>()) {
ICHECK_EQ(params.size(), 2);
Expand Down
10 changes: 5 additions & 5 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,26 +91,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

// SparseFixedAxis
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) {
ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
node->indices = std::move(indices);
node->num_cols = std::move(num_cols);
node->nnz_cols = std::move(nnz_cols);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
return SparseFixedAxis(name, length, indices, num_cols);
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) {
return SparseFixedAxis(name, length, indices, nnz_cols);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SparseFixedAxisNode*>(node.get());
p->stream << "sparse_fixed(" << op->name << ", " << op->length << ", " << op->num_cols << ", "
p->stream << "sparse_fixed(" << op->name << ", " << op->length << ", " << op->nnz_cols << ", "
<< op->indices->name << ")";
});

Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class IndexTransformer : public StmtExprMutator {
if (axis->IsInstance<DenseFixedAxisNode>()) {
return ana_.Simplify(std::move(prev_lowered_index) * axis->length + std::move(index));
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
return ana_.Simplify(std::move(prev_lowered_index) * sf_axis->num_cols + std::move(index));
return ana_.Simplify(std::move(prev_lowered_index) * sf_axis->nnz_cols + std::move(index));
} else if (const auto* dv_axis = axis.as<DenseVariableAxisNode>()) {
return ana_.Simplify(
add(BufferLoad(dv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index)));
Expand Down

0 comments on commit 332df3e

Please sign in to comment.