Add flash-attn #41

Merged: 6 commits, Mar 24, 2023
README.md (12 additions, 0 deletions)
@@ -333,6 +333,18 @@ Theoretical memory savings vary depending on the combination of the model's para
| bf16 param, fp32 grads | 18 | 6 + 12/d |
| fp32 param, fp32 grads | 16 | 8 + 8/d |

## FlashAttention

Usage: `--use-flash-attn`. Supports attention head dimensions of at most 128.

[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and
memory-efficient algorithm to compute exact attention. It speeds up model
training and reduces memory requirements.

To install FlashAttention:
```sh
pip install flash-attn
```

## GPT-3 Example

megatron/arguments.py (3 additions, 0 deletions)
@@ -648,6 +648,9 @@ def _add_training_args(parser):
group.add_argument('--no-bias-dropout-fusion', action='store_false',
help='Disable bias and dropout fusion.',
dest='bias_dropout_fusion')
group.add_argument('--use-flash-attn', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
megatron/model/transformer.py (98 additions, 5 deletions)
@@ -22,6 +22,7 @@

from megatron import get_timers, get_args, get_global_memory_buffer
from megatron import mpu
from megatron.utils import print_rank_0
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType, PositionEmbeddingType
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
@@ -38,6 +39,16 @@
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

try:
from einops import rearrange
except ImportError:
rearrange = None

try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None


""" We use the following notation throughout this file:
h: hidden size
@@ -459,6 +470,48 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
return context_layer


class FlashSelfAttention(torch.nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
'e.g., with pip install flash-attn')
assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout

def forward(self, q, k, v):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda
batch_size, seqlen = q.shape[0], q.shape[1]
q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=q.device)
output = flash_attn_unpadded_func(
q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
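
For reference, a minimal usage sketch of this module (illustrative only, not part of the diff; it assumes flash-attn and einops are installed and a CUDA device is available, and follows the (batch, seqlen, nheads, headdim) layout documented in `forward`):

```python
import torch
from megatron.model.transformer import FlashSelfAttention

attn = FlashSelfAttention(causal=True, attention_dropout=0.1)
attn.eval()  # forward() only applies dropout when the module is in training mode

b, s, h, d = 2, 1024, 16, 64  # head dimension must be at most 128 for flash-attn
q, k, v = [torch.randn(b, s, h, d, dtype=torch.float16, device='cuda')
           for _ in range(3)]

out = attn(q, k, v)
print(out.shape)  # torch.Size([2, 1024, 16, 64]), i.e. (b, s, h, d)
```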


class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.

@@ -477,6 +530,9 @@ def __init__(self, init_method,
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
self.attention_head_type = args.attention_head_type
self.sequence_parallel = args.sequence_parallel

self.use_flash_attn = args.use_flash_attn

projection_size = args.kv_channels * args.num_attention_heads

@@ -533,6 +589,26 @@ def __init__(self, init_method,
else:
self.core_attention = MultiQueryCoreAttention(self.layer_number, self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective'

if self.use_flash_attn:
if flash_attn_unpadded_func is None:
raise ImportError('FlashAttention is not installed, please install with '
'pip install flash-attn')
assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
'self-attention for now')
assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
'supports causal mask for now')
assert args.position_embedding_type != PositionEmbeddingType.alibi, \
('FlashAttention does not support alibi positional embeddings yet')
if rearrange is None:
raise ImportError('einops is not installed, please install with pip install einops')

if self.checkpoint_core_attention:
print_rank_0(" Warning, using selective recomputation with flash-attn: this is already handled in the "
"flash-attn library and has no effect.")
self.core_attention_flash = FlashSelfAttention(
causal=True, attention_dropout=args.attention_dropout
)

# Output.
self.dense = mpu.RowParallelLinear(
@@ -699,13 +775,30 @@ def forward(self, hidden_states, attention_mask,
# ==================================
# core attention computation
# ==================================
if self.use_flash_attn:
if self.attention_head_type == "multiquery":
sq, b, np, hn = query_layer.size()
# Expand kv to be compatible with flash-attn implementation
# [sq, b, 1, hn] -> [sq, b, np, hn]
key_layer = key_layer.expand((sq, b, np, hn))
Collaborator:
I'm wondering if FlashAttention would work with just expand, which doesn't allocate new memory. If it were to work, we would get the full benefits of FlashAttention for MQA. (I would expect it to enforce contiguous tensors, but it's worth checking.)

Collaborator (Author):
Are you asking whether it would still work if we removed the call to .contiguous() on the next line?

Collaborator:
That would almost certainly not work (transposed tensors are much harder to deal with), but maybe if we do the expand after the transpose or skip the transpose altogether.
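
As a small PyTorch illustration of the memory behavior being discussed (a standalone sketch, not part of this diff): `expand` returns a broadcasted view with stride 0 along the expanded dimension and allocates nothing, while `.contiguous()` materializes the full copy.

```python
import torch

sq, b, np_, hn = 2048, 4, 16, 64
k = torch.randn(sq, b, 1, hn, dtype=torch.float16)  # multi-query: a single kv head

k_view = k.expand(sq, b, np_, hn)   # broadcasted view, no new memory allocated
print(k_view.stride())              # stride is 0 along the expanded head dimension
print(k_view.is_contiguous())       # False

k_copy = k_view.contiguous()        # materializes np_ copies of the kv head
print(k_copy.nelement() * k_copy.element_size())  # np_ times the original storage
```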

value_layer = value_layer.expand((sq, b, np, hn))
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
Collaborator:
That looks very bad. Megatron uses the s b format precisely to avoid this kind of reshape. If FlashAttention uses b s, we should use that format instead. It should be OK to just comment out the two conversions, at least without sequence parallelism (SP would need extra changes, but we probably won't use it anyway): https://github.com/bigcode-project/Megatron-LM/blob/multi-query-attention/megatron/model/language_model.py#L240 https://github.com/bigcode-project/Megatron-LM/blob/multi-query-attention/megatron/model/gpt_model.py#L43

Collaborator (Author):
Are you suggesting using b s throughout the whole transformer model? I think that would require a big chunk of refactoring work, plus testing to make sure we are not breaking anything. Given the nice performance improvements that flash-attn brings, I wouldn't take the risk of breaking everything else just to avoid a transpose here.

Collaborator:
Actually, the order only matters for attention (and sequence parallelism), so it should just be about bypassing these two lines.

Collaborator:
The transposes have a big impact on memory usage and a moderate one on speed (I think) so it's quite important.
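
To make that cost concrete (an illustrative sketch, not part of this diff, assuming einops is installed): the `rearrange` from `s b` to `b s` is a free transposed view, but the `.contiguous()` that follows it allocates and copies the full activation, once per q, k and v.

```python
import torch
from einops import rearrange

s, b, h, d = 2048, 4, 16, 64
x = torch.randn(s, b, h, d, dtype=torch.float16)

y = rearrange(x, 's b ... -> b s ...')  # equivalent to x.transpose(0, 1): a view, no copy
print(y.data_ptr() == x.data_ptr())     # True, same underlying storage
print(y.is_contiguous())                # False

z = y.contiguous()                      # allocates a second buffer the size of x and copies
print(z.data_ptr() == x.data_ptr())     # False
```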

for x in (query_layer, key_layer, value_layer)]
if self.sequence_parallel:
context_layer = self.core_attention_flash(q, k, v)
else:
with mpu.get_cuda_rng_tracker().fork():
context_layer = self.core_attention_flash(q, k, v)
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer, key_layer, value_layer, attention_mask, alibi)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask, alibi)


# =================
# Output. [sq, b, h]