diff --git a/Dockerfile b/Dockerfile index 82d16a561886..7722555357b2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -72,6 +72,11 @@ WORKDIR /tmp/nemo COPY requirements . RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-check --no-cache-dir -r $f; done +# install flash attention dependencies +RUN pip install flash-attn +# pinned triton version for flash-attention https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3 +RUN pip install triton==2.0.0.dev20221202 + # install k2, skip if installation fails COPY scripts /tmp/nemo/scripts/ RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/speech_recognition/k2/setup.sh); INSTALL_CODE=$?; \ diff --git a/Jenkinsfile b/Jenkinsfile index d16379cabb8a..d335378173f0 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -65,6 +65,14 @@ pipeline { pip install -e .' } } + + stage('Flash Attention installation') { + steps { + // pinned triton version for flash-attention https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3 + sh 'pip install flash-attn && \ + pip install triton==2.0.0.dev20221202' + } + } stage('PyTorch Lightning version') { steps { @@ -3144,6 +3152,88 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' } } stage('L2: Megatron GPT Pretraining and Resume Training TP=2') { + when { + anyOf { + branch 'main' + changeRequest target: 'main' + } + } + failFast true + steps { + sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.precision=16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=1 \ + model.optim.sched.constant_steps=1 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.normalization=rmsnorm \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_granularity='full' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=1 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=6 \ + trainer.precision=16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + 
model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.normalization=rmsnorm \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_granularity='full' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + sh "rm -rf examples/nlp/language_modeling/gpt_pretrain_results" + sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" + } + } + stage('L2: Megatron GPT with Rope Pretraining and Resume Training TP=2') { when { anyOf { branch 'main' @@ -3229,6 +3319,262 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" } } + stage('L2: Megatron GPT with Rope Pretraining using Flash Attention and Resume Training TP=2') { + when { + anyOf { + branch 'main' + changeRequest target: 'main' + } + } + failFast true + steps { + sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.precision=16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=1 \ + model.optim.sched.constant_steps=1 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=rope \ + model.rotary_percentage=0.5 \ + model.normalization=rmsnorm \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_granularity='full' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings \ + model.use_flash_attention=True" + sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=1 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=6 \ + trainer.precision=16 \ + 
trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=rope \ + model.rotary_percentage=0.5 \ + model.normalization=rmsnorm \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_granularity='full' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings \ + model.use_flash_attention=True" + sh "rm -rf examples/nlp/language_modeling/gpt_pretrain_results" + sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" + } + } + stage('L2: Megatron GPT with ALiBi Pretraining and Resume Training TP=2') { + when { + anyOf { + branch 'main' + changeRequest target: 'main' + } + } + failFast true + steps { + sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.precision=16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=1 \ + model.optim.sched.constant_steps=1 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=alibi \ + model.normalization=rmsnorm \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_granularity='full' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=1 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=6 \ + trainer.precision=16 \ + 
trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=alibi \ + model.normalization=rmsnorm \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_granularity='full' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + sh "rm -rf examples/nlp/language_modeling/gpt_pretrain_results" + sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" + } + } + stage('L2: Megatron GPT with KERPLE Pretraining and Resume Training TP=2') { + when { + anyOf { + branch 'main' + changeRequest target: 'main' + } + } + failFast true + steps { + sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.precision=16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=1 \ + model.optim.sched.constant_steps=1 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=kerple \ + model.normalization=rmsnorm \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_granularity='full' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + sh "python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=1 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=6 \ + trainer.precision=16 \ + trainer.gradient_clip_val=1.0 \ + 
exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=kerple \ + model.normalization=rmsnorm \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_granularity='full' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + sh "rm -rf examples/nlp/language_modeling/gpt_pretrain_results" + sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" + } + } stage('L2: Megatron GPT Pretraining and Resume Training PP=2') { when { anyOf { diff --git a/README.rst b/README.rst index b9ba7fce30f3..863b279b2be8 100644 --- a/README.rst +++ b/README.rst @@ -280,6 +280,16 @@ It is highly recommended to use the NVIDIA PyTorch or NeMo container if having i Transformer Engine requires PyTorch to be built with CUDA 11.8. + +Flash Attention +~~~~~~~~~~~~~~~~~~~~ +Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, or with an attention bias (introduced by position encodings such as ALiBi), please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_. + +.. code-block:: bash + + pip install flash-attn + pip install triton==2.0.0.dev20221202 + NeMo Text Processing ~~~~~~~~~~~~~~~~~~~~ NeMo Text Processing, specifically (Inverse) Text Normalization, is now a separate repository `https://github.com/NVIDIA/NeMo-text-processing `_. diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index d502f255bd8e..d1132a32349a 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -77,7 +77,7 @@ model: transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer'] openai_gelu: False # Use OpenAI's GELU instead of the default GeLU normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True. - position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'rope'] + position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple', 'xpos', 'sandwich']. xpos and sandwich are experimental. rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this. attention_type: 'multihead' # Attention type. 
Options ['multihead'] share_embeddings_and_output_weights: True # Share embedding and output layer weights. @@ -167,6 +167,9 @@ model: reduce_amax: True # Perform reduction to sync amax tensors across GPUs after every iteration use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + ## Flash Attention + use_flash_attention: False # Use flash attention in self-attention module, this config does nothing when transformer_engine=True + data: # Path to data must be specified by the user. # Supports List, String and Dictionary diff --git a/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml b/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml index d3feb97ea9b4..e98ebae6da63 100644 --- a/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml @@ -36,4 +36,5 @@ megatron_legacy: False # Whether to use the legacy Megatron model. This affects normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True. num_moe_experts: 1 # When >1, FFNs are changed to MoE layers moe_frequency: 1 # every Nth ffn layer will be made MoE -moe_dropout: 0.0 # Dropout value for MoE layers \ No newline at end of file +moe_dropout: 0.0 # Dropout value for MoE layers +use_flash_attention: false # Use flash attention in self-attention module \ No newline at end of file diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_eval_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_eval_config.yaml index 69dc17f244f5..8c21117969ab 100755 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_eval_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_eval_config.yaml @@ -129,4 +129,5 @@ inference: repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. 
compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False - outfile_path: output.txt \ No newline at end of file + outfile_path: output.txt + compute_attention_mask: True \ No newline at end of file diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index e890e6ae4807..b43dc98f2fe7 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -151,6 +151,7 @@ def __init__( gradient_accumulation_fusion=False, persist_layer_norm=False, openai_gelu=False, + megatron_legacy=False, onnx_safe=False, sequence_parallel=False, transformer_engine=False, @@ -163,6 +164,7 @@ def __init__( fp8_amax_compute_algo='most_recent', reduce_amax=True, use_emha=False, + use_flash_attention=False, ): super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights) @@ -232,6 +234,7 @@ def __init__( persist_layer_norm=persist_layer_norm, openai_gelu=openai_gelu, onnx_safe=onnx_safe, + megatron_legacy=megatron_legacy, sequence_parallel=sequence_parallel, transformer_engine=transformer_engine, fp8=fp8, @@ -243,6 +246,7 @@ def __init__( fp8_amax_compute_algo=fp8_amax_compute_algo, reduce_amax=reduce_amax, use_emha=use_emha, + use_flash_attention=use_flash_attention, ) if self.share_embeddings_and_output_weights: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 2568a14f8dbf..7be679376175 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -25,6 +25,7 @@ from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.nlp_model import NLPModel +from nemo.collections.nlp.modules.common.megatron.attention import HAVE_FLASH_ATTENTION from nemo.collections.nlp.modules.common.megatron.clip_grads import ( clip_grad_norm_distributed_optimizer, clip_grad_norm_fp32, @@ -84,6 +85,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): if trainer is None: raise ValueError(f"Trainer cannot be None for Megatron-based models. Please provide a PTL trainer object.") + if cfg.get('use_flash_attention', False) and not HAVE_FLASH_ATTENTION: + raise ImportError( + "flash_attn was not found. Please see the installation instructions: https://github.com/HazyResearch/flash-attention. " + "If you use flash_attn with triton, please install triton==2.0.0.dev20221202." 
+ ) + # this prevents base constructor from initializing tokenizer self.tokenizer = None @@ -205,9 +212,10 @@ def _build_tokenizer(self): self.tokenizer = get_nmt_tokenizer( library=self._cfg.tokenizer.library, model_name=self._cfg.tokenizer.type, - tokenizer_model=self.register_artifact("tokenizer.model", self._cfg.tokenizer.model), - vocab_file=self.register_artifact("tokenizer.vocab_file", self._cfg.tokenizer.vocab_file), - merges_file=self.register_artifact("tokenizer.merge_file", self._cfg.tokenizer.merge_file), + tokenizer_model=self.register_artifact("tokenizer.model", self._cfg.tokenizer.get('model', None)), + vocab_file=self.register_artifact("tokenizer.vocab_file", self._cfg.tokenizer.get('vocab_file', None)), + merges_file=self.register_artifact("tokenizer.merge_file", self._cfg.tokenizer.get('merge_file', None)), + use_fast=self.cfg.tokenizer.get('use_fast', False), delimiter=self.cfg.tokenizer.get('delimiter', None), legacy=legacy, ) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 8eff896cf9d8..853c637eb3b3 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -300,7 +300,7 @@ def get_inference_config(self): def model_provider_func(self, pre_process, post_process): """Model depends on pipeline paralellism.""" model = GPTModel( - vocab_size=self.padded_vocab_size, + vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), hidden_size=self.cfg.hidden_size, max_position_embeddings=self.cfg.max_position_embeddings, num_layers=self.cfg.num_layers, @@ -357,6 +357,8 @@ def model_provider_func(self, pre_process, post_process): fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'most_recent'), reduce_amax=self.cfg.get('reduce_amax', True), use_emha=self.cfg.get('use_emha', False), + use_flash_attention=self.cfg.get('use_flash_attention', False), + megatron_legacy=self.cfg.get('megatron_legacy', False), ) return model @@ -765,7 +767,6 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if self.get_attention_mask_from_fusion: required_keys.remove('attention_mask') batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} - # Model forward pass output_tensor = model( batch['tokens'], @@ -822,9 +823,10 @@ def fwd_output_only_func(dataloader_iter, model): inference_max_sequence_len, ) = batch tokens = tokens.cuda() - attention_mask = attention_mask.cuda() position_ids = position_ids.cuda() - attention_mask = attention_mask[0:1] + if attention_mask is not None: + attention_mask = attention_mask.cuda() + attention_mask = attention_mask[0:1] extra_arg['set_inference_key_value_memory'] = set_inference_key_value_memory[0].item() extra_arg['inference_max_sequence_len'] = inference_max_sequence_len[0].item() output_tensor = model(tokens, position_ids, attention_mask, **extra_arg) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index 95448e67bd11..81ca1c283ad0 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -753,6 +753,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] "add_BOS": 
inference_config["add_BOS"], "all_probs": inference_config["all_probs"], "compute_logprob": inference_config["compute_logprob"], + "compute_attention_mask": inference_config.get("compute_attention_mask", True), } task_ids, processed_inputs = batch diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index 9c954b5e6313..b0d98e0c2fb1 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -27,8 +27,15 @@ ) from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax from nemo.collections.nlp.modules.common.megatron.module import MegatronModule -from nemo.collections.nlp.modules.common.megatron.rotary_pos_embedding import apply_rotary_pos_emb -from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, attention_mask_func +from nemo.collections.nlp.modules.common.megatron.position_embedding import XPOSPositionEmbedding +from nemo.collections.nlp.modules.common.megatron.position_embedding.rotary_position_embedding import ( + apply_rotary_pos_emb, +) +from nemo.collections.nlp.modules.common.megatron.utils import ( + ApexGuardDefaults, + _cast_if_autocast_enabled, + attention_mask_func, +) from nemo.collections.nlp.parts import utils_funcs from nemo.core import adapter_mixins @@ -55,6 +62,20 @@ HAVE_MEGATRON_CORE = False +try: + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import flash_attn_unpadded_func + from flash_attn.flash_attn_triton import flash_attn_func + + HAVE_FLASH_ATTENTION = True + +except (ImportError, ModuleNotFoundError): + + HAVE_FLASH_ATTENTION = False + + flash_attn_unpadded_func, flash_attn_func = None, None + unpad_input, pad_input = None, None + """ We use the following notation throughout this file: h: hidden size n: number of attention heads @@ -104,9 +125,9 @@ def __init__( sequence_parallel=False, gradient_accumulation_fusion=False, normalize_attention_scores=True, + use_flash_attention=False, ): super(ParallelAttention, self).__init__() - self.layer_number = max(1, layer_number) self.attention_type = attention_type self.attn_mask_type = attn_mask_type @@ -201,6 +222,8 @@ def __init__( multi_query_attention=multi_query_attention, sequence_parallel=sequence_parallel, normalize_attention_scores=normalize_attention_scores, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, ) # Output. 
@@ -292,14 +315,14 @@ def custom_forward(*inputs): return hidden_states - def _allocate_memory(self, inference_max_sequence_len, batch_size, dtype): + def _allocate_memory(self, inference_max_sequence_len, batch_size, dtype, device): return torch.empty( inference_max_sequence_len, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=dtype, - device=torch.cuda.current_device(), + device=device, ) def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first): @@ -357,10 +380,10 @@ def forward( if set_inference_key_value_memory: assert inference_max_sequence_len and inference_max_sequence_len > 0 self.inference_key_memory = self._allocate_memory( - inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype + inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype, hidden_states.device ) self.inference_value_memory = self._allocate_memory( - inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype + inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype, hidden_states.device ) self.inference_current_sequence_len = 0 @@ -469,7 +492,8 @@ def forward( key_layer = self.inference_key_memory[:end, ...] value_layer = self.inference_value_memory[:end, ...] # Adjust attention mask - attention_mask = attention_mask[..., start:end, :end] + if attention_mask is not None: + attention_mask = attention_mask[..., start:end, :end] # adjust the key rotary positional embedding if rotary_pos_emb is not None: q_pos_emb, k_pos_emb = rotary_pos_emb @@ -711,6 +735,8 @@ def __init__( sequence_parallel=False, normalize_attention_scores=True, multi_query_attention=False, + position_embedding_type='learned_absolute', + use_flash_attention=False, ): super(CoreAttention, self).__init__() @@ -723,6 +749,7 @@ def __init__( elif int(precision) == 16: self.fp16 = True self.multi_query_attention = multi_query_attention + self.position_embedding_type = position_embedding_type self.apply_query_key_layer_scaling = apply_query_key_layer_scaling self.attention_softmax_in_fp32 = False @@ -772,8 +799,17 @@ def __init__( # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. + self.attention_dropout_p = attention_dropout self.attention_dropout = torch.nn.Dropout(attention_dropout) + if use_flash_attention: + self.attn_fn = self.flash_attention + else: + self.attn_fn = self.torch_attention + + if position_embedding_type.lower() == 'xpos': + self.xpos = XPOSPositionEmbedding(kv_channels) + def forward( self, query_layer, @@ -786,19 +822,43 @@ def forward( relative_position_bias=None, headscale_tensor=None, ): + b, np, sq, sk, hn = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + query_layer.size(3), + ) - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== + # ================================================== + # Update attention mask for inference. 
[b, np, sq, sk] + # ================================================== + if get_key_value: + with torch.no_grad(): + if layer_past is not None: + attention_mask = attention_mask[..., sq - 1, :sk].unsqueeze(2) + else: + attention_mask = attention_mask[..., :sq, :sk] - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + # ================================================== + # Update attention bias. [b, np, sq, sk] + # ================================================== + if relative_position_bias is not None: + relative_position_bias = relative_position_bias[ + :, + self.num_attention_heads_partition_offset : self.num_attention_heads_partition_offset + + self.num_attention_heads_per_partition, + -sq:, + -sk:, + ] + # ================================================== + # Update query_layer, key_layer, value_layer + # ================================================== # TODO: figure out how to do this # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: q_pos_emb, k_pos_emb = rotary_pos_emb - query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) # TODO, can apply positional embedding to value_layer so it has @@ -806,86 +866,67 @@ def forward( # otherwise, only relative positional embedding takes effect # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) - if self.multi_query_attention: - # [sq, b, np, hn] -> [b, np * sq, hn] - query_layer = query_layer.permute([1, 2, 0, 3]).reshape( - output_size[0], output_size[1] * output_size[2], -1 - ) + if self.position_embedding_type.lower() == 'xpos': + query_layer = self.xpos(query_layer, offset=key_layer.shape[-2] - query_layer.shape[-2], downscale=False) + key_layer = self.xpos(key_layer, offset=0, downscale=True) - # [sk, b, 1, hn] -> [b, hn, sk] - key_layer = key_layer.squeeze(2).permute(1, 2, 0) + # ================================================== + # query_layer [sq, b, np, hn] + # key_layer [sk, b, np, hn] + # value_layer [sk, b, np, hn] + # attention_mask [b, 1, sq, sk] or [b, s] + # relative_position_bias [b, np, sq, sk] + # context_layer [b, np, sq, hn] + # ================================================== + context_layer = self.attn_fn(query_layer, key_layer, value_layer, attention_mask, relative_position_bias) - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device(), - ) + if headscale_tensor is not None: + context_layer = context_layer * headscale_tensor - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer, # [b * np, sq, hn] - key_layer, # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - else: - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device(), - ) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0, - ) + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) + return context_layer - if relative_position_bias is not None: - attention_scores += relative_position_bias[ - :, - self.num_attention_heads_partition_offset : self.num_attention_heads_partition_offset - + self.num_attention_heads_per_partition, - : attention_scores.size(2), - : attention_scores.size(3), - ] + def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias): + sq, b, np, hn = query_layer.shape + sk = key_layer.shape[0] - # ================================================== - # Update attention mask for inference. [b, np, sq, sk] - # ================================================== + if self.multi_query_attention: + query_layer = rearrange(query_layer, 'sq b np hn -> b (np sq) hn') + key_layer = rearrange(key_layer, 'sk b 1 hn -> b hn sk') + value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn') + else: + query_layer = rearrange(query_layer, 'sq b np hn -> (b np) sq hn') + key_layer = rearrange(key_layer, 'sk b np hn -> (b np) hn sk') + value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn') + + matmul_input_buffer = torch.empty( + query_layer.shape[0], + query_layer.shape[1], + key_layer.shape[2], + dtype=query_layer.dtype, + device=query_layer.device, + ) - if get_key_value: - with torch.no_grad(): - if layer_past is not None: - attention_mask = attention_mask[ - ..., attention_scores.size(3) - 1, : attention_scores.size(3) - ].unsqueeze(2) - else: - attention_mask = attention_mask[..., : attention_scores.size(3), : attention_scores.size(3)] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, + key_layer, + beta=0.0, + alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(b, np, sq, sk) - # =========================== - # Attention probs and dropout - # =========================== + if attention_bias is not None: + attention_scores += attention_bias - # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) # This is actually dropping out entire tokens to attend to, which might @@ -897,36 +938,111 @@ def forward( else: attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= + # change view [b * np, sq, sk] + attention_probs = rearrange(attention_probs, 'b np sq sk -> (b np) sq sk') - # value_layer -> context layer. 
- # [sk, b, np, hn] --> [b, np, sq, hn] + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [b, np, sq, hn] + context_layer = rearrange(context_layer, '(b np) sq hn -> b np sq hn', np=np) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + return context_layer - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias): + query_layer = rearrange(query_layer, 'sq b np hn -> b sq np hn') + key_layer = rearrange(key_layer, 'sk b np hn -> b sk np hn') + value_layer = rearrange(value_layer, 'sv b np hn -> b sv np hn') - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # Use to ensure dtype cast to fp16 or bf16 + query_layer = _cast_if_autocast_enabled(query_layer) + key_layer = _cast_if_autocast_enabled(key_layer) + value_layer = _cast_if_autocast_enabled(value_layer) + attention_mask = _cast_if_autocast_enabled(attention_mask) + attention_bias = _cast_if_autocast_enabled(attention_bias) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) + if attention_bias is not None: + return self.flash_attention_triton(query_layer, key_layer, value_layer, attention_mask, attention_bias,) + else: + return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask,) + + def reset_is_causal(self, query_length, key_length, causal): + if query_length != key_length: + if query_length == 1: + return False + raise NotImplementedError( + "Flash attention does not support query and key with different number of tokens, unless number of query tokens is 1." 
+ ) + return causal + + def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask): + batch_size, seqlen, nheads, _ = query_layer.shape + + # True: attend / False: not attend + if attention_mask is None: + attention_mask_q = torch.ones(batch_size, query_layer.shape[1], device=query_layer.device).bool() + attention_mask_kv = torch.ones(batch_size, key_layer.shape[1], device=query_layer.device).bool() + elif len(attention_mask.shape) == 4: + # [b, 1, sq, sk] -> [b, sq] / [b, sk] + attention_mask_q = torch.any(torch.eq(attention_mask, False), dim=3).squeeze(1) + attention_mask_kv = torch.any(torch.eq(attention_mask, False), dim=2).squeeze(1) + else: + assert len(attention_mask.shape) == 2 + attention_mask_q = attention_mask + attention_mask_kv = attention_mask + + q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_layer, attention_mask_q) + k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_layer, attention_mask_kv) + v, _, _, _ = unpad_input(value_layer, attention_mask_kv) + causal = self.reset_is_causal( + query_layer.shape[1], key_layer.shape[1], self.attn_mask_type == AttnMaskType.causal + ) + context_layer = flash_attn_unpadded_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=self.attention_dropout_p if self.training else 0.0, + causal=causal, + ) - if headscale_tensor is not None: - context_layer = context_layer * headscale_tensor + # [b, sq, np, hn] + context_layer = pad_input(context_layer, indices_q, batch_size, seqlen) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [b, sq, np, hn] -> [b, np, sq, hn] + context_layer = context_layer.permute(0, 2, 1, 3) + return context_layer - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) + def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_mask, attention_bias): + if self.attention_dropout_p > 0.0: + raise NotImplementedError(f'attention_dropout not implemented for flash_attention with attention bias') + + if attention_mask is not None: + if len(attention_mask.shape) == 4: + # [b, 1, sq, sk] -> [b, 1, sq, 1] / [b, 1, 1, sk] + attention_mask_q = torch.any(torch.eq(attention_mask, False), dim=3).unsqueeze(3) + attention_mask_kv = torch.any(torch.eq(attention_mask, False), dim=2).unsqueeze(2) + else: + # [b, s] -> [b, 1, s, 1] / [b, 1, 1, s] + assert len(attention_mask.shape) == 2 + attention_mask_q = attention_mask.unsqueeze(1).unsqueeze(3) + attention_mask_kv = attention_mask.unsqueeze(1).unsqueeze(2) + + attention_bias = attention_bias.masked_fill(~attention_mask_q, torch.finfo(query_layer.dtype).min) + attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min) + + causal = self.reset_is_causal( + query_layer.shape[1], key_layer.shape[1], self.attn_mask_type == AttnMaskType.causal + ) + context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_bias, causal) + + # [b, sq, np, hn] -> [b, np, sq, hn] + context_layer = context_layer.permute(0, 2, 1, 3) + + if attention_mask is not None: + context_layer = context_layer * attention_mask_q return context_layer diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index b8b12cf0caec..2d10576dc7d0 100755 --- 
a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -21,7 +21,12 @@ ) from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.module import MegatronModule -from nemo.collections.nlp.modules.common.megatron.rotary_pos_embedding import RotaryEmbedding +from nemo.collections.nlp.modules.common.megatron.position_embedding import ( + ALiBiRelativePositionEmbedding, + KERPLERelativePositionEmbedding, + RotaryEmbedding, + SandwichRelativePositionEmbedding, +) from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer from nemo.collections.nlp.modules.common.megatron.utils import ( ApexGuardDefaults, @@ -116,6 +121,7 @@ def get_language_model( fp8_amax_compute_algo='most_recent', reduce_amax=True, use_emha=False, + use_flash_attention=False, ): """Build language model and return along with the key to save.""" @@ -191,6 +197,7 @@ def get_language_model( fp8_amax_compute_algo=fp8_amax_compute_algo, reduce_amax=reduce_amax, use_emha=use_emha, + use_flash_attention=use_flash_attention, ) # key used for checkpoints. language_model_key = 'language_model' @@ -497,6 +504,7 @@ def __init__( fp8_amax_compute_algo='most_recent', reduce_amax=True, use_emha=False, + use_flash_attention=False, ): super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights) @@ -518,7 +526,6 @@ def __init__( self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.sequence_parallel = sequence_parallel self.dtype = utils_funcs.dtype_from_precision(precision, megatron_amp_O2) - if kv_channels is None: assert ( @@ -551,6 +558,40 @@ def __init__( rotary_dim = int(rotary_dim * rotary_percentage) self.rotary_pos_emb = RotaryEmbedding(rotary_dim) + elif position_embedding_type == 'alibi': + # TODO: If this is used for encoder-decodemax_position_embeddingsr model, implement proper logic and following + # addition for decoder. Currently it is only used for decoder model only. + # Encoder-decoder model, such as T5 is implemented in token_level_encoder_decoder.py + self.encoder_relative_position_embedding = ALiBiRelativePositionEmbedding( + bidirectional=encoder_attn_mask_type != AttnMaskType.causal, + num_attention_heads=num_attention_heads, + layer_type=LayerType.encoder, + num_attention_heads_alibi=None, + max_seq_len=max_position_embeddings, + ) + + elif position_embedding_type == 'kerple': + # TODO: If this is used for encoder-decodemax_position_embeddingsr model, implement proper logic and following + # addition for decoder. Currently it is only used for decoder model only. 
+ # Encoder-decoder model, such as T5 is implemented in token_level_encoder_decoder.py + self.encoder_relative_position_embedding = KERPLERelativePositionEmbedding( + bidirectional=encoder_attn_mask_type != AttnMaskType.causal, + num_attention_heads=num_attention_heads, + layer_type=LayerType.encoder, + num_attention_heads_kerple=None, + max_seq_len=max_position_embeddings, + ) + assert use_flash_attention == False # flash-attention not supported with kerple at this point + + elif position_embedding_type == 'sandwich': + self.encoder_relative_position_embedding = SandwichRelativePositionEmbedding( + bidirectional=encoder_attn_mask_type != AttnMaskType.causal, + num_attention_heads=num_attention_heads, + layer_type=LayerType.encoder, + hidden_size=self.hidden_size // num_attention_heads if kv_channels is None else kv_channels, + max_seq_len=max_position_embeddings, + ) + # Transformer. self.encoder = ParallelTransformer( init_method=self.init_method, @@ -602,6 +643,8 @@ def __init__( fp8_amax_compute_algo=fp8_amax_compute_algo, reduce_amax=reduce_amax, use_emha=use_emha, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, ) self._encoder_key = 'encoder' @@ -642,6 +685,8 @@ def __init__( activations_checkpoint_granularity=activations_checkpoint_granularity, activations_checkpoint_layers_per_pipeline=activations_checkpoint_layers_per_pipeline, transformer_engine=transformer_engine, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, ) self._decoder_key = 'decoder' @@ -713,26 +758,35 @@ def forward( pass # enc_attn_mask: [1, 1, s, s] - - if self.position_embedding_type == 'rope': - if inference_max_sequence_len is not None: - rotary_pos_emb = self.rotary_pos_emb(inference_max_sequence_len) - elif self.encoder.input_tensor is not None: - if self.sequence_parallel: - rotary_pos_emb = self.rotary_pos_emb( - self.encoder.input_tensor.size(0) * parallel_state.get_tensor_model_parallel_world_size() - ) - else: - rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0)) + if inference_max_sequence_len is not None: + enc_seq_length = inference_max_sequence_len + elif self.encoder.input_tensor is not None: + if self.sequence_parallel: + enc_seq_length = ( + self.encoder.input_tensor.size(0) * parallel_state.get_tensor_model_parallel_world_size() + ) else: - if self.sequence_parallel: - rotary_pos_emb = self.rotary_pos_emb( - encoder_input.size(0) * parallel_state.get_tensor_model_parallel_world_size() - ) - else: - rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0)) + enc_seq_length = self.encoder.input_tensor.size(0) else: - rotary_pos_emb = None + if self.sequence_parallel: + enc_seq_length = encoder_input.size(0) * parallel_state.get_tensor_model_parallel_world_size() + else: + enc_seq_length = encoder_input.size(0) + + rotary_pos_emb = None + encoder_self_attention_relative_position_bias = None + if self.position_embedding_type == 'rope': + rotary_pos_emb = self.rotary_pos_emb(enc_seq_length) + elif ( + self.position_embedding_type == 'alibi' + or self.position_embedding_type == 'sandwich' + or self.position_embedding_type == 'kerple' + ): + encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( + query_seq_length=enc_seq_length, key_seq_length=enc_seq_length, + ) + # causal attention bias: [1, head, 1, k] + # non-causal attention bias: [1, head, q, k] # encoder. 
if enc_hidden_states is None: @@ -747,6 +801,7 @@ def forward( rotary_pos_emb=(rotary_pos_emb, None, None) if rotary_pos_emb is not None else None, # This assumes that this being used as a GPT/BERT model only (no cross-attention) + self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, ) else: encoder_output = enc_hidden_states.to(encoder_input.dtype) diff --git a/nemo/collections/nlp/modules/common/megatron/layer_norm_1p.py b/nemo/collections/nlp/modules/common/megatron/layer_norm_1p.py index ca59bcc8850a..4a94b37aae7b 100644 --- a/nemo/collections/nlp/modules/common/megatron/layer_norm_1p.py +++ b/nemo/collections/nlp/modules/common/megatron/layer_norm_1p.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from torch import nn +import torch +from nemo.collections.nlp.modules.common.megatron.utils import _cast_if_autocast_enabled try: from apex.contrib.layer_norm.layer_norm import FastLayerNorm as OrigFastLayerNorm @@ -35,8 +36,8 @@ def __init__(self, *args, **kwargs): ), 'LayerNorm1P implemented only as an apex.contrib.layer_norm.FastLayerNorm extension' def reset_parameters(self): - nn.init.zeros_(self.weight) - nn.init.zeros_(self.bias) + torch.nn.init.zeros_(self.weight) + torch.nn.init.zeros_(self.bias) def forward(self, x): return _fast_layer_norm(x, self.weight + 1, self.bias, self.epsilon) @@ -44,6 +45,27 @@ def forward(self, x): else: - class LayerNorm1P(nn.Module): + class LayerNorm1P(torch.nn.Module): def __init__(self, *args, **kwargs): raise NotImplementedError('LayerNorm1P available only with apex installed') + + +class LPLayerNorm(torch.nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + device=device, + dtype=dtype, + ) + + def forward(self, x): + module_device = x.device + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + with torch.autocast(enabled=False, device_type=module_device.type): + return torch.nn.functional.layer_norm( + downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps + ) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py index 28eb39e630fc..ca2000842fe4 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py @@ -88,6 +88,8 @@ def get_decoder_model( moe_dropout=0.0, turn_off_rop=False, # turn off the RoP positional embedding version=1, + position_embedding_type='learned_absolute', + use_flash_attention=False, ): """Build language model and return along with the key to save.""" @@ -145,6 +147,8 @@ def get_decoder_model( num_moe_experts=num_moe_experts, moe_frequency=moe_frequency, moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, ) elif arch == "retro": decoder = MegatronRetrievalTransformerDecoderModule( diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py index 4005ffbd879e..9f5d917e2077 100644 --- 
a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py @@ -90,6 +90,8 @@ def get_encoder_model( moe_dropout=0.0, turn_off_rop=False, # turn off the RoP positional embedding version=1, # model version + position_embedding_type='learned_absolute', + use_flash_attention=False, ): """Build language model and return along with the key to save.""" @@ -147,6 +149,8 @@ def get_encoder_model( num_moe_experts=num_moe_experts, moe_frequency=moe_frequency, moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, ) elif arch == "retro": encoder = MegatronRetrievalTransformerEncoderModule( diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py index c3cb1fd05c3b..f2c42597eb83 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py @@ -85,6 +85,8 @@ def __init__( num_moe_experts=1, moe_frequency=1, moe_dropout=0.0, + position_embedding_type='learned_absolute', + use_flash_attention=False, ): super(MegatronTransformerDecoderModule, self).__init__() @@ -149,6 +151,8 @@ def __init__( num_moe_experts=num_moe_experts, moe_frequency=moe_frequency, moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, ) self._model_key = 'model' diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py index 2eacf8aad672..60c347338105 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py @@ -82,6 +82,8 @@ def __init__( num_moe_experts=1, moe_frequency=1, moe_dropout=0.0, + position_embedding_type='learned_absolute', + use_flash_attention=False, ): super(MegatronTransformerEncoderModule, self).__init__() @@ -96,6 +98,7 @@ def __init__( self.parent_model_type = parent_model_type self.normalization = normalization self.transformer_block_type = transformer_block_type + self.use_flash_attention = use_flash_attention if kv_channels is None: @@ -147,6 +150,8 @@ def __init__( num_moe_experts=num_moe_experts, moe_frequency=moe_frequency, moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, ) self._model_key = 'model' @@ -163,9 +168,12 @@ def forward( enc_self_attention_relative_position_bias=None, ): # convert to Megatron mask - enc_attn_mask_3d = build_attention_mask_3d( - source_mask=enc_attn_mask, target_mask=enc_attn_mask, attn_mask_type=self.model_attn_mask_type, - ) + if self.use_flash_attention: + enc_attn_mask_3d = enc_attn_mask < 0.5 + else: + enc_attn_mask_3d = build_attention_mask_3d( + source_mask=enc_attn_mask, target_mask=enc_attn_mask, attn_mask_type=self.model_attn_mask_type, + ) # transformer encoder enc_output = self.model( diff --git a/nemo/collections/nlp/modules/common/megatron/position_embedding/__init__.py b/nemo/collections/nlp/modules/common/megatron/position_embedding/__init__.py new file mode 100644 index 000000000000..fdbbed86cb2c --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/position_embedding/__init__.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# Copyright (c) 2023, 
NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.nlp.modules.common.megatron.position_embedding.alibi_relative_position_embedding import ( + ALiBiRelativePositionEmbedding, +) +from nemo.collections.nlp.modules.common.megatron.position_embedding.kerple_relative_position_embedding import ( + KERPLERelativePositionEmbedding, +) +from nemo.collections.nlp.modules.common.megatron.position_embedding.rotary_position_embedding import RotaryEmbedding +from nemo.collections.nlp.modules.common.megatron.position_embedding.sandwich_relative_position_embedding import ( + SandwichRelativePositionEmbedding, +) +from nemo.collections.nlp.modules.common.megatron.position_embedding.t5_relative_position_embedding import ( + T5RelativePositionEmbedding, +) +from nemo.collections.nlp.modules.common.megatron.position_embedding.xpos_position_embedding import ( + XPOSPositionEmbedding, +) diff --git a/nemo/collections/nlp/modules/common/megatron/alibi_relative_position_embedding.py b/nemo/collections/nlp/modules/common/megatron/position_embedding/alibi_relative_position_embedding.py similarity index 73% rename from nemo/collections/nlp/modules/common/megatron/alibi_relative_position_embedding.py rename to nemo/collections/nlp/modules/common/megatron/position_embedding/alibi_relative_position_embedding.py index 4f5abd96743b..6425e288f277 100644 --- a/nemo/collections/nlp/modules/common/megatron/alibi_relative_position_embedding.py +++ b/nemo/collections/nlp/modules/common/megatron/position_embedding/alibi_relative_position_embedding.py @@ -42,20 +42,31 @@ def build_slopes(num_attention_heads, num_attention_heads_alibi): """ Builds a slopes tensor. 
""" - slopes = torch.Tensor( - get_slopes(num_attention_heads_alibi) + [0] * (num_attention_heads - num_attention_heads_alibi) - ).cuda() - return slopes.unsqueeze(-1).unsqueeze(-1) + slopes = ( + torch.Tensor(get_slopes(num_attention_heads_alibi) + [0] * (num_attention_heads - num_attention_heads_alibi)) + .unsqueeze(-1) + .unsqueeze(-1) + ) + if torch.cuda.is_available(): + slopes = slopes.to(torch.cuda.current_device()) -def build_relative_position(query_length, key_length, num_attention_heads): - context_position = torch.arange(query_length)[:, None].cuda() - memory_position = torch.arange(key_length)[None, :].cuda() - # shape (query_length, key_length, num_heads) - relative_position = memory_position - context_position + return slopes + + +def build_relative_position(max_seq_len, full=True): + """ + full=True: shape (max_seq_len, max_seq_len) + full=False: shape (max_seq_len) + """ + relative_position = torch.arange(1 - max_seq_len, 1)[None, :].mul(-1) # (1, max_seq_len) + + if full: + memory_position = torch.arange(1 - max_seq_len, 1)[:, None].mul(-1) + relative_position = torch.abs(memory_position - relative_position) # (max_seq_len, max_seq_len) - # shape (num_attention_heads, max_seq_len, max_seq_len) - relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_attention_heads, -1, -1) + if torch.cuda.is_available(): + relative_position = relative_position.to(torch.cuda.current_device()) return relative_position @@ -68,7 +79,7 @@ class ALiBiRelativePositionEmbedding(torch.nn.Module): """ def __init__( - self, bidirectional, num_attention_heads, layer_type, num_attention_heads_alibi=None, max_seq_len=512 + self, bidirectional, num_attention_heads, layer_type, num_attention_heads_alibi=None, max_seq_len=512, ): """ Args: @@ -101,20 +112,25 @@ def __init__( # cache the slopes self.slopes = build_slopes(num_attention_heads, num_attention_heads_alibi) # cache the relative position bias. 
shape (num_attention_heads, max_seq_len, max_seq_len) - self.relative_position = build_relative_position(max_seq_len, max_seq_len, num_attention_heads) + # if we use causal attention (not bidrectional), we can use singleton relative position + self.relative_position = ( + build_relative_position(max_seq_len, full=bidirectional).unsqueeze(0).expand(num_attention_heads, -1, -1) + ) def forward(self, query_seq_length, key_seq_length): # used cached relative position if possible max_seq_len = max(query_seq_length, key_seq_length) if max_seq_len > self.max_seq_len: - relative_position = build_relative_position(max_seq_len, max_seq_len, self.num_attention_heads) + relative_position = ( + build_relative_position(max_seq_len, full=self.bidirectional) + .unsqueeze(0) + .expand(self.num_attention_heads, -1, -1) + ) else: relative_position = self.relative_position # shape (num_attention_heads, query_seq_length, key_seq_length) - relative_position = relative_position[:, :query_seq_length, :key_seq_length] + relative_position = relative_position[:, -query_seq_length:, -key_seq_length:] # if not bidirectional, mask out the future positions - if not self.bidirectional: - relative_position = torch.tril(relative_position) # shape (1, num_heads, query_length, key_length) return -relative_position.unsqueeze(0) * self.slopes diff --git a/nemo/collections/nlp/modules/common/megatron/kerple_relative_position_embedding.py b/nemo/collections/nlp/modules/common/megatron/position_embedding/kerple_relative_position_embedding.py similarity index 81% rename from nemo/collections/nlp/modules/common/megatron/kerple_relative_position_embedding.py rename to nemo/collections/nlp/modules/common/megatron/position_embedding/kerple_relative_position_embedding.py index 54276d6fa21e..fc0c837da556 100644 --- a/nemo/collections/nlp/modules/common/megatron/kerple_relative_position_embedding.py +++ b/nemo/collections/nlp/modules/common/megatron/position_embedding/kerple_relative_position_embedding.py @@ -17,7 +17,7 @@ import torch -from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import ( +from nemo.collections.nlp.modules.common.megatron.position_embedding.alibi_relative_position_embedding import ( build_relative_position, build_slopes, ) @@ -33,7 +33,7 @@ class KERPLERelativePositionEmbedding(torch.nn.Module): """ def __init__( - self, bidirectional, num_attention_heads, layer_type, num_attention_heads_kerple=None, max_seq_len=512 + self, bidirectional, num_attention_heads, layer_type, num_attention_heads_kerple=None, max_seq_len=512, ): """ Args: @@ -65,21 +65,26 @@ def __init__( # initialize the slopes self.kerple_b = torch.nn.Parameter(build_slopes(num_attention_heads, num_attention_heads_kerple)) - self.kerple_a = torch.zeros_like(self.kerple_b) - self.kerple_p = torch.ones_like(self.kerple_b) + self.kerple_a = torch.nn.Parameter(torch.ones_like(self.kerple_b)) + self.kerple_p = torch.nn.Parameter(torch.ones_like(self.kerple_b)) # cache the relative position bias. 
shape (num_attention_heads, max_seq_len, max_seq_len) - self.relative_position = build_relative_position(max_seq_len, max_seq_len, num_attention_heads) + # if we use causal attention (not bidrectional), we can use singleton relative position + self.relative_position = ( + build_relative_position(max_seq_len, full=True).unsqueeze(0).expand(num_attention_heads, -1, -1) + ) def forward(self, query_seq_length, key_seq_length): # used cached relative position if possible max_seq_len = max(query_seq_length, key_seq_length) if max_seq_len > self.max_seq_len: - relative_position = build_relative_position(max_seq_len, max_seq_len, self.num_attention_heads) + relative_position = ( + build_relative_position(max_seq_len, full=True).unsqueeze(0).expand(self.num_attention_heads, -1, -1) + ) else: relative_position = self.relative_position # shape (num_attention_heads, query_seq_length, key_seq_length) - relative_position = relative_position[:, :query_seq_length, :key_seq_length] + relative_position = relative_position[:, -query_seq_length:, -key_seq_length:] # if not bidirectional, mask out the future positions if not self.bidirectional: relative_position = torch.tril(relative_position) diff --git a/nemo/collections/nlp/modules/common/megatron/rotary_pos_embedding.py b/nemo/collections/nlp/modules/common/megatron/position_embedding/rotary_position_embedding.py similarity index 96% rename from nemo/collections/nlp/modules/common/megatron/rotary_pos_embedding.py rename to nemo/collections/nlp/modules/common/megatron/position_embedding/rotary_position_embedding.py index 191601054ef8..5a8d6d7dd333 100644 --- a/nemo/collections/nlp/modules/common/megatron/rotary_pos_embedding.py +++ b/nemo/collections/nlp/modules/common/megatron/position_embedding/rotary_position_embedding.py @@ -38,7 +38,8 @@ def forward(self, max_seq_len, offset=0): def _rotate_half(x): """ - change sign so the last dimension becomes [-odd, +even] + change sign so the last dimension + [A, B, C, D] -> [-C, -D, A, B] """ x = rearrange(x, '... (j d) -> ... j d', j=2) x1, x2 = x.unbind(dim=-2) diff --git a/nemo/collections/nlp/modules/common/megatron/position_embedding/sandwich_relative_position_embedding.py b/nemo/collections/nlp/modules/common/megatron/position_embedding/sandwich_relative_position_embedding.py new file mode 100644 index 000000000000..0e2dfd7d2ef6 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/position_embedding/sandwich_relative_position_embedding.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
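For orientation, here is a minimal standalone sketch (an illustration, not the NeMo helper itself) of the refactored relative-position cache: the distance matrix is built once at max_seq_len and then sliced with negative indices for shorter query/key lengths, which is the convention ALiBi and KERPLE now share.

```python
import torch


def build_relative_position_sketch(max_seq_len: int, full: bool = True) -> torch.Tensor:
    # distance of each key position from the final position: [L-1, ..., 1, 0]
    relative_position = torch.arange(1 - max_seq_len, 1)[None, :].mul(-1)  # (1, L)
    if full:
        memory_position = torch.arange(1 - max_seq_len, 1)[:, None].mul(-1)
        relative_position = torch.abs(memory_position - relative_position)  # (L, L)
    return relative_position


# cache once, then slice from the end, matching the new
# relative_position[:, -query_seq_length:, -key_seq_length:] indexing
cached = build_relative_position_sketch(8, full=True)
q_len, k_len = 4, 6
assert cached[-q_len:, -k_len:].shape == (q_len, k_len)
```

For causal ALiBi (`bidirectional=False`, hence `full=False`) only the last-row distances are kept, so the cached bias collapses to a single row per head, which is why the causal test further below expects a `(1, num_heads, 1, k_len)` shape.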
+ +import torch + +from nemo.collections.nlp.modules.common.megatron.position_embedding.alibi_relative_position_embedding import ( + build_relative_position, +) +from nemo.utils.decorators import experimental + +__all__ = ['SandwichRelativePositionEmbedding'] + + +@experimental +class SandwichRelativePositionEmbedding(torch.nn.Module): + """ + Dissecting Transformer Length Extrapolation via the Lens of Receptive Field Analysis + Based on https://arxiv.org/abs/2212.10356 + """ + + def __init__( + self, bidirectional, num_attention_heads, layer_type, hidden_size, max_seq_len=512, + ): + """ + Args: + num_attention_heads: Number of attention heads + hidden_size: Hidden size per attention head + """ + super().__init__() + self.bidirectional = bidirectional + self.layer_type = layer_type + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.max_seq_len = max_seq_len + self.relative_position = build_relative_position(max_seq_len, full=True) + + def forward(self, query_seq_length, key_seq_length): + # used cached relative position if possible + max_seq_len = max(query_seq_length, key_seq_length) + if max_seq_len > self.max_seq_len: + relative_position = build_relative_position(max_seq_len, full=True) + else: + relative_position = self.relative_position + + # shape (query_seq_length, key_seq_length) + relative_position = relative_position[-query_seq_length:, -key_seq_length:] + # if not bidirectional, mask out the future positions + if not self.bidirectional: + relative_position = torch.tril(relative_position) + + inv_freq = 1.0 / ( + 10000 + ** (2 * torch.arange(1, self.hidden_size / 2 + 1, device=relative_position.device) / self.hidden_size) + ) + + _bias = torch.sum((relative_position[:, :, None].repeat(1, 1, len(inv_freq)) * inv_freq).cos(), axis=2) + bias = _bias.repeat(self.num_attention_heads, 1, 1) + + _bias_scales = torch.arange(1, self.num_attention_heads + 1, 1, device=relative_position.device) + bias_scales = _bias_scales[:, None, None] + + scaled_bias = (bias - self.hidden_size / 2) / (bias_scales * 8 / self.num_attention_heads).unsqueeze(0) + + return scaled_bias diff --git a/nemo/collections/nlp/modules/common/megatron/t5_relative_position_embedding.py b/nemo/collections/nlp/modules/common/megatron/position_embedding/t5_relative_position_embedding.py similarity index 95% rename from nemo/collections/nlp/modules/common/megatron/t5_relative_position_embedding.py rename to nemo/collections/nlp/modules/common/megatron/position_embedding/t5_relative_position_embedding.py index c2a0c8661acf..4566d9aa7876 100644 --- a/nemo/collections/nlp/modules/common/megatron/t5_relative_position_embedding.py +++ b/nemo/collections/nlp/modules/common/megatron/position_embedding/t5_relative_position_embedding.py @@ -43,9 +43,7 @@ def __init__( # Relative position Embedding # Relative Position embedding (all attention layers). 
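Because `SandwichRelativePositionEmbedding` above is new, a hypothetical usage sketch may help; the head count, per-head hidden size, and sequence lengths are made-up values, the import paths come from the new position_embedding package, and a working NeMo + Apex environment is assumed. The returned tensor is an additive attention bias, broadcastable over the batch dimension like the other relative position embeddings.

```python
from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType
from nemo.collections.nlp.modules.common.megatron.position_embedding import SandwichRelativePositionEmbedding

pe = SandwichRelativePositionEmbedding(
    bidirectional=True,
    num_attention_heads=8,
    layer_type=LayerType.encoder,
    hidden_size=64,  # per-head hidden size, per the constructor docstring
    max_seq_len=512,
)
bias = pe(query_seq_length=128, key_seq_length=128)
assert bias.shape == (1, 8, 128, 128)  # (1, num_heads, q_len, k_len), added to the attention scores
```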
- self.relative_position_embedding = torch.nn.Embedding( - self.relative_position_num_buckets, num_attention_heads - ).to(torch.cuda.current_device()) + self.relative_position_embedding = torch.nn.Embedding(self.relative_position_num_buckets, num_attention_heads) self._relative_position_embedding_key = 'relative_position_embedding' init_method(self.relative_position_embedding.weight) @@ -104,8 +102,9 @@ def _compute_relative_position_bucket(self, query_length, key_length): """ """Compute binned relative position bias""" - context_position = torch.arange(query_length, dtype=torch.long, device=torch.cuda.current_device())[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=torch.cuda.current_device())[None, :] + device = self.relative_position_embedding.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket_tensor = self._relative_position_bucket( diff --git a/nemo/collections/nlp/modules/common/megatron/position_embedding/xpos_position_embedding.py b/nemo/collections/nlp/modules/common/megatron/position_embedding/xpos_position_embedding.py new file mode 100644 index 000000000000..ef59234790c5 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/position_embedding/xpos_position_embedding.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from einops import rearrange +from nemo.utils.decorators import experimental + + +def fixed_pos_embedding(x): + seq_len, dim = x.shape + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, ::2] + x2 = x[:, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. 
+ """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m + + +def apply_rotary_pos_emb(x, sin, cos, scale=1): + sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +@experimental +class XPOSPositionEmbedding(nn.Module): + def __init__(self, head_dim, scale_base=2048): + super().__init__() + self.head_dim = head_dim + self.scale_base = scale_base + self.register_buffer("scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim)) + + def forward(self, x, offset=0, downscale=False): + length, b = x.shape[0], x.shape[1] + x = rearrange(x, 's b np hn -> (b np) s hn') + min_pos = -(length + offset) // 2 + max_pos = length + offset + min_pos + scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] + sin, cos = fixed_pos_embedding(scale) + + if scale.shape[0] > length: + scale = scale[-length:] + sin = sin[-length:] + cos = cos[-length:] + + if downscale: + scale = 1 / scale + + x = apply_rotary_pos_emb(x, sin, cos, scale) + x = rearrange(x, '(b np) s hn -> s b np hn', b=b) + return x diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_transformer.py b/nemo/collections/nlp/modules/common/megatron/retrieval_transformer.py index 73c41cee6c6f..83dea362c3e1 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_transformer.py @@ -19,7 +19,7 @@ from einops import rearrange, repeat from nemo.collections.nlp.modules.common.megatron.module import MegatronModule -from nemo.collections.nlp.modules.common.megatron.rotary_pos_embedding import RotaryEmbedding +from nemo.collections.nlp.modules.common.megatron.position_embedding import RotaryEmbedding from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, build_attention_mask_3d diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index 229a9af48048..fc16295020fb 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -15,12 +15,6 @@ import torch from omegaconf import DictConfig -from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import ( - ALiBiRelativePositionEmbedding, -) -from nemo.collections.nlp.modules.common.megatron.kerple_relative_position_embedding import ( - KERPLERelativePositionEmbedding, -) from nemo.collections.nlp.modules.common.megatron.language_model import Embedding from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_decoders import get_decoder_model @@ -29,7 +23,11 @@ ) from nemo.collections.nlp.modules.common.megatron.megatron_encoders import get_encoder_model from nemo.collections.nlp.modules.common.megatron.module import MegatronModule -from nemo.collections.nlp.modules.common.megatron.t5_relative_position_embedding import T5RelativePositionEmbedding +from 
nemo.collections.nlp.modules.common.megatron.position_embedding import ( + ALiBiRelativePositionEmbedding, + KERPLERelativePositionEmbedding, + T5RelativePositionEmbedding, +) from nemo.collections.nlp.modules.common.megatron.utils import ( ApexGuardDefaults, build_position_ids, @@ -197,6 +195,11 @@ def __init__( else: self.encoder_relative_position_embedding = None + if encoder_cfg.get('use_flash_attention', False) and encoder_cfg.get( + 'position_embedding_type', 'learned_absolute' + ) in ['relative', 'kerple']: + raise ValueError('flash-attention not supported with relative or kerple at this point') + encoder = get_encoder_model( arch=encoder_cfg.arch, hidden_size=encoder_cfg.hidden_size, @@ -243,6 +246,8 @@ def __init__( num_moe_experts=encoder_cfg.get('num_moe_experts', 1), moe_frequency=encoder_cfg.get('moe_frequency', 1), moe_dropout=encoder_cfg.get('moe_dropout', 0.0), + position_embedding_type=encoder_cfg.get('position_embedding_type', 'learned_absolute'), + use_flash_attention=encoder_cfg.get('use_flash_attention', False), ) if add_decoder: @@ -307,6 +312,7 @@ def __init__( ): self.decoder_cross_attention_relative_position_embeddings_weight().data.fill_(0) self.decoder_cross_attention_relative_position_embeddings_weight().shared = True + elif self.decoder_cfg.get('position_embedding_type', 'learned_absolute') == 'alibi': self.decoder_relative_position_embedding = ALiBiRelativePositionEmbedding( bidirectional=False, @@ -328,6 +334,11 @@ def __init__( else: self.decoder_relative_position_embedding = None + if decoder_cfg.get('use_flash_attention', False) and decoder_cfg.get( + 'position_embedding_type', 'learned_absolute' + ) in ['relative', 'kerple']: + raise ValueError('flash-attention not supported with relative or kerple at this point') + decoder = get_decoder_model( arch=decoder_cfg.arch, hidden_size=decoder_cfg.hidden_size, @@ -373,6 +384,8 @@ def __init__( num_moe_experts=decoder_cfg.get('num_moe_experts', 1), moe_frequency=decoder_cfg.get('moe_frequency', 1), moe_dropout=decoder_cfg.get('moe_dropout', 0.0), + position_embedding_type=decoder_cfg.get('position_embedding_type', 'learned_absolute'), + use_flash_attention=decoder_cfg.get('use_flash_attention', False), ) self.enc_dec_model = MegatronTransformerEncoderDecoderModule( diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index f5dfbcabcd0e..8a0b22b4d289 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Optional import torch +import torch.nn as nn from einops import rearrange from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig @@ -33,7 +34,7 @@ dropout_add, ) from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm -from nemo.collections.nlp.modules.common.megatron.layer_norm_1p import LayerNorm1P +from nemo.collections.nlp.modules.common.megatron.layer_norm_1p import LayerNorm1P, LPLayerNorm from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.mlp import ParallelMLP, SwitchMLP from nemo.collections.nlp.modules.common.megatron.module import MegatronModule @@ -115,6 +116,12 @@ def _dropout_add(x, bias, residual, prob): return _dropout_add +def remove_bias_from_layernorm(layer): + for module in layer.modules(): + if hasattr(module, 'bias') and isinstance(module.bias, 
nn.Parameter): + module.register_parameter('bias', None) + + class ParallelTransformerLayer_(MegatronModule, adapter_mixins.AdapterModuleMixin): """A single transformer layer. @@ -164,6 +171,7 @@ def __init__( num_moe_experts=1, moe_frequency=1, moe_dropout=0.0, + use_flash_attention=False, ): super(ParallelTransformerLayer_, self).__init__() @@ -187,7 +195,9 @@ def __init__( 'bias_dropout_add_fusion=True requires bias=True, found bias=False. Either set both to True or both to False.' ) - if normalization not in ['layernorm', 'layernorm1p', 'rmsnorm']: + # the low_precision_layernorm does not require a bias term, whereas layernorm1p from apex + # does require a bias, so it cannot be used for bias-less low precision LN such as in MPT-7B + if normalization not in ['layernorm', 'layernorm1p', 'rmsnorm', 'low_precision_layernorm']: raise ValueError(f'normalization must be "layernorm", "layernorm1p" or "rmsnorm", found {normalization}') if transformer_block_type not in ['pre_ln', 'post_ln', 'normformer']: @@ -212,8 +222,16 @@ def __init__( self.input_layernorm = LayerNorm1P( hidden_size, layernorm_epsilon, sequence_parallel_enabled=sequence_parallel ) + elif normalization == 'low_precision_layernorm': + self.input_layernorm = LPLayerNorm(hidden_size, layernorm_epsilon) else: self.input_layernorm = MixedFusedRMSNorm(hidden_size, layernorm_epsilon) + # for architectures such as MPT, there is no bias term even on the layernorms + # this code allows us to remove the bias terms from the layernorm module + # so that we can support MPT. However, certain apex-based LNs don't support + # removing bias, so we also have to check for that + if not bias and normalization not in ['layernorm', 'layernorm1p']: + remove_bias_from_layernorm(self.input_layernorm) self.self_attention = ParallelAttention( init_method=init_method, @@ -240,6 +258,7 @@ def __init__( sequence_parallel=sequence_parallel, gradient_accumulation_fusion=gradient_accumulation_fusion, normalize_attention_scores=normalize_attention_scores, + use_flash_attention=use_flash_attention, ) if transformer_block_type == 'normformer': @@ -261,8 +280,12 @@ def __init__( self.post_attention_layernorm = LayerNorm1P( hidden_size, layernorm_epsilon, sequence_parallel_enabled=sequence_parallel ) + elif normalization == 'low_precision_layernorm': + self.post_attention_layernorm = LPLayerNorm(hidden_size, layernorm_epsilon) else: self.post_attention_layernorm = MixedFusedRMSNorm(hidden_size, layernorm_epsilon) + if not bias and normalization not in ['layernorm', 'layernorm1p']: + remove_bias_from_layernorm(self.post_attention_layernorm) if self.layer_type == LayerType.decoder_pre_mlp: # skip MLP and cross attention @@ -280,8 +303,12 @@ def __init__( self.post_attention_layernorm = LayerNorm1P( hidden_size, layernorm_epsilon, sequence_parallel_enabled=sequence_parallel ) + elif normalization == 'low_precision_layernorm': + self.post_attention_layernorm = LPLayerNorm(hidden_size, layernorm_epsilon) else: self.post_attention_layernorm = MixedFusedRMSNorm(hidden_size, layernorm_epsilon) + if not bias and normalization not in ['layernorm', 'layernorm1p']: + remove_bias_from_layernorm(self.post_attention_layernorm) if self.layer_type == LayerType.decoder or self.layer_type == LayerType.retrieval_encoder: self.inter_attention = ParallelAttention( @@ -669,6 +696,7 @@ def __init__( num_moe_experts=1, moe_frequency=1, moe_dropout=0.0, + use_flash_attention=False, ): super(ParallelTransformerLayer, self).__init__( init_method=init_method, @@ -711,6 +739,7 @@ def 
__init__( num_moe_experts=num_moe_experts, moe_frequency=moe_frequency, moe_dropout=moe_dropout, + use_flash_attention=use_flash_attention, ) # Dtype for forward pass - ignore amp O2 @@ -924,6 +953,7 @@ def __init__( num_moe_experts=1, moe_frequency=1, moe_dropout=0.0, + use_flash_attention=False, ): super(ParallelTransformer, self).__init__() @@ -1104,6 +1134,7 @@ def build_layer(layer_number): num_moe_experts=num_moe_experts, moe_frequency=moe_frequency, moe_dropout=moe_dropout, + use_flash_attention=use_flash_attention, ) if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: @@ -1154,8 +1185,16 @@ def build_layer(layer_number): self.final_layernorm = LayerNorm1P( hidden_size, layernorm_epsilon, sequence_parallel_enabled=sequence_parallel ) + elif normalization == 'low_precision_layernorm': + self.final_layernorm = LPLayerNorm(hidden_size, layernorm_epsilon) else: self.final_layernorm = MixedFusedRMSNorm(hidden_size, layernorm_epsilon) + # for architectures such as MPT, there is no bias term even on the layernorms + # this code allows us to remove the bias terms from the layernorm module + # so that we can support MPT. However, certain apex-based LNs don't support + # removing bias, so we also have to check for that + if not bias and normalization not in ['layernorm', 'layernorm1p']: + remove_bias_from_layernorm(self.final_layernorm) def _get_layer(self, layer_number): return self.layers[layer_number] diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 8ef46c10d49b..7c7a428fa43f 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -179,7 +179,9 @@ def average_losses_across_data_parallel_group(losses): return averaged_losses -def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss): +def get_ltor_masks_and_position_ids( + data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss, compute_attention_mask=True +): """Build masks and position id for left to right model.""" # Extract batch size and sequence length. @@ -190,9 +192,12 @@ def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_a att_mask_batch = micro_batch_size else: att_mask_batch = 1 - attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view( - att_mask_batch, 1, seq_length, seq_length - ) + + attention_mask = None + if compute_attention_mask: + attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length + ) # Loss mask. 
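As a small illustration of the `remove_bias_from_layernorm` helper introduced above, here is a sketch on a plain `torch.nn.LayerNorm` (not MPT itself): registering the bias as `None` removes the parameter, so bias-less checkpoints such as MPT-style models can be loaded without missing-key failures, and `layer_norm` simply treats the absent bias as zero.

```python
import torch

ln = torch.nn.LayerNorm(8)
for module in ln.modules():
    # mirrors remove_bias_from_layernorm: drop the bias parameter entirely
    if hasattr(module, 'bias') and isinstance(module.bias, torch.nn.Parameter):
        module.register_parameter('bias', None)

assert ln.bias is None
out = ln(torch.randn(2, 8))  # still valid: functional layer_norm accepts bias=None
```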
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) @@ -228,8 +233,9 @@ def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_a position_ids[b, (i + 1) :] -= i + 1 - prev_index prev_index = i + 1 - # Convert attention mask to binary: - attention_mask = attention_mask < 0.5 + if compute_attention_mask: + # Convert attention mask to binary: + attention_mask = attention_mask < 0.5 return attention_mask, loss_mask, position_ids @@ -381,3 +387,16 @@ def get_iterator_k_split(batch: List[torch.Tensor], num_microbatches: int) -> It microbatches = [[elem[i] for elem in split_batch] for i in range(num_microbatches)] return itertools.chain(microbatches) + + +def _cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if isinstance(tensor, torch.Tensor): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index 310065fc3523..8608c0c9a680 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -97,10 +97,11 @@ def clip_max_len(self, maxlen: int) -> int: pass @abc.abstractclassmethod - def init_batch(self, context_tokens: torch.Tensor, context_length: int): + def init_batch(self, context_tokens: torch.Tensor, context_length: int, compute_attention_mask: bool): """initialize the batch data before the inference steps. It will save the intermediate results as object attributes context_length (int): the context token length + compute_attention_mask: bool: set to True to compute attention mask (not needed for FA) Args: context_tokens (torch.Tensor): The padded context tokens including the space for tokens to be generated """ @@ -187,7 +188,7 @@ def clip_max_len(self, maxlen: int) -> int: maxlen = self.model.cfg.encoder_seq_length + 1 return maxlen - def init_batch(self, context_tokens: torch.Tensor, context_length: int): + def init_batch(self, context_tokens: torch.Tensor, context_length: int, compute_attention_mask: bool): """initialize the batch data before the inference steps.""" # Move to GPU. 
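The `_cast_if_autocast_enabled` helper above pairs with the `LPLayerNorm` added to `layer_norm_1p.py` earlier in this diff: inputs and affine parameters are downcast to the active autocast dtype, then `layer_norm` runs with autocast disabled so the kernel stays in low precision instead of being re-promoted to fp32. A minimal sketch of the pattern (illustrative only; the CUDA autocast path is assumed):

```python
import torch
import torch.nn.functional as F


def cast_if_autocast_enabled(t):
    # simplified, CUDA-only version of the helper above
    if torch.is_autocast_enabled() and t.device.type == 'cuda':
        return t.to(dtype=torch.get_autocast_gpu_dtype())
    return t


class LPLayerNormSketch(torch.nn.LayerNorm):
    def forward(self, x):
        x_lp = cast_if_autocast_enabled(x)
        w_lp = cast_if_autocast_enabled(self.weight) if self.weight is not None else None
        b_lp = cast_if_autocast_enabled(self.bias) if self.bias is not None else None
        with torch.autocast(device_type=x.device.type, enabled=False):
            return F.layer_norm(x_lp, self.normalized_shape, w_lp, b_lp, self.eps)


if torch.cuda.is_available():
    ln = LPLayerNormSketch(32).cuda()
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        y = ln(torch.randn(4, 32, device='cuda'))
    assert y.dtype == torch.float16  # normalization ran in half precision
```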
tokenizer = self.model.tokenizer @@ -199,10 +200,17 @@ def init_batch(self, context_tokens: torch.Tensor, context_length: int): self.model.cfg.get('reset_position_ids', False), self.model.cfg.get('reset_attention_mask', False), self.model.cfg.get('eod_mask_loss', False), + compute_attention_mask=compute_attention_mask, ) def prepare_batch_at_step( - self, tokens: torch.Tensor, maxlen: int, micro_batch_size: int, step: int, context_length: int + self, + tokens: torch.Tensor, + maxlen: int, + micro_batch_size: int, + step: int, + context_length: int, + compute_attention_mask: bool = True, ) -> Tuple[List[torch.Tensor], List[int]]: """ generate the batch used in inference for each of the steps @@ -226,7 +234,10 @@ def prepare_batch_at_step( # types2use = type_ids[:, context_length - 1].view(batch_size, -1) """Prepare batch for each of the inference steps""" - attention_mask_repeat = torch.concat([self.attention_mask for _ in range(micro_batch_size)]) + attention_mask_repeat = None + if compute_attention_mask: + attention_mask_repeat = torch.concat([self.attention_mask for _ in range(micro_batch_size)]) + setkey_value_array = torch.tensor( [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device() ) @@ -243,7 +254,7 @@ def __init__(self, model, task_ids): self.task_ids = task_ids self.forward_model = self.model - def init_batch(self, context_tokens: torch.Tensor, context_length: int): + def init_batch(self, context_tokens: torch.Tensor, context_length: int, compute_attention_mask: bool): """initialize the batch data before the inference steps.""" # Move to GPU. tokenizer = self.model.tokenizer @@ -255,6 +266,7 @@ def init_batch(self, context_tokens: torch.Tensor, context_length: int): self.model.cfg.get('reset_position_ids', False), self.model.cfg.get('reset_attention_mask', False), self.model.cfg.get('eod_mask_loss', False), + compute_attention_mask=compute_attention_mask, ) def clip_max_len(self, maxlen: int) -> int: @@ -264,7 +276,13 @@ def clip_max_len(self, maxlen: int) -> int: return maxlen def prepare_batch_at_step( - self, tokens: torch.Tensor, maxlen: int, micro_batch_size: int, step: int, context_length: int + self, + tokens: torch.Tensor, + maxlen: int, + micro_batch_size: int, + step: int, + context_length: int, + compute_attention_mask: bool, ) -> Tuple[List[torch.Tensor], List[int]]: # types2use = None if step == 0: @@ -285,7 +303,9 @@ def prepare_batch_at_step( # types2use = type_ids[:, context_length - 1].view(batch_size, -1) """Prepare batch for each of the inference steps""" - attention_mask_repeat = torch.concat([self.attention_mask for _ in range(micro_batch_size)]) + attention_mask_repeat = None + if compute_attention_mask: + attention_mask_repeat = torch.concat([self.attention_mask for _ in range(micro_batch_size)]) setkey_value_array = torch.tensor( [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device() ) diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index a56304970bdc..6417f887c0cd 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -105,7 +105,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para greedy=sampling_params['use_greedy'], repetition_penalty=sampling_params['repetition_penalty'], min_tokens_to_generate=length_params['min_length'], - **strategy_args, + 
compute_attention_mask=sampling_params.get("compute_attention_mask", True), + **strategy_args, ) compute_prob_response = get_computeprob_response(tokenizer, response, inputs) return compute_prob_response @@ -376,6 +376,7 @@ def synced_generate( top_k=0, top_p=0.0, greedy=False, + compute_attention_mask=True, compute_logprob=False, repetition_penalty=1.2, min_tokens_to_generate=0, @@ -401,6 +402,7 @@ def synced_generate( context_length_tensor, tokens_to_generate, all_probs, + compute_attention_mask=compute_attention_mask, compute_logprob=compute_logprob, temperature=temperature, end_strings=end_strings, @@ -469,6 +471,7 @@ def generate( top_k=0, top_p=0.0, greedy=False, + compute_attention_mask=True, compute_logprob=False, repetition_penalty=1.0, min_tokens_to_generate=0, @@ -550,6 +553,7 @@ def generate( tokens_to_generate, all_probs, temperature, + compute_attention_mask=compute_attention_mask, compute_logprob=compute_logprob, top_k=top_k, top_p=top_p, @@ -635,6 +639,7 @@ def sample_sequence_batch( context_lengths, tokens_to_generate, all_probs=False, + compute_attention_mask=True, compute_logprob=False, type_ids=None, temperature=None, @@ -666,7 +671,7 @@ def sample_sequence_batch( # initialize the batch with torch.no_grad(): context_length = context_lengths.min().item() - inference_strategy.init_batch(context_tokens, context_length) + inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask) # added eos_id to support the function generate_samples_eval that passes # eos_id as an argument and needs termination when that id id found. eod_id = tokenizer.eos_id @@ -685,7 +690,7 @@ def sample_sequence_batch( lengths = torch.ones([batch_size]).long().cuda() * maxlen while context_length < maxlen: batch, tensor_shape = inference_strategy.prepare_batch_at_step( - tokens, maxlen, micro_batch_size, counter, context_length + tokens, maxlen, micro_batch_size, counter, context_length, compute_attention_mask ) output = inference_strategy.forward_step(batch, tensor_shape) diff --git a/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py b/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py index 6e992f5348ae..22c657b25613 100644 --- a/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py +++ b/scripts/asr_language_modeling/ngram_lm/create_lexicon_from_arpa.py @@ -1,76 +1,77 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Use this file to create a lexicon file for Flashlight decoding from an existing KenLM arpa file -# A lexicon file is required for Flashlight decoding in most cases, as it acts as a map from the words -# in you arpa file to the representation used by your ASR AM.
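With the text-generation plumbing above, inference can skip attention-mask construction entirely by setting `compute_attention_mask` in `sampling_params`, since the flash-attention kernel applies causal masking itself. A hedged sketch follows; the other keys mirror NeMo's usual sampling-params layout, and the prompt and `model` are placeholders.

```python
length_params = {"min_length": 0, "max_length": 64}
sampling_params = {
    "use_greedy": True,
    "temperature": 1.0,
    "top_k": 0,
    "top_p": 0.9,
    "repetition_penalty": 1.0,
    "add_BOS": False,
    "all_probs": False,
    "compute_logprob": False,
    "end_strings": ["<|endoftext|>"],
    "compute_attention_mask": False,  # new: no (b, 1, s, s) mask is built when flash attention is used
}
# response = megatron_gpt_generate(model, ["Deep learning is"], model.tokenizer, length_params, sampling_params)
```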
-# For more details, see: https://github.com/flashlight/flashlight/tree/main/flashlight/app/asr#data-preparation -# -# Usage: python create_lexicon_from_arpa.py --arpa /path/to/english.arpa --model /path/to/model.nemo --lower -# -# - - -import argparse -import os -import re - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Utility script for generating lexicon file from a KenLM arpa file") - parser.add_argument("--arpa", required=True, help="path to your arpa file") - parser.add_argument("--dst", help="directory to store generated lexicon", default=None) - parser.add_argument("--lower", action='store_true', help="Whether to lowercase the arpa vocab") - parser.add_argument("--model", default=None, help="path to Nemo model for its tokeniser") - - args = parser.parse_args() - - if not os.path.exists(args.arpa): - print("ARPA file not detected on disk, aborting!", flush=True) - exit(255) - - if args.dst is not None: - save_path = args.dst - else: - save_path = os.path.dirname(args.arpa) - os.makedirs(save_path, exist_ok=True) - - tokenizer = None - if args.model is not None: - from nemo.collections.asr.models import ASRModel - - model = ASRModel.restore_from(restore_path=args.model, map_location='cpu') - if hasattr(model, 'tokenizer'): - tokenizer = model.tokenizer - else: - print('WARNING: supplied Nemo model does not contain a tokenizer', flush=True) - - lex_file = os.path.join(save_path, os.path.splitext(os.path.basename(args.arpa))[0] + '.lexicon') - print(f"Writing Lexicon file - {lex_file}...", flush=True) - with open(lex_file, "w", encoding='utf_8', newline='\n') as f: - with open(args.arpa, "r", encoding='utf_8') as arpa: - for line in arpa: - # verify if the line corresponds to unigram - if not re.match(r"[-]*[0-9\.]+\t\S+\t*[-]*[0-9\.]*$", line): - continue - word = line.split("\t")[1] - word = word.strip().lower() if args.lower else word.strip() - if word == "" or word == "" or word == "" or word == "": - continue - - if tokenizer is None: - f.write("{w}\t{s}\n".format(w=word, s=" ".join(word))) - else: - f.write("{w}\t{s}\n".format(w=word, s=" ".join(tokenizer.text_to_tokens(word)))) - - print("Done!", flush=True) +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Use this file to create a lexicon file for Flashlight decoding from an existing KenLM arpa file +# A lexicon file is required for Flashlight decoding in most cases, as it acts as a map from the words +# in you arpa file to the representation used by your ASR AM. 
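To make the unigram filter in `create_lexicon_from_arpa.py` concrete, here is a small illustration with made-up ARPA lines; only 1-gram entries of the form `<logprob>\t<word>[\t<backoff>]` survive the regex and are written to the lexicon.

```python
import re

pattern = r"[-]*[0-9\.]+\t\S+\t*[-]*[0-9\.]*$"
lines = [
    "-2.845\thello\t-0.31",    # unigram with backoff weight -> kept
    "-3.120\tworld",           # unigram without backoff -> kept
    "-0.523\tthe cat\t-0.11",  # 2-gram line (contains a space) -> dropped
    "ngram 1=50000",           # ARPA header line -> dropped
]
kept = [line.split("\t")[1] for line in lines if re.match(pattern, line)]
assert kept == ["hello", "world"]
```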
+# For more details, see: https://github.com/flashlight/flashlight/tree/main/flashlight/app/asr#data-preparation +# +# Usage: python create_lexicon_from_arpa.py --arpa /path/to/english.arpa --model /path/to/model.nemo --lower +# +# + + +import argparse +import os +import re + +from nemo.utils import logging + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Utility script for generating lexicon file from a KenLM arpa file") + parser.add_argument("--arpa", required=True, help="path to your arpa file") + parser.add_argument("--dst", help="directory to store generated lexicon", default=None) + parser.add_argument("--lower", action='store_true', help="Whether to lowercase the arpa vocab") + parser.add_argument("--model", default=None, help="path to Nemo model for its tokeniser") + + args = parser.parse_args() + + if not os.path.exists(args.arpa): + logging.critical(f"ARPA file [ {args.arpa} ] not detected on disk, aborting!") + exit(255) + + if args.dst is not None: + save_path = args.dst + else: + save_path = os.path.dirname(args.arpa) + os.makedirs(save_path, exist_ok=True) + + tokenizer = None + if args.model is not None: + from nemo.collections.asr.models import ASRModel + + model = ASRModel.restore_from(restore_path=args.model, map_location='cpu') + if hasattr(model, 'tokenizer'): + tokenizer = model.tokenizer + else: + logging.warning('Supplied Nemo model does not contain a tokenizer') + + lex_file = os.path.join(save_path, os.path.splitext(os.path.basename(args.arpa))[0] + '.lexicon') + + logging.info(f"Writing Lexicon file to: {lex_file}...") + with open(lex_file, "w", encoding='utf_8', newline='\n') as f: + with open(args.arpa, "r", encoding='utf_8') as arpa: + for line in arpa: + # verify if the line corresponds to unigram + if not re.match(r"[-]*[0-9\.]+\t\S+\t*[-]*[0-9\.]*$", line): + continue + word = line.split("\t")[1] + word = word.strip().lower() if args.lower else word.strip() + if word == "" or word == "" or word == "" or word == "": + continue + + if tokenizer is None: + f.write("{w}\t{s}\n".format(w=word, s=" ".join(word))) + else: + f.write("{w}\t{s}\n".format(w=word, s=" ".join(tokenizer.text_to_tokens(word)))) diff --git a/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py b/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py new file mode 100644 index 000000000000..14d7b6ae54ea --- /dev/null +++ b/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py @@ -0,0 +1,212 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +A script to convert the Mosaic MPT-7B checkpoint on HuggingFace to Megatron GPTModel +This script is hardcoded specifically for the MPT-7B pretrained model only, and is not +generalisable to any other models. + +This script will load and convert the model entirely on CPU for OOM safety, but there +is an option to put the model onto GPU before the save down, which sets the map_location +to cuda for the restore_from call. 
You can do this by adding --cuda to this script call. + +This script requires that you have downloaded the 2 .bin weight files for MPT-7B from +HuggingFace located here: https://huggingface.co/mosaicml/mpt-7b/tree/main +These files MUST have the following file names and be saved somewhere where this script +can read them: + pytorch_model-00001-of-00002.bin + pytorch_model-00002-of-00002.bin + +This script will generate a Megatron model with TP=1 and PP=1. If you need different TP/PP +values, then after running this script, please use the script located below to set whatever +TP/PP values you want: + NeMo/examples/nlp/language_modeling/megatron_change_num_partitions.py + + +Here is an example usage command: + +```python +python scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py -i /path/to/mpt_7b -o /path/to/save +``` + +""" + + +import argparse +import os + +import pytorch_lightning as pl +import torch +from omegaconf import OmegaConf + +from nemo.collections.nlp.models.language_modeling.megatron import GPTModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.utils import logging + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '-i', '--input', required=True, type=str, help='path to the two MPT-7B .bin weight files from HuggingFace' + ) + parser.add_argument( + '-o', '--output', required=False, default=None, type=str, help='path to dir where to store output .nemo file' + ) + parser.add_argument('--cuda', action='store_true', help='put Nemo model onto GPU prior to savedown') + + args = parser.parse_args() + + if not os.path.exists(args.input): + logging.critical(f'Input directory [ {args.input} ] does not exist or cannot be found. Aborting.') + exit(255) + + model_dict = { + 'micro_batch_size': 4, + 'global_batch_size': 8, + 'rampup_batch_size': None, + 'tensor_model_parallel_size': 1, + 'pipeline_model_parallel_size': 1, + 'virtual_pipeline_model_parallel_size': None, + 'megatron_amp_O2': True, + 'transformer_engine': False, + 'use_cpu_initialization': True, + 'hidden_size': 4096, + 'max_position_embeddings': 2048, + 'num_layers': 32, + 'num_attention_heads': 32, + 'ffn_hidden_size': 4 * 4096, + 'precision': 'bf16', + 'pre_process': True, + 'post_process': True, + 'num_tokentypes': 0, + 'apply_query_key_layer_scaling': False, + 'parallel_output': False, + 'bias': False, + 'bias_dropout_add_fusion': False, + 'bias_activation_fusion': False, + 'transformer_block_type': 'pre_ln', + 'normalization': 'low_precision_layernorm', + 'fp32_residual_connection': False, + 'hidden_dropout': 0, + 'attention_dropout': 0, + 'ffn_dropout': 0, + 'megatron_legacy': True, + 'share_embeddings_and_output_weights': True, + 'sequence_parallel': False, + 'position_embedding_type': 'alibi', + 'normalize_attention_scores': True, + 'use_flash_attention': False, + 'override_vocab_size': 50432, + } + tokeniser_dict = { + 'library': 'huggingface', + 'type': 'EleutherAI/gpt-neox-20b', + 'use_fast': True, + } + optim_dict = { + 'name': 'fused_adam', + 'lr': 2e-4, + 'weight_decay': 0.01, + } + trainer_dict = { + 'devices': 1, + 'num_nodes': 1, + 'accelerator': 'gpu' if args.cuda else 'cpu', + 'precision': 'bf16', + 'logger': False, # logger provided by exp_manager + 'enable_checkpointing': False, + 'replace_sampler_ddp': False, + 'max_epochs': -1, # PTL default. In practice, max_steps will be reached first. 
+ 'max_steps': 100000, # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + 'log_every_n_steps': 10, + 'val_check_interval': 100, + 'limit_val_batches': 50, + 'limit_test_batches': 500, + 'accumulate_grad_batches': 1, + 'gradient_clip_val': 1.0, + 'benchmark': False, + 'enable_model_summary': False, + } + + model_dict['tokenizer'] = tokeniser_dict + model_dict['optim'] = optim_dict + + omega_cfg = OmegaConf.create(model_dict) + + trainer = pl.Trainer(**trainer_dict) + + model = MegatronGPTModel(omega_cfg, trainer) + + model_keys = list(model.state_dict().keys()) + model_dtypes = list(set([model.state_dict()[x].dtype for x in model_keys])) + + if not (len(model_dtypes) == 1 and model_dtypes[0] is torch.bfloat16): + model = model.bfloat16() + + if args.cuda: + model = model.cuda() + + mpt_1 = torch.load(os.path.join(args.input, 'pytorch_model-00001-of-00002.bin'), map_location="cpu") + mpt_2 = torch.load(os.path.join(args.input, 'pytorch_model-00002-of-00002.bin'), map_location="cpu") + mpt_dict = {**mpt_1, **mpt_2} + del mpt_1, mpt_2 + + def convert_state_dict(state_dict, amp=False): + def get_new_key(old_key): + if old_key == 'transformer.wte.weight': + return 'language_model.embedding.word_embeddings.weight' + elif old_key == 'transformer.norm_f.weight': + return 'language_model.encoder.final_layernorm.weight' + else: + p1 = old_key.replace('transformer.blocks.', 'language_model.encoder.layers.') + p2 = p1.replace('norm_1.weight', 'input_layernorm.weight') + p3 = p2.replace('attn.Wqkv.weight', 'self_attention.query_key_value.weight') + p4 = p3.replace('attn.out_proj.weight', 'self_attention.dense.weight') + p5 = p4.replace('norm_2.weight', 'post_attention_layernorm.weight') + p6 = p5.replace('ffn.up_proj.weight', 'mlp.dense_h_to_4h.weight') + p7 = p6.replace('ffn.down_proj.weight', 'mlp.dense_4h_to_h.weight') + + return p7 + + new_dict = {} + + for old_key, val in state_dict.items(): + new_key = get_new_key(old_key) + if amp: + new_key = 'module.' + new_key + + new_dict[new_key] = val + + return new_dict + + convert_dict = convert_state_dict(mpt_dict, amp=model_dict['megatron_amp_O2']) + + if model_dict['megatron_amp_O2']: + missing_keys, unexpected_keys = model.model.load_state_dict(convert_dict, strict=True) + else: + missing_keys, unexpected_keys = super(GPTModel, model.model).load_state_dict(convert_dict, strict=True) + + if len(missing_keys) > 0: + logging.critical('Missing keys were detected during the load, something has gone wrong. Aborting.') + logging.critical(f'Missing keys: \n{missing_keys}') + exit(255) + + if len(unexpected_keys) > 0: + logging.warning('Unexpected keys were detected which should not happen. Please investigate.') + logging.warning(f'Unexpected keys: \n{unexpected_keys}') + + if args.output is None: + args.output = os.path.dirname(os.path.abspath(__file__)) + + model.save_to(os.path.join(args.output, 'megatron_mpt_7b_base_tp1_pp1.nemo')) diff --git a/tests/collections/nlp/test_flash_attention.py b/tests/collections/nlp/test_flash_attention.py new file mode 100644 index 000000000000..cead91ff312a --- /dev/null +++ b/tests/collections/nlp/test_flash_attention.py @@ -0,0 +1,247 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import pytest +import torch +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.modules.common.megatron.attention import CoreAttention +from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo +from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy + +try: + from apex.transformer.enums import AttnMaskType + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + import flash_attn + + HAVE_FA = True +except (ImportError, ModuleNotFoundError): + HAVE_FA = False + +try: + import triton + + HAVE_TRITON = True +except (ImportError, ModuleNotFoundError): + HAVE_TRITON = False + +import pynvml + + +def HAVE_AMPERE_GPU(): + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + device_arch = pynvml.nvmlDeviceGetArchitecture(handle) + pynvml.nvmlShutdown() + return device_arch == pynvml.NVML_DEVICE_ARCH_AMPERE + + +@pytest.mark.run_only_on('GPU') +@pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed") +class TestFlashAttention: + @classmethod + def setup_class(cls): + if not torch.cuda.is_available(): + return + + GPUS = 1 + TP_SIZE = GPUS + PP_SIZE = 1 + MB_SIZE = 4 + GB_SIZE = 8 + SEED = 1234 + trainer = Trainer(strategy=NLPDDPStrategy(), devices=GPUS, accelerator='gpu', num_nodes=1, logger=None,) + + initialize_model_parallel_for_nemo( + world_size=trainer.world_size, + global_rank=trainer.global_rank, + local_rank=trainer.local_rank, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + micro_batch_size=MB_SIZE, + global_batch_size=GB_SIZE, + seed=SEED, + apex_transformer_log_level=30, + ) + + @pytest.fixture() + def cfg(self): + cfg = { + 'bz': random.randint(1, 7), + 'sl': random.randint(1, 7), + 'head': random.randint(1, 7), + 'device': torch.cuda.current_device(), + } + # flash attention requires head dimensions are multiples of 8 + head_dim = random.randint(1, 7) * 8 + cfg['hidden'] = cfg['head'] * head_dim + + return cfg + + @pytest.mark.skipif(not HAVE_FA, reason="flash-attention is not installed") + @pytest.mark.unit + def test_flash_attention(self, cfg): + device = cfg['device'] + bz, sl, np, h = cfg['bz'], cfg['sl'], cfg['head'], cfg['hidden'] + hn = h // np + + q = torch.rand(sl, bz, np, hn, device=device).half() + k = torch.rand(sl, bz, np, hn, device=device).half() + v = torch.rand(sl, bz, np, hn, device=device).half() + + attention_mask_2d = torch.arange(sl, device=device).unsqueeze(0) < torch.randint( + 1, sl, (bz,), device=device + ).unsqueeze(1) + + attention_mask_padding_3d = build_attention_mask_3d( + source_mask=attention_mask_2d, target_mask=attention_mask_2d, attn_mask_type=AttnMaskType.padding + ).unsqueeze(1) + + attention_mask_causal_3d = build_attention_mask_3d( + source_mask=attention_mask_2d, target_mask=attention_mask_2d, attn_mask_type=AttnMaskType.causal + ).unsqueeze(1) + + # Non-causal + attention = CoreAttention( + layer_number=1, + num_attention_heads=np, + 
hidden_size=h, + attn_mask_type=AttnMaskType.padding, + attention_dropout=0.0, + ) + + attention_fa = CoreAttention( + layer_number=1, + num_attention_heads=np, + hidden_size=h, + attn_mask_type=AttnMaskType.padding, + attention_dropout=0.0, + use_flash_attention=True, + ) + + out = attention(q, k, v, attention_mask_padding_3d) + out_fa = attention_fa(q, k, v, attention_mask_padding_3d) + assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3) + out_fa = attention_fa(q, k, v, attention_mask_2d) + assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3) + + # Causal + attention = CoreAttention( + layer_number=1, + num_attention_heads=np, + hidden_size=h, + attn_mask_type=AttnMaskType.causal, + attention_dropout=0.0, + ) + + attention_fa = CoreAttention( + layer_number=1, + num_attention_heads=np, + hidden_size=h, + attn_mask_type=AttnMaskType.causal, + attention_dropout=0.0, + use_flash_attention=True, + ) + + out = attention(q, k, v, attention_mask_causal_3d) + out_fa = attention_fa(q, k, v, attention_mask_causal_3d) + assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3) + out_fa = attention_fa(q, k, v, attention_mask_2d) + assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3) + + @pytest.mark.skipif(not HAVE_FA, reason="flash-attention is not installed") + @pytest.mark.skipif(not HAVE_TRITON, reason="triton is not installed") + @pytest.mark.skipif( + not HAVE_AMPERE_GPU(), + reason="should only run on AMPERE GPU. Please see https://github.com/HazyResearch/flash-attention/issues/245", + ) + @pytest.mark.unit + def test_flash_attention_triton(self, cfg): + device = cfg['device'] + bz, sl, np, h = cfg['bz'], cfg['sl'], cfg['head'], cfg['hidden'] + hn = h // np + + q = torch.rand(sl, bz, np, hn, device=device).half() + k = torch.rand(sl, bz, np, hn, device=device).half() + v = torch.rand(sl, bz, np, hn, device=device).half() + + attention_mask_2d = torch.arange(sl, device=device).unsqueeze(0) < torch.randint( + 1, sl, (bz,), device=device + ).unsqueeze(1) + + attention_mask_padding_3d = build_attention_mask_3d( + source_mask=attention_mask_2d, target_mask=attention_mask_2d, attn_mask_type=AttnMaskType.padding + ).unsqueeze(1) + + attention_mask_causal_3d = build_attention_mask_3d( + source_mask=attention_mask_2d, target_mask=attention_mask_2d, attn_mask_type=AttnMaskType.causal + ).unsqueeze(1) + + attention_bias = torch.rand(bz, np, sl, sl, device=device) + + # Non-causal + attention = CoreAttention( + layer_number=1, + num_attention_heads=np, + hidden_size=h, + attn_mask_type=AttnMaskType.padding, + attention_dropout=0.0, + ) + + attention_fa = CoreAttention( + layer_number=1, + num_attention_heads=np, + hidden_size=h, + attn_mask_type=AttnMaskType.padding, + attention_dropout=0.0, + use_flash_attention=True, + ) + + out = attention(q, k, v, attention_mask_padding_3d, relative_position_bias=attention_bias) + out_fa = attention_fa(q, k, v, attention_mask_padding_3d, relative_position_bias=attention_bias) + assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3) + out_fa = attention_fa(q, k, v, attention_mask_2d, relative_position_bias=attention_bias) + assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3) + + # Causal + attention = CoreAttention( + layer_number=1, + num_attention_heads=np, + hidden_size=h, + attn_mask_type=AttnMaskType.causal, + attention_dropout=0.0, + ) + + attention_fa = CoreAttention( + layer_number=1, + num_attention_heads=np, + hidden_size=h, + attn_mask_type=AttnMaskType.causal, + attention_dropout=0.0, + use_flash_attention=True, + ) + + out = 
attention(q, k, v, attention_mask_causal_3d, relative_position_bias=attention_bias) + out_fa = attention_fa(q, k, v, attention_mask_causal_3d, relative_position_bias=attention_bias) + assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3) + out_fa = attention_fa(q, k, v, attention_mask_2d, relative_position_bias=attention_bias) + assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3) diff --git a/tests/collections/nlp/test_position_embedding.py b/tests/collections/nlp/test_position_embedding.py new file mode 100644 index 000000000000..263ca8669d81 --- /dev/null +++ b/tests/collections/nlp/test_position_embedding.py @@ -0,0 +1,211 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import pytest +import torch + +from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType +from nemo.collections.nlp.modules.common.megatron.position_embedding import ( + ALiBiRelativePositionEmbedding, + KERPLERelativePositionEmbedding, + RotaryEmbedding, + SandwichRelativePositionEmbedding, + T5RelativePositionEmbedding, + XPOSPositionEmbedding, +) +from nemo.collections.nlp.modules.common.megatron.position_embedding.rotary_position_embedding import ( + apply_rotary_pos_emb, +) +from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal + + +@pytest.fixture() +def cfg(): + cfg = { + 'max_seq_len': 8, + 'num_attention_heads': 2, + 'layer_type': LayerType.encoder, + 'hidden_size': 4, + 'rpe_init_method_std': 0.02, + 'rpe_num_buckets': 6, + 'rpe_max_distance': 16, + } + return cfg + + +@pytest.mark.unit +def test_alibi(cfg): + # non-causal + PE_nc = ALiBiRelativePositionEmbedding( + bidirectional=True, + num_attention_heads=cfg['num_attention_heads'], + layer_type=cfg['layer_type'], + max_seq_len=cfg['max_seq_len'], + ) + + # causal + PE_c = ALiBiRelativePositionEmbedding( + bidirectional=False, + num_attention_heads=cfg['num_attention_heads'], + layer_type=cfg['layer_type'], + max_seq_len=cfg['max_seq_len'], + ) + + q_len = k_len = random.randint(1, cfg['max_seq_len'] * 2) + + bias_nc = PE_nc(q_len, k_len) + assert bias_nc.shape == (1, cfg['num_attention_heads'], q_len, k_len) + assert torch.equal(bias_nc, bias_nc.transpose(2, 3)) + + bias_c = PE_c(q_len, k_len) + assert bias_c.shape == (1, cfg['num_attention_heads'], 1, k_len) + assert torch.equal(bias_c, bias_nc[:, :, -1:, :]) + + +@pytest.mark.unit +def test_sandwich(cfg): + # non-causal + PE_nc = SandwichRelativePositionEmbedding( + bidirectional=True, + num_attention_heads=cfg['num_attention_heads'], + layer_type=cfg['layer_type'], + max_seq_len=cfg['max_seq_len'], + hidden_size=cfg['hidden_size'], + ) + + # causal + PE_c = SandwichRelativePositionEmbedding( + bidirectional=False, + num_attention_heads=cfg['num_attention_heads'], + layer_type=cfg['layer_type'], + max_seq_len=cfg['max_seq_len'], + hidden_size=cfg['hidden_size'], + ) + + q_len = k_len = random.randint(1, cfg['max_seq_len'] * 2) + + bias_nc = PE_nc(q_len, k_len) + assert 
bias_nc.shape == (1, cfg['num_attention_heads'], q_len, k_len) + assert torch.equal(bias_nc, bias_nc.transpose(2, 3)) + + bias_c = PE_c(q_len, k_len) + assert bias_c.shape == (1, cfg['num_attention_heads'], q_len, k_len) + assert torch.all(torch.triu(bias_c, diagonal=0) == 0) + + +@pytest.mark.unit +def test_kerple(cfg): + # non-causal + PE_nc = KERPLERelativePositionEmbedding( + bidirectional=True, + num_attention_heads=cfg['num_attention_heads'], + layer_type=cfg['layer_type'], + max_seq_len=cfg['max_seq_len'], + ) + + # causal + PE_c = KERPLERelativePositionEmbedding( + bidirectional=False, + num_attention_heads=cfg['num_attention_heads'], + layer_type=cfg['layer_type'], + max_seq_len=cfg['max_seq_len'], + ) + + q_len = k_len = random.randint(1, cfg['max_seq_len'] * 2) + + bias_nc = PE_nc(q_len, k_len) + assert bias_nc.shape == (1, cfg['num_attention_heads'], q_len, k_len) + assert torch.equal(bias_nc, bias_nc.transpose(2, 3)) + + bias_c = PE_c(q_len, k_len) + assert bias_c.shape == (1, cfg['num_attention_heads'], q_len, k_len) + assert torch.all(torch.triu(bias_c, diagonal=0) == 0) + + +@pytest.mark.unit +def test_t5relative(cfg): + # non-causal + PE_nc = T5RelativePositionEmbedding( + bidirectional=True, + num_attention_heads=cfg['num_attention_heads'], + layer_type=cfg['layer_type'], + init_method=init_method_normal(cfg['rpe_init_method_std']), + relative_position_num_buckets=cfg['rpe_num_buckets'], + relative_position_max_distance=cfg['rpe_max_distance'], + ) + + # causal + PE_c = T5RelativePositionEmbedding( + bidirectional=False, + num_attention_heads=cfg['num_attention_heads'], + layer_type=cfg['layer_type'], + init_method=init_method_normal(cfg['rpe_init_method_std']), + relative_position_num_buckets=cfg['rpe_num_buckets'], + relative_position_max_distance=cfg['rpe_max_distance'], + ) + + q_len = k_len = random.randint(1, cfg['max_seq_len'] * 2) + + bias_nc = PE_nc(q_len, k_len) + assert bias_nc.shape == (1, cfg['num_attention_heads'], q_len, k_len) + + bias_c = PE_c(q_len, k_len) + assert bias_c.shape == (1, cfg['num_attention_heads'], q_len, k_len) + assert ( + len(torch.triu(bias_c, diagonal=0).unique()) == cfg['num_attention_heads'] + 1 + if q_len > 1 + else cfg['num_attention_heads'] + ) + + +@pytest.mark.unit +def test_rotary(cfg): + PE = RotaryEmbedding(dim=cfg['hidden_size']) + rotary_embedding = PE(cfg['max_seq_len']) + + x = torch.rand(cfg['max_seq_len'], 1, cfg['num_attention_heads'], cfg['hidden_size']) + x_rotary = apply_rotary_pos_emb(x, rotary_embedding) + assert x_rotary.shape == x.shape + + hd = cfg['hidden_size'] // 2 + x_rotary_test = torch.cat( + ( + x[..., :hd] * rotary_embedding[..., :hd].cos() + x[..., hd:] * rotary_embedding[..., hd:].sin() * -1, + x[..., :hd] * rotary_embedding[..., :hd].sin() + x[..., hd:] * rotary_embedding[..., hd:].cos(), + ), + dim=-1, + ) + assert torch.equal(x_rotary, x_rotary_test) + + offset = random.choice(range(1, cfg['max_seq_len'])) + rotary_embedding_offset = PE(cfg['max_seq_len'], offset=offset) + x_rotary = apply_rotary_pos_emb(x[: offset + 1], rotary_embedding[: offset + 1]) + x_rotary_offset = apply_rotary_pos_emb(x[offset : offset + 1], rotary_embedding_offset[:1]) + assert torch.equal(x_rotary[-1], x_rotary_offset[0]) + + +@pytest.mark.unit +def test_xpos(cfg): + PE = XPOSPositionEmbedding(head_dim=cfg['hidden_size']) + x = torch.rand(cfg['max_seq_len'], 1, cfg['num_attention_heads'], cfg['hidden_size']) + + x_rotary = PE(x) + assert x_rotary.shape == x.shape + + offset = random.choice(range(1, cfg['max_seq_len'])) + 
x_rotary = PE(x[: offset + 1]) + x_rotary_offset = PE(x[offset : offset + 1], offset=offset) + assert torch.equal(x_rotary[-1], x_rotary_offset[0]) diff --git a/tests/collections/nlp/test_retrieval_module.py b/tests/collections/nlp/test_retrieval_module.py index 3a2d46f4fed2..08425964e566 100644 --- a/tests/collections/nlp/test_retrieval_module.py +++ b/tests/collections/nlp/test_retrieval_module.py @@ -21,6 +21,7 @@ from nemo.collections.nlp.modules.common.megatron.attention import ParallelChunkedCrossAttention from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo +from nemo.collections.nlp.modules.common.megatron.position_embedding import RotaryEmbedding from nemo.collections.nlp.modules.common.megatron.retrieval_token_level_encoder_decoder import ( MegatronRetrievalTokenLevelEncoderDecoderModule, ) @@ -28,7 +29,6 @@ MegatronRetrievalTransformerDecoderModule, MegatronRetrievalTransformerEncoderModule, ) -from nemo.collections.nlp.modules.common.megatron.rotary_pos_embedding import RotaryEmbedding from nemo.collections.nlp.modules.common.megatron.utils import ( build_attention_mask_3d, init_method_normal, diff --git a/tests/collections/nlp/test_retrieval_module_inference.py b/tests/collections/nlp/test_retrieval_module_inference.py index 16e7e556bd10..a9aa002815b2 100644 --- a/tests/collections/nlp/test_retrieval_module_inference.py +++ b/tests/collections/nlp/test_retrieval_module_inference.py @@ -22,6 +22,7 @@ from nemo.collections.nlp.modules.common.megatron.attention import ParallelChunkedCrossAttention from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo +from nemo.collections.nlp.modules.common.megatron.position_embedding import RotaryEmbedding from nemo.collections.nlp.modules.common.megatron.retrieval_token_level_encoder_decoder import ( MegatronRetrievalTokenLevelEncoderDecoderModule, ) @@ -29,7 +30,6 @@ MegatronRetrievalTransformerDecoderModule, MegatronRetrievalTransformerEncoderModule, ) -from nemo.collections.nlp.modules.common.megatron.rotary_pos_embedding import RotaryEmbedding from nemo.collections.nlp.modules.common.megatron.utils import ( build_attention_mask_3d, init_method_normal,
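Editor's note: the two hunks above only relocate the RotaryEmbedding import from megatron.rotary_pos_embedding to the new megatron.position_embedding package; the retrieval tests themselves are unchanged. A minimal CPU-only sketch of the relocated path, mirroring test_rotary from the new test_position_embedding.py (the sequence/batch/head/dim sizes are illustrative, not required by the API):

# Sketch only: exercise the relocated RotaryEmbedding import path (CPU is enough).
# Shapes follow the [seq, batch, heads, dim] convention used in test_rotary.
import torch

from nemo.collections.nlp.modules.common.megatron.position_embedding import RotaryEmbedding
from nemo.collections.nlp.modules.common.megatron.position_embedding.rotary_position_embedding import (
    apply_rotary_pos_emb,
)

seq_len, batch, heads, dim = 8, 1, 2, 4
rope = RotaryEmbedding(dim=dim)

x = torch.rand(seq_len, batch, heads, dim)
emb = rope(seq_len)                    # rotary table, broadcastable over batch and heads
x_rot = apply_rotary_pos_emb(x, emb)
assert x_rot.shape == x.shape

# Rotation depends only on absolute position: computing position "offset" via the
# offset argument matches slicing the full table (same check as in test_rotary).
offset = 3
emb_off = rope(seq_len, offset=offset)
assert torch.equal(
    apply_rotary_pos_emb(x[offset : offset + 1], emb_off[:1]),
    apply_rotary_pos_emb(x[: offset + 1], emb[: offset + 1])[-1:],
)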
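test_position_embedding.py also covers the bias-producing embeddings (ALiBi, KERPLE, Sandwich, T5 relative). For reference, a CPU-only sketch of the ALiBi case, following the shapes asserted in test_alibi; tensors with this (1, heads, q_len, k_len) layout are the kind CoreAttention accepts through its relative_position_bias argument (exercised with a random bias in test_flash_attention_triton). The head count and sequence length below are arbitrary:

# Sketch only: build ALiBi relative-position biases with the new
# position_embedding package (CPU only, mirrors test_alibi).
import torch

from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType
from nemo.collections.nlp.modules.common.megatron.position_embedding import ALiBiRelativePositionEmbedding

heads, max_seq_len = 2, 8
q_len = k_len = max_seq_len

# Non-causal (bidirectional) bias: full symmetric (1, heads, q_len, k_len) tensor.
alibi_nc = ALiBiRelativePositionEmbedding(
    bidirectional=True, num_attention_heads=heads, layer_type=LayerType.encoder, max_seq_len=max_seq_len,
)
bias_nc = alibi_nc(q_len, k_len)
assert bias_nc.shape == (1, heads, q_len, k_len)
assert torch.equal(bias_nc, bias_nc.transpose(2, 3))

# Causal bias: collapses to a single query row, (1, heads, 1, k_len).
alibi_c = ALiBiRelativePositionEmbedding(
    bidirectional=False, num_attention_heads=heads, layer_type=LayerType.encoder, max_seq_len=max_seq_len,
)
bias_c = alibi_c(q_len, k_len)
assert bias_c.shape == (1, heads, 1, k_len)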
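Finally, for reviewers who want to reproduce the new TestFlashAttention equivalence check interactively rather than through pytest, a condensed sketch of the causal case. It assumes a single CUDA GPU with apex and flash-attn installed and reuses the same model-parallel initialization as TestFlashAttention.setup_class; the batch, sequence, head, and head-dimension sizes are illustrative (the head dimension must stay a multiple of 8 for flash attention):

# Condensed smoke check -- a sketch only, mirroring TestFlashAttention above.
# Assumes: 1 CUDA GPU, apex and flash-attn installed. All APIs used here
# (Trainer/NLPDDPStrategy setup, initialize_model_parallel_for_nemo,
# CoreAttention, build_attention_mask_3d) are the ones the new test imports.
import torch
from apex.transformer.enums import AttnMaskType
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.modules.common.megatron.attention import CoreAttention
from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo
from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy

trainer = Trainer(strategy=NLPDDPStrategy(), devices=1, accelerator='gpu', num_nodes=1, logger=None)
initialize_model_parallel_for_nemo(
    world_size=trainer.world_size,
    global_rank=trainer.global_rank,
    local_rank=trainer.local_rank,
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    micro_batch_size=4,
    global_batch_size=8,
    seed=1234,
    apex_transformer_log_level=30,
)

device = torch.cuda.current_device()
bz, sl, heads, head_dim = 2, 16, 4, 64   # head_dim must be a multiple of 8 for flash-attn
hidden = heads * head_dim

# Inputs follow the [sl, bz, np, hn] layout used by TestFlashAttention.
q = torch.rand(sl, bz, heads, head_dim, device=device).half()
k = torch.rand(sl, bz, heads, head_dim, device=device).half()
v = torch.rand(sl, bz, heads, head_dim, device=device).half()

# All-valid padding mask, expanded to the causal 3D mask CoreAttention expects.
mask_2d = torch.ones(bz, sl, dtype=torch.bool, device=device)
mask_3d = build_attention_mask_3d(
    source_mask=mask_2d, target_mask=mask_2d, attn_mask_type=AttnMaskType.causal
).unsqueeze(1)

common = dict(
    layer_number=1, num_attention_heads=heads, hidden_size=hidden,
    attn_mask_type=AttnMaskType.causal, attention_dropout=0.0,
)
baseline = CoreAttention(**common)
flash = CoreAttention(**common, use_flash_attention=True)

out = baseline(q, k, v, mask_3d)
out_fa = flash(q, k, v, mask_3d)
print("max abs diff:", (out - out_fa).abs().max().item())
assert torch.allclose(out, out_fa, rtol=1e-3, atol=1e-3)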