From 65084040cd7575dd66bdf14d3cdafa6386e8c84e Mon Sep 17 00:00:00 2001
From: Keshav Santhanam
Date: Sun, 2 Feb 2025 10:51:23 -0800
Subject: [PATCH] ADLR/megatron-lm!2614 - Fix pipeline parallelism bugs in
 MCore inference

---
 megatron/core/inference/communication_utils.py |  4 ++++
 .../abstract_model_inference_wrapper.py        |  9 ++++++++-
 .../t5/t5_inference_wrapper.py                 |  6 ++----
 .../text_generation_controller.py              | 18 ++++++++++--------
 .../test_simple_text_generation_controller.py  | 16 ++++++++++++----
 5 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/megatron/core/inference/communication_utils.py b/megatron/core/inference/communication_utils.py
index 0c23a583de..8b2f5188f0 100644
--- a/megatron/core/inference/communication_utils.py
+++ b/megatron/core/inference/communication_utils.py
@@ -14,6 +14,10 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     """Broadcast a tensor from last pipeline stage to all ranks."""
     if parallel_state.is_pipeline_last_stage():
+        assert size == list(
+            tensor.shape
+        ), f"Expected tensor of shape {size} but got {list(tensor.shape)}"
+        assert dtype == tensor.dtype, f"Expected tensor of type {dtype} but got {tensor.dtype}"
         _is_cuda(tensor)
         assert tensor.is_contiguous()
     else:
diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
index ea319b08fc..06f7248f32 100644
--- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
+++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
@@ -157,6 +157,9 @@ def forward_pass_with_pipeline_parallel_small_input_batch(
             logits = output_tensor
             logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
 
+            # Explicitly cast logits to expected dtype
+            logits = logits.to(self.inference_wrapper_config.params_dtype)
+
         return logits
 
     def forward_pass_with_pipeline_parallel_large_input_batch(
@@ -188,7 +191,7 @@ def forward_pass_with_pipeline_parallel_large_input_batch(
         if parallel_state.is_pipeline_last_stage():
             logits = torch.empty(
                 (batch_size, seq_len, self.inference_wrapper_config.padded_vocab_size),
-                dtype=torch.float32,
+                dtype=self.pipeline_communication_dtype,
                 device=torch.cuda.current_device(),
             )
 
@@ -223,8 +226,12 @@ def forward_pass_with_pipeline_parallel_large_input_batch(
                 output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
                     output_tensor
                 )
+                assert logits is not None
                 logits[start:end, ...] = output_tensor
 
+        # Explicitly cast logits to expected dtype
+        logits = logits.to(self.inference_wrapper_config.params_dtype)
+
         # Once done with all micro batches, we reset batch size offset and seq len offset
         self.inference_params.sequence_len_offset += seq_len
         self.inference_params.batch_size_offset = 0
diff --git a/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py
index dce3a6ae17..9dddb9ab8a 100644
--- a/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py
+++ b/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py
@@ -14,6 +14,7 @@
     InferenceWrapperConfig,
 )
 from megatron.core.models.T5 import T5Model
+from megatron.core.utils import get_attr_wrapped_model
 
 # pylint: disable=line-too-long
 
@@ -56,10 +57,7 @@ def prep_inference_input(
             A dict with all the inference input needed for the batch.
         """
         # get max_sequence_length
-        if hasattr(self.model, "module"):  # if self.model is Float16Module
-            max_sequence_length = self.model.module.max_sequence_length
-        else:
-            max_sequence_length = self.model.max_sequence_length
+        max_sequence_length = get_attr_wrapped_model(self.model, "max_sequence_length")
 
         encoder_prompts_tokens_list = [
             self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
index ab2ac294be..bab66a63bf 100644
--- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -16,6 +16,7 @@
 )
 from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
+from megatron.core.utils import get_attr_wrapped_model, get_model_config
 
 
 class TextGenerationController:
@@ -335,6 +336,13 @@ def generate_all_output_tokens_static_batch(
                 batch_size, device=torch.cuda.current_device()
             ).cuda()
 
+        # Use model vocab size since tokenizer vocab size might not include padding
+        # to nearest power of 2
+        vocab_size = get_attr_wrapped_model(self.inference_wrapped_model.model, "vocab_size")
+
+        # Check whether CUDA graphs are enabled
+        enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph
+
         streaming_enabled = active_streams is not None and len(active_streams) > 0
         if streaming_enabled:
             # Start a separate thread for streaming tokens to avoid blocking the
@@ -352,12 +360,6 @@ def generate_all_output_tokens_static_batch(
             streaming_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
             stream_tokens = functools.partial(self.stream_tokens, sampling_params)
 
-        # Check whether CUDA graphs are enabled
-        if hasattr(self.inference_wrapped_model.model, "module"):  # if model is Float16Module
-            enable_cuda_graph = self.inference_wrapped_model.model.module.config.enable_cuda_graph
-        else:
-            enable_cuda_graph = self.inference_wrapped_model.model.config.enable_cuda_graph
-
         use_attention_mask = True
 
         with torch.no_grad():
@@ -398,7 +400,7 @@ def generate_all_output_tokens_static_batch(
                 if self.model_is_pipeline_parallel:
                     context_length = context_end_position - context_start_position
                     logits = broadcast_from_last_pipeline_stage(
-                        [batch_size, context_length, self.tokenizer.vocab_size],
+                        [batch_size, context_length, vocab_size],
                         dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
                         tensor=logits,
                     )
@@ -409,7 +411,7 @@ def generate_all_output_tokens_static_batch(
                     generation_started = prompt_lengths_in_batch <= context_end_position
                     last_token_logits = logits[:, -1, :]
                     sampled_logits = self.sample_from_logits(
-                        last_token_logits, sampling_params, self.tokenizer.vocab_size
+                        last_token_logits, sampling_params, vocab_size
                     )
 
                     # Substitute the sampled logits only for only the prompts that
diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
index c6548170ad..e39148522c 100644
--- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
+++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
@@ -31,7 +31,7 @@
 
 
 class TestTextGenerationController:
-    def setup_method(self, method):
+    def setup_model(self, dtype):
         Utils.initialize_model_parallel(
             tensor_model_parallel_size=2, pipeline_model_parallel_size=2
         )
@@ -60,7 +60,7 @@ def setup_method(self, method):
             hidden_size=self.hidden_size,
             inference_batch_times_seqlen_threshold=-1,
             fp32_residual_connection=False,
-            params_dtype=torch.float,
+            params_dtype=dtype,
             padded_vocab_size=self.vocab_size,
         )
 
@@ -76,6 +76,8 @@ def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
     def test_sample_from_logits(self):
+        self.setup_model(torch.float32)
+
        with pytest.raises(AssertionError) as aerror:
             self.text_generation_controller.sample_from_logits(
                 last_token_logits=None,
@@ -139,7 +141,10 @@ def test_sample_from_logits(self):
             sampled_logits >= expected_min_value
         ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}"
 
-    def test_generate_all_output_tokens_static_batch(self):
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+    def test_generate_all_output_tokens_static_batch(self, dtype):
+        self.setup_model(dtype)
+
         self.mock_tokenizer.vocab_size = self.vocab_size
         self.mock_tokenizer.eod = self.vocab_size - 1
         self.mock_tokenizer.detokenize.return_value = ''.join(
@@ -183,7 +188,10 @@ def test_generate_all_output_tokens_static_batch(self):
                 all_prompt_tokens[request_id] == request.prompt_tokens
             ), "Prompt tokens should not have changed during generation"
 
-    def test_output_log_probs(self):
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+    def test_output_log_probs(self, dtype):
+        self.setup_model(dtype)
+
         self.mock_tokenizer.vocab_size = self.vocab_size
         self.mock_tokenizer.bos = 0
         self.mock_tokenizer.eod = self.vocab_size - 1
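
Note (not part of the patch): the new assertions in broadcast_from_last_pipeline_stage make the last pipeline stage fail fast if the logits tensor it is about to send does not match the shape and dtype that the other ranks were told to expect, and the explicit logits.to(params_dtype) casts added in the inference wrapper are what keep that agreement. Below is a minimal, self-contained sketch of that check under illustrative sizes; check_broadcast_contract is a hypothetical stand-in, and the real function additionally requires a contiguous CUDA tensor and performs the actual broadcast.

    import torch

    def check_broadcast_contract(size, dtype, tensor):
        # Mirrors the shape/dtype assertions added on the last pipeline stage.
        assert size == list(tensor.shape), f"Expected tensor of shape {size} but got {list(tensor.shape)}"
        assert dtype == tensor.dtype, f"Expected tensor of type {dtype} but got {tensor.dtype}"

    batch_size, context_length, vocab_size = 4, 16, 1024  # illustrative values
    params_dtype = torch.bfloat16  # dtype the receiving ranks expect

    # Before the fix: last-stage logits could still be float32 while the broadcast
    # was announced as params_dtype, so sender and receivers disagreed.
    logits = torch.zeros(batch_size, context_length, vocab_size, dtype=torch.float32)
    try:
        check_broadcast_contract([batch_size, context_length, vocab_size], params_dtype, logits)
    except AssertionError as err:
        print(err)  # dtype mismatch is now caught on the sender before broadcasting

    # After the fix: the wrapper casts logits to params_dtype, so the contract holds.
    logits = logits.to(params_dtype)
    check_broadcast_contract([batch_size, context_length, vocab_size], params_dtype, logits)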