diff --git a/tensorflow_probability/python/internal/backend/numpy/numpy_test.py b/tensorflow_probability/python/internal/backend/numpy/numpy_test.py index 4aa17732d7..5ef9fafd98 100644 --- a/tensorflow_probability/python/internal/backend/numpy/numpy_test.py +++ b/tensorflow_probability/python/internal/backend/numpy/numpy_test.py @@ -95,9 +95,13 @@ def _maybe_get_subnormal_kwarg(allow_subnormal=ALLOW_SUBNORMAL): class TestCase(dict): """`dict` object containing test strategies for a single function.""" - def __init__(self, name, strategy_list, **kwargs): + def __init__(self, name, strategy_list, enable_mlir_bridge=True, **kwargs): + self.name = name + if not enable_mlir_bridge: + tf.config.experimental.disable_mlir_bridge() + tensorflow_function = kwargs.pop('tensorflow_function', None) if not tensorflow_function: tensorflow_function = _getattr(tf, name) @@ -1052,6 +1056,8 @@ def _not_implemented(*args, **kwargs): # disabled=NUMPY_MODE and six.PY2 disabled=True), TestCase('linalg.diag_part', [single_arrays(shape=shapes(min_dims=2))]), + # MLIR bridge does not support MatrixDiagPartV2. Disable it to use the + # legacy bridge TestCase( 'raw_ops.MatrixDiagPartV2', [ hps.fixed_dictionaries( @@ -1060,6 +1066,7 @@ def _not_implemented(*args, **kwargs): k=hps.sampled_from([-1, 0, 1]), padding_value=hps.just(0.))).map(Kwargs) ], + enable_mlir_bridge=False, xla_const_args=('k',)), TestCase('identity', [single_arrays()]),