From ca9ec51d4915deb3ae5fde95fc3be8a569cfd2eb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 6 Oct 2023 03:01:46 -0700 Subject: [PATCH] Ignore incorrect type annotations related to jax dtypes PiperOrigin-RevId: 571283879 --- rax/_src/losses.py | 24 ++++++++++++------------ rax/_src/metrics.py | 14 +++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/rax/_src/losses.py b/rax/_src/losses.py index 06bfb17..b2ad49e 100644 --- a/rax/_src/losses.py +++ b/rax/_src/losses.py @@ -65,7 +65,7 @@ from rax._src.types import ReduceFn -def softmax_loss( +def softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -143,7 +143,7 @@ def softmax_loss( return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def poly1_softmax_loss( +def poly1_softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -241,7 +241,7 @@ def poly1_softmax_loss( return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def unique_softmax_loss( +def unique_softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -340,7 +340,7 @@ def unique_softmax_loss( return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def listmle_loss( +def listmle_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -422,7 +422,7 @@ def listmle_loss( return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def pairwise_loss( +def pairwise_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -489,7 +489,7 @@ def pairwise_loss( return utils.safe_reduce(pair_losses, where=valid_pairs, reduce_fn=reduce_fn) -def pairwise_hinge_loss( +def pairwise_hinge_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -545,7 +545,7 @@ def _hinge_loss( ) -def pairwise_logistic_loss( +def pairwise_logistic_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -604,7 +604,7 @@ def _logistic_loss( ) -def pairwise_soft_zero_one_loss( +def pairwise_soft_zero_one_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -667,7 +667,7 @@ def _soft_zero_one_loss( ) -def pointwise_sigmoid_loss( +def pointwise_sigmoid_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -729,7 +729,7 @@ def pointwise_sigmoid_loss( return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def pointwise_mse_loss( +def pointwise_mse_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -778,7 +778,7 @@ def pointwise_mse_loss( return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def pairwise_mse_loss( +def pairwise_mse_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -835,7 +835,7 @@ def _mse_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]: ) -def pairwise_qr_loss( +def pairwise_qr_loss( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, diff --git a/rax/_src/metrics.py b/rax/_src/metrics.py index 4a46b47..8403c8a 100644 --- a/rax/_src/metrics.py +++ b/rax/_src/metrics.py @@ -142,7 +142,7 @@ def default_discount_fn(rank: Array) -> Array: return 1.0 / jnp.log2(rank + 1) -def mrr_metric( +def mrr_metric( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -242,7 +242,7 @@ def mrr_metric( return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn) -def recall_metric( +def recall_metric( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -342,7 +342,7 @@ def recall_metric( return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn) -def precision_metric( +def precision_metric( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -442,7 +442,7 @@ def precision_metric( return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn) -def ap_metric( +def ap_metric( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -553,7 +553,7 @@ def ap_metric( return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn) -def opa_metric( +def opa_metric( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -619,7 +619,7 @@ def opa_metric( return utils.safe_reduce(per_list_opa, where=where, reduce_fn=reduce_fn) -def dcg_metric( +def dcg_metric( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *, @@ -715,7 +715,7 @@ def dcg_metric( return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn) -def ndcg_metric( +def ndcg_metric( # pytype: disable=annotation-type-mismatch # jnp-type scores: Array, labels: Array, *,