Skip to content

Commit

Permalink
[Frontend][Tensorflow] Sparse dense matmul adjoint option added
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Jan 13, 2021
1 parent d949d15 commit 2741533
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 19 deletions.
19 changes: 3 additions & 16 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,8 +941,10 @@ def _impl(inputs, attr, params, mod):
(values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
)

if sparse_lhs:
if sparse_lhs or attr.get("adjoint_b"):
data = _op.transpose(data)
elif attr.get("adjoint_a"):
weight_sp = weight_sp
else:
weight_sp = csr_matrix(weight_sp.transpose())

Expand All @@ -955,21 +957,6 @@ def _impl(inputs, attr, params, mod):
if not sparse_lhs:
ret = _op.transpose(ret)

# Case 1. If both are true means first input was dense and second was sparse
# Case 2. If both are false means first input was sparse and second was dense
# TODO(ANSHUMAN87): Support other adjoint option too
if not (
(attr.get("adjoint_a") and attr.get("adjoint_b"))
or ((not attr.get("adjoint_a")) and (not attr.get("adjoint_b")))
):
raise tvm.error.OpAttributeUnImplemented(
"Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True"
"or with adjoint_a=False and adjoint_b=False"
" is supported, but adjoint_a={} and adjoint_b={} was supplied.".format(
attr.get("adjoint_a"), attr.get("adjoint_b")
)
)

return ret

return _impl
Expand Down
8 changes: 5 additions & 3 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,9 +1761,11 @@ def test_forward_batch_matmul():
def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False):
""" One iteration of sparse_dense_matmul """

# TODO(ANSHUMAN87): Support adjoint options too
for adjoint_a in [False]:
for adjoint_b in [False]:
for adjoint_a in [False, True]:
for adjoint_b in [False, True]:
A_shape = A_shape[::-1] if adjoint_a else A_shape
B_shape = B_shape[::-1] if adjoint_b else B_shape

with tf.Graph().as_default():
A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape)
B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
Expand Down

0 comments on commit 2741533

Please sign in to comment.