Merge branch 'hn-inference-cudagraphs-mcore' into 'main'
Inference CUDA graphs (MCore version)

See merge request ADLR/megatron-lm!2429
jaredcasper committed Jan 9, 2025
2 parents 93cb1c1 + fa93a05 commit 8fba594
Showing 12 changed files with 184 additions and 36 deletions.
4 changes: 4 additions & 0 deletions megatron/core/extensions/transformer_engine.py
@@ -1135,6 +1135,10 @@ class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
"""Wraps TransformerEngine's CudaRNGStatesTracker so that it is
interchangeable with Megatron's RNG tracker"""

def __init__(self):
super().__init__()
self.reset()

def is_initialized(self):
"""Checks if the internal RNG state has been set wirth set_states()."""
return self._is_initialized
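Note: the added __init__ calls reset() so that is_initialized() is well defined before any state is restored. A minimal, self-contained sketch of that pattern, using a stand-in base class rather than TransformerEngine's actual CudaRNGStatesTracker:

class BaseRNGTracker:
    """Stand-in for te.pytorch.distributed.CudaRNGStatesTracker (illustrative only)."""

    def reset(self):
        self._states = {}
        self._is_initialized = False

    def set_states(self, states):
        self._states = states
        self._is_initialized = True


class WrappedRNGTracker(BaseRNGTracker):
    def __init__(self):
        super().__init__()
        self.reset()  # guarantees the _is_initialized flag exists from construction

    def is_initialized(self):
        """Checks if the internal RNG state has been set with set_states()."""
        return self._is_initialized


tracker = WrappedRNGTracker()
assert not tracker.is_initialized()
tracker.set_states({"model-parallel-rng": object()})
assert tracker.is_initialized()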
42 changes: 41 additions & 1 deletion megatron/core/inference_params.py
@@ -6,6 +6,7 @@ class InferenceParams:
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.current_batch_size = max_batch_size # Required for bookkeeping variable-sized batches
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
@@ -28,4 +29,43 @@ def swap_key_value_dict(self, batch_idx):
)

def __str__(self):
return f"InferenceParams(max_seq_len = {self.max_sequence_length}, max_batch_size = {self.max_batch_size}, sequence_len_offset = {self.sequence_len_offset}, batch_size_offset = {self.batch_size_offset}, key_value_memory_dict = {self.key_value_memory_dict.keys()})"
return (
f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
f"max_batch_size = {self.max_batch_size}, "
f"current_batch_size = {self.current_batch_size}, "
f"sequence_len_offset = {self.sequence_len_offset}, "
f"batch_size_offset = {self.batch_size_offset}, "
f"key_value_memory_dict = {self.key_value_memory_dict.keys()})"
)

def __eq__(self, other):

if not isinstance(other, InferenceParams):
return False

# Check that the other object defines the same basic attributes
basic_attrs = [
'max_sequence_length',
'max_batch_size',
'current_batch_size',
'sequence_len_offset',
'batch_size_offset',
]

if not all(hasattr(other, attr) for attr in basic_attrs):
return False

# Check dictionary keys match; i.e. the same number of layers are cached
if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys():
return False

# Check each tensor tuple in the dictionary
for key in self.key_value_memory_dict:
self_tensors = self.key_value_memory_dict[key]
other_tensors = other.key_value_memory_dict[key]

# Compare each key, value tensor in the tuple
for self_tensor, other_tensor in zip(self_tensors, other_tensors):
if not self_tensor.shape == other_tensor.shape:
return False
return True
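A short usage sketch of the new equality semantics. The layer number and tensor sizes below are illustrative; two InferenceParams instances compare equal when they cache the same set of layers with matching key/value tensor shapes.

import torch

from megatron.core import InferenceParams

a = InferenceParams(max_batch_size=2, max_sequence_length=128)
b = InferenceParams(max_batch_size=2, max_sequence_length=128)

# KV cache buffers are keyed by layer number and shaped
# [seq, batch, num_query_groups, head_dim]; the sizes here are made up.
kv_shape = (128, 2, 8, 64)
a.key_value_memory_dict[1] = (torch.empty(kv_shape), torch.empty(kv_shape))
b.key_value_memory_dict[1] = (torch.empty(kv_shape), torch.empty(kv_shape))
assert a == b

# Caching an extra layer (or a different shape) breaks equality.
b.key_value_memory_dict[2] = (torch.empty(kv_shape), torch.empty(kv_shape))
assert a != b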
20 changes: 17 additions & 3 deletions megatron/core/models/gpt/gpt_model.py
@@ -3,6 +3,7 @@
from collections import OrderedDict
from typing import Dict, Literal, Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, tensor_parallel
@@ -121,6 +122,9 @@ def __init__(
use_cpu_initialization=self.config.use_cpu_initialization,
)

# Cache for RoPE tensors which do not change between iterations.
self.rotary_pos_emb_cache = {}

# Transformer.
self.decoder = TransformerBlock(
config=self.config,
@@ -224,10 +228,11 @@ def forward(
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode:
if not self.training and self.config.flash_decode and inference_params:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cos_sin(
inference_params.max_sequence_length
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
@@ -238,6 +243,14 @@
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
if (self.config.enable_cuda_graph or self.config.flash_decode) and inference_params:
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Run decoder.
hidden_states = self.decoder(
@@ -248,6 +261,7 @@
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)

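A standalone sketch of the caching idea behind rotary_pos_emb_cache. The get_cos_sin helper below is a generic RoPE table builder, not Megatron's RotaryEmbedding API: the cos/sin tables depend only on the maximum sequence length, so the same pair of tensors can be reused for every decode step, giving stable addresses for CUDA graph capture.

import torch


def get_cos_sin(seq_len: int, head_dim: int = 64, base: float = 10000.0):
    # Build standard RoPE cos/sin tables for positions [0, seq_len).
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)
    return freqs.cos(), freqs.sin()


rotary_pos_emb_cache = {}


def cached_cos_sin(seq_len: int):
    # Mirrors the setdefault pattern above: the first call stores the tables and
    # later calls return the stored pair. setdefault still evaluates its default
    # argument eagerly, so the win is reusing the same tensors rather than
    # avoiding the computation entirely.
    return rotary_pos_emb_cache.setdefault(seq_len, get_cos_sin(seq_len))


cos_a, sin_a = cached_cos_sin(2048)
cos_b, sin_b = cached_cos_sin(2048)
assert cos_a is cos_b and sin_a is sin_b  # the same cached tensors are reused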
42 changes: 34 additions & 8 deletions megatron/core/tensor_parallel/random.py
@@ -164,35 +164,59 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
_CUDA_RNG_STATE_TRACKER_INITIALIZED = False


def initialize_rng_tracker(use_te_rng_tracker: bool = False):
def initialize_rng_tracker(use_te_rng_tracker: bool = False, inference_rng_tracker: bool = False):
"""Create the RNG tracker. 'use_te_rng_tracker' determines whether to use
Megatron or TransformerEngine's implementation.
In particular, TransformerEngine's implementation is cudagraphable and supports FP8.
"""

global _CUDA_RNG_STATE_TRACKER
global _CUDA_RNG_STATE_TRACKER_INITIALIZED
if _CUDA_RNG_STATE_TRACKER_INITIALIZED:
return

# Get the base tracker class
base_tracker = None
if use_te_rng_tracker:
if not is_te_min_version("1.5.0"):
raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5")
from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker

_CUDA_RNG_STATE_TRACKER = TECudaRNGStatesTracker()
base_tracker = TECudaRNGStatesTracker
else:
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
base_tracker = CudaRNGStatesTracker

if inference_rng_tracker:

class InferenceCudaRNGStatesTracker(base_tracker):
"""RNG tracker for inference."""

def add(self, name, seed):
"""Mirrors the interface from the training RNG tracker."""
pass

def set_states(self, states):
"""Mirrors the interface from the training RNG tracker."""
pass

def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Mirrors the interface from the training RNG tracker."""
return contextlib.nullcontext()

tracker_class = InferenceCudaRNGStatesTracker
else:
tracker_class = base_tracker

_CUDA_RNG_STATE_TRACKER = tracker_class()
_CUDA_RNG_STATE_TRACKER_INITIALIZED = True


def get_cuda_rng_tracker(use_te_rng_tracker=False):
def get_cuda_rng_tracker(use_te_rng_tracker=False, inference_rng_tracker=False):
"""Get cuda rng tracker."""
initialize_rng_tracker(use_te_rng_tracker)
initialize_rng_tracker(use_te_rng_tracker, inference_rng_tracker)
return _CUDA_RNG_STATE_TRACKER


def model_parallel_cuda_manual_seed(seed):
def model_parallel_cuda_manual_seed(seed, te_rng_tracker=False, inference_rng_tracker=False):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
@@ -216,7 +240,7 @@ def model_parallel_cuda_manual_seed(seed):
# Data parallel gets the original seed.
data_parallel_seed = seed

initialize_rng_tracker()
initialize_rng_tracker(te_rng_tracker, inference_rng_tracker)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
@@ -241,6 +265,7 @@ class CheckpointFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, run_function, distribute_saved_activations, *args):
"""Forward pass."""
ctx.run_function = run_function
ctx.distribute_saved_activations = distribute_saved_activations

@@ -267,6 +292,7 @@ def forward(ctx, run_function, distribute_saved_activations, *args):

@staticmethod
def backward(ctx, *args):
"""Backward pass."""
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
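A self-contained sketch of the dynamic-subclass trick behind inference_rng_tracker. The base class here is a stand-in for CudaRNGStatesTracker / TECudaRNGStatesTracker: add, set_states, and fork become no-ops, so nothing mutates RNG state while a CUDA graph is captured or replayed.

import contextlib

_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


class BaseRNGTracker:
    """Stand-in training tracker (illustrative only)."""

    def add(self, name, seed):
        print(f"registering RNG state {name!r} with seed {seed}")

    def set_states(self, states):
        print("restoring RNG states")

    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        print(f"forking RNG state {name!r}")
        return contextlib.nullcontext()


def make_inference_tracker(base_tracker):
    # Same shape as the code above: subclass whichever tracker was selected and
    # neutralize every method that would mutate RNG state.
    class InferenceCudaRNGStatesTracker(base_tracker):
        def add(self, name, seed):
            pass

        def set_states(self, states):
            pass

        def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
            return contextlib.nullcontext()

    return InferenceCudaRNGStatesTracker


tracker = make_inference_tracker(BaseRNGTracker)()
with tracker.fork():
    pass  # no RNG state is saved, swapped, or restored here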
23 changes: 16 additions & 7 deletions megatron/core/transformer/attention.py
@@ -169,14 +169,14 @@ def custom_forward(*inputs):

return hidden_states

def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype):
def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
"""Allocate memory to store kv cache during inference."""

return torch.empty(
inference_max_sequence_length,
batch_size,
self.num_query_groups_per_partition,
dim,
self.hidden_size_per_attention_head,
dtype=dtype,
device=torch.cuda.current_device(),
)
@@ -190,6 +190,7 @@ def _adjust_key_value_for_inference(
rotary_pos_emb: Tensor,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
sequence_len_offset=None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
Saves the generated key and value tensors to the end of the buffers in inference_params.
@@ -210,10 +211,10 @@
inf_max_seq_length = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_length, inf_max_batch_size, key.shape[-1], key.dtype
inf_max_seq_length, inf_max_batch_size, key.dtype
)
inference_value_memory = self._allocate_memory(
inf_max_seq_length, inf_max_batch_size, value.shape[-1], value.dtype
inf_max_seq_length, inf_max_batch_size, value.dtype
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
@@ -246,7 +247,7 @@ def _adjust_key_value_for_inference(
rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end]
rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end]
rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end]
else:
else: # Prefill
rotary_pos_cos_q = rotary_pos_cos[:sequence_end]
rotary_pos_sin_q = rotary_pos_sin[:sequence_end]
rotary_pos_cos_k = rotary_pos_cos[:sequence_end]
@@ -341,6 +342,7 @@ def forward(
rotary_pos_sin=None,
attention_bias=None,
packed_seq_params=None,
sequence_len_offset=None,
):
"""
Perform a forward pass through the attention module.
@@ -380,7 +382,7 @@ def forward(
self.layer_number
]
output = self.flash_decoding(
sequence_len_offset=inference_params.sequence_len_offset,
sequence_len_offset=sequence_len_offset,
query_layer=query,
key_layer=key,
value_layer=value,
@@ -395,7 +397,14 @@
return output, bias

query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
inference_params,
query,
key,
value,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
sequence_len_offset,
)

if packed_seq_params is not None:
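The _allocate_memory change fixes the last dimension to hidden_size_per_attention_head, so the KV cache buffers have a static shape. A sketch of that buffer layout and the in-place decode-step update it enables; all sizes are illustrative and CPU tensors are used so the snippet runs anywhere.

import torch


def allocate_kv_cache(max_seq_len, max_batch, num_query_groups, head_dim,
                      dtype=torch.float32, device="cpu"):
    # Mirrors the [seq, batch, kv-heads-per-partition, head-dim] layout above.
    return torch.empty(max_seq_len, max_batch, num_query_groups, head_dim,
                       dtype=dtype, device=device)


key_cache = allocate_kv_cache(2048, 4, 8, 64)
value_cache = allocate_kv_cache(2048, 4, 8, 64)

# One decode step for 4 active requests: write the new key/value at the
# current offset. The buffers keep a fixed address and shape, which is what
# lets a captured CUDA graph replay the update.
sequence_len_offset = 17
new_key = torch.randn(1, 4, 8, 64)
new_value = torch.randn(1, 4, 8, 64)
key_cache[sequence_len_offset : sequence_len_offset + 1] = new_key
value_cache[sequence_len_offset : sequence_len_offset + 1] = new_value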
18 changes: 13 additions & 5 deletions megatron/core/transformer/cuda_graphs.py
@@ -31,6 +31,8 @@

_IS_GRAPH_CAPTURING = False

logger = logging.getLogger(__name__)


def is_graph_capturing():
"""Query if currently capturing."""
@@ -53,7 +55,9 @@ def _set_capture_end():
def _check_supported_type(arg):
"""Check if arg is a supported type for cudagraph input/outputs."""

_SUPPORTED_TYPES = {torch.Tensor, type(None), bool, int, str, float}
from megatron.core import InferenceParams # guard against circular import

_SUPPORTED_TYPES = {torch.Tensor, type(None), bool, int, str, float, InferenceParams}
assert type(arg) in _SUPPORTED_TYPES or is_dataclass(
arg
), f"Cudagraphs received an arg of type {type(arg)} which is not supported."
@@ -153,7 +157,6 @@ def create_cudagraphs(cls):
vpp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
if vpp_size is None or vpp_size == 1:
bwd_mempool = fwd_mempool

if optimize_transformer_layer_graph_buffers:
if graph_type == 'fwd':
args, kwargs = g[3:]
@@ -550,6 +553,7 @@ def record_graph_capture(self, args, kwargs):
`create_cudagraphs()` is called. Subsequent fwd passes will replay the cudagraph.
"""
if not self.fwd_graph_recorded:
logger.debug(f"Recording forward graph creation...")
_CudagraphGlobalRecord.record_fwd_graph(self, args, kwargs)
self.fwd_graph_recorded = True

@@ -598,10 +602,11 @@ def forward(self, is_first_microbatch, args, kwargs):
return out

def matches_graph_inputs(self, args, kwargs):
"""Check the the passed args, kwargs match with the arg, kwargs
"""Check that the passed args, kwargs match with the arg, kwargs
the graph was created with."""

def check(val, ref):

_check_supported_type(val)
_check_supported_type(ref)

@@ -758,9 +763,12 @@ def __call__(self, megatron_module, args, kwargs):
if self.training and torch.is_grad_enabled():
runner = _CudaGraphRunner(megatron_module, len(self.cudagraph_runners))
self.cudagraph_runners.append(runner)
elif 'inference_params' in kwargs.keys() and kwargs['inference_params']:
# Instantiate the cudagraphed version of the module in inference mode
runner = _CudaGraphRunner(megatron_module, len(self.cudagraph_runners))
runner.eval()
self.cudagraph_runners.append(runner)
else:
# No cudagraphs were found in inference mode, so fall back to eager since
# tensor.requires_grad is needed to correctly trace the backward graph.
return super(MegatronModule, megatron_module).__call__(*args, **kwargs)

# Trigger Mcore DDP pre-forward hooks
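A small sketch of the allowlist check that now admits InferenceParams, using the same deferred import as cuda_graphs.py to avoid a circular dependency. The function name is a local stand-in for _check_supported_type.

import torch
from dataclasses import is_dataclass


def check_supported_type(arg):
    # Deferred import, as in cuda_graphs.py, so this module can be imported
    # before megatron.core is fully initialized.
    from megatron.core import InferenceParams

    supported = {torch.Tensor, type(None), bool, int, str, float, InferenceParams}
    assert type(arg) in supported or is_dataclass(arg), (
        f"Cudagraphs received an arg of type {type(arg)} which is not supported."
    )


check_supported_type(torch.ones(2))
check_supported_type(None)
check_supported_type(3.14)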
6 changes: 6 additions & 0 deletions megatron/core/transformer/transformer_block.py
@@ -404,6 +404,7 @@ def forward(
attention_bias: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: Tensor = None,
):
"""
Perform the forward pass through the transformer block.
@@ -436,6 +437,10 @@
# See set_input_tensor()
hidden_states = self.input_tensor

# Update the inference parameters with the current batch size in case it is variable
if inference_params and not self.training:
inference_params.current_batch_size = hidden_states.size(1)

# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
@@ -512,6 +517,7 @@
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
else:
# CUDA graph replay for layer `l_no` and microbatch
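A short sketch of the bookkeeping added here. Megatron activations are laid out as [sequence, batch, hidden], so dimension 1 of hidden_states gives the live batch size, which is recorded on the InferenceParams when batches can shrink between requests; the sizes below are illustrative.

import torch

from megatron.core import InferenceParams

inference_params = InferenceParams(max_batch_size=8, max_sequence_length=256)

# One decode step with 3 active requests and a hidden size of 1024.
hidden_states = torch.randn(1, 3, 1024)
inference_params.current_batch_size = hidden_states.size(1)

assert inference_params.current_batch_size == 3
assert inference_params.current_batch_size <= inference_params.max_batch_size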