diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ef8288bd65d7c..a78a5e3831690 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1025,7 +1025,9 @@ def _sparse_tensor_dense_add(): from scipy.sparse import csr_matrix def _impl(inputs, attr, params, mod): - assert len(inputs) == 4, "There should be 4 input tensors" + assert ( + len(inputs) == 4 + ), "There should be 4 input tensors [sparse_indices, sparse_values, sparse_shape, dense]." indices_tensor = _infer_value(inputs[0], params, mod).asnumpy() values_tensor = _infer_value(inputs[1], params, mod).asnumpy() diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index e4cceecc1890b..af64873ee9049 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -144,7 +144,7 @@ def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type): # sparse_add reg.register_strategy("nn.sparse_add", strategy.sparse_add_strategy) -reg.register_pattern("nn.sparse_add", reg.OpPattern.ELEMWISE) +reg.register_pattern("nn.sparse_add", reg.OpPattern.OPAQUE) @reg.register_compute("nn.internal.sparse_dense_padded") diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 9e5ba6bcffcd7..f80507cf16563 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -816,7 +816,7 @@ def sparse_add_strategy(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_sparse_add(topi.nn.sparse_add), - wrap_topi_schedule(topi.generic.schedule_sparse_add), + wrap_topi_schedule(topi.generic.schedule_extern), name="sparse_add.generic", ) return strategy diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index c4e687c404450..c17c0f9a6d3ae 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -382,7 +382,7 @@ def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr): 2-D with shape [M, N] """ # TODO(ANSHUMAN87): support BSR format too - assert len(sparse_data.shape) == 1 + assert len(sparse_data.shape) == 1, "only CSR format is supported" return _sparse_add_csr(dense_data, sparse_data, sparse_indices, sparse_indptr) @@ -417,6 +417,11 @@ def _csr_add_ir(dense_data, sparse_data, sparse_indices, sparse_indptr, out_data inputs=[dense_data_inp, sparse_data_inp, sparse_indices_inp, sparse_indptr_inp], fcompute=lambda ins, outs: _csr_add_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), tag="sparse_add_csr", - dtype=["float32", "float32", "int32", "int32"], + dtype=[ + dense_data_inp.dtype, + sparse_data_inp.dtype, + sparse_indices_inp.dtype, + sparse_indptr_inp.dtype, + ], name="out", ) diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 08a8394858a1e..94d088d08e619 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -202,9 +202,10 @@ bool SparseAddRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK_EQ(types.size(), 5); const auto* dense_data = types[0].as(); const auto* sparse_data = types[1].as(); - ICHECK_EQ(sparse_data->shape.size(), 1); + ICHECK(reporter->Assert(sparse_data->dtype == dense_data->dtype)); + ICHECK(reporter->Assert(sparse_data->shape.size() == 1)); const auto* sparse_indices = types[2].as(); - ICHECK_EQ(sparse_indices->shape.size(), 1); + ICHECK(reporter->Assert(sparse_indices->shape.size() == 1)); reporter->Assign(types[4], TensorType(dense_data->shape, dense_data->dtype)); return true; @@ -228,9 +229,9 @@ RELAY_REGISTER_OP("nn.sparse_add") )code" TVM_ADD_FILELINE) .set_num_inputs(4) .add_argument("dense_data", "2D Tensor", "Dense data matrix.") - .add_argument("sparse_data", "1D Tensor", "Sparse data matrix.") - .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") - .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.") + .add_argument("sparse_data", "1D Tensor", "Sparse data vector.") + .add_argument("sparse_indices", "1D Tensor", "Sparse indices vector.") + .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer vector.") .set_support_level(1) .add_type_rel("SparseAdd", SparseAddRel);