feat(losses): add Tversky loss implementation #19511

Merged 2 commits on Apr 15, 2024
4 changes: 4 additions & 0 deletions keras/losses/__init__.py
@@ -21,6 +21,7 @@
from keras.losses.losses import Poisson
from keras.losses.losses import SparseCategoricalCrossentropy
from keras.losses.losses import SquaredHinge
from keras.losses.losses import Tversky
from keras.losses.losses import binary_crossentropy
from keras.losses.losses import binary_focal_crossentropy
from keras.losses.losses import categorical_crossentropy
@@ -40,6 +41,7 @@
from keras.losses.losses import poisson
from keras.losses.losses import sparse_categorical_crossentropy
from keras.losses.losses import squared_hinge
from keras.losses.losses import tversky
from keras.saving import serialization_lib

ALL_OBJECTS = {
@@ -68,6 +70,7 @@
    CategoricalHinge,
    # Image segmentation
    Dice,
    Tversky,
    # Probabilistic
    kl_divergence,
    poisson,
@@ -90,6 +93,7 @@
    categorical_hinge,
    # Image segmentation
    dice,
    tversky,
}

ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
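
Adding `Tversky` and `tversky` to `ALL_OBJECTS` also makes them resolvable through the usual string-identifier path. A minimal sketch of the expected effect, assuming the standard `keras.losses.get()` lookup (the model below is illustrative only):

import keras

# After registration, the loss can be retrieved by name...
loss_fn = keras.losses.get("tversky")

# ...or passed to compile() as a plain string identifier.
model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="tversky")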
90 changes: 90 additions & 0 deletions keras/losses/losses.py
@@ -2000,3 +2000,93 @@ def dice(y_true, y_pred):
    )

    return 1 - dice


@keras_export("keras.losses.Tversky")
class Tversky(LossFunctionWrapper):
"""Computes the Tversky loss value between `y_true` and `y_pred`.

This loss function is weighted by the alpha and beta coefficients
that penalize false positives and false negatives.

With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to
Dice Loss.

Args:
y_true: tensor of true targets.
y_pred: tensor of predicted targets.
alpha: coefficient controlling incidence of false positives.
beta: coefficient controlling incidence of false negatives.

Returns:
lpizzinidev marked this conversation as resolved.
Show resolved Hide resolved
Tversky loss value.

Reference:

- [Salehi et al., 2017](https://arxiv.org/abs/1706.05721)
"""

    def __init__(
        self,
        alpha=0.5,
        beta=0.5,
        reduction="sum_over_batch_size",
        name="tversky",
    ):
        super().__init__(
            tversky,
            alpha=alpha,
            beta=beta,
            name=name,
            reduction=reduction,
        )
        self.alpha = alpha
        self.beta = beta

    def get_config(self):
        return {
            "name": self.name,
            "alpha": self.alpha,
            "beta": self.beta,
            "reduction": self.reduction,
        }


@keras_export("keras.losses.tversky")
def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
"""Computes the Tversky loss value between `y_true` and `y_pred`.

This loss function is weighted by the alpha and beta coefficients
that penalize false positives and false negatives.

With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to
Dice Loss.

Args:
y_true: tensor of true targets.
y_pred: tensor of predicted targets.
alpha: coefficient controlling incidence of false positives.
beta: coefficient controlling incidence of false negatives.

Returns:
Tversky loss value.

Reference:

- [Salehi et al., 2017](https://arxiv.org/abs/1706.05721)
"""
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.cast(y_true, y_pred.dtype)

inputs = ops.reshape(y_true, [-1])
targets = ops.reshape(y_pred, [-1])

intersection = ops.sum(inputs * targets)
fp = ops.sum((1 - targets) * inputs)
fn = ops.sum(targets * (1 - inputs))
tversky = ops.divide(
intersection,
intersection + fp * alpha + fn * beta + backend.epsilon(),
)

return 1 - tversky
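
For context, a quick usage sketch of the new API as exported above; the tensors and model here are illustrative only:

import numpy as np
import keras

# Standalone call with custom coefficients (with the defaults alpha=0.5 and
# beta=0.5 the value would match the Dice loss on the same inputs).
y_true = np.array([[1.0, 0.0, 1.0, 0.0]])
y_pred = np.array([[0.9, 0.1, 0.8, 0.2]])
loss_value = keras.losses.Tversky(alpha=0.3, beta=0.7)(y_true, y_pred)

# Or pass an instance to compile() like any other built-in loss.
model = keras.Sequential([keras.layers.Dense(4, activation="sigmoid")])
model.compile(optimizer="adam", loss=keras.losses.Tversky())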
37 changes: 37 additions & 0 deletions keras/losses/losses_test.py
@@ -1409,3 +1409,40 @@ def test_binary_segmentation(self):
        )
        output = losses.Dice()(y_true, y_pred)
        self.assertAllClose(output, 0.77777773)


class TverskyTest(testing.TestCase):
    def test_config(self):
        self.run_class_serialization_test(losses.Tversky(name="mytversky"))

    def test_correctness(self):
        y_true = np.array([[1, 2], [1, 2]])
        y_pred = np.array([[4, 1], [6, 1]])
        output = losses.Tversky()(y_true, y_pred)
        self.assertAllClose(output, -0.55555546)

    def test_correctness_custom_coefficients(self):
        y_true = np.array([[1, 2], [1, 2]])
        y_pred = np.array([[4, 1], [6, 1]])
        output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred)
        self.assertAllClose(output, -0.29629636)

    def test_binary_segmentation(self):
        y_true = np.array(
            [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]
        )
        y_pred = np.array(
            [[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]]
        )
        output = losses.Tversky()(y_true, y_pred)
        self.assertAllClose(output, 0.77777773)

    def test_binary_segmentation_custom_coefficients(self):
        y_true = np.array(
            [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]
        )
        y_pred = np.array(
            [[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]]
        )
        output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred)
        self.assertAllClose(output, 0.7916667)
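
As a sanity check on the expected values above, the binary segmentation cases can be reproduced by hand with NumPy, mirroring the ops used in `tversky()` (variable names here are illustrative):

import numpy as np

y_true = np.array(
    [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=float
)
y_pred = np.array(
    [[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]], dtype=float
)

t, p = y_true.ravel(), y_pred.ravel()
intersection = np.sum(t * p)  # 2.0
fp = np.sum((1 - p) * t)      # 6.0
fn = np.sum(p * (1 - t))      # 8.0

# Defaults alpha=beta=0.5: 1 - 2 / (2 + 0.5*6 + 0.5*8) = 7/9 ≈ 0.77777773
print(1 - intersection / (intersection + 0.5 * fp + 0.5 * fn))
# alpha=0.2, beta=0.8: 1 - 2 / (2 + 0.2*6 + 0.8*8) = 1 - 2/9.6 ≈ 0.7916667
print(1 - intersection / (intersection + 0.2 * fp + 0.8 * fn))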