Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a couple of options to learn-to-rank losses and metrics #125

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rax/_src/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def softmax_loss(
weights: An optional ``[..., list_size]``-:class:`~jax.Array`, indicating
the weight for each item.
label_fn: A label function that maps labels to probabilities. Default keeps
labels as-is.
labels as-is. See rax.utils.normalize_probabilities for an example.
reduce_fn: An optional function that reduces the loss values. Can be
:func:`jax.numpy.sum` or :func:`jax.numpy.mean`. If ``None``, no reduction
is performed.
Expand Down
2 changes: 2 additions & 0 deletions rax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from rax._src.utils import approx_ranks
from rax._src.utils import compute_pairs
from rax._src.utils import cutoff
from rax._src.utils import normalize_probabilities
from rax._src.utils import ranks
from rax._src.utils import safe_reduce

Expand All @@ -27,6 +28,7 @@
"approx_ranks",
"compute_pairs",
"cutoff",
"normalize_probabilities",
"pairwise_loss",
"ranks",
"safe_reduce",
Expand Down
Loading