Skip to content

Commit

Permalink
[1] Review comment handled
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Feb 18, 2021
1 parent b9abbb6 commit 054ecfe
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 10 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,9 @@ def _sparse_tensor_dense_add():
from scipy.sparse import csr_matrix

def _impl(inputs, attr, params, mod):
assert len(inputs) == 4, "There should be 4 input tensors"
assert (
len(inputs) == 4
), "There should be 4 input tensors [sparse_indices, sparse_values, sparse_shape, dense]."

indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
values_tensor = _infer_value(inputs[1], params, mod).asnumpy()
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type):

# sparse_add
reg.register_strategy("nn.sparse_add", strategy.sparse_add_strategy)
reg.register_pattern("nn.sparse_add", reg.OpPattern.ELEMWISE)
reg.register_pattern("nn.sparse_add", reg.OpPattern.OPAQUE)


@reg.register_compute("nn.internal.sparse_dense_padded")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def sparse_add_strategy(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_add(topi.nn.sparse_add),
wrap_topi_schedule(topi.generic.schedule_sparse_add),
wrap_topi_schedule(topi.generic.schedule_extern),
name="sparse_add.generic",
)
return strategy
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/topi/nn/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr):
2-D with shape [M, N]
"""
# TODO(ANSHUMAN87): support BSR format too
assert len(sparse_data.shape) == 1
assert len(sparse_data.shape) == 1, "only CSR format is supported"
return _sparse_add_csr(dense_data, sparse_data, sparse_indices, sparse_indptr)


Expand Down Expand Up @@ -417,6 +417,11 @@ def _csr_add_ir(dense_data, sparse_data, sparse_indices, sparse_indptr, out_data
inputs=[dense_data_inp, sparse_data_inp, sparse_indices_inp, sparse_indptr_inp],
fcompute=lambda ins, outs: _csr_add_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
tag="sparse_add_csr",
dtype=["float32", "float32", "int32", "int32"],
dtype=[
dense_data_inp.dtype,
sparse_data_inp.dtype,
sparse_indices_inp.dtype,
sparse_indptr_inp.dtype,
],
name="out",
)
11 changes: 6 additions & 5 deletions src/relay/op/nn/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,10 @@ bool SparseAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
ICHECK_EQ(types.size(), 5);
const auto* dense_data = types[0].as<TensorTypeNode>();
const auto* sparse_data = types[1].as<TensorTypeNode>();
ICHECK_EQ(sparse_data->shape.size(), 1);
ICHECK(reporter->Assert(sparse_data->dtype == dense_data->dtype));
ICHECK(reporter->Assert(sparse_data->shape.size() == 1));
const auto* sparse_indices = types[2].as<TensorTypeNode>();
ICHECK_EQ(sparse_indices->shape.size(), 1);
ICHECK(reporter->Assert(sparse_indices->shape.size() == 1));

reporter->Assign(types[4], TensorType(dense_data->shape, dense_data->dtype));
return true;
Expand All @@ -228,9 +229,9 @@ RELAY_REGISTER_OP("nn.sparse_add")
)code" TVM_ADD_FILELINE)
.set_num_inputs(4)
.add_argument("dense_data", "2D Tensor", "Dense data matrix.")
.add_argument("sparse_data", "1D Tensor", "Sparse data matrix.")
.add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.")
.add_argument("sparse_data", "1D Tensor", "Sparse data vector.")
.add_argument("sparse_indices", "1D Tensor", "Sparse indices vector.")
.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer vector.")
.set_support_level(1)
.add_type_rel("SparseAdd", SparseAddRel);

Expand Down

0 comments on commit 054ecfe

Please sign in to comment.