From 88291d40d6d17c19403bf7b051893abc23635325 Mon Sep 17 00:00:00 2001
From: Rolf Jagerman
Date: Tue, 4 Feb 2025 12:44:19 -0800
Subject: [PATCH] Relax numerical precision tolerance in unit tests and doctests.

PiperOrigin-RevId: 723192022
---
 examples/t5x/models_test.py | 12 ++++++++----
 rax/_src/lambdaweights.py   |  4 ++--
 rax/_src/losses.py          | 13 ++++++++-----
 rax/_src/losses_test.py     |  4 ++--
 rax/_src/metrics.py         | 10 ++++++----
 rax/_src/t12n.py            | 36 ++++++++++++++++++++++--------------
 6 files changed, 48 insertions(+), 31 deletions(-)

diff --git a/examples/t5x/models_test.py b/examples/t5x/models_test.py
index d2dd7e2..aa1d214 100644
--- a/examples/t5x/models_test.py
+++ b/examples/t5x/models_test.py
@@ -190,10 +190,14 @@ def test_loss_fn(self):
     self.assertEqual(args[3].shape, (16 * 4, 1))  # decoder_target_tokens
 
     # Check the loss and metric values.
-    np.testing.assert_allclose(loss, 20.415768)
-    np.testing.assert_allclose(metrics["loss"].compute(), 20.415768)
-    np.testing.assert_allclose(metrics["metrics/ndcg"].compute(), 0.41030282)
-    np.testing.assert_allclose(metrics["metrics/mrr"].compute(), 0.30208334)
+    np.testing.assert_allclose(loss, 20.415768, rtol=1e-5)
+    np.testing.assert_allclose(metrics["loss"].compute(), 20.415768, rtol=1e-5)
+    np.testing.assert_allclose(
+        metrics["metrics/ndcg"].compute(), 0.41030282, rtol=1e-5
+    )
+    np.testing.assert_allclose(
+        metrics["metrics/mrr"].compute(), 0.30208334, rtol=1e-5
+    )
 
 
 if __name__ == "__main__":
diff --git a/rax/_src/lambdaweights.py b/rax/_src/lambdaweights.py
index 629b6a3..70847dc 100644
--- a/rax/_src/lambdaweights.py
+++ b/rax/_src/lambdaweights.py
@@ -25,8 +25,8 @@
 >>> labels = jnp.array([1.0, 2.0, 0.0])
 >>> loss = rax.pairwise_logistic_loss(
 ...     scores, labels, lambdaweight_fn=rax.labeldiff_lambdaweight)
->>> print(loss)
-1.8923712
+>>> print(f"{loss:.5f}")
+1.89237
 """
 
 import operator
diff --git a/rax/_src/losses.py b/rax/_src/losses.py
index e7914f5..7362ab2 100644
--- a/rax/_src/losses.py
+++ b/rax/_src/losses.py
@@ -29,17 +29,20 @@
 
 >>> scores = jnp.array([2., 1., 3.])
 >>> labels = jnp.array([1., 0., 0.])
->>> print(rax.softmax_loss(scores, labels))
-1.4076059
+>>> loss = rax.softmax_loss(scores, labels)
+>>> print(f"{loss:.5f}")
+1.40761
 
 Usage with a batch of data and a mask to indicate valid items.
 
 >>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
 >>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
 >>> where = jnp.array([[True, True, False], [True, True, True]])
->>> print(rax.pairwise_hinge_loss(
-...     scores, labels, where=where, reduce_fn=jnp.mean))
-0.16666667
+>>> loss = rax.pairwise_hinge_loss(
+...     scores, labels, where=where, reduce_fn=jnp.mean
+... )
+>>> print(f"{loss:.5f}")
+0.16667
 
 To compute gradients of each loss function, please use standard JAX
 transformations such as :func:`jax.grad` or :func:`jax.value_and_grad`:
diff --git a/rax/_src/losses_test.py b/rax/_src/losses_test.py
index d1326de..a31d5f3 100644
--- a/rax/_src/losses_test.py
+++ b/rax/_src/losses_test.py
@@ -176,7 +176,7 @@ def test_computes_loss_value(self, loss_fn, expected_value):
 
     loss = loss_fn(scores, labels, reduce_fn=jnp.sum)
 
-    np.testing.assert_allclose(jnp.asarray(expected_value), loss)
+    np.testing.assert_allclose(jnp.asarray(expected_value), loss, rtol=1e-5)
 
   @parameterized.parameters([
       {
@@ -523,7 +523,7 @@ def test_computes_loss_value_with_vmap(self, loss_fn, expected_value):
     vmap_loss_fn = jax.vmap(loss_fn, in_axes=(0, 0), out_axes=0)
    loss = vmap_loss_fn(scores, labels)
 
-    np.testing.assert_allclose(jnp.asarray(expected_value), loss)
+    np.testing.assert_allclose(jnp.asarray(expected_value), loss, rtol=1e-5)
 
   @parameterized.parameters([
       {
diff --git a/rax/_src/metrics.py b/rax/_src/metrics.py
index 7cf869f..825e4c1 100644
--- a/rax/_src/metrics.py
+++ b/rax/_src/metrics.py
@@ -31,16 +31,18 @@
 >>> import rax
 >>> scores = jnp.array([2., 1., 3.])
 >>> labels = jnp.array([2., 0., 1.])
->>> print(rax.ndcg_metric(scores, labels))
-0.79670763
+>>> loss = rax.ndcg_metric(scores, labels)
+>>> print(f"{loss:.5f}")
+0.79671
 
 Usage with a batch of data and a mask to indicate valid items:
 
 >>> scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]])
 >>> labels = jnp.array([[2., 0., 1.], [0., 0., 1.]])
 >>> where = jnp.array([[True, True, False], [True, True, True]])
->>> print(rax.ndcg_metric(scores, labels))
-0.8983538
+>>> loss = rax.ndcg_metric(scores, labels)
+>>> print(f"{loss:.5f}")
+0.89835
 
 Usage with :func:`jax.vmap` batching and a mask to indicate valid items:
 
diff --git a/rax/_src/t12n.py b/rax/_src/t12n.py
index bf0cbb4..2a9c7e3 100644
--- a/rax/_src/t12n.py
+++ b/rax/_src/t12n.py
@@ -24,8 +24,9 @@
 >>> scores = jnp.asarray([0., 1., 3., 2.])
 >>> labels = jnp.asarray([0., 0., 1., 2.])
 >>> approx_ndcg_loss_fn = rax.approx_t12n(rax.ndcg_metric)
->>> print(approx_ndcg_loss_fn(scores, labels))
--0.71789175
+>>> loss = approx_ndcg_loss_fn(scores, labels)
+>>> print(f"{loss:.5f}")
+-0.71789
 """
 
 import functools
@@ -67,16 +68,19 @@ def approx_t12n(metric_fn: MetricFn, temperature: float = 1.0) -> LossFn:
   >>> approx_mrr = rax.approx_t12n(rax.mrr_metric)
   >>> scores = jnp.asarray([0., 1., 3., 2.])
   >>> labels = jnp.asarray([0., 0., 1., 2.])
-  >>> print(approx_mrr(scores, labels))
-  -0.6965873
+  >>> loss = approx_mrr(scores, labels)
+  >>> print(f"{loss:.5f}")
+  -0.69659
 
   Example usage together with :func:`rax.gumbel_t12n`:
 
   >>> gumbel_approx_mrr = rax.gumbel_t12n(rax.approx_t12n(rax.mrr_metric))
   >>> scores = jnp.asarray([0., 1., 3., 2.])
   >>> labels = jnp.asarray([0., 0., 1., 2.])
-  >>> print(gumbel_approx_mrr(scores, labels, key=jax.random.PRNGKey(42)))
-  -0.71880937
+  >>> key = jax.random.PRNGKey(42)
+  >>> loss = gumbel_approx_mrr(scores, labels, key=key)
+  >>> print(f"{loss:.5f}")
+  -0.71881
 
   Args:
     metric_fn: The metric function to convert to an approximate loss.
@@ -122,16 +126,18 @@ def bound_t12n(metric_fn: MetricFn):
   >>> bound_mrr = rax.bound_t12n(rax.mrr_metric)
   >>> scores = jnp.asarray([0., 1., 3., 2.])
   >>> labels = jnp.asarray([0., 1., 0., 1.])
-  >>> print(bound_mrr(scores, labels))
-  -0.33333334
+  >>> loss = bound_mrr(scores, labels)
+  >>> print(f"{loss:.5f}")
+  -0.33333
 
   Example usage together with :func:`rax.gumbel_t12n`:
 
   >>> gumbel_bound_mrr = rax.gumbel_t12n(rax.bound_t12n(rax.mrr_metric))
   >>> scores = jnp.asarray([0., 1., 3., 2.])
   >>> labels = jnp.asarray([0., 1., 0., 1.])
-  >>> print(gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42)))
-  -0.31619418
+  >>> loss = gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42))
+  >>> print(f"{loss:.5f}")
+  -0.31619
 
   Args:
     metric_fn: The metric function to convert to a lower-bound loss.
@@ -184,10 +190,12 @@ def gumbel_t12n(
   >>> loss_fn = rax.gumbel_t12n(rax.softmax_loss)
   >>> scores = jnp.asarray([0., 1., 3., 2.])
   >>> labels = jnp.asarray([0., 0., 1., 2.])
-  >>> print(loss_fn(scores, labels, key=jax.random.PRNGKey(42)))
-  6.2066536
-  >>> print(loss_fn(scores, labels, key=jax.random.PRNGKey(79)))
-  5.0127797
+  >>> loss = loss_fn(scores, labels, key=jax.random.PRNGKey(42))
+  >>> print(f"{loss:.5f}")
+  6.20665
+  >>> loss = loss_fn(scores, labels, key=jax.random.PRNGKey(79))
+  >>> print(f"{loss:.5f}")
+  5.01278
 
   Args:
     loss_or_metric_fn: A Rax loss or metric function.
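
A minimal illustrative sketch of what the relaxed tolerance and the rounded
doctest output buy (not part of the patch; the "recomputed" value below is
hypothetical, chosen only to mimic small cross-backend drift, while 20.415768
and 1.8923712 come from the hunks above). numpy's assert_allclose defaults to
rtol=1e-7 with atol=0, i.e. it requires |actual - desired| <= rtol * |desired|.

    import numpy as np

    reference = 20.415768   # expected loss hard-coded in examples/t5x/models_test.py
    recomputed = 20.4158    # hypothetical value from a different backend or library version

    # Relative difference is about 1.6e-6, so the default tolerance (rtol=1e-7) fails:
    # np.testing.assert_allclose(recomputed, reference)  # would raise AssertionError
    # while the relaxed tolerance used throughout this patch accepts it:
    np.testing.assert_allclose(recomputed, reference, rtol=1e-5)

    # Doctests are made robust the same way, by printing only five decimals:
    print(f"{1.8923712:.5f}")  # -> 1.89237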