support for transformers 4.41.0
vaibhavad committed May 22, 2024
1 parent 01a0ed7 · commit 4aad5ff
Showing 1 changed file with 2 additions and 1 deletion.
llm2vec/models/bidirectional_llama.py (3 changes: 2 additions & 1 deletion)
@@ -107,7 +107,7 @@ def __init__(self, config: LlamaConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None, output_attentions=False):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -179,6 +179,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
+            and not output_attentions
         ):
             causal_mask = AttentionMaskConverter._unmask_unattended(
                 causal_mask, min_dtype
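
For context, transformers 4.41.0 passes an output_attentions argument into _update_causal_mask, so the bidirectional override must accept the extra parameter, and the SDPA unmask-unattended workaround is skipped when attention weights are requested. The sketch below isolates the patched condition in a standalone helper; the helper name maybe_unmask_unattended and the example tensors are hypothetical, for illustration only, and are not part of llm2vec.

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


def maybe_unmask_unattended(causal_mask, attention_mask, attn_implementation,
                            output_attentions, min_dtype):
    """Apply the SDPA unmask-unattended workaround only when attention weights
    are not requested, mirroring the condition in the patched method.

    Hypothetical standalone helper for illustration; not part of llm2vec.
    """
    if (
        attn_implementation == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
        and not output_attentions
    ):
        # Attend to all tokens in fully masked rows (e.g. with left padding) so
        # the memory-efficient SDPA path does not produce NaNs.
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
    return causal_mask


# Example: with output_attentions=True (or a non-CUDA mask) the guard is a no-op
# and the causal mask is returned unchanged.
mask = torch.zeros(1, 1, 4, 4)
padding_mask = torch.ones(1, 4)
out = maybe_unmask_unattended(mask, padding_mask, "sdpa",
                              output_attentions=True,
                              min_dtype=torch.finfo(torch.float32).min)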
