Skip to content

Commit

Permalink
Update guided similarity computation in mini-batch
Browse files — browse the repository at this point in the history
  • Loading branch information
JacksonCakes committed Apr 15, 2024
1 parent 5c054da commit 3208e61
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions sentence_transformers/losses/CachedGISTEmbedLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,6 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid
if negative:
negative = torch.cat(negative, dim=0)
negative_guide = torch.cat(negative_guide, dim=0)
guided_an_sim = self.sim_matrix(anchor_guide, negative_guide)

# Let's compute the similarity matrices for the combinations of anchor and positive samples.
guided_ap_sim = self.sim_matrix(anchor_guide, positive_guide)
guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide)
guided_pp_sim = self.sim_matrix(positive_guide, positive_guide)
# Define the anchor threshold
guided_sim = guided_ap_sim.diagonal().view(-1, 1)

labels = torch.arange(anchor.size(0)).long().to(anchor.device)
batch_size = anchor.shape[0]
Expand All @@ -238,6 +230,13 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid
disable=not self.show_progress_bar,
):
e = b + self.mini_batch_size
# Let's compute the similarity matrices for the combinations of anchor and positive samples.
guided_ap_sim = self.sim_matrix(anchor_guide[b:e], positive_guide)
guided_aa_sim = self.sim_matrix(anchor_guide[b:e], anchor_guide)
guided_pp_sim = self.sim_matrix(positive_guide[b:e], positive_guide)
# Define the anchor threshold
guided_sim = guided_ap_sim.diagonal().view(-1, 1)

# Compute similarity scores for current mini-batch.
# anchor (mbsz,hdim), positive (bsz,hdim)
ap_sim = self.sim_matrix(anchor[b:e], positive) # (mbsz,bsz)
Expand All @@ -248,16 +247,17 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid
# more similar to the query than the assigned positive as deemed by the guide model.
# For these samples, we mask them with -inf to basically ignore their contribution to
# the loss.
ap_sim[guided_ap_sim[b:e] > guided_sim[b:e]] = -torch.inf
aa_sim[guided_aa_sim[b:e] > guided_sim[b:e]] = -torch.inf
pp_sim[guided_pp_sim[b:e] > guided_sim[b:e]] = -torch.inf
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim > guided_sim] = -torch.inf
pp_sim[guided_pp_sim > guided_sim] = -torch.inf

scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

# Handle the case where we have a negative sample
if negative is not None:
guided_an_sim = self.sim_matrix(anchor_guide[b:e], negative_guide)
an_sim = self.sim_matrix(anchor[b:e], negative)
an_sim[guided_an_sim[b:e] > guided_sim[b:e]] = -torch.inf
an_sim[guided_an_sim > guided_sim] = -torch.inf
scores = torch.cat([scores, an_sim], dim=1)
scores = scores / self.temperature
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
Expand Down

0 comments on commit 3208e61

Please sign in to comment.