Relax numerical precision tolerance in unit tests and doctests.
PiperOrigin-RevId: 723192022
rjagerman authored and Rax Developers committed Feb 4, 2025
1 parent a0a3b33 commit 9b66b8c
Showing 11 changed files with 125 additions and 69 deletions.
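A note on the recurring pattern: across these files, exact np.testing.assert_allclose checks against hard-coded floats are replaced by assertAlmostEqual comparisons rounded to three decimal places, each wrapped in self.subTest so the checks pass or fail independently. A before/after sketch, assembled from the first hunk below:

# Before: assert_allclose against a value tied to the legacy PRNG stream.
np.testing.assert_allclose(output["ApproxAP"]["AP"], 0.88122, rtol=1e-3)

# After: expected value updated for the partitionable PRNG, compared to
# three decimal places inside an independently reported subTest.
with self.subTest(method="ApproxAP", metric="AP"):
  self.assertAlmostEqual(output["ApproxAP"]["AP"], 0.794514, places=3)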
32 changes: 22 additions & 10 deletions examples/approx_metrics/main_test.py
@@ -20,12 +20,15 @@
from unittest import mock

from absl.testing import absltest
import jax
import numpy as np

from examples.approx_metrics import main
import tensorflow as tf
import tensorflow_datasets as tfds

# Opt-in to the partitionable PRNG implementation.
jax.config.update("jax_threefry_partitionable", True)


def letor_dataset(self, *args, **kwargs):
del args, kwargs # Unused but needed for `tfds.testing.mock_data()`.
@@ -65,19 +68,28 @@ def test_end_to_end(self):
output = json.loads(mock_stdout.getvalue())

# ApproxAP works best on AP.
np.testing.assert_allclose(output["ApproxAP"]["AP"], 0.88122, rtol=1e-3)
np.testing.assert_allclose(output["ApproxNDCG"]["AP"], 0.70764, rtol=1e-3)
np.testing.assert_allclose(output["ApproxR@50"]["AP"], 0.45507, rtol=1e-3)
with self.subTest(method="ApproxAP", metric="AP"):
self.assertAlmostEqual(output["ApproxAP"]["AP"], 0.794514, places=3)
with self.subTest(method="ApproxNDCG", metric="AP"):
self.assertAlmostEqual(output["ApproxNDCG"]["AP"], 0.64406, places=3)
with self.subTest(method="ApproxR@50", metric="AP"):
self.assertAlmostEqual(output["ApproxR@50"]["AP"], 0.48270, places=3)

# ApproxNDCG works best on NDCG
np.testing.assert_allclose(output["ApproxAP"]["NDCG"], 0.79133, rtol=1e-3)
np.testing.assert_allclose(output["ApproxNDCG"]["NDCG"], 0.90464, rtol=1e-3)
np.testing.assert_allclose(output["ApproxR@50"]["NDCG"], 0.67759, rtol=1e-3)
with self.subTest(method="ApproxAP", metric="NDCG"):
self.assertAlmostEqual(output["ApproxAP"]["NDCG"], 0.76585, places=3)
with self.subTest(method="ApproxNDCG", metric="NDCG"):
self.assertAlmostEqual(output["ApproxNDCG"]["NDCG"], 0.80603, places=3)
with self.subTest(method="ApproxR@50", metric="NDCG"):
self.assertAlmostEqual(output["ApproxR@50"]["NDCG"], 0.64390, places=3)

# ApproxR@50 is not best on R@50 due to difficulty of that metric.
np.testing.assert_allclose(output["ApproxNDCG"]["R@50"], 0.92857, rtol=1e-3)
np.testing.assert_allclose(output["ApproxAP"]["R@50"], 0.92460, rtol=1e-3)
np.testing.assert_allclose(output["ApproxR@50"]["R@50"], 0.85317, rtol=1e-3)
with self.subTest(method="ApproxAP", metric="R@50"):
self.assertAlmostEqual(output["ApproxNDCG"]["R@50"], 0.88095, places=3)
with self.subTest(method="ApproxNDCG", metric="R@50"):
self.assertAlmostEqual(output["ApproxAP"]["R@50"], 0.92857, places=3)
with self.subTest(method="ApproxR@50", metric="R@50"):
self.assertAlmostEqual(output["ApproxR@50"]["R@50"], 0.82937, places=3)


if __name__ == "__main__":
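Several of the test files opt in to JAX's partitionable threefry PRNG via jax.config.update("jax_threefry_partitionable", True). The flag selects a different (still deterministic) way of deriving random bits from a key, so tests that hard-code values produced from a PRNGKey see different numbers; that is presumably why the expected constants above changed. A minimal sketch of the opt-in, independent of this test suite:

import jax

# Opt-in to the partitionable PRNG implementation, as done in this commit.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.PRNGKey(42)
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))
assert (a == b).all()  # Same key, same draw: still deterministic.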
32 changes: 19 additions & 13 deletions examples/flax_integration/main_test.py
@@ -20,11 +20,13 @@
from unittest import mock

from absl.testing import absltest
import numpy as np

import jax
from examples.flax_integration import main
import tensorflow_datasets as tfds

# Opt-in to the partitionable PRNG implementation.
jax.config.update("jax_threefry_partitionable", True)


class Web30kTest(absltest.TestCase):

@@ -41,23 +43,27 @@ def test_end_to_end(self):
output = json.loads(mock_stdout.getvalue())

# Epochs should increase.
self.assertEqual(output[0]["epoch"], 1)
self.assertEqual(output[1]["epoch"], 2)
self.assertEqual(output[2]["epoch"], 3)
with self.subTest(name="Epochs increase"):
self.assertEqual(output[0]["epoch"], 1)
self.assertEqual(output[1]["epoch"], 2)
self.assertEqual(output[2]["epoch"], 3)

# Loss should decrease consistently.
self.assertGreater(output[0]["loss"], output[1]["loss"])
self.assertGreater(output[1]["loss"], output[2]["loss"])
with self.subTest(name="Loss decreases consistently"):
self.assertGreater(output[0]["loss"], output[1]["loss"])
self.assertGreater(output[1]["loss"], output[2]["loss"])

# Metrics should increase consistently.
self.assertLess(output[0]["metric/ndcg"], output[1]["metric/ndcg"])
self.assertLess(output[1]["metric/ndcg"], output[2]["metric/ndcg"])
self.assertLess(output[0]["metric/ndcg@10"], output[1]["metric/ndcg@10"])
self.assertLess(output[1]["metric/ndcg@10"], output[2]["metric/ndcg@10"])
with self.subTest(name="Metrics increase consistently"):
self.assertLess(output[0]["metric/ndcg"], output[1]["metric/ndcg"])
self.assertLess(output[1]["metric/ndcg"], output[2]["metric/ndcg"])
self.assertLess(output[0]["metric/ndcg@10"], output[1]["metric/ndcg@10"])
self.assertLess(output[1]["metric/ndcg@10"], output[2]["metric/ndcg@10"])

# Evaluate metric values after training.
np.testing.assert_allclose(output[2]["metric/ndcg"], 0.829134, atol=0.03)
np.testing.assert_allclose(output[2]["metric/ndcg@10"], 0.650672, atol=0.03)
with self.subTest(name="Metric values after training"):
self.assertAlmostEqual(output[2]["metric/ndcg"], 0.81434, places=3)
self.assertAlmostEqual(output[2]["metric/ndcg@10"], 0.576916, places=3)


if __name__ == "__main__":
24 changes: 16 additions & 8 deletions examples/segmentation/main_test.py
@@ -20,11 +20,15 @@
from unittest import mock

from absl.testing import absltest
import jax
import numpy as np

from examples.segmentation import main
import tensorflow_datasets as tfds

# Opt-in to the partitionable PRNG implementation.
jax.config.update("jax_threefry_partitionable", True)


class SegmentationTest(absltest.TestCase):

@@ -41,20 +45,24 @@ def test_end_to_end(self):
output = json.loads(mock_stdout.getvalue())

# Steps should increase.
self.assertEqual(output[0]["step"], 20)
self.assertEqual(output[1]["step"], 40)
self.assertEqual(output[2]["step"], 60)
with self.subTest(name="Steps increase"):
self.assertEqual(output[0]["step"], 20)
self.assertEqual(output[1]["step"], 40)
self.assertEqual(output[2]["step"], 60)

# Loss should decrease consistently.
self.assertGreater(output[0]["loss"], output[1]["loss"])
self.assertGreater(output[1]["loss"], output[2]["loss"])
with self.subTest("Loss decreases consistently"):
self.assertGreater(output[0]["loss"], output[1]["loss"])
self.assertGreater(output[1]["loss"], output[2]["loss"])

# NDCG@10 metric should increase consistently.
self.assertLess(output[0]["ndcg@10"], output[1]["ndcg@10"])
self.assertLess(output[1]["ndcg@10"], output[2]["ndcg@10"])
with self.subTest(name="NDCG@10 increases consistently"):
self.assertLess(output[0]["ndcg@10"], output[1]["ndcg@10"])
self.assertLess(output[1]["ndcg@10"], output[2]["ndcg@10"])

# Evaluate exact NDCG@10 metric value after training.
np.testing.assert_allclose(output[2]["ndcg@10"], 0.732969, atol=0.03)
with self.subTest(name="Exact NDCG@10 value after training"):
np.testing.assert_allclose(output[2]["ndcg@10"], 0.732969, atol=0.03)


if __name__ == "__main__":
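For readers unfamiliar with self.subTest: assertions wrapped this way are reported individually, and one failing check does not abort the rest of the test method. A minimal standalone sketch (a hypothetical test, not part of this commit):

from absl.testing import absltest


class SubTestExample(absltest.TestCase):

  def test_several_checks(self):
    for name, value in [("loss", 0.3), ("ndcg@10", 0.8)]:
      with self.subTest(name=name):
        # Each subTest passes or fails on its own; a failure here does
        # not stop the remaining iterations from running.
        self.assertGreater(value, 0.0)


if __name__ == "__main__":
  absltest.main()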
27 changes: 19 additions & 8 deletions examples/t5x/models_test.py
@@ -24,6 +24,9 @@
from examples.t5x import models
import tensorflow as tf

# Opt-in to the partitionable PRNG implementation.
jax.config.update("jax_threefry_partitionable", True)


class RankingEncDecFeatureConverterTest(absltest.TestCase):

@@ -184,16 +187,24 @@ def test_loss_fn(self):
# args where the leading (batch_size, list_size, ...) dimensions are
# flattened to (batch_size * list_size, ...).
args = mocked_module.apply.call_args.args
self.assertEqual(args[0], {"params": params})
self.assertEqual(args[1].shape, (16 * 4, 5)) # encoder_input_tokens
self.assertEqual(args[2].shape, (16 * 4, 1)) # decoder_input_tokens
self.assertEqual(args[3].shape, (16 * 4, 1)) # decoder_target_tokens
with self.subTest(name="Check parameters and shapes"):
self.assertEqual(args[0], {"params": params})
self.assertEqual(args[1].shape, (16 * 4, 5)) # encoder_input_tokens
self.assertEqual(args[2].shape, (16 * 4, 1)) # decoder_input_tokens
self.assertEqual(args[3].shape, (16 * 4, 1)) # decoder_target_tokens

# Check the loss and metric values.
np.testing.assert_allclose(loss, 20.415768)
np.testing.assert_allclose(metrics["loss"].compute(), 20.415768)
np.testing.assert_allclose(metrics["metrics/ndcg"].compute(), 0.41030282)
np.testing.assert_allclose(metrics["metrics/mrr"].compute(), 0.30208334)
with self.subTest(name="Loss value"):
self.assertAlmostEqual(loss, 19.37958, places=3)
self.assertAlmostEqual(metrics["loss"].compute(), 19.37958, places=3)

with self.subTest(name="Metric values"):
self.assertAlmostEqual(
metrics["metrics/ndcg"].compute(), 0.42098486, places=3
)
self.assertAlmostEqual(
metrics["metrics/mrr"].compute(), 0.40104166, places=3
)


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions rax/_src/lambdaweights.py
@@ -25,8 +25,8 @@
>>> labels = jnp.array([1.0, 2.0, 0.0])
>>> loss = rax.pairwise_logistic_loss(
... scores, labels, lambdaweight_fn=rax.labeldiff_lambdaweight)
>>> print(loss)
1.8923712
>>> print(f"{loss:.5f}")
1.89237
"""

import operator
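The doctest edit above shows the pattern applied throughout the library: printing a raw float32 pins the doctest to every trailing digit, while formatting to five decimal places keeps it stable under small numerical drift. A minimal sketch with the same value, independent of Rax:

>>> import numpy as np
>>> loss = np.float32(1.8923712)
>>> print(f"{loss:.5f}")
1.89237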
13 changes: 8 additions & 5 deletions rax/_src/losses.py
@@ -29,17 +29,20 @@
>>> scores = jnp.array([2., 1., 3.])
>>> labels = jnp.array([1., 0., 0.])
>>> print(rax.softmax_loss(scores, labels))
1.4076059
>>> loss = rax.softmax_loss(scores, labels)
>>> print(f"{loss:.5f}")
1.40761
Usage with a batch of data and a mask to indicate valid items.
>>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> print(rax.pairwise_hinge_loss(
... scores, labels, where=where, reduce_fn=jnp.mean))
0.16666667
>>> loss = rax.pairwise_hinge_loss(
... scores, labels, where=where, reduce_fn=jnp.mean
... )
>>> print(f"{loss:.5f}")
0.16667
To compute gradients of each loss function, please use standard JAX
transformations such as :func:`jax.grad` or :func:`jax.value_and_grad`:
4 changes: 2 additions & 2 deletions rax/_src/losses_test.py
@@ -176,7 +176,7 @@ def test_computes_loss_value(self, loss_fn, expected_value):

loss = loss_fn(scores, labels, reduce_fn=jnp.sum)

np.testing.assert_allclose(jnp.asarray(expected_value), loss)
np.testing.assert_allclose(jnp.asarray(expected_value), loss, rtol=1e-5)

@parameterized.parameters([
{
@@ -523,7 +523,7 @@ def test_computes_loss_value_with_vmap(self, loss_fn, expected_value):
vmap_loss_fn = jax.vmap(loss_fn, in_axes=(0, 0), out_axes=0)
loss = vmap_loss_fn(scores, labels)

np.testing.assert_allclose(jnp.asarray(expected_value), loss)
np.testing.assert_allclose(jnp.asarray(expected_value), loss, rtol=1e-5)

@parameterized.parameters([
{
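For reference, with the default atol of 0, np.testing.assert_allclose(actual, desired, rtol=1e-5) requires |actual - desired| <= 1e-5 * |desired|. A minimal sketch with illustrative numbers, not taken from this test suite:

import numpy as np

# Passes: the relative difference is 4e-6, within rtol=1e-5.
np.testing.assert_allclose(1.000004, 1.0, rtol=1e-5)

# Would raise AssertionError: the relative difference is 2e-5 > 1e-5.
# np.testing.assert_allclose(1.00002, 1.0, rtol=1e-5)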
10 changes: 6 additions & 4 deletions rax/_src/metrics.py
@@ -31,16 +31,18 @@
>>> import rax
>>> scores = jnp.array([2., 1., 3.])
>>> labels = jnp.array([2., 0., 1.])
>>> print(rax.ndcg_metric(scores, labels))
0.79670763
>>> loss = rax.ndcg_metric(scores, labels)
>>> print(f"{loss:.5f}")
0.79671
Usage with a batch of data and a mask to indicate valid items:
>>> scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[2., 0., 1.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> print(rax.ndcg_metric(scores, labels))
0.8983538
>>> loss = rax.ndcg_metric(scores, labels)
>>> print(f"{loss:.5f}")
0.89835
Usage with :func:`jax.vmap` batching and a mask to indicate valid items:
36 changes: 22 additions & 14 deletions rax/_src/t12n.py
@@ -24,8 +24,9 @@
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> approx_ndcg_loss_fn = rax.approx_t12n(rax.ndcg_metric)
>>> print(approx_ndcg_loss_fn(scores, labels))
-0.71789175
>>> loss = approx_ndcg_loss_fn(scores, labels)
>>> print(f"{loss:.5f}")
-0.71789
"""

import functools
@@ -67,16 +68,19 @@ def approx_t12n(metric_fn: MetricFn, temperature: float = 1.0) -> LossFn:
>>> approx_mrr = rax.approx_t12n(rax.mrr_metric)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> print(approx_mrr(scores, labels))
-0.6965873
>>> loss = approx_mrr(scores, labels)
>>> print(f"{loss:.5f}")
-0.69659
Example usage together with :func:`rax.gumbel_t12n`:
>>> gumbel_approx_mrr = rax.gumbel_t12n(rax.approx_t12n(rax.mrr_metric))
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> print(gumbel_approx_mrr(scores, labels, key=jax.random.PRNGKey(42)))
-0.71880937
>>> key = jax.random.PRNGKey(42)
>>> loss = gumbel_approx_mrr(scores, labels, key=key)
>>> print(f"{loss:.5f}")
-0.75971
Args:
metric_fn: The metric function to convert to an approximate loss.
@@ -122,16 +126,18 @@ def bound_t12n(metric_fn: MetricFn):
>>> bound_mrr = rax.bound_t12n(rax.mrr_metric)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 1., 0., 1.])
>>> print(bound_mrr(scores, labels))
-0.33333334
>>> loss = bound_mrr(scores, labels)
>>> print(f"{loss:.5f}")
-0.33333
Example usage together with :func:`rax.gumbel_t12n`:
>>> gumbel_bound_mrr = rax.gumbel_t12n(rax.bound_t12n(rax.mrr_metric))
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 1., 0., 1.])
>>> print(gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42)))
-0.31619418
>>> loss = gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42))
>>> print(f"{loss:.5f}")
-0.40368
Args:
metric_fn: The metric function to convert to a lower-bound loss.
@@ -184,10 +190,12 @@ def gumbel_t12n(
>>> loss_fn = rax.gumbel_t12n(rax.softmax_loss)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> print(loss_fn(scores, labels, key=jax.random.PRNGKey(42)))
6.2066536
>>> print(loss_fn(scores, labels, key=jax.random.PRNGKey(79)))
5.0127797
>>> loss = loss_fn(scores, labels, key=jax.random.PRNGKey(42))
>>> print(f"{loss:.5f}")
3.45703
>>> loss = loss_fn(scores, labels, key=jax.random.PRNGKey(79))
>>> print(f"{loss:.5f}")
4.12491
Args:
loss_or_metric_fn: A Rax loss or metric function.
7 changes: 5 additions & 2 deletions rax/_src/t12n_test.py
@@ -30,6 +30,9 @@
from rax._src import metrics
from rax._src import t12n

# Opt-in to the partitionable PRNG implementation.
jax.config.update("jax_threefry_partitionable", True)


class ApproxT12nTest(parameterized.TestCase):

@@ -244,7 +247,7 @@ def test_samples_scores_using_key(self):

loss = new_loss_fn(scores, labels, key=jax.random.PRNGKey(42))
np.testing.assert_allclose(
loss, jnp.asarray([[0.589013, 0.166654, 0.962401]]), rtol=1e-5
loss, jnp.asarray([[0.334093, 1.952019, 2.725531]]), rtol=1e-5
)

def test_repeats_inputs_n_times(self):
Expand Down Expand Up @@ -282,7 +285,7 @@ def test_handles_extreme_scores(self):

loss = new_loss_fn(scores, labels, key=jax.random.PRNGKey(42))
np.testing.assert_allclose(
loss, jnp.asarray([[-3e18, 1.666543e-01, 2e22]]), rtol=1e-5
loss, jnp.asarray([[-3e18, 1.952019, 2e22]]), rtol=1e-5
)

def test_raises_an_error_if_no_key_is_provided(self):
(The diff for the eleventh changed file did not load and is not shown here.)
