diff --git a/README.md b/README.md
index 3f7e8d4804..00d95ff4cd 100644
--- a/README.md
+++ b/README.md
@@ -333,6 +333,18 @@ Theoretical memory savings vary depending on the combination of the model's para
 | bf16 param, fp32 grads | 18 | 6 + 12/d |
 | fp32 param, fp32 grads | 16 | 8 + 8/d |
 
+## FlashAttention
+
+Usage: `--use-flash-attn`. Supports attention head dimensions of at most 128.
+
+[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and
+memory-efficient algorithm to compute exact attention. It speeds up model
+training and reduces memory requirements.
+
+To install FlashAttention:
+```sh
+pip install flash-attn
+```
 
 ## GPT-3 Example
 
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 4bd3e92066..5e6a6e5baa 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -612,6 +612,9 @@ def _add_training_args(parser):
     group.add_argument('--no-bias-dropout-fusion', action='store_false',
                        help='Disable bias and dropout fusion.',
                        dest='bias_dropout_fusion')
+    group.add_argument('--use-flash-attn', action='store_true',
+                       help='use FlashAttention implementation of attention. '
+                       'https://arxiv.org/abs/2205.14135')
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 57d992fa26..9d74d5f6c9 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -15,6 +15,16 @@
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
 
+try:
+    from einops import rearrange
+except ImportError:
+    rearrange = None
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    flash_attn_unpadded_func = None
+
 """ We use the following notation throughout this file:
      h: hidden size
@@ -306,6 +316,48 @@ def forward(self, query_layer, key_layer,
         return context_layer
 
 
+class FlashSelfAttention(torch.nn.Module):
+    """Implements scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+                                                      'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensors containing the query, key, and value, each of shape (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        batch_size, seqlen = q.shape[0], q.shape[1]
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        max_s = seqlen
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                  device=q.device)
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
+            self.dropout_p if self.training else 0.0,
+            softmax_scale=self.softmax_scale, causal=self.causal
+        )
+        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
 
@@ -323,6 +375,19 @@ def __init__(self, init_method,
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
+        self.sequence_parallel = args.sequence_parallel
+
+        self.use_flash_attn = args.use_flash_attn
+        if self.use_flash_attn:
+            if flash_attn_unpadded_func is None:
+                raise ImportError('FlashAttention is not installed, please install with '
+                                  'pip install flash-attn')
+            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
+                                                          'self-attention for now')
+            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
+                                                                'supports causal mask for now')
+            if rearrange is None:
+                raise ImportError('einops is not installed, please install with pip install einops')
 
         projection_size = args.kv_channels * args.num_attention_heads
 
@@ -365,6 +430,11 @@ def __init__(self, init_method,
             self.attn_mask_type)
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'
 
+        if self.use_flash_attn:
+            self.core_attention_flash = FlashSelfAttention(
+                causal=True, attention_dropout=args.attention_dropout
+            )
+
         # Output.
         self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
@@ -487,12 +557,22 @@ def forward(self, hidden_states, attention_mask,
         # core attention computation
         # ==================================
 
-        if self.checkpoint_core_attention:
-            context_layer = self._checkpointed_attention_forward(
-                query_layer, key_layer, value_layer, attention_mask)
+        if not self.use_flash_attn:
+            if self.checkpoint_core_attention:
+                context_layer = self._checkpointed_attention_forward(
+                    query_layer, key_layer, value_layer, attention_mask)
+            else:
+                context_layer = self.core_attention(
+                    query_layer, key_layer, value_layer, attention_mask)
         else:
-            context_layer = self.core_attention(
-                query_layer, key_layer, value_layer, attention_mask)
+            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
+                       for x in (query_layer, key_layer, value_layer)]
+            if not self.sequence_parallel:
+                with tensor_parallel.get_cuda_rng_tracker().fork():
+                    context_layer = self.core_attention_flash(q, k, v)
+            else:
+                context_layer = self.core_attention_flash(q, k, v)
+            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
 
         # =================
         # Output. [sq, b, h]
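
For reference, here is a minimal sketch of how the new `FlashSelfAttention` module can be exercised in isolation. It assumes a flash-attn version that exposes `flash_attn_unpadded_func` (as imported in this patch), that `einops` is installed, and that a CUDA device with fp16 support is available; the shapes are illustrative and not taken from any particular Megatron configuration.

```python
import torch

from megatron.model.transformer import FlashSelfAttention

# Illustrative shapes: batch 2, sequence length 1024, 16 heads, head dim 64
# (FlashAttention supports head dimensions of at most 128).
b, s, h, d = 2, 1024, 16, 64
q = torch.randn(b, s, h, d, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = FlashSelfAttention(causal=True, attention_dropout=0.0)
attn.eval()  # dropout_p is only applied while self.training is True

with torch.no_grad():
    out = attn(q, k, v)

# Internally, the (b, s, h, d) inputs are flattened to (b*s, h, d) and described by
# cu_seqlens = [0, s, 2*s, ..., b*s], the cumulative sequence lengths expected by the
# unpadded (varlen) kernel; the output is reshaped back to (b, s, h, d).
print(out.shape)  # torch.Size([2, 1024, 16, 64])
```

With `--use-flash-attn`, `ParallelAttention` performs the equivalent steps itself: it rearranges its `[s, b, ...]` query/key/value tensors to `[b, s, ...]`, runs `core_attention_flash` inside the tensor-parallel CUDA RNG tracker (skipping the fork when sequence parallelism is enabled), and rearranges the result back to `[sq, b, h]` for the output projection.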