Skip to content

Commit

Permalink
Disable MLIR bridge for the tests that MLIR bridge silently fails
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676939123
  • Loading branch information
Googler authored and tensorflower-gardener committed Sep 20, 2024
1 parent 6aed4cc commit c6c86e7
Showing 1 changed file with 8 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ def _maybe_get_subnormal_kwarg(allow_subnormal=ALLOW_SUBNORMAL):
class TestCase(dict):
"""`dict` object containing test strategies for a single function."""

def __init__(self, name, strategy_list, **kwargs):
def __init__(self, name, strategy_list, enable_mlir_bridge=True, **kwargs):

self.name = name

if not enable_mlir_bridge:
tf.config.experimental.disable_mlir_bridge()

tensorflow_function = kwargs.pop('tensorflow_function', None)
if not tensorflow_function:
tensorflow_function = _getattr(tf, name)
Expand Down Expand Up @@ -1052,6 +1056,8 @@ def _not_implemented(*args, **kwargs):
# disabled=NUMPY_MODE and six.PY2
disabled=True),
TestCase('linalg.diag_part', [single_arrays(shape=shapes(min_dims=2))]),
# MLIR bridge does not support MatrixDiagPartV2. Disable it to use the
# legacy bridge
TestCase(
'raw_ops.MatrixDiagPartV2', [
hps.fixed_dictionaries(
Expand All @@ -1060,6 +1066,7 @@ def _not_implemented(*args, **kwargs):
k=hps.sampled_from([-1, 0, 1]),
padding_value=hps.just(0.))).map(Kwargs)
],
enable_mlir_bridge=False,
xla_const_args=('k',)),
TestCase('identity', [single_arrays()]),

Expand Down

0 comments on commit c6c86e7

Please sign in to comment.