Skip to content

Commit

Permalink
GH-3524: Refactor fill_mean_token_embeddings for performance optimization on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
sheldon-roberts committed Aug 9, 2024
1 parent e17ab12 commit dd73547
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions flair/embeddings/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,30 @@ def fill_mean_token_embeddings(
word_ids: torch.Tensor,
token_lengths: torch.Tensor,
):
for i in torch.arange(all_token_embeddings.shape[0]):
for _id in torch.arange(token_lengths[i]): # type: ignore[call-overload]
all_token_embeddings[i, _id, :] = torch.nan_to_num(
sentence_hidden_states[i][word_ids[i] == _id].mean(dim=0)
)
batch_size, max_tokens, embedding_dim = all_token_embeddings.shape
mask = word_ids >= 0

# sum embeddings for each token
all_token_embeddings.scatter_add_(
1,
word_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, embedding_dim),
sentence_hidden_states * mask.unsqueeze(-1).float(),
)

# calculate the mean of subtokens
subtoken_counts = torch.zeros_like(all_token_embeddings[:, :, 0])
subtoken_counts.scatter_add_(1, word_ids.clamp(min=0), mask.float())
all_token_embeddings = torch.where(
subtoken_counts.unsqueeze(-1) > 0,
all_token_embeddings / subtoken_counts.unsqueeze(-1),
torch.zeros_like(all_token_embeddings),
)

# Create a mask for valid tokens based on token_lengths
token_mask = torch.arange(max_tokens, device=token_lengths.device)[None, :] < token_lengths[:, None]
all_token_embeddings = all_token_embeddings * token_mask.unsqueeze(-1)
all_token_embeddings = torch.nan_to_num(all_token_embeddings)

return all_token_embeddings


Expand Down

0 comments on commit dd73547

Please sign in to comment.