From 9991c8d948fe0b4fe9f1731c2b38b80c7a0ecf35 Mon Sep 17 00:00:00 2001
From: Lina Khodja <57141057+l-k-11235@users.noreply.github.com>
Date: Thu, 27 Jun 2024 19:56:05 +0200
Subject: [PATCH] fixed masked flash attention (#2589)

* fixed masked flash attention
---
 onmt/modules/multi_headed_attn.py | 30 ++++++++++--------------------
 1 file changed, 10 insertions(+), 20 deletions(-)

diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index 2510afef83..a432b42e10 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -439,7 +439,6 @@ def forward(
         """
         # 1) Project key, value, and query.
         # as a reminder at training layer_cache[0] remains False
-        key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
         if self.layer_cache[0]:
             # Retrieve keys and values from the KV cache (decoding mode only).
             if self.attn_type == "self":
@@ -484,6 +483,16 @@ def forward(
                         key = key[:, :, 1:, :]
                         value = value[:, :, 1:, :]
 
+                if step == 0:
+                    key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
+                    if key_pad_mask is not None:
+                        x = key_pad_mask.expand(
+                            -1, self.head_count // self.parallel_gpu, -1
+                        )
+                        x = x.unsqueeze(3)
+                        x = x.expand(-1, -1, -1, value.size(3))
+                        value = value.masked_fill(x, 0)
+
                 self.layer_cache[1]["keys"] = key
                 self.layer_cache[1]["values"] = value
 
@@ -565,19 +574,6 @@ def forward(
                 self.layer_cache[1]["keys"] = key
                 self.layer_cache[1]["values"] = value
 
-            if key_pad_mask is not None:
-                # Increase the cached key pad mask by concatenation.
-                # For decoding only.
-                if step > 0:
-                    y = torch.zeros(
-                        (key_pad_mask.size(0), key_pad_mask.size(1), 1),
-                        dtype=torch.bool,
-                        device=key_pad_mask.device,
-                    )
-                    self.layer_cache[1]["key_pad_mask"] = torch.cat(
-                        (key_pad_mask, y), 2
-                    )
-                    key_pad_mask = self.layer_cache[1]["key_pad_mask"]
         else:
             # Retrieve keys and values from linear layers (training mode).
             key = self.maybe_ckpt(self.linear_keys, key)
@@ -706,8 +702,6 @@ def forward(
             scores = self.alibi(scores)
 
         scores = scores.float()
-        if key_pad_mask is not None and mask is None:
-            mask = key_pad_mask.unsqueeze(1)
 
         if mask is not None:
             # not 100% necessary but expand to nb of heads
@@ -727,10 +721,6 @@ def forward(
                 attn_output.add_(relative_matmul(drop_attn, relations_values, False))
 
             context = unshape(attn_output)
-            if key_pad_mask is not None:
-                if key_pad_mask.size(0) > 1 and context.size(1) > 1:
-                    x = key_pad_mask.squeeze(1).unsqueeze(2).expand(-1, -1, context.size(2))
-                    context = context.masked_fill(x, 0)
 
         if self.layer_cache[0]:
             attn_output = self.final_linear(context)
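
Illustrative note (not part of the patch): below is a minimal, self-contained
sketch of the value-masking this patch adds at step 0, for readers who want to
see the tensor shapes in isolation. All names and shapes (batch, heads,
seq_len, dim_per_head) are invented for the demo; heads stands in for
self.head_count // self.parallel_gpu, and key_pad_mask is assumed to be a
boolean [batch, 1, seq_len] tensor that is True at padded source positions, as
the patch implies. The apparent idea: rather than growing a cached
key_pad_mask by one column per decoding step and re-masking scores and context
on every call (the removed code), the cached value rows at padded positions
are zeroed once when the KV cache is first filled, so those positions add
nothing to any later attention-weighted sum.

    import torch

    # Invented shapes for the demo.
    batch, heads, seq_len, dim_per_head = 2, 4, 5, 8
    value = torch.randn(batch, heads, seq_len, dim_per_head)  # cached values

    # Assumed layout: boolean [batch, 1, seq_len], True at padded positions.
    key_pad_mask = torch.zeros(batch, 1, seq_len, dtype=torch.bool)
    key_pad_mask[0, 0, 3:] = True  # pretend the last two tokens of sample 0 are pad

    # Broadcast the mask over heads, then over the per-head feature dim,
    # mirroring the expand/unsqueeze/expand chain in the patch.
    x = key_pad_mask.expand(-1, heads, -1)      # [batch, heads, seq_len]
    x = x.unsqueeze(3)                          # [batch, heads, seq_len, 1]
    x = x.expand(-1, -1, -1, value.size(3))     # [batch, heads, seq_len, dim_per_head]

    # Zero the value rows of padded keys so they contribute nothing to the
    # weighted sum of values, even when the attention kernel itself is given
    # no padding mask.
    value = value.masked_fill(x, 0)

    assert value[0, :, 3:, :].abs().sum().item() == 0.0

Doing this once at cache-build time also avoids the per-step cost of the
removed bookkeeping, which concatenated a fresh zero column onto the cached
mask at every step > 0.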