diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index fd46321e91..ab4490500c 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -171,11 +171,30 @@ def fill_mean_token_embeddings(
     word_ids: torch.Tensor,
     token_lengths: torch.Tensor,
 ):
-    for i in torch.arange(all_token_embeddings.shape[0]):
-        for _id in torch.arange(token_lengths[i]):  # type: ignore[call-overload]
-            all_token_embeddings[i, _id, :] = torch.nan_to_num(
-                sentence_hidden_states[i][word_ids[i] == _id].mean(dim=0)
-            )
+    batch_size, max_tokens, embedding_dim = all_token_embeddings.shape
+    mask = word_ids >= 0
+
+    # sum embeddings for each token
+    all_token_embeddings.scatter_add_(
+        1,
+        word_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, embedding_dim),
+        sentence_hidden_states * mask.unsqueeze(-1).float(),
+    )
+
+    # calculate the mean of subtokens
+    subtoken_counts = torch.zeros_like(all_token_embeddings[:, :, 0])
+    subtoken_counts.scatter_add_(1, word_ids.clamp(min=0), mask.float())
+    all_token_embeddings = torch.where(
+        subtoken_counts.unsqueeze(-1) > 0,
+        all_token_embeddings / subtoken_counts.unsqueeze(-1),
+        torch.zeros_like(all_token_embeddings),
+    )
+
+    # Create a mask for valid tokens based on token_lengths
+    token_mask = torch.arange(max_tokens, device=token_lengths.device)[None, :] < token_lengths[:, None]
+    all_token_embeddings = all_token_embeddings * token_mask.unsqueeze(-1)
+    all_token_embeddings = torch.nan_to_num(all_token_embeddings)
+    return all_token_embeddings
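
Below is a minimal sketch (not part of the patch) that checks the vectorized pooling against the behaviour of the removed nested loop. It assumes what the new code relies on: the caller passes a zero-initialized all_token_embeddings buffer, and negative entries in word_ids (here -1) mark special or padding subtokens, which the word_ids >= 0 mask excludes from the scatter_add_ sums. The shapes and the example word_ids/token_lengths values are illustrative only.

import torch

from flair.embeddings.transformer import fill_mean_token_embeddings

torch.manual_seed(0)
batch_size, num_subtokens, embedding_dim = 2, 6, 4
sentence_hidden_states = torch.randn(batch_size, num_subtokens, embedding_dim)

# word_ids maps each subtoken to the index of the token it belongs to;
# -1 (assumed sentinel) marks special tokens / padding subtokens.
word_ids = torch.tensor(
    [
        [-1, 0, 0, 1, 2, -1],   # sentence 1: 3 tokens, token 0 split into two subtokens
        [-1, 0, 1, 1, -1, -1],  # sentence 2: 2 tokens, token 1 split into two subtokens
    ]
)
token_lengths = torch.tensor([3, 2])
max_tokens = int(token_lengths.max())

# the scatter_add_-based implementation sums into this buffer, so it must start at zero
all_token_embeddings = torch.zeros(batch_size, max_tokens, embedding_dim)
vectorized = fill_mean_token_embeddings(
    all_token_embeddings, sentence_hidden_states, word_ids, token_lengths
)

# reference: what the removed nested loop computed
reference = torch.zeros(batch_size, max_tokens, embedding_dim)
for i in range(batch_size):
    for t in range(int(token_lengths[i])):
        reference[i, t] = torch.nan_to_num(
            sentence_hidden_states[i][word_ids[i] == t].mean(dim=0)
        )

assert torch.allclose(vectorized, reference, atol=1e-6)
print("vectorized mean pooling matches the loop reference")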