Skip to content

Commit

Permalink
add ignore_class param to sparse CE loss (#599)
Browse files Browse the repository at this point in the history
* add ignore_class param to sparse CE loss

* fix reshape args

* fix where condition

* exclude torch backend for now

* format code

* ignore torch backend for now

* fix tests

* non-jittable version

* add jitable version

* replace condition with try/except block

* correct indentation for try/except block
  • Loading branch information
AakashKumarNain authored Aug 27, 2023
1 parent 115ab8b commit 65ef46a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
35 changes: 33 additions & 2 deletions keras_core/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,7 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper):
def __init__(
self,
from_logits=False,
ignore_class=None,
reduction="sum_over_batch_size",
name="sparse_categorical_crossentropy",
):
Expand All @@ -955,14 +956,17 @@ def __init__(
name=name,
reduction=reduction,
from_logits=from_logits,
ignore_class=ignore_class,
)
self.from_logits = from_logits
self.ignore_class = ignore_class

def get_config(self):
return {
"name": self.name,
"reduction": self.reduction,
"from_logits": self.from_logits,
"ignore_class": self.ignore_class,
}


Expand Down Expand Up @@ -1659,14 +1663,21 @@ def categorical_focal_crossentropy(
"keras_core.losses.sparse_categorical_crossentropy",
]
)
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
def sparse_categorical_crossentropy(
y_true, y_pred, from_logits=False, ignore_class=None, axis=-1
):
"""Computes the sparse categorical crossentropy loss.
Args:
y_true: Ground truth values.
y_pred: The predicted values.
from_logits: Whether `y_pred` is expected to be a logits tensor. By
default, we assume that `y_pred` encodes a probability distribution.
ignore_class: Optional integer. The ID of a class to be ignored during
loss computation. This is useful, for example, in segmentation
problems featuring a "void" class (commonly -1 or 255) in
segmentation maps. By default (`ignore_class=None`), all classes are
considered.
axis: Defaults to -1. The dimension along which the entropy is
computed.
Expand All @@ -1682,13 +1693,33 @@ def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
>>> loss
array([0.0513, 2.303], dtype=float32)
"""
return ops.sparse_categorical_crossentropy(

if ignore_class is not None:
res_shape = ops.shape(y_pred)[:-1]
valid_mask = ops.not_equal(y_true, ops.cast(ignore_class, y_pred.dtype))
y_true = y_true * ops.cast(valid_mask, y_true.dtype)
y_pred = y_pred * ops.cast(
ops.expand_dims(valid_mask, -1), y_pred.dtype
)

res = ops.sparse_categorical_crossentropy(
y_true,
y_pred,
from_logits=from_logits,
axis=axis,
)

if ignore_class is not None:
valid_mask = ops.reshape(valid_mask, res_shape)
res = ops.where(valid_mask, res, 0.0)

try:
res._keras_mask = valid_mask
except AttributeError:
pass

return res


@keras_core_export(
[
Expand Down
9 changes: 9 additions & 0 deletions keras_core/losses/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,15 @@ def test_no_reduction(self):
loss = cce_obj(y_true, logits)
self.assertAllClose((0.001822, 0.000459, 0.169846), loss, 3)

def test_ignore_class(self):
y_true = np.array([[-1, 2]])
logits = np.array([[[0.854, 0.698, 0.598], [0.088, 0.86, 0.018]]])
cce_obj = losses.SparseCategoricalCrossentropy(
from_logits=True, ignore_class=-1, reduction=None
)
loss = cce_obj(y_true, logits)
self.assertAllClose([[0.0, 1.48012]], loss, 3)


class BinaryFocalCrossentropyTest(testing.TestCase):
def test_config(self):
Expand Down

0 comments on commit 65ef46a

Please sign in to comment.