Relax numerical precision tolerance in unit tests and doctests.
PiperOrigin-RevId: 723192022
rjagerman authored and Rax Developers committed Feb 4, 2025
1 parent a0a3b33 commit 88291d4
Showing 6 changed files with 48 additions and 31 deletions.
examples/t5x/models_test.py (12 changes: 8 additions & 4 deletions)
@@ -190,10 +190,14 @@ def test_loss_fn(self):
     self.assertEqual(args[3].shape, (16 * 4, 1))  # decoder_target_tokens

     # Check the loss and metric values.
-    np.testing.assert_allclose(loss, 20.415768)
-    np.testing.assert_allclose(metrics["loss"].compute(), 20.415768)
-    np.testing.assert_allclose(metrics["metrics/ndcg"].compute(), 0.41030282)
-    np.testing.assert_allclose(metrics["metrics/mrr"].compute(), 0.30208334)
+    np.testing.assert_allclose(loss, 20.415768, rtol=1e-5)
+    np.testing.assert_allclose(metrics["loss"].compute(), 20.415768, rtol=1e-5)
+    np.testing.assert_allclose(
+        metrics["metrics/ndcg"].compute(), 0.41030282, rtol=1e-5
+    )
+    np.testing.assert_allclose(
+        metrics["metrics/mrr"].compute(), 0.30208334, rtol=1e-5
+    )


 if __name__ == "__main__":
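Note on the pattern: `np.testing.assert_allclose` defaults to `rtol=1e-7` with `atol=0`, which is tight enough that the last few bits of an XLA-compiled computation can flip a test across backends or compiler versions. Passing `rtol=1e-5` keeps the assertions meaningful while tolerating that drift. A minimal standalone sketch of the difference (the values are illustrative, not taken from the Rax tests):

```python
import numpy as np

expected = 20.415768
# Simulate a result that drifted by a few low-order bits, as can happen
# when the same JAX computation runs on a different backend.
actual = expected * (1.0 + 3e-7)

# The strict default tolerance (rtol=1e-7) flags the tiny drift...
try:
    np.testing.assert_allclose(actual, expected)
except AssertionError:
    print("default tolerance: mismatch")

# ...while rtol=1e-5 accepts anything within 0.001% of the expected value.
np.testing.assert_allclose(actual, expected, rtol=1e-5)
print("rtol=1e-5: ok")
```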
rax/_src/lambdaweights.py (4 changes: 2 additions & 2 deletions)
@@ -25,8 +25,8 @@
 >>> labels = jnp.array([1.0, 2.0, 0.0])
 >>> loss = rax.pairwise_logistic_loss(
 ...     scores, labels, lambdaweight_fn=rax.labeldiff_lambdaweight)
->>> print(loss)
-1.8923712
+>>> print(f"{loss:.5f}")
+1.89237
 """

 import operator
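The doctest changes apply the same idea at the documentation level: instead of matching the full float repr, which encodes every last bit of the result, the examples now print a rounded value, so only the stable digits are part of the expected output. A minimal sketch of the pattern, with a hypothetical `compute_loss` standing in for any Rax call:

```python
def compute_loss() -> float:
    """Returns a loss value, documented with a precision-tolerant doctest.

    >>> loss = compute_loss()
    >>> print(f"{loss:.5f}")  # round to five decimals instead of full repr
    1.89237
    """
    return 1.8923712


if __name__ == "__main__":
    import doctest

    doctest.testmod()
```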
rax/_src/losses.py (13 changes: 8 additions & 5 deletions)
@@ -29,17 +29,20 @@
 >>> scores = jnp.array([2., 1., 3.])
 >>> labels = jnp.array([1., 0., 0.])
->>> print(rax.softmax_loss(scores, labels))
-1.4076059
+>>> loss = rax.softmax_loss(scores, labels)
+>>> print(f"{loss:.5f}")
+1.40761

 Usage with a batch of data and a mask to indicate valid items:

 >>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
 >>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
 >>> where = jnp.array([[True, True, False], [True, True, True]])
->>> print(rax.pairwise_hinge_loss(
-...     scores, labels, where=where, reduce_fn=jnp.mean))
-0.16666667
+>>> loss = rax.pairwise_hinge_loss(
+...     scores, labels, where=where, reduce_fn=jnp.mean
+... )
+>>> print(f"{loss:.5f}")
+0.16667

 To compute gradients of each loss function, please use standard JAX
 transformations such as :func:`jax.grad` or :func:`jax.value_and_grad`:
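The same tolerant comparison works outside doctests when checking documented values by hand. A sketch, assuming `rax` and `jax` are installed (the inputs are the ones from the docstring above, and 0.16667 is 1/6):

```python
import jax.numpy as jnp
import rax

scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
where = jnp.array([[True, True, False], [True, True, True]])

loss = rax.pairwise_hinge_loss(scores, labels, where=where, reduce_fn=jnp.mean)
# Check against the documented value with a tolerance, not string equality.
assert jnp.isclose(loss, 1.0 / 6.0, rtol=1e-5)
print(f"{loss:.5f}")  # 0.16667
```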
rax/_src/losses_test.py (4 changes: 2 additions & 2 deletions)
@@ -176,7 +176,7 @@ def test_computes_loss_value(self, loss_fn, expected_value):

     loss = loss_fn(scores, labels, reduce_fn=jnp.sum)

-    np.testing.assert_allclose(jnp.asarray(expected_value), loss)
+    np.testing.assert_allclose(jnp.asarray(expected_value), loss, rtol=1e-5)

   @parameterized.parameters([
       {
@@ -523,7 +523,7 @@ def test_computes_loss_value_with_vmap(self, loss_fn, expected_value):
     vmap_loss_fn = jax.vmap(loss_fn, in_axes=(0, 0), out_axes=0)
     loss = vmap_loss_fn(scores, labels)

-    np.testing.assert_allclose(jnp.asarray(expected_value), loss)
+    np.testing.assert_allclose(jnp.asarray(expected_value), loss, rtol=1e-5)

   @parameterized.parameters([
       {
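For context, the vmap test exercises the pattern below: mapping a loss over the leading batch axis should agree with looping over examples, but only up to floating-point noise, which is why the assertion now carries `rtol`. A sketch assuming `rax` is installed:

```python
import jax
import jax.numpy as jnp
import numpy as np
import rax

# A batch of two ranking lists.
scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]])
labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])

# Map the loss over the leading axis: one loss value per list.
vmap_loss_fn = jax.vmap(rax.softmax_loss, in_axes=(0, 0), out_axes=0)
losses = vmap_loss_fn(scores, labels)  # shape (2,)

# The vmapped result should match a plain Python loop up to float noise.
expected = jnp.stack([rax.softmax_loss(s, l) for s, l in zip(scores, labels)])
np.testing.assert_allclose(losses, expected, rtol=1e-5)
```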
rax/_src/metrics.py (10 changes: 6 additions & 4 deletions)
@@ -31,16 +31,18 @@
 >>> import rax
 >>> scores = jnp.array([2., 1., 3.])
 >>> labels = jnp.array([2., 0., 1.])
->>> print(rax.ndcg_metric(scores, labels))
-0.79670763
+>>> loss = rax.ndcg_metric(scores, labels)
+>>> print(f"{loss:.5f}")
+0.79671

 Usage with a batch of data and a mask to indicate valid items:

 >>> scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]])
 >>> labels = jnp.array([[2., 0., 1.], [0., 0., 1.]])
 >>> where = jnp.array([[True, True, False], [True, True, True]])
->>> print(rax.ndcg_metric(scores, labels))
-0.8983538
+>>> loss = rax.ndcg_metric(scores, labels)
+>>> print(f"{loss:.5f}")
+0.89835

 Usage with :func:`jax.vmap` batching and a mask to indicate valid items:
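Checking a documented metric value locally follows the same recipe; this sketch pins the first ndcg example above against its documented value with a tolerance (assumes `rax` is installed):

```python
import jax.numpy as jnp
import numpy as np
import rax

scores = jnp.array([2., 1., 3.])
labels = jnp.array([2., 0., 1.])

metric = rax.ndcg_metric(scores, labels)
# The doctest prints f"{...:.5f}"; numerically that corresponds to a
# tolerant comparison against the previously documented full value.
np.testing.assert_allclose(metric, 0.79670763, rtol=1e-5)
print(f"{metric:.5f}")  # 0.79671
```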
rax/_src/t12n.py (36 changes: 22 additions & 14 deletions)
@@ -24,8 +24,9 @@
 >>> scores = jnp.asarray([0., 1., 3., 2.])
 >>> labels = jnp.asarray([0., 0., 1., 2.])
 >>> approx_ndcg_loss_fn = rax.approx_t12n(rax.ndcg_metric)
->>> print(approx_ndcg_loss_fn(scores, labels))
--0.71789175
+>>> loss = approx_ndcg_loss_fn(scores, labels)
+>>> print(f"{loss:.5f}")
+-0.71789
 """

 import functools
@@ -67,16 +68,19 @@ def approx_t12n(metric_fn: MetricFn, temperature: float = 1.0) -> LossFn:
 >>> approx_mrr = rax.approx_t12n(rax.mrr_metric)
 >>> scores = jnp.asarray([0., 1., 3., 2.])
 >>> labels = jnp.asarray([0., 0., 1., 2.])
->>> print(approx_mrr(scores, labels))
--0.6965873
+>>> loss = approx_mrr(scores, labels)
+>>> print(f"{loss:.5f}")
+-0.69659

 Example usage together with :func:`rax.gumbel_t12n`:

 >>> gumbel_approx_mrr = rax.gumbel_t12n(rax.approx_t12n(rax.mrr_metric))
 >>> scores = jnp.asarray([0., 1., 3., 2.])
 >>> labels = jnp.asarray([0., 0., 1., 2.])
->>> print(gumbel_approx_mrr(scores, labels, key=jax.random.PRNGKey(42)))
--0.71880937
+>>> key = jax.random.PRNGKey(42)
+>>> loss = gumbel_approx_mrr(scores, labels, key=key)
+>>> print(f"{loss:.5f}")
+-0.71881

 Args:
   metric_fn: The metric function to convert to an approximate loss.
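The point of `approx_t12n` is that the resulting `LossFn` is differentiable even though the underlying ranking metric is not, so standard JAX autodiff applies. A small sketch (assumes `rax` is installed):

```python
import jax
import jax.numpy as jnp
import rax

approx_ndcg_loss_fn = rax.approx_t12n(rax.ndcg_metric)
scores = jnp.asarray([0., 1., 3., 2.])
labels = jnp.asarray([0., 0., 1., 2.])

# The approximate loss is smooth in the scores, so jax.grad works directly.
grads = jax.grad(approx_ndcg_loss_fn)(scores, labels)
print(grads.shape)  # (4,)
```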
@@ -122,16 +126,18 @@ def bound_t12n(metric_fn: MetricFn):
 >>> bound_mrr = rax.bound_t12n(rax.mrr_metric)
 >>> scores = jnp.asarray([0., 1., 3., 2.])
 >>> labels = jnp.asarray([0., 1., 0., 1.])
->>> print(bound_mrr(scores, labels))
--0.33333334
+>>> loss = bound_mrr(scores, labels)
+>>> print(f"{loss:.5f}")
+-0.33333

 Example usage together with :func:`rax.gumbel_t12n`:

 >>> gumbel_bound_mrr = rax.gumbel_t12n(rax.bound_t12n(rax.mrr_metric))
 >>> scores = jnp.asarray([0., 1., 3., 2.])
 >>> labels = jnp.asarray([0., 1., 0., 1.])
->>> print(gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42)))
--0.31619418
+>>> loss = gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42))
+>>> print(f"{loss:.5f}")
+-0.31619

 Args:
   metric_fn: The metric function to convert to a lower-bound loss.
@@ -184,10 +190,12 @@ def gumbel_t12n(
 >>> loss_fn = rax.gumbel_t12n(rax.softmax_loss)
 >>> scores = jnp.asarray([0., 1., 3., 2.])
 >>> labels = jnp.asarray([0., 0., 1., 2.])
->>> print(loss_fn(scores, labels, key=jax.random.PRNGKey(42)))
-6.2066536
->>> print(loss_fn(scores, labels, key=jax.random.PRNGKey(79)))
-5.0127797
+>>> loss = loss_fn(scores, labels, key=jax.random.PRNGKey(42))
+>>> print(f"{loss:.5f}")
+6.20665
+>>> loss = loss_fn(scores, labels, key=jax.random.PRNGKey(79))
+>>> print(f"{loss:.5f}")
+5.01278

 Args:
   loss_or_metric_fn: A Rax loss or metric function.
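The Gumbel examples are deterministic only because the PRNG key is fixed; different keys sample different Gumbel noise and therefore give different losses, as the two printed values above show. A sketch of the key-handling pattern (assumes `rax` is installed):

```python
import jax
import jax.numpy as jnp
import rax

loss_fn = rax.gumbel_t12n(rax.softmax_loss)
scores = jnp.asarray([0., 1., 3., 2.])
labels = jnp.asarray([0., 0., 1., 2.])

# Same key, same Gumbel noise: the loss is reproducible.
key = jax.random.PRNGKey(42)
assert loss_fn(scores, labels, key=key) == loss_fn(scores, labels, key=key)

# A different key samples different noise, so the loss changes.
loss_a = loss_fn(scores, labels, key=jax.random.PRNGKey(42))
loss_b = loss_fn(scores, labels, key=jax.random.PRNGKey(79))
print(f"{loss_a:.5f} vs {loss_b:.5f}")  # 6.20665 vs 5.01278
```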
