Skip to content

Commit

Permalink
Remove use of initial argument to jax.nn.softmax and `jax.nn.log_…
Browse files Browse the repository at this point in the history
…softmax`

This has been deprecated and has had no effect since JAX v0.4.27, and will result in a TypeError in JAX v0.4.36.

PiperOrigin-RevId: 690740882
  • Loading branch information
Jake VanderPlas authored and Rax Developers committed Oct 28, 2024
1 parent 4a713d5 commit 8daea3c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
1 change: 0 additions & 1 deletion rax/_src/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def unique_softmax_loss(
scores_repeated,
axis=-1,
where=identity_mask | labels_lt,
initial=jnp.min(scores),
)
log_softmax = jnp.diagonal(log_softmax, axis1=-2, axis2=-1)

Expand Down
8 changes: 4 additions & 4 deletions rax/_src/segment_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def test_computes_log_softmax_with_mask(self):
segments = jnp.array([0, 0, 0, 1, 1])
mask = jnp.array([1, 1, 0, 1, 1])

expected_1 = jax.nn.log_softmax(a[0:3], where=mask[0:3], initial=jnp.min(a))
expected_2 = jax.nn.log_softmax(a[3:5], where=mask[3:5], initial=jnp.min(a))
expected_1 = jax.nn.log_softmax(a[0:3], where=mask[0:3])
expected_2 = jax.nn.log_softmax(a[3:5], where=mask[3:5])
actual = segment_utils.segment_log_softmax(a, segments, where=mask)

np.testing.assert_allclose(actual[0:2], expected_1[0:2])
Expand Down Expand Up @@ -146,8 +146,8 @@ def test_computes_softmax_with_mask(self):
segments = jnp.array([0, 0, 0, 1, 1])
mask = jnp.array([1, 1, 0, 1, 1])

expected_1 = jax.nn.softmax(a[0:3], where=mask[0:3], initial=jnp.min(a))
expected_2 = jax.nn.softmax(a[3:5], where=mask[3:5], initial=jnp.min(a))
expected_1 = jax.nn.softmax(a[0:3], where=mask[0:3])
expected_2 = jax.nn.softmax(a[3:5], where=mask[3:5])
actual = segment_utils.segment_softmax(a, segments, where=mask)

np.testing.assert_allclose(actual[0:2], expected_1[0:2])
Expand Down

0 comments on commit 8daea3c

Please sign in to comment.