Skip to content
This repository has been archived by the owner on Nov 29, 2022. It is now read-only.

Commit

Permalink
Added class_weight parameter (#11)
Browse files Browse the repository at this point in the history
* Added class_weight parameter to sparse_categorical_focal_loss

* Added class_weight to SparseCategoricalFocalLoss
  • Loading branch information
artemmavrin authored Nov 1, 2020
1 parent 9e023de commit 5b2ca68
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
28 changes: 26 additions & 2 deletions src/focal_loss/_categorical_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
# |_| \___/ \___| \__,_| |_| |_| \___/ |___/ |___/

import itertools
from typing import Any, Optional

import tensorflow as tf

_EPSILON = tf.keras.backend.epsilon()


def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
class_weight: Optional[Any] = None,
from_logits: bool = False, axis: int = -1
) -> tf.Tensor:
r"""Focal loss function for multiclass classification with integer labels.
Expand Down Expand Up @@ -65,6 +67,10 @@ def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
one-dimensional tensor, in which case it specifies a focusing parameter
for each class.
class_weight : tensor-like of shape (K,), optional
Weighting factor for each of the :math:`K` classes. If not specified,
then all classes are weighted equally.
from_logits : bool, optional
Whether `y_pred` contains logits or probabilities.
Expand Down Expand Up @@ -116,6 +122,11 @@ def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
gamma_rank = gamma.shape.rank
scalar_gamma = gamma_rank == 0

# Process class weight
if class_weight is not None:
class_weight = tf.convert_to_tensor(class_weight,
dtype=tf.dtypes.float32)

# Process prediction tensor
y_pred = tf.convert_to_tensor(y_pred)
y_pred_rank = y_pred.shape.rank
Expand Down Expand Up @@ -165,6 +176,11 @@ def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
focal_modulation = (1 - probs) ** gamma
loss = focal_modulation * xent_loss

if class_weight is not None:
class_weight = tf.gather(class_weight, y_true, axis=0,
batch_dims=y_true_rank)
loss *= class_weight

if reshape_needed:
loss = tf.reshape(loss, y_pred_shape[:-1])

Expand Down Expand Up @@ -193,6 +209,10 @@ class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
one-dimensional tensor, in which case it specifies a focusing parameter
for each class.
class_weight : tensor-like of shape (K,), optional
Weighting factor for each of the :math:`K` classes. If not specified,
then all classes are weighted equally.
from_logits : bool, optional
Whether model prediction will be logits or probabilities.
Expand Down Expand Up @@ -238,9 +258,11 @@ class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
tensor and a prediction tensor and outputting a loss.
"""

def __init__(self, gamma, from_logits: bool = False, **kwargs):
def __init__(
    self,
    gamma,
    class_weight: Optional[Any] = None,
    from_logits: bool = False,
    **kwargs
):
    """Create the loss object and record its hyperparameters.

    Parameters
    ----------
    gamma
        Focusing parameter(s) of the focal loss; stored unchanged.
    class_weight : optional
        Per-class weighting factors. ``None`` (the default) is stored
        as-is and means no class weighting is requested.
    from_logits : bool, optional
        Whether model predictions will be logits or probabilities.
    **kwargs
        Forwarded unchanged to the parent class constructor.
    """
    # Let the parent class finish its own setup before attaching the
    # focal-loss specific attributes.
    super().__init__(**kwargs)
    self.from_logits = from_logits
    self.class_weight = class_weight
    self.gamma = gamma

def get_config(self):
Expand All @@ -256,7 +278,8 @@ def get_config(self):
This loss's config.
"""
config = super().get_config()
config.update(gamma=self.gamma, from_logits=self.from_logits)
config.update(gamma=self.gamma, class_weight=self.class_weight,
from_logits=self.from_logits)
return config

def call(self, y_true, y_pred):
Expand All @@ -283,5 +306,6 @@ def call(self, y_true, y_pred):
:meth:`~focal_loss.SparseCategoricalFocalLoss.__call__` method.
"""
return sparse_categorical_focal_loss(y_true=y_true, y_pred=y_pred,
class_weight=self.class_weight,
gamma=self.gamma,
from_logits=self.from_logits)
23 changes: 23 additions & 0 deletions src/focal_loss/tests/test_sparse_categorical_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,4 +355,27 @@ def test_with_dynamic_ranks(self, gamma, from_logits):

self.assertAllClose(loss, loss_numpy)

@named_parameters_with_testcase_names(y_true=Y_TRUE, y_pred=Y_PRED_PROB,
                                      gamma=[0, 1, 2])
def test_class_weight(self, y_true, y_pred, gamma):
    """Check that ``class_weight`` scales each loss term by its label's weight.

    For ten randomly drawn weight vectors, the loss computed with
    ``class_weight`` must match the loss computed without it, multiplied
    elementwise by the weight looked up for each true label.
    """
    # The unweighted loss does not depend on the sampled weights, so it is
    # computed once instead of once per iteration.
    unweighted_loss = sparse_categorical_focal_loss(
        y_true=y_true,
        y_pred=y_pred,
        gamma=gamma,
    ).numpy()
    num_classes = np.shape(y_pred)[-1]

    # Fixed seed keeps the sampled weights reproducible across runs.
    rng = np.random.default_rng(0)
    for _ in range(10):
        class_weight = rng.uniform(size=num_classes)

        loss_with_weight = sparse_categorical_focal_loss(
            y_true=y_true,
            y_pred=y_pred,
            gamma=gamma,
            class_weight=class_weight,
        )

        # Apply the class weights by hand to the unweighted loss. A new
        # array is built each time so the shared baseline is never mutated.
        expected_loss = unweighted_loss * np.take(class_weight, y_true)

        self.assertAllClose(loss_with_weight, expected_loss)

0 comments on commit 5b2ca68

Please sign in to comment.