ADLR/megatron-lm!2614 - Fix pipeline parallelism bugs in MCore inference
santhnm2 authored and deepakn94 committed Feb 2, 2025
1 parent 731fbfd commit 6508404
Showing 5 changed files with 36 additions and 17 deletions.
4 changes: 4 additions & 0 deletions megatron/core/inference/communication_utils.py
@@ -14,6 +14,10 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     """Broadcast a tensor from last pipeline stage to all ranks."""

     if parallel_state.is_pipeline_last_stage():
+        assert size == list(
+            tensor.shape
+        ), f"Expected tensor of shape {size} but got {list(tensor.shape)}"
+        assert dtype == tensor.dtype, f"Expected tensor of type {dtype} but got {tensor.dtype}"
         _is_cuda(tensor)
         assert tensor.is_contiguous()
     else:
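Note on the new assertions: broadcast_from_last_pipeline_stage takes the expected size and dtype as metadata on every rank, and the added checks make the last pipeline stage fail fast when the tensor it is about to broadcast does not match that metadata (for example float32 logits described with a bfloat16 params_dtype). A minimal usage sketch follows; the helper's signature comes from this diff, while the call site and variable names are illustrative assumptions:

# Illustrative sketch (not part of the diff): how the new assertions guard the
# broadcast contract between pipeline stages.
from megatron.core import parallel_state
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage

def broadcast_logits(logits, batch_size, context_length, vocab_size, params_dtype):
    # Every rank must agree on the size/dtype metadata; only the last pipeline
    # stage supplies the actual tensor. The new asserts catch a tensor whose
    # shape or dtype does not match what the other ranks will allocate.
    tensor = logits if parallel_state.is_pipeline_last_stage() else None
    return broadcast_from_last_pipeline_stage(
        [batch_size, context_length, vocab_size], dtype=params_dtype, tensor=tensor
    )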
@@ -157,6 +157,9 @@ def forward_pass_with_pipeline_parallel_small_input_batch(
             logits = output_tensor
             logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

+            # Explicitly cast logits to expected dtype
+            logits = logits.to(self.inference_wrapper_config.params_dtype)
+
         return logits

     def forward_pass_with_pipeline_parallel_large_input_batch(
@@ -188,7 +191,7 @@ def forward_pass_with_pipeline_parallel_large_input_batch(
         if parallel_state.is_pipeline_last_stage():
             logits = torch.empty(
                 (batch_size, seq_len, self.inference_wrapper_config.padded_vocab_size),
-                dtype=torch.float32,
+                dtype=self.pipeline_communication_dtype,
                 device=torch.cuda.current_device(),
             )

@@ -223,8 +226,12 @@ def forward_pass_with_pipeline_parallel_large_input_batch(
                 output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
                     output_tensor
                 )
+                assert logits is not None
                 logits[start:end, ...] = output_tensor

+                # Explicitly cast logits to expected dtype
+                logits = logits.to(self.inference_wrapper_config.params_dtype)
+
         # Once done with all micro batches, we reset batch size offset and seq len offset
         self.inference_params.sequence_len_offset += seq_len
         self.inference_params.batch_size_offset = 0
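Taken together, the changes in this file keep dtypes consistent across the pipelined forward pass: the logits buffer is allocated in pipeline_communication_dtype instead of hard-coded float32, and the assembled logits are explicitly cast to params_dtype before use. A standalone sketch of the mismatch being avoided (illustrative shapes and dtypes, not the wrapper's actual code):

# Standalone sketch of the dtype issue these hunks address (illustrative values).
import torch

params_dtype = torch.bfloat16                 # model weights / expected logits dtype
pipeline_communication_dtype = torch.bfloat16

received = torch.randn(2, 4, 8, dtype=pipeline_communication_dtype)

# Buffer now matches the communication dtype; previously it was torch.float32,
# so copying bf16 activations into it silently up-cast the logits.
buffer = torch.empty(2, 4, 8, dtype=pipeline_communication_dtype)
buffer[:] = received

# The explicit cast mirrors the added "logits.to(params_dtype)" line, so
# downstream sampling always sees params_dtype regardless of the wire dtype.
logits = buffer.to(params_dtype)
assert logits.dtype == params_dtype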
@@ -14,6 +14,7 @@
     InferenceWrapperConfig,
 )
 from megatron.core.models.T5 import T5Model
+from megatron.core.utils import get_attr_wrapped_model


 # pylint: disable=line-too-long
@@ -56,10 +57,7 @@ def prep_inference_input(
             A dict with all the inference input needed for the batch.
         """
         # get max_sequence_length
-        if hasattr(self.model, "module"):  # if self.model is Float16Module
-            max_sequence_length = self.model.module.max_sequence_length
-        else:
-            max_sequence_length = self.model.max_sequence_length
+        max_sequence_length = get_attr_wrapped_model(self.model, "max_sequence_length")

         encoder_prompts_tokens_list = [
             self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
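get_attr_wrapped_model (imported from megatron.core.utils above) replaces the per-call-site hasattr check for Float16Module-style wrappers. A simplified stand-in showing the lookup pattern it expresses; this is an assumption about its behavior, not the actual megatron.core.utils implementation:

# Minimal sketch of the attribute lookup this change centralizes.
def get_attr_through_wrappers(model, attr):
    # Walk through wrapper layers (e.g. Float16Module, DDP) that expose the
    # wrapped model as `.module` until the requested attribute is found.
    while not hasattr(model, attr):
        if not hasattr(model, "module"):
            raise AttributeError(f"{attr!r} not found on model or its wrappers")
        model = model.module
    return getattr(model, attr)

# Usage mirroring the change above:
# max_sequence_length = get_attr_through_wrappers(self.model, "max_sequence_length")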
@@ -16,6 +16,7 @@
 )
 from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
+from megatron.core.utils import get_attr_wrapped_model, get_model_config


 class TextGenerationController:
@@ -335,6 +336,13 @@ def generate_all_output_tokens_static_batch(
             batch_size, device=torch.cuda.current_device()
         ).cuda()

+        # Use model vocab size since tokenizer vocab size might not include padding
+        # to nearest power of 2
+        vocab_size = get_attr_wrapped_model(self.inference_wrapped_model.model, "vocab_size")
+
+        # Check whether CUDA graphs are enabled
+        enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph
+
         streaming_enabled = active_streams is not None and len(active_streams) > 0
         if streaming_enabled:
             # Start a separate thread for streaming tokens to avoid blocking the
@@ -352,12 +360,6 @@
             streaming_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
             stream_tokens = functools.partial(self.stream_tokens, sampling_params)

-        # Check whether CUDA graphs are enabled
-        if hasattr(self.inference_wrapped_model.model, "module"): # if model is Float16Module
-            enable_cuda_graph = self.inference_wrapped_model.model.module.config.enable_cuda_graph
-        else:
-            enable_cuda_graph = self.inference_wrapped_model.model.config.enable_cuda_graph
-
         use_attention_mask = True

         with torch.no_grad():
@@ -398,7 +400,7 @@ def generate_all_output_tokens_static_batch(
                 if self.model_is_pipeline_parallel:
                     context_length = context_end_position - context_start_position
                     logits = broadcast_from_last_pipeline_stage(
-                        [batch_size, context_length, self.tokenizer.vocab_size],
+                        [batch_size, context_length, vocab_size],
                         dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
                         tensor=logits,
                     )
@@ -409,7 +411,7 @@
                 generation_started = prompt_lengths_in_batch <= context_end_position
                 last_token_logits = logits[:, -1, :]
                 sampled_logits = self.sample_from_logits(
-                    last_token_logits, sampling_params, self.tokenizer.vocab_size
+                    last_token_logits, sampling_params, vocab_size
                 )

                 # Substitute the sampled logits only for only the prompts that
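The controller now sizes the broadcast metadata and the sampler with the model's (padded) vocabulary size obtained via get_attr_wrapped_model, rather than tokenizer.vocab_size, which can be smaller than the model's output dimension. An illustrative sketch of the mismatch (made-up numbers; the actual padding rule is configuration-dependent, as the added comment notes):

# Illustrative sketch only: why the padded model vocab size must describe the
# logits tensor rather than tokenizer.vocab_size.
import torch

tokenizer_vocab_size = 50257        # tokens the tokenizer actually produces
padded_vocab_size = 50304           # model output dimension after padding

batch_size, context_length = 4, 1
logits = torch.zeros(batch_size, context_length, padded_vocab_size)

# Describing this tensor with tokenizer_vocab_size would trip the new shape
# assert in broadcast_from_last_pipeline_stage, since the last dimension of
# the logits is the padded size.
assert list(logits.shape) == [batch_size, context_length, padded_vocab_size]
assert list(logits.shape) != [batch_size, context_length, tokenizer_vocab_size]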
@@ -31,7 +31,7 @@

 class TestTextGenerationController:

-    def setup_method(self, method):
+    def setup_model(self, dtype):
         Utils.initialize_model_parallel(
             tensor_model_parallel_size=2, pipeline_model_parallel_size=2
         )
@@ -60,7 +60,7 @@ def setup_method(self, method):
             hidden_size=self.hidden_size,
             inference_batch_times_seqlen_threshold=-1,
             fp32_residual_connection=False,
-            params_dtype=torch.float,
+            params_dtype=dtype,
             padded_vocab_size=self.vocab_size,
         )

@@ -76,6 +76,8 @@ def teardown_method(self, method):
         Utils.destroy_model_parallel()

     def test_sample_from_logits(self):
+        self.setup_model(torch.float32)
+
         with pytest.raises(AssertionError) as aerror:
             self.text_generation_controller.sample_from_logits(
                 last_token_logits=None,
@@ -139,7 +141,10 @@ def test_sample_from_logits(self):
             sampled_logits >= expected_min_value
         ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}"

-    def test_generate_all_output_tokens_static_batch(self):
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+    def test_generate_all_output_tokens_static_batch(self, dtype):
+        self.setup_model(dtype)
+
         self.mock_tokenizer.vocab_size = self.vocab_size
         self.mock_tokenizer.eod = self.vocab_size - 1
         self.mock_tokenizer.detokenize.return_value = ''.join(
@@ -183,7 +188,10 @@ def test_generate_all_output_tokens_static_batch(self):
             all_prompt_tokens[request_id] == request.prompt_tokens
         ), "Prompt tokens should not have changed during generation"

-    def test_output_log_probs(self):
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+    def test_output_log_probs(self, dtype):
+        self.setup_model(dtype)
+
         self.mock_tokenizer.vocab_size = self.vocab_size
         self.mock_tokenizer.bos = 0
         self.mock_tokenizer.eod = self.vocab_size - 1