-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add ignore_class param to sparse CE loss #599
add ignore_class param to sparse CE loss #599
Conversation
@fchollet Apart from that, |
We should make the behavior consistent across all backends -- that's important. And when applicable the behavior should also be standardized on the numpy behavior. |
Thank you. I will make the changes accordingly |
@fchollet I have skipped |
I just realized that even though the tests for jax are passing, I am skeptical about it working when the loss function is complied. Why? Because the masking is dynamic in this case |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR. What's the error in torch? And for JAX, surely correctness is being tested so the test passing should be proof that it works?
Thanks for review it. Here is the traceback for torch:
I don't know how the test is passing because I definitely get error when I try to def loss_fn(y_true, y_pred, ignore_class=-1):
mask = ops.equal(y_true, ignore_class)
mask_idx = ops.stack(jnp.nonzero(mask, size=len(mask)), -1)
y_true_masked = ops.where(mask, 0.0, y_true)
y_pred_masked = ops.where(mask, 0.0, y_pred)
loss = ops.sparse_categorical_crossentropy(
target=y_true_masked,
output=y_pred_masked,
from_logits=True,
axis=-1
)
mask = ops.reshape(mask, y_pred.shape[:-1])
loss = ops.where(mask, 0.0, loss)
return loss The problem with this code though that it implements the lax compatible |
I don't know if it's the place to add a comment or try to help and I may be mistaken, but I believe the behaviour of ignore_class could be expected without using ops.where, for example using the following implementation:
It does not use ops.sparse_categorical_crossentropy though, but the masking bit seems to work accross all backends. |
Thanks for the suggestion. If you look at the changed files in this PR, this is exactly what it does except for using |
Update: Now that I have fixed the |
@fchollet good news! The current version is now fully jit compatible. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update! Looking good
@fchollet thank you for the review and the valuable suggestions. I have made the necessary changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Add
ignore_class
param for sparse categorical cross entropy loss