diff --git a/docs/generate_docs_netlify.sh b/docs/generate_docs_netlify.sh index 06acff69..9545b9d6 100755 --- a/docs/generate_docs_netlify.sh +++ b/docs/generate_docs_netlify.sh @@ -13,7 +13,7 @@ poetry build -f wheel pip install dist/$(ls -1 dist | grep .whl) pip install pytorch-metric-learning==1.3.2 -pip install sphinx>=5.0.1 +pip install sphinx==6.1.3 pip install "git+https://github.com/qdrant/qdrant_sphinx_theme.git@master#egg=qdrant-sphinx-theme" sphinx-apidoc --force --separate --no-toc -o docs/source quaterion diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index f1e83afb..af2d0a15 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -128,8 +128,9 @@ Implementations ~softmax_loss.SoftmaxLoss ~triplet_loss.TripletLoss ~circle_loss.CircleLoss - ~fastap_loss.FastAPLoss + ~fast_ap_loss.FastAPLoss ~cos_face_loss.CosFaceLoss + ~center_loss.CenterLoss Extras ++++++ diff --git a/docs/source/quaterion.loss.center_loss.rst b/docs/source/quaterion.loss.center_loss.rst new file mode 100644 index 00000000..7d45b3ac --- /dev/null +++ b/docs/source/quaterion.loss.center_loss.rst @@ -0,0 +1,7 @@ +quaterion.loss.center\_loss module +================================== + +.. automodule:: quaterion.loss.center_loss + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/quaterion.loss.circle_loss.rst b/docs/source/quaterion.loss.circle_loss.rst new file mode 100644 index 00000000..99b09b58 --- /dev/null +++ b/docs/source/quaterion.loss.circle_loss.rst @@ -0,0 +1,7 @@ +quaterion.loss.circle\_loss module +================================== + +.. automodule:: quaterion.loss.circle_loss + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/quaterion.loss.cos_face_loss.rst b/docs/source/quaterion.loss.cos_face_loss.rst new file mode 100644 index 00000000..54589bc0 --- /dev/null +++ b/docs/source/quaterion.loss.cos_face_loss.rst @@ -0,0 +1,7 @@ +quaterion.loss.cos\_face\_loss module +===================================== + +.. automodule:: quaterion.loss.cos_face_loss + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/quaterion.loss.fast_ap_loss.rst b/docs/source/quaterion.loss.fast_ap_loss.rst new file mode 100644 index 00000000..5a96dc89 --- /dev/null +++ b/docs/source/quaterion.loss.fast_ap_loss.rst @@ -0,0 +1,7 @@ +quaterion.loss.fast\_ap\_loss module +==================================== + +.. automodule:: quaterion.loss.fast_ap_loss + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/quaterion.loss.rst b/docs/source/quaterion.loss.rst index 9e4a254d..f90ef41d 100644 --- a/docs/source/quaterion.loss.rst +++ b/docs/source/quaterion.loss.rst @@ -16,7 +16,11 @@ Submodules :maxdepth: 4 quaterion.loss.arcface_loss + quaterion.loss.center_loss + quaterion.loss.circle_loss quaterion.loss.contrastive_loss + quaterion.loss.cos_face_loss + quaterion.loss.fast_ap_loss quaterion.loss.group_loss quaterion.loss.multiple_negatives_ranking_loss quaterion.loss.online_contrastive_loss diff --git a/docs/source/tutorials/triplet_loss_trick.rst b/docs/source/tutorials/triplet_loss_trick.rst index ec57fe0e..867f02ad 100644 --- a/docs/source/tutorials/triplet_loss_trick.rst +++ b/docs/source/tutorials/triplet_loss_trick.rst @@ -1,5 +1,5 @@ Triplet Loss: Vector Collapse Prevention -============================ +======================================== Triplet Loss is one of the most widely known loss functions in similarity learning. If you want to deep-dive into the details of its implementations and advantages, diff --git a/quaterion/loss/__init__.py b/quaterion/loss/__init__.py index 5cb1280d..c0a4d52a 100644 --- a/quaterion/loss/__init__.py +++ b/quaterion/loss/__init__.py @@ -1,4 +1,5 @@ from quaterion.loss.arcface_loss import ArcFaceLoss +from quaterion.loss.center_loss import CenterLoss from quaterion.loss.circle_loss import CircleLoss from quaterion.loss.contrastive_loss import ContrastiveLoss from quaterion.loss.cos_face_loss import CosFaceLoss diff --git a/quaterion/loss/center_loss.py b/quaterion/loss/center_loss.py new file mode 100644 index 00000000..d5f93a97 --- /dev/null +++ b/quaterion/loss/center_loss.py @@ -0,0 +1,57 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import LongTensor, Tensor + +from quaterion.loss.group_loss import GroupLoss +from quaterion.utils import l2_norm + + +class CenterLoss(GroupLoss): + """ + Center Loss as defined in the paper "A Discriminative Feature Learning Approach + for Deep Face Recognition" (http://ydwen.github.io/papers/WenECCV16.pdf) + It aims to minimize the intra-class variations while keeping the features of + different classes separable. + + Args: + embedding_size: Output dimension of the encoder. + num_groups: Number of groups (classes) in the dataset. + lambda_c: A regularization parameter that controls the contribution of the center loss. + """ + + def __init__( + self, embedding_size: int, num_groups: int, lambda_c: Optional[float] = 0.5 + ): + super(GroupLoss, self).__init__() + self.num_groups = num_groups + self.centers = nn.Parameter(torch.randn(num_groups, embedding_size)) + self.lambda_c = lambda_c + + nn.init.xavier_uniform_(self.centers) + + def forward(self, embeddings: Tensor, groups: LongTensor) -> Tensor: + """ + Compute the Center Loss value. + + Args: + embeddings: shape (batch_size, vector_length) - Output embeddings from the encoder. + groups: shape (batch_size,) - Group (class) ids associated with embeddings. + + Returns: + Tensor: loss value. + """ + embeddings = l2_norm(embeddings, 1) + + # Gather the center for each embedding's corresponding group + centers_batch = self.centers.index_select(0, groups) + + # Calculate the distance between embeddings and their respective class centers + loss = F.mse_loss(embeddings, centers_batch) + + # Scale the loss by the regularization parameter + loss *= self.lambda_c + + return loss diff --git a/tests/eval/losses/test_center_loss.py b/tests/eval/losses/test_center_loss.py new file mode 100644 index 00000000..cd5a1035 --- /dev/null +++ b/tests/eval/losses/test_center_loss.py @@ -0,0 +1,32 @@ +import torch + +from quaterion.loss import CenterLoss + + +class TestCenterLoss: + embeddings = torch.Tensor( + [ + [0.0, -1.0, 0.5], + [0.1, 2.0, 0.5], + [0.0, 0.3, 0.2], + [1.0, 0.0, 0.9], + [1.2, -1.2, 0.01], + [-0.7, 0.0, 1.5], + ] + ) + groups = torch.LongTensor([1, 2, 0, 0, 2, 1]) + + def test_batch_all(self): + # Initialize the CenterLoss + loss = CenterLoss(embedding_size=self.embeddings.size()[1], num_groups=3) + + # Calculate the loss + loss_res = loss.forward(embeddings=self.embeddings, groups=self.groups) + + # Assertions to check the output shape and type + assert isinstance( + loss_res, torch.Tensor + ), "Loss result should be a torch.Tensor" + assert loss_res.shape == torch.Size( + [] + ), "Loss result should be a scalar (0-dimension tensor)"