Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ignore_class param to sparse CE loss #599

Merged
merged 15 commits into from
Aug 27, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions keras_core/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,7 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper):
def __init__(
self,
from_logits=False,
ignore_class=None,
reduction="sum_over_batch_size",
name="sparse_categorical_crossentropy",
):
Expand All @@ -955,14 +956,17 @@ def __init__(
name=name,
reduction=reduction,
from_logits=from_logits,
ignore_class=ignore_class,
)
self.from_logits = from_logits
self.ignore_class = ignore_class

def get_config(self):
    """Return this loss's configuration as a plain dict.

    Returns:
        A dict with the keys `"name"`, `"reduction"`, `"from_logits"` and
        `"ignore_class"`, each mapped to the attribute of the same name
        on this instance.
    """
    config = {}
    # Keys are inserted in a fixed order so the resulting dict is laid out
    # the same way on every call.
    for key in ("name", "reduction", "from_logits", "ignore_class"):
        config[key] = getattr(self, key)
    return config


Expand Down Expand Up @@ -1659,14 +1663,21 @@ def categorical_focal_crossentropy(
"keras_core.losses.sparse_categorical_crossentropy",
]
)
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
def sparse_categorical_crossentropy(
y_true, y_pred, from_logits=False, ignore_class=None, axis=-1
):
"""Computes the sparse categorical crossentropy loss.

Args:
y_true: Ground truth values.
y_pred: The predicted values.
from_logits: Whether `y_pred` is expected to be a logits tensor. By
default, we assume that `y_pred` encodes a probability distribution.
ignore_class: Optional integer. The ID of a class to be ignored during
loss computation. This is useful, for example, in segmentation
problems featuring a "void" class (commonly -1 or 255) in
segmentation maps. By default (ignore_class=None), all classes are
AakashKumarNain marked this conversation as resolved.
Show resolved Hide resolved
considered.
axis: Defaults to -1. The dimension along which the entropy is
computed.

Expand All @@ -1682,13 +1693,31 @@ def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
>>> loss
array([0.0513, 2.303], dtype=float32)
"""
return ops.sparse_categorical_crossentropy(

if ignore_class is not None:
res_shape = ops.shape(y_pred)[:-1]
valid_mask = ops.not_equal(y_true, ops.cast(ignore_class, y_pred.dtype))
y_true = y_true * ops.cast(valid_mask, y_true.dtype)
y_pred = y_pred * ops.cast(
ops.expand_dims(valid_mask, -1), y_pred.dtype
)

res = ops.sparse_categorical_crossentropy(
y_true,
y_pred,
from_logits=from_logits,
axis=axis,
)

if ignore_class is not None:
valid_mask = ops.reshape(valid_mask, res_shape)
res = ops.where(valid_mask, res, 0.0)

if backend.backend() != "numpy":
AakashKumarNain marked this conversation as resolved.
Show resolved Hide resolved
res._keras_mask = valid_mask

return res


@keras_core_export(
[
Expand Down
9 changes: 9 additions & 0 deletions keras_core/losses/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,15 @@ def test_no_reduction(self):
loss = cce_obj(y_true, logits)
self.assertAllClose((0.001822, 0.000459, 0.169846), loss, 3)

def test_ignore_class(self):
    """Positions labelled with `ignore_class` yield a loss of exactly zero."""
    # One sample with two positions: the first carries the ignored class id
    # (-1), the second is labelled with class 2.
    y_true = np.array([[-1, 2]])
    logits = np.array([[[0.854, 0.698, 0.598], [0.088, 0.86, 0.018]]])
    cce_obj = losses.SparseCategoricalCrossentropy(
        from_logits=True, ignore_class=-1, reduction=None
    )
    loss = cce_obj(y_true, logits)
    # Tolerances are given by keyword. A bare positional `3` is bound to a
    # tolerance parameter (it does not mean "3 decimal places"), which would
    # accept almost any value in place of the expected ~1.48.
    # NOTE(review): assumes `assertAllClose` accepts `atol`/`rtol` keywords;
    # confirm against the project's `testing.TestCase`.
    self.assertAllClose([[0.0, 1.48012]], loss, atol=1e-3, rtol=1e-3)


class BinaryFocalCrossentropyTest(testing.TestCase):
def test_config(self):
Expand Down