Skip to content

Commit

Permalink
Disable test_gradient_with_additional_parameters for JAX backend.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 693475540
  • Loading branch information
ColCarroll authored and tensorflower-gardener committed Nov 5, 2024
1 parent b0a692b commit 5047529
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,23 @@
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.random import random_ops

_DIFFERENT_HYPOTHESIS_KWARGS = {}

# This check is done on recent versions of hypothesis, but not all,
# as of November 2024.
if hasattr(hp.HealthCheck, 'differing_executors'):
_DIFFERENT_HYPOTHESIS_KWARGS['suppress_health_check'] = [
hp.HealthCheck.differing_executors
]


@test_util.test_all_tf_execution_regimes
class _BatchBroadcastTest(object):

@hp.given(hps.data())
@tfp_hps.tfp_hp_settings(default_max_examples=5)
@tfp_hps.tfp_hp_settings(
default_max_examples=5,
**_DIFFERENT_HYPOTHESIS_KWARGS)
def test_shapes(self, data):
batch_shape = data.draw(tfp_hps.shapes())
bcast_arg, dist_batch_shp = data.draw(
Expand All @@ -63,7 +74,9 @@ def test_shapes(self, data):
dist.event_shape_tensor())

@hp.given(hps.data())
@tfp_hps.tfp_hp_settings(default_max_examples=5)
@tfp_hps.tfp_hp_settings(
default_max_examples=5,
**_DIFFERENT_HYPOTHESIS_KWARGS)
def test_sample(self, data):
batch_shape = data.draw(tfp_hps.shapes())
bcast_arg, dist_batch_shp = data.draw(
Expand Down Expand Up @@ -109,7 +122,9 @@ def test_sample(self, data):
self.assertAllClose(lp, dist.log_prob(sample2))

@hp.given(hps.data())
@tfp_hps.tfp_hp_settings(default_max_examples=5)
@tfp_hps.tfp_hp_settings(
default_max_examples=5,
**_DIFFERENT_HYPOTHESIS_KWARGS)
def test_log_prob(self, data):
batch_shape = data.draw(tfp_hps.shapes())
bcast_arg, dist_batch_shp = data.draw(
Expand Down Expand Up @@ -235,7 +250,9 @@ def test_docstring_example(self):
self.evaluate(lp)

@hp.given(hps.data())
@tfp_hps.tfp_hp_settings(default_max_examples=5)
@tfp_hps.tfp_hp_settings(
default_max_examples=5,
**_DIFFERENT_HYPOTHESIS_KWARGS)
def test_default_bijector(self, data):
batch_shape = data.draw(tfp_hps.shapes())
bcast_arg, dist_batch_shp = data.draw(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def ildj_fn(y):
self.assertAllClose(ildj, ildj_true, atol=1e-4)
self.assertAllClose(ildj_grad, ildj_grad_true, rtol=1e-4)

@test_util.disable_test_for_backend(
disable_jax=True, reason='Tracer leak from additional parameters.')
@test_util.numpy_disable_gradient_test
@parameterized.named_parameters(
{
Expand Down

0 comments on commit 5047529

Please sign in to comment.