Skip to content

Commit

Permalink
Ignore incorrect type annotations related to jax dtypes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571283879
  • Loading branch information
Jake VanderPlas authored and Rax Developers committed Oct 6, 2023
1 parent e44b62a commit ca9ec51
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
24 changes: 12 additions & 12 deletions rax/_src/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from rax._src.types import ReduceFn


def softmax_loss(
def softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -143,7 +143,7 @@ def softmax_loss(
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def poly1_softmax_loss(
def poly1_softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -241,7 +241,7 @@ def poly1_softmax_loss(
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def unique_softmax_loss(
def unique_softmax_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -340,7 +340,7 @@ def unique_softmax_loss(
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def listmle_loss(
def listmle_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -422,7 +422,7 @@ def listmle_loss(
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def pairwise_loss(
def pairwise_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -489,7 +489,7 @@ def pairwise_loss(
return utils.safe_reduce(pair_losses, where=valid_pairs, reduce_fn=reduce_fn)


def pairwise_hinge_loss(
def pairwise_hinge_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -545,7 +545,7 @@ def _hinge_loss(
)


def pairwise_logistic_loss(
def pairwise_logistic_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -604,7 +604,7 @@ def _logistic_loss(
)


def pairwise_soft_zero_one_loss(
def pairwise_soft_zero_one_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -667,7 +667,7 @@ def _soft_zero_one_loss(
)


def pointwise_sigmoid_loss(
def pointwise_sigmoid_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -729,7 +729,7 @@ def pointwise_sigmoid_loss(
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def pointwise_mse_loss(
def pointwise_mse_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -778,7 +778,7 @@ def pointwise_mse_loss(
return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn)


def pairwise_mse_loss(
def pairwise_mse_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -835,7 +835,7 @@ def _mse_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]:
)


def pairwise_qr_loss(
def pairwise_qr_loss( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down
14 changes: 7 additions & 7 deletions rax/_src/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def default_discount_fn(rank: Array) -> Array:
return 1.0 / jnp.log2(rank + 1)


def mrr_metric(
def mrr_metric( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -242,7 +242,7 @@ def mrr_metric(
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def recall_metric(
def recall_metric( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -342,7 +342,7 @@ def recall_metric(
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def precision_metric(
def precision_metric( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -442,7 +442,7 @@ def precision_metric(
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def ap_metric(
def ap_metric( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -553,7 +553,7 @@ def ap_metric(
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def opa_metric(
def opa_metric( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -619,7 +619,7 @@ def opa_metric(
return utils.safe_reduce(per_list_opa, where=where, reduce_fn=reduce_fn)


def dcg_metric(
def dcg_metric( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down Expand Up @@ -715,7 +715,7 @@ def dcg_metric(
return utils.safe_reduce(values, where=where, reduce_fn=reduce_fn)


def ndcg_metric(
def ndcg_metric( # pytype: disable=annotation-type-mismatch # jnp-type
scores: Array,
labels: Array,
*,
Expand Down

0 comments on commit ca9ec51

Please sign in to comment.