From 1ec2e033adea0448c83d86789b3ffcf3e9f1731b Mon Sep 17 00:00:00 2001 From: ilanaliouchouche Date: Sun, 29 Dec 2024 02:03:01 +0100 Subject: [PATCH 1/4] todo: final formula + doc + add scales(temp) --- .../DebiasedMultipleNegativesRankingLoss.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py diff --git a/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py new file mode 100644 index 000000000..59964c0c1 --- /dev/null +++ b/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import torch +from torch import Tensor, nn + +from sentence_transformers import util +from sentence_transformers.SentenceTransformer import SentenceTransformer + + +class DebiasedMultipleNegativesRankingLoss(nn.Module): + def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim, tau_plus: float = 0.01) -> None: + super().__init__() + self.model = model + self.scale = scale + self.similarity_fct = similarity_fct + self.tau_plus = tau_plus + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor: + # Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives) + embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features] + anchors = embeddings[0] # (batch_size, embedding_dim) + candidates = torch.cat(embeddings[1:]) # (batch_size * (1 + num_negatives), embedding_dim) + + # For every anchor, we compute the similarity to all other candidates (positives and negatives), + # also from other anchors. This gives us a lot of in-batch negatives. + scores: Tensor = self.similarity_fct(anchors, candidates) * self.scale + # (batch_size, batch_size * (1 + num_negatives)) + + mask = torch.ones_like(scores, dtype=torch.bool) + for i in range(scores.size(0)): + mask[i, i] = False + + neg_exp = torch.exp(scores.masked_fill(mask, float("-inf"))).sum(dim=-1) + + pos_exp = torch.exp(torch.gather(scores, -1, + torch.arange(scores.size(0), + device=scores.device).unsqueeze(1)).squeeze()) + + N = scores.size(1) - 1 + + def get_config_dict(self) -> dict[str, Any]: + return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__} + + @property + def citation(self) -> str: + return """ +TODO: Add citation +""" From 32c41dbf48785eba26600c1812fa9d5bbd90d35b Mon Sep 17 00:00:00 2001 From: ilanaliouchouche Date: Sun, 29 Dec 2024 16:51:31 +0100 Subject: [PATCH 2/4] Loss Class & Doc Done. 
TODO: Prepare PR
---
 .../DebiasedMultipleNegativesRankingLoss.py   | 115 ++++++++++++++++--
 1 file changed, 105 insertions(+), 10 deletions(-)

diff --git a/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py
index 59964c0c1..f3931012e 100644
--- a/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py
+++ b/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py
@@ -12,6 +12,84 @@ class DebiasedMultipleNegativesRankingLoss(nn.Module):
     def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim, tau_plus: float = 0.01) -> None:
+        """
+        This loss is a debiased version of :class:`MultipleNegativesRankingLoss` that addresses the inherent sampling bias in in-batch negative examples.
+
+        In the standard contrastive loss, negative samples are drawn randomly from the dataset, so some of them may actually be false negatives.
+
+        This debiased loss corrects for that sampling bias by reweighting the contributions of the positive and negative terms in the denominator.
+
+        For each ``a_i``, it uses all other ``p_j`` as negative samples, i.e., for ``a_i``, we have 1 positive example
+        (``p_i``) and ``n-1`` negative examples (``p_j``). Unlike the standard implementation, this loss applies a correction
+        term to account for the sampling bias introduced by in-batch negatives. Specifically, it adjusts the influence of
+        negatives based on a prior probability ``tau_plus``.
+
+        It then minimizes the negative log-likelihood for softmax-normalized scores while reweighting the contributions of
+        positive and negative terms in the denominator.
+
+        This loss function works great to train embeddings for retrieval setups where you have positive pairs
+        (e.g., (query, relevant_doc)), as it will randomly sample ``n-1`` negative docs in each batch and incorporate a bias
+        correction for improved robustness.
+
+        The performance usually increases with larger batch sizes.
+
+        You can also provide one or multiple hard negatives per anchor-positive pair by structuring the data like this:
+        ``(a_1, p_1, n_1), (a_2, p_2, n_2)``. Then, ``n_1`` is a hard negative for ``(a_1, p_1)``. For the pair
+        ``(a_i, p_i)``, the loss will use all ``p_j`` for ``j != i`` and all ``n_j`` as negatives, and apply the bias correction.
+
+        Args:
+            model: SentenceTransformer model
+            scale: Output of similarity function is multiplied by scale
+                value
+            similarity_fct: similarity function between sentence
+                embeddings. By default, cos_sim. Can also be set to dot
+                product (and then set scale to 1)
+            tau_plus: Prior probability that an in-batch negative is actually a false negative for its anchor; controls the strength of the debiasing correction.
+
+        References:
+            - Chuang et al. (2020). Debiased Contrastive Learning. NeurIPS 2020. https://arxiv.org/pdf/2007.00224.pdf
+
+        Requirements:
+            1. The input batch should consist of (anchor, positive) pairs or (anchor, positive, negative) triplets.
+
+        Inputs:
+            +-------------------------------------------------+--------+
+            | Texts                                           | Labels |
+            +=================================================+========+
+            | (anchor, positive) pairs                        | none   |
+            +-------------------------------------------------+--------+
+            | (anchor, positive, negative) triplets           | none   |
+            +-------------------------------------------------+--------+
+            | (anchor, positive, negative_1, ..., negative_n) | none   |
+            +-------------------------------------------------+--------+
+
+        Recommendations:
+            - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
+              ensure that no in-batch negatives are duplicates of the anchor or positive samples.
+ + Relations: + - Extends :class:`MultipleNegativesRankingLoss` by incorporating a bias correction term. + + Example: + :: + + from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses + from datasets import Dataset + + model = SentenceTransformer("microsoft/mpnet-base") + train_dataset = Dataset.from_dict({ + "anchor": ["It's nice weather outside today.", "He drove to work."], + "positive": ["It's so sunny.", "He took the car to the office."], + }) + loss = losses.DebiasedMultipleNegativesRankingLoss(model, tau_plus=0.02) + + trainer = SentenceTransformerTrainer( + model=model, + train_dataset=train_dataset, + loss=loss, + ) + trainer.train() + """ super().__init__() self.model = model self.scale = scale @@ -30,17 +108,28 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor scores: Tensor = self.similarity_fct(anchors, candidates) * self.scale # (batch_size, batch_size * (1 + num_negatives)) - mask = torch.ones_like(scores, dtype=torch.bool) - for i in range(scores.size(0)): - mask[i, i] = False + # Compute the mask to remove the similarity of the anchor to the positive candidate. + batch_size = scores.size(0) + mask = torch.ones_like(scores, dtype=torch.bool) # (batch_size, batch_size * (1 + num_negatives)) + positive_indices = torch.arange(0, batch_size, device=scores.device) + mask[positive_indices, positive_indices] = False + + # Get the similarity of the anchor to the negative candidates. + neg_exp = torch.exp(scores.masked_fill(mask, float("-inf"))).sum(dim=-1) # (batch_size,) + # Get the similarity of the anchor to the positive candidate. + pos_exp = torch.exp(torch.gather(scores, -1, positive_indices.unsqueeze(1)).squeeze()) + # (batch_size,) + + # Compute the g estimator with the exponential of the similarities. + N_neg = scores.size(1) - 1 # Number of negatives + g = torch.clamp((1 / (1 - self.tau_plus)) * ((neg_exp / N_neg) - (self.tau_plus * pos_exp)), + min=torch.exp(-torch.tensor(self.scale))) + # (batch_size,) - neg_exp = torch.exp(scores.masked_fill(mask, float("-inf"))).sum(dim=-1) + # Compute the final debiased loss. 
+ loss = - torch.log(pos_exp / (pos_exp + N_neg * g)).mean() - pos_exp = torch.exp(torch.gather(scores, -1, - torch.arange(scores.size(0), - device=scores.device).unsqueeze(1)).squeeze()) - - N = scores.size(1) - 1 + return loss def get_config_dict(self) -> dict[str, Any]: return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__} @@ -48,5 +137,11 @@ def get_config_dict(self) -> dict[str, Any]: @property def citation(self) -> str: return """ -TODO: Add citation +@inproceedings{chuang2020debiased, + title={Debiased Contrastive Learning}, + author={Ching-Yao Chuang and Joshua Robinson and Lin Yen-Chen and Antonio Torralba and Stefanie Jegelka}, + booktitle={Advances in Neural Information Processing Systems}, + year={2020}, + url={https://arxiv.org/pdf/2007.00224} +} """ From 2d03076aad8dd5e510ea17d89cdad910053412e3 Mon Sep 17 00:00:00 2001 From: ilanaliouchouche Date: Wed, 8 Jan 2025 01:17:02 +0100 Subject: [PATCH 3/4] Corrected errors in DebiasedMultipleNegativesRankingLoss.py --- .../DebiasedMultipleNegativesRankingLoss.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py index f3931012e..e43f22bba 100644 --- a/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py @@ -6,12 +6,15 @@ import torch from torch import Tensor, nn +import numpy as np + from sentence_transformers import util from sentence_transformers.SentenceTransformer import SentenceTransformer +from collections import defaultdict class DebiasedMultipleNegativesRankingLoss(nn.Module): - def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim, tau_plus: float = 0.01) -> None: + def __init__(self, model: SentenceTransformer, scale: float = 1.0, similarity_fct=util.cos_sim, tau_plus: float = 0.1) -> None: """ This loss is a debiased version of the `MultipleNegativesRankingLoss` loss that addresses the inherent sampling bias in the negative examples. @@ -110,23 +113,19 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor # Compute the mask to remove the similarity of the anchor to the positive candidate. batch_size = scores.size(0) - mask = torch.ones_like(scores, dtype=torch.bool) # (batch_size, batch_size * (1 + num_negatives)) + mask = torch.zeros_like(scores, dtype=torch.bool) # (batch_size, batch_size * (1 + num_negatives)) positive_indices = torch.arange(0, batch_size, device=scores.device) - mask[positive_indices, positive_indices] = False + mask[positive_indices, positive_indices] = True # Get the similarity of the anchor to the negative candidates. neg_exp = torch.exp(scores.masked_fill(mask, float("-inf"))).sum(dim=-1) # (batch_size,) # Get the similarity of the anchor to the positive candidate. pos_exp = torch.exp(torch.gather(scores, -1, positive_indices.unsqueeze(1)).squeeze()) - # (batch_size,) # Compute the g estimator with the exponential of the similarities. N_neg = scores.size(1) - 1 # Number of negatives g = torch.clamp((1 / (1 - self.tau_plus)) * ((neg_exp / N_neg) - (self.tau_plus * pos_exp)), - min=torch.exp(-torch.tensor(self.scale))) - # (batch_size,) - - # Compute the final debiased loss. 
+ min=np.exp(-self.scale)) loss = - torch.log(pos_exp / (pos_exp + N_neg * g)).mean() return loss From 370bf473e60b57f7d01a6e084b5acaabdac38a2c Mon Sep 17 00:00:00 2001 From: ilanaliouchouche Date: Wed, 8 Jan 2025 01:23:39 +0100 Subject: [PATCH 4/4] enhancement & correction of DebiasedMNRL --- .../losses/DebiasedMultipleNegativesRankingLoss.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py index e43f22bba..9520f5725 100644 --- a/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/DebiasedMultipleNegativesRankingLoss.py @@ -11,8 +11,6 @@ from sentence_transformers import util from sentence_transformers.SentenceTransformer import SentenceTransformer -from collections import defaultdict - class DebiasedMultipleNegativesRankingLoss(nn.Module): def __init__(self, model: SentenceTransformer, scale: float = 1.0, similarity_fct=util.cos_sim, tau_plus: float = 0.1) -> None: """
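As a quick, standalone illustration of what the final version of the loss computes, the debiased objective of Chuang et al. (2020) can be evaluated by hand for a toy batch. This snippet is not part of the patches; the similarity values are arbitrary and the tensor names only mirror the ``forward`` pass above::

    import math

    import torch

    scale, tau_plus = 1.0, 0.1
    # Toy similarity matrix: 2 anchors x 4 candidates (1 positive + in-batch negatives per anchor).
    scores = torch.tensor([[0.9, 0.1, 0.3, 0.2],
                           [0.2, 0.8, 0.1, 0.4]]) * scale

    batch_size = scores.size(0)
    positive_indices = torch.arange(batch_size)
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[positive_indices, positive_indices] = True  # mark each anchor's positive column

    pos_exp = torch.exp(scores[positive_indices, positive_indices])           # (batch_size,)
    neg_exp = torch.exp(scores.masked_fill(mask, float("-inf"))).sum(dim=-1)  # (batch_size,)

    N_neg = scores.size(1) - 1
    # Debiased estimator of the negative term, clamped from below at exp(-scale) as in Chuang et al. (2020).
    g = torch.clamp((neg_exp / N_neg - tau_plus * pos_exp) / (1 - tau_plus), min=math.exp(-scale))
    loss = -torch.log(pos_exp / (pos_exp + N_neg * g)).mean()
    print(loss)  # with tau_plus = 0.0 this reduces to the standard in-batch softmax (InfoNCE) loss

The sketch assumes ``scale=1.0`` and ``tau_plus=0.1``, i.e., the defaults introduced by the third commit of this series.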