From 83b5d2ea055cd0fe5d3c7b1a1c5a74e0bfb42265 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sat, 13 Feb 2021 10:33:36 +0530 Subject: [PATCH] lint resolved --- python/tvm/relay/op/strategy/generic.py | 2 ++ python/tvm/topi/generic/nn.py | 2 ++ python/tvm/topi/nn/sparse.py | 10 +++++----- src/relay/op/nn/sparse.cc | 3 +-- tests/python/topi/python/test_topi_sparse.py | 2 ++ 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 7a90817eb7471..9e5ba6bcffcd7 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -821,12 +821,14 @@ def sparse_add_strategy(attrs, inputs, out_type, target): ) return strategy + @generic_func def schedule_sparse_add(attrs, outs, target): """schedule sparse_add""" with target: return topi.generic.schedule_sparse_add(outs) + # sparse_transpose @generic_func def schedule_sparse_transpose(attrs, outs, target): diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py index cbb6a94a28194..49281e0356a64 100644 --- a/python/tvm/topi/generic/nn.py +++ b/python/tvm/topi/generic/nn.py @@ -729,6 +729,7 @@ def schedule_sparse_transpose(outs): """ return _default_schedule(outs, False) + def schedule_sparse_add(outs): """Schedule for sparse_add @@ -745,6 +746,7 @@ def schedule_sparse_add(outs): """ return _default_schedule(outs, False) + def schedule_batch_matmul(outs): """Schedule for batch_matmul diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index dc8693b41e6d1..f19e5a512cb9b 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -357,6 +357,7 @@ def sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type): """ return None + def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr): """ Computes sparse-dense addition @@ -380,10 +381,11 @@ def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr): output : tvm.te.Tensor 2-D with shape [M, N] """ - #TODO(ANSHUMAN87): support BSR format too + # TODO(ANSHUMAN87): support BSR format too assert len(sparse_data.shape) == 1 return _sparse_add_csr(dense_data, sparse_data, sparse_indices, sparse_indptr) + def _sparse_add_csr(dense_data_inp, sparse_data_inp, sparse_indices_inp, sparse_indptr_inp): oshape = get_const_tuple(dense_data_inp.shape) @@ -413,10 +415,8 @@ def _csr_add_ir(dense_data, sparse_data, sparse_indices, sparse_indptr, out_data return te.extern( shape=oshape, inputs=[dense_data_inp, sparse_data_inp, sparse_indices_inp, sparse_indptr_inp], - fcompute=lambda ins, outs: _csr_add_ir( - ins[0], ins[1], ins[2], ins[3], outs[0] - ), + fcompute=lambda ins, outs: _csr_add_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), tag="sparse_add_csr", dtype=["float32", "float32", "int32", "int32"], name="out", - ) \ No newline at end of file + ) diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 7b3aaa285b030..f026aa949bd9a 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -198,7 +198,7 @@ RELAY_REGISTER_OP("nn.sparse_transpose") // relay.nn.sparse_add bool SparseAddRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 5); const auto* dense_data = types[0].as(); const auto* sparse_data = types[1].as(); @@ -236,6 +236,5 @@ RELAY_REGISTER_OP("nn.sparse_add") .set_support_level(1) .add_type_rel("SparseAdd", SparseAddRel); - } // namespace relay } // namespace tvm diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index 1b3bcc7f2732a..5d92694cfaa4e 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -525,6 +525,7 @@ def test_sparse_dense_padded_alter_op(): with tvm.transform.PassContext(opt_level=3, required_pass="AlterOpLayout"): x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda")) + def test_sparse_add_csr(): M, K, density = 3, 49, 0.2 X_np = np.random.randn(M, K).astype("float32") @@ -549,6 +550,7 @@ def test_sparse_add_csr(): ) tvm.testing.assert_allclose(Z_tvm.asnumpy(), Z_np, atol=1e-4, rtol=1e-4) + if __name__ == "__main__": test_csrmv() test_csrmm()