diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d5746a38582cb..b20a654ede12c 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -941,8 +941,10 @@ def _impl(inputs, attr, params, mod): (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist()) ) - if sparse_lhs: + if sparse_lhs or attr.get("adjoint_b"): data = _op.transpose(data) + elif attr.get("adjoint_a"): + weight_sp = weight_sp else: weight_sp = csr_matrix(weight_sp.transpose()) @@ -955,21 +957,6 @@ def _impl(inputs, attr, params, mod): if not sparse_lhs: ret = _op.transpose(ret) - # Case 1. If both are true means first input was dense and second was sparse - # Case 2. If both are false means first input was sparse and second was dense - # TODO(ANSHUMAN87): Support other adjoint option too - if not ( - (attr.get("adjoint_a") and attr.get("adjoint_b")) - or ((not attr.get("adjoint_a")) and (not attr.get("adjoint_b"))) - ): - raise tvm.error.OpAttributeUnImplemented( - "Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True" - "or with adjoint_a=False and adjoint_b=False" - " is supported, but adjoint_a={} and adjoint_b={} was supplied.".format( - attr.get("adjoint_a"), attr.get("adjoint_b") - ) - ) - return ret return _impl diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 22ed6c5b2edff..f286bff7eba27 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1761,9 +1761,11 @@ def test_forward_batch_matmul(): def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False): """ One iteration of sparse_dense_matmul """ - # TODO(ANSHUMAN87): Support adjoint options too - for adjoint_a in [False]: - for adjoint_b in [False]: + for adjoint_a in [False, True]: + for adjoint_b in [False, True]: + A_shape = A_shape[::-1] if adjoint_a else A_shape + B_shape = B_shape[::-1] if adjoint_b else B_shape + with tf.Graph().as_default(): A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape) B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")