diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py
index ccc79f4d7..9e0e1e170 100644
--- a/sentence_transformers/losses/CachedGISTEmbedLoss.py
+++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -217,14 +217,6 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid
         if negative:
             negative = torch.cat(negative, dim=0)
             negative_guide = torch.cat(negative_guide, dim=0)
-            guided_an_sim = self.sim_matrix(anchor_guide, negative_guide)
-
-        # Let's compute the similarity matrices for the combinations of anchor and positive samples.
-        guided_ap_sim = self.sim_matrix(anchor_guide, positive_guide)
-        guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide)
-        guided_pp_sim = self.sim_matrix(positive_guide, positive_guide)
-        # Define the anchor threshold
-        guided_sim = guided_ap_sim.diagonal().view(-1, 1)
 
         labels = torch.arange(anchor.size(0)).long().to(anchor.device)
         batch_size = anchor.shape[0]
@@ -238,6 +230,13 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid
             disable=not self.show_progress_bar,
         ):
             e = b + self.mini_batch_size
+            # Let's compute the similarity matrices for the combinations of anchor and positive samples.
+            guided_ap_sim = self.sim_matrix(anchor_guide[b:e], positive_guide)
+            guided_aa_sim = self.sim_matrix(anchor_guide[b:e], anchor_guide)
+            guided_pp_sim = self.sim_matrix(positive_guide[b:e], positive_guide)
+            # Define the anchor threshold; row i holds anchor b+i, so its positive is in column b+i.
+            guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)
+
             # Compute similarity scores for current mini-batch.
             # anchor (mbsz,hdim), positive (bsz,hdim)
             ap_sim = self.sim_matrix(anchor[b:e], positive)  # (mbsz,bsz)
@@ -248,16 +247,17 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid
             # more similar to the query than the assigned positive as deemed by the guide model.
             # For these samples, we mask them with -inf to basically ignore their contribution to
             # the loss.
-            ap_sim[guided_ap_sim[b:e] > guided_sim[b:e]] = -torch.inf
-            aa_sim[guided_aa_sim[b:e] > guided_sim[b:e]] = -torch.inf
-            pp_sim[guided_pp_sim[b:e] > guided_sim[b:e]] = -torch.inf
+            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
+            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
+            pp_sim[guided_pp_sim > guided_sim] = -torch.inf
 
             scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)
 
             # Handle the case where we have a negative sample
             if negative is not None:
+                guided_an_sim = self.sim_matrix(anchor_guide[b:e], negative_guide)
                 an_sim = self.sim_matrix(anchor[b:e], negative)
-                an_sim[guided_an_sim[b:e] > guided_sim[b:e]] = -torch.inf
+                an_sim[guided_an_sim > guided_sim] = -torch.inf
                 scores = torch.cat([scores, an_sim], dim=1)
             scores = scores / self.temperature
             loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
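
For context when reviewing: before this patch, the guide model's similarity matrices (`guided_ap_sim`, `guided_aa_sim`, `guided_pp_sim`, `guided_an_sim`) were built once over the full cached batch, materializing several `(bsz, bsz)` tensors up front and then slicing them with `[b:e]` inside the loop. That quadratic allocation defeats the memory savings that the cached mini-batching exists to provide. The patch moves the computation inside the loop, so only `(mini_batch_size, bsz)` slices are alive at a time, and reads the per-anchor threshold off the diagonal shifted by the mini-batch offset `b`. Below is a minimal self-contained sketch of that pattern; the `sim_matrix` helper and the sizes are illustrative assumptions, not the library's API:

```python
import torch

def sim_matrix(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in for the loss's similarity function: a cosine-similarity matrix.
    a = torch.nn.functional.normalize(a, dim=-1)
    b = torch.nn.functional.normalize(b, dim=-1)
    return a @ b.T

bsz, mbsz, hdim = 1024, 32, 64  # illustrative sizes
anchor_guide = torch.randn(bsz, hdim)
positive_guide = torch.randn(bsz, hdim)

# Before: one (bsz, bsz) matrix per pair combination, held for the whole pass.
# guided_ap_sim = sim_matrix(anchor_guide, positive_guide)  # (1024, 1024)

# After: one (mbsz, bsz) slice per iteration; peak memory is O(mbsz * bsz).
for b in range(0, bsz, mbsz):
    e = b + mbsz
    guided_ap_sim = sim_matrix(anchor_guide[b:e], positive_guide)  # (32, 1024)
    # Row i holds anchor b+i, whose positive sits in column b+i, so the
    # per-anchor threshold is the diagonal shifted by the batch offset.
    guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)      # (32, 1)
    assert torch.allclose(
        guided_sim.squeeze(1),
        sim_matrix(anchor_guide[b:e], positive_guide[b:e]).diagonal(),
    )
```

The `assert` checks that the shifted diagonal recovers each anchor's similarity to its own positive; a plain `.diagonal()` on the `(mbsz, bsz)` slice would pair anchor `b+i` with positive `i` as soon as `b > 0`, which is why the added threshold line uses `offset=b`.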