Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568300084
  • Loading branch information
Jake VanderPlas authored and Rax Developers committed Sep 25, 2023
1 parent 37132c2 commit e44b62a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions rax/_src/segment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def segment_log_softmax(
a: Array, segments: Array, where: Optional[Array] = None
) -> Array:
"""Returns segment log-softmax."""
a_max = segment_max(a, segments, where=where, initial=jnp.min(a))
a_max = segment_max(a, segments, where=where, initial=jnp.min(a)) # pytype: disable=wrong-arg-types # jnp-type
shifted = a - jax.lax.stop_gradient(a_max)
shifted_logsumexp = jnp.log(
segment_sum(jnp.exp(shifted), segments, where=where)
Expand All @@ -79,7 +79,7 @@ def segment_softmax(
a: Array, segments: Array, where: Optional[Array] = None
) -> Array:
"""Returns segment softmax."""
a_max = segment_max(a, segments, where=where, initial=jnp.min(a))
a_max = segment_max(a, segments, where=where, initial=jnp.min(a)) # pytype: disable=wrong-arg-types # jnp-type
unnormalized = jnp.exp(a - jax.lax.stop_gradient(a_max))
return unnormalized / segment_sum(unnormalized, segments, where=where)

Expand Down

0 comments on commit e44b62a

Please sign in to comment.