[Frontend][Tensorflow] Sparse dense matmul adjoint option added (#7267)
* [Frontend][Tensorflow] Sparse dense matmul adjoint option added

* [1] Review comments handled

* [2] Review comments handled

* [3] Review comments handled
ANSHUMAN TRIPATHY authored Jan 28, 2021
1 parent cbc035f commit dda8f5d
Showing 2 changed files with 53 additions and 28 deletions.
69 changes: 46 additions & 23 deletions python/tvm/relay/frontend/tensorflow.py
@@ -926,13 +926,6 @@ def _impl(inputs, attr, params, mod):

data = inputs[3]

# By default, in TensorFlow the first input, i.e. data, is sparse
sparse_lhs = True

# If both adjoint flags are true, the first input was dense and the second was sparse
if attr.get("adjoint_a") and attr.get("adjoint_b"):
    sparse_lhs = False

rows = [x[0] for x in indices_tensor]
cols = [x[1] for x in indices_tensor]

@@ -941,9 +934,53 @@ def _impl(inputs, attr, params, mod):
(values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
)

if sparse_lhs:
# As per the TensorFlow implementation, we have 4 possible input combinations,
# where the first input (A) is always sparse and the second input (B) is always dense.
# Case 1: A , B , adjoint_a=False, adjoint_b=False --> A * B
# Case 2: A , B , adjoint_a=True, adjoint_b=False --> A.T * B
# Case 3: A , B , adjoint_a=False, adjoint_b=True --> A * B.T
# Case 4: A , B , adjoint_a=True, adjoint_b=True --> A.T * B.T
#
# The Topi implementation of sparse_dense (matmul) has 2 possible input
# combinations, where the first input (A) is always dense
# and the second input (B) is always sparse.
# Case 1: A , B, sparse_lhs = False --> A * B.T
# Case 2: A , B, sparse_lhs = True --> B * A.T
#
# The mapping would be as below:
# TF Case 1: A , B , adjoint_a=False, adjoint_b=False
# --> In TF: A * B --> In Topi: A * B.T.T
# --> sparse_dense(transpose(B), A, sparse_lhs=True)
#
# TF Case 2: A , B , adjoint_a=True, adjoint_b=False
# --> In TF: A.T * B --> In Topi: A.T * B.T.T
# --> sparse_dense(transpose(B), transpose(A), sparse_lhs=True)
#
# TF Case 3: A , B , adjoint_a=False, adjoint_b=True
# --> In TF: A * B.T --> In Topi: A * B
# --> sparse_dense(B, A, sparse_lhs=True)
#
# TF Case 4: A , B , adjoint_a=True, adjoint_b=True
# --> In TF: A.T * B.T --> In Topi: (B * A).T
# --> transpose(sparse_dense(B, transpose(A), sparse_lhs=False))

# By default, in TensorFlow the first input, i.e. data, is sparse
sparse_lhs = True

# TF Case 1:
if not attr.get("adjoint_a") and not attr.get("adjoint_b"):
    data = _op.transpose(data)
# TF Case 2:
elif attr.get("adjoint_a") and not attr.get("adjoint_b"):
    data = _op.transpose(data)
    weight_sp = csr_matrix(weight_sp.transpose())
# TF Case 3:
elif not attr.get("adjoint_a") and attr.get("adjoint_b"):
    pass
# TF Case 4:
# attr.get("adjoint_a") and attr.get("adjoint_b"):
else:
    sparse_lhs = False
    weight_sp = csr_matrix(weight_sp.transpose())

weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
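To make the four-case mapping above concrete, here is a small NumPy/SciPy sanity check (not part of the commit). The `sparse_dense` helper below is a stand-in that follows the Topi contract quoted in the comment block (`sparse_lhs=False` -> dense * sparse.T, `sparse_lhs=True` -> sparse * dense.T); it is not the actual Relay/Topi operator, and the shapes and values are illustrative.

```python
# Illustrative sanity check of the four-case mapping (NumPy/SciPy only).
# NOTE: this `sparse_dense` is a stand-in mirroring the Topi semantics quoted
# in the comment block above, not tvm.relay.nn.sparse_dense itself.
import numpy as np
from scipy.sparse import csr_matrix


def sparse_dense(dense, sparse, sparse_lhs):
    # sparse_lhs=False: dense * sparse.T ; sparse_lhs=True: sparse * dense.T
    return sparse @ dense.T if sparse_lhs else dense @ sparse.toarray().T


n = 5
A = csr_matrix(np.random.rand(n, n) * (np.random.rand(n, n) > 0.5))  # TF sparse operand
B = np.random.rand(n, n)                                             # TF dense operand
Ad, At = A.toarray(), csr_matrix(A.T)

# TF Case 1: A * B  ==  sparse_dense(transpose(B), A, sparse_lhs=True)
assert np.allclose(Ad @ B, sparse_dense(B.T, A, True))
# TF Case 2: A.T * B  ==  sparse_dense(transpose(B), transpose(A), sparse_lhs=True)
assert np.allclose(Ad.T @ B, sparse_dense(B.T, At, True))
# TF Case 3: A * B.T  ==  sparse_dense(B, A, sparse_lhs=True)
assert np.allclose(Ad @ B.T, sparse_dense(B, A, True))
# TF Case 4: A.T * B.T  ==  transpose(sparse_dense(B, transpose(A), sparse_lhs=False))
assert np.allclose(Ad.T @ B.T, sparse_dense(B, At, False).T)
```

If the Topi contract assumed here matches the real operator, each assertion corresponds to one of the four rewrites performed by the frontend code above.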
@@ -953,23 +990,9 @@ def _impl(inputs, attr, params, mod):
ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs], sparse_lhs)

if not sparse_lhs:
    # TF Case 4
    ret = _op.transpose(ret)

# Case 1. If both are true, the first input was dense and the second was sparse
# Case 2. If both are false, the first input was sparse and the second was dense
# TODO(ANSHUMAN87): Support other adjoint options too
if not (
    (attr.get("adjoint_a") and attr.get("adjoint_b"))
    or ((not attr.get("adjoint_a")) and (not attr.get("adjoint_b")))
):
    raise tvm.error.OpAttributeUnImplemented(
        "Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True"
        " or with adjoint_a=False and adjoint_b=False"
        " is supported, but adjoint_a={} and adjoint_b={} was supplied.".format(
            attr.get("adjoint_a"), attr.get("adjoint_b")
        )
    )

return ret

return _impl
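As a rough end-to-end sketch (not part of the commit), the newly supported adjoint combinations could be exercised by importing a small TF graph through relay.frontend.from_tensorflow. The shapes, placeholder name, and graph-mode setup below are illustrative assumptions following the pattern used in test_forward.py; environment details (e.g., TF1 vs. TF2 graph handling) may differ.

```python
# Hypothetical end-to-end sketch (not part of this commit): build a TF1-style
# graph that uses adjoint_a/adjoint_b and import it with the Relay frontend.
import tensorflow.compat.v1 as tf
from tvm import relay

A_shape, B_shape = [3, 4], [5, 3]  # A.T (4x3) @ B.T (3x5) -> (4x5)

with tf.Graph().as_default() as graph:
    A_sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=A_shape)
    B = tf.placeholder(shape=B_shape, dtype="float32", name="B")
    # Case 4: both adjoints set; previously this raised OpAttributeUnImplemented.
    out = tf.sparse.sparse_dense_matmul(A_sp, B, adjoint_a=True, adjoint_b=True)
    graph_def = graph.as_graph_def()

# out.op.name avoids guessing the TF-generated node name for the output.
mod, params = relay.frontend.from_tensorflow(graph_def, shape={"B": B_shape}, outputs=[out.op.name])
print(mod["main"])
```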
12 changes: 7 additions & 5 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -1758,19 +1758,21 @@ def test_forward_batch_matmul():
# ----------------------------------


def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False):
def _test_sparse_dense_matmul(indices, values, A_inp_shape, B_inp_shape, dtype, flip=False):
""" One iteration of sparse_dense_matmul """

# TODO(ANSHUMAN87): Support adjoint options too
for adjoint_a in [False]:
for adjoint_b in [False]:
for adjoint_a in [False, True]:
for adjoint_b in [False, True]:
A_shape = A_inp_shape[::-1] if adjoint_a else A_inp_shape
B_shape = B_inp_shape[::-1] if adjoint_b else B_inp_shape

with tf.Graph().as_default():
A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape)
B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")

if flip:
result = tf.sparse.sparse_dense_matmul(
B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b
B, A_sp, adjoint_a=adjoint_b, adjoint_b=adjoint_a
)
else:
result = tf.sparse.sparse_dense_matmul(
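For reference, the adjoint flags simply transpose the corresponding operand before the multiply, which is why the updated helper reverses A_inp_shape / B_inp_shape when a flag is set. Below is a minimal standalone check of that behavior (assuming TensorFlow 2.x eager execution; the indices, values, and shapes are illustrative, not taken from the test suite).

```python
# Standalone illustration of the adjoint flags (assumes TF 2.x eager execution).
import numpy as np
import tensorflow as tf

indices, values, A_shape = [[0, 0], [1, 2]], [4.0, 8.0], [3, 4]
A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape)
B = tf.constant(np.random.rand(3, 4).astype("float32"))

# adjoint_b=True multiplies by B.T, so A (3x4) @ B.T (4x3) -> (3x3)
out = tf.sparse.sparse_dense_matmul(A_sp, B, adjoint_a=False, adjoint_b=True)
ref = tf.sparse.to_dense(A_sp) @ tf.transpose(B)
np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=1e-6)
```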
