adding support to return logits and generate for Megatron-LM GPT models #819

Merged
103 changes: 102 additions & 1 deletion docs/source/usage_guides/megatron_lm.mdx
@@ -435,6 +435,103 @@ python checkpoint_utils/megatgron_gpt2/checkpoint_reshaping_and_interoperability
--print-checkpoint-structure
```

## Megatron-LM GPT models: returning logits and the `megatron_generate` method for text generation

1. Returning logits requires setting `return_logits=True` in `MegatronLMPlugin` as shown below.
The logits are available only on the last stage of the pipeline (a sketch of consuming them follows this list).
```python
megatron_lm_plugin = MegatronLMPlugin(return_logits=True)
```

2. `megatron_generate` method for Megatron-LM GPT models: this uses tensor and pipeline parallelism to complete
generation for a batch of inputs when using greedy decoding with/without top_k/top_p sampling, and for individual prompt inputs when using beam-search decoding.
Only a subset of the features of the `transformers` `generate` method is supported. This helps in using large models via tensor and pipeline parallelism
for generation (key-value caching is done and fused kernels are used by default).
This requires the data parallel size to be 1, and sequence parallelism and activation checkpointing to be disabled.
It also requires specifying the paths to the tokenizer's vocab and merges files.
The example below shows how to configure and use the `megatron_generate` method for a Megatron-LM GPT model.
```python
import os

# specify the tokenizer's vocab and merges files
vocab_file = os.path.join(args.resume_from_checkpoint, "vocab.json")
merge_file = os.path.join(args.resume_from_checkpoint, "merges.txt")
other_megatron_args = {"vocab_file": vocab_file, "merge_file": merge_file}
megatron_lm_plugin = MegatronLMPlugin(other_megatron_args=other_megatron_args)

# inference using `megatron_generate` functionality
tokenizer.pad_token = tokenizer.eos_token
max_new_tokens = 64
batch_texts = [
"Are you human?",
"The purpose of life is",
"The arsenal was constructed at the request of",
"How are you doing these days?",
]
batch_encodings = tokenizer(batch_texts, return_tensors="pt", padding=True)

# top-p sampling
generated_tokens = model.megatron_generate(
batch_encodings["input_ids"],
batch_encodings["attention_mask"],
max_new_tokens=max_new_tokens,
top_p=0.8,
top_p_decay=0.5,
temperature=0.9,
)
decoded_preds = tokenizer.batch_decode(generated_tokens.cpu().numpy())
accelerator.print(decoded_preds)

# top-k sampling
generated_tokens = model.megatron_generate(
batch_encodings["input_ids"],
batch_encodings["attention_mask"],
max_new_tokens=max_new_tokens,
top_k=50,
temperature=0.9,
)
decoded_preds = tokenizer.batch_decode(generated_tokens.cpu().numpy())
accelerator.print(decoded_preds)

# adding `bos` token at the start
generated_tokens = model.megatron_generate(
batch_encodings["input_ids"], batch_encodings["attention_mask"], max_new_tokens=max_new_tokens, add_BOS=True
)
decoded_preds = tokenizer.batch_decode(generated_tokens.cpu().numpy())
accelerator.print(decoded_preds)

# beam search => only accepts a single prompt
batch_texts = ["The purpose of life is"]
batch_encodings = tokenizer(batch_texts, return_tensors="pt", padding=True)
generated_tokens = model.megatron_generate(
batch_encodings["input_ids"],
batch_encodings["attention_mask"],
max_new_tokens=max_new_tokens,
num_beams=20,
length_penalty=1.5,
)
decoded_preds = tokenizer.batch_decode(generated_tokens.cpu().numpy())
accelerator.print(decoded_preds)
```

3. An end-to-end example of using the `megatron_generate` method for a Megatron-LM GPT model is available at
[megatron_gpt2_generation.py](https://github.com/pacman100/accelerate-megatron-test/blob/main/src/inference/megatron_gpt2_generation.py) with
the config file [megatron_lm_gpt_generate_config.yaml](https://github.com/pacman100/accelerate-megatron-test/blob/main/src/Configs/megatron_lm_gpt_generate_config.yaml).
The bash script with the `accelerate launch` command is available at [megatron_lm_gpt_generate.sh](https://github.com/pacman100/accelerate-megatron-test/blob/main/megatron_lm_gpt_generate.sh).
The output logs of the script are available at [megatron_lm_gpt_generate.log](https://github.com/pacman100/accelerate-megatron-test/blob/main/output_logs/megatron_lm_gpt_generate.log).
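
Referring back to point 1 above, here is a minimal sketch of consuming the returned logits. It assumes the training-loop names `model`, `batch_data`, and `accelerator` from the surrounding script; the `"logits"` output key is an assumption for illustration, not confirmed by this PR:

```python
# A sketch, not the verbatim API: assumes the prepared Megatron-LM model's
# forward pass exposes logits alongside the loss on the last pipeline stage
# when `return_logits=True` is set in `MegatronLMPlugin`.
outputs = model(**batch_data)
if accelerator.is_main_process:  # the last rank under the Megatron-LM integration
    logits = outputs["logits"]  # hypothetical key; only populated on the last stage
    loss = outputs["loss"]
```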

## Support for RoPE and ALiBi positional embeddings and Multi-Query Attention

1. For RoPE/ALiBi attention, pass `position_embedding_type` with one of `("absolute" | "rotary" | "alibi")` to `MegatronLMPlugin` as shown below.
```python
other_megatron_args = {"position_embedding_type": "alibi"}
megatron_lm_plugin = MegatronLMPlugin(other_megatron_args=other_megatron_args)
```

2. For Multi-Query Attention, pass `attention_head_type` with one of `("multihead" | "multiquery")` to `MegatronLMPlugin` as shown below. The two options can also be combined; see the sketch after this list.
```python
other_megatron_args = {"attention_head_type": "multiquery"}
megatron_lm_plugin = MegatronLMPlugin(other_megatron_args=other_megatron_args)
```
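
As referenced above, the two options can be combined in a single `other_megatron_args` dict (a minimal sketch; both keys are taken from the two examples above):

```python
# Enable RoPE positional embeddings together with Multi-Query Attention.
other_megatron_args = {
    "position_embedding_type": "rotary",
    "attention_head_type": "multiquery",
}
megatron_lm_plugin = MegatronLMPlugin(other_megatron_args=other_megatron_args)
```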

## Caveats

1. Supports Transformers GPT2, Megatron-BERT and T5 models.
@@ -445,8 +542,12 @@
there is quite a complex interplay of pipeline, tensor and data parallelism behind the scenes.
The `model(**batch_data)` call returns loss(es) averaged across the data parallel ranks.
This is fine for most cases wherein pre-training jobs are run using Megatron-LM features and
you can easily compute the `perplexity` using the loss.
For GPT models, returning logits in addition to loss(es) is supported.
These logits aren't gathered across the data parallel ranks; use `accelerate.utils.gather_across_data_parallel_groups`
to gather them across the data parallel ranks. The gathered logits, along with the labels, can then be used to compute various
performance metrics (see the sketch after this list).

3. The main process is the last rank, as the losses/logits are available in the last stage of the pipeline.
`accelerator.is_main_process` and `accelerator.is_local_main_process` return `True` for the last rank when using
the Megatron-LM integration.
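
A minimal sketch of the gathering described in caveat 2, assuming `logits` and `labels` are tensors already on the last pipeline stage and that `gather_across_data_parallel_groups` mirrors `accelerator.gather` in returning concatenated tensors; the accuracy metric is purely illustrative:

```python
from accelerate.utils import gather_across_data_parallel_groups

# Gather logits and labels from all data parallel ranks on the last stage.
logits = gather_across_data_parallel_groups(logits)
labels = gather_across_data_parallel_groups(labels)

# Illustrative metric: token-level accuracy (for causal LM, shift
# logits/labels first, as appropriate for your task).
preds = logits.argmax(dim=-1)
accuracy = (preds == labels).float().mean()
```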

1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -107,6 +107,7 @@
MegatronLMSchedulerWrapper,
T5TrainStep,
avg_losses_across_data_parallel_group,
gather_across_data_parallel_groups,
)
from .megatron_lm import initialize as megatron_lm_initialize
from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
35 changes: 31 additions & 4 deletions src/accelerate/utils/dataclasses.py
@@ -771,6 +771,18 @@ class MegatronLMPlugin:
default=False,
metadata={"help": "Whether to set all logging options."},
)
eval_iters: int = field(
default=100, metadata={"help": "Number of iterations to run for evaluation validation/test for."}
)
eval_interval: int = field(
default=1000, metadata={"help": "Interval between running evaluation on validation set."}
)
return_logits: bool = field(
default=False,
metadata={"help": "Whether to return logits from the model."},
)

# custom train step args
custom_train_step_class: Optional[Any] = field(
default=None,
metadata={"help": "Custom train step class."},
@@ -779,11 +791,22 @@
default=None,
metadata={"help": "Custom train step kwargs."},
)
# custom model args
custom_model_provider_function: Optional[Callable] = field(
    default=None,
    metadata={"help": "Custom model provider function."},
)
custom_prepare_model_function: Optional[Callable] = field(
    default=None,
    metadata={"help": "Custom prepare model function."},
)

# remaining args such as enabling Alibi/ROPE positional embeddings,
# wandb logging, Multi-Query Attention, etc.
other_megatron_args: Optional[Dict[str, Any]] = field(
default=None,
metadata={"help": "Other Megatron-LM arguments. Please refer Megatron-LM"},
)

def __post_init__(self):
@@ -840,6 +863,8 @@ def __post_init__(self):
self.megatron_lm_default_args["tensorboard_dir"] = self.tensorboard_dir
if self.set_all_logging_options:
self.set_tensorboard_logging_options()
if self.other_megatron_args is not None:
self.megatron_lm_default_args.update(self.other_megatron_args)

def set_network_size_args(self, model, batch_data=None):
# Check if the model is either BERT, GPT or T5 else raise error
@@ -884,6 +909,8 @@ def set_network_size_args(self, model, batch_data=None):
else:
self.seq_length = max_position_embeddings
self.megatron_lm_default_args["seq_length"] = self.seq_length
self.megatron_lm_default_args["return_logits"] = self.return_logits
self.megatron_lm_default_args["tokenizer_type"] = "GPT2BPETokenizer"
elif "t5" in model.config.model_type.lower():
model_type_name = "t5"
num_layers = model.config.num_layers
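
For context, a minimal sketch wiring the new `MegatronLMPlugin` fields from this diff together (the `model_provider` signature follows Megatron-LM's usual `pre_process`/`post_process` convention and is hypothetical, not taken from this PR):

```python
from accelerate.utils import MegatronLMPlugin

def model_provider(pre_process=True, post_process=True):
    # Hypothetical: build and return the Megatron-LM model for this pipeline stage.
    ...

megatron_lm_plugin = MegatronLMPlugin(
    return_logits=True,
    eval_iters=100,
    eval_interval=1000,
    custom_model_provider_function=model_provider,
    # Remaining args (RoPE/ALiBi, Multi-Query Attention, wandb logging, etc.)
    # flow into `megatron_lm_default_args` via `__post_init__`.
    other_megatron_args={"position_embedding_type": "rotary"},
)
```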
11 changes: 7 additions & 4 deletions src/accelerate/utils/imports.py
@@ -13,7 +13,9 @@
# limitations under the License.

import importlib
import os
import sys
from distutils.util import strtobool
from functools import lru_cache

import torch
@@ -90,10 +92,11 @@ def is_bf16_available(ignore_tpu=False):


def is_megatron_lm_available():
    if strtobool(os.environ.get("USE_MEGATRON_LM", "False")) == 1:
        package_exists = importlib.util.find_spec("megatron") is not None
        if package_exists:
            megatron_version = parse(importlib_metadata.version("megatron-lm"))
            return compare_versions(megatron_version, ">=", "2.2.0")
return False
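
A note on the guard above: Megatron-LM availability is now opt-in via the `USE_MEGATRON_LM` environment variable, parsed with `strtobool`. A minimal sketch (set the variable before the check runs):

```python
import os

# Any strtobool-truthy value works: "1", "true", "yes", "on", ...
os.environ["USE_MEGATRON_LM"] = "true"

from accelerate.utils import is_megatron_lm_available

# True only if the opt-in is set and megatron-lm>=2.2.0 is installed.
print(is_megatron_lm_available())
```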

