Include (most) CI hints: long lines, docstrings, unused imports
Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
janekl committed Jan 9, 2025
1 parent 3e5fb00 commit 633e973
Showing 3 changed files with 35 additions and 11 deletions.
40 changes: 32 additions & 8 deletions nemo/deploy/nlp/megatronllm_deployable.py
@@ -95,7 +95,7 @@ def GetNumpyDtype(pyvalue):


class ServerSync(IntEnum):
"""Enum for synchronization messages using torch.distributed"""
"""Enum for synchronization messages using torch.distributed."""

WAIT = auto()
SIGNAL = auto()
@@ -105,6 +105,11 @@ def to_long_tensor(self):


class MegatronLLMDeploy:
"""
A factory class for creating deployable instances of Megatron LLM models.
This class provides a method to get the appropriate deployable instance
based on the version of the NeMo checkpoint model used.
"""

@staticmethod
def get_deployable(
@@ -115,7 +120,20 @@ def get_deployable(
pipeline_model_parallel_size: int = 1,
context_parallel_size: int = 1,
):
"""
Returns the appropriate deployable instance for the given NeMo checkpoint.
Args:
nemo_checkpoint_filepath (str): Path to the .nemo checkpoint file.
num_devices (int): Number of devices to use for deployment.
num_nodes (int): Number of nodes to use for deployment.
tensor_model_parallel_size (int): Size of the tensor model parallelism.
pipeline_model_parallel_size (int): Size of the pipeline model parallelism.
context_parallel_size (int): Size of the context parallelism.
Returns:
ITritonDeployable: An instance of a deployable class compatible with Triton inference server.
"""
if nemo_checkpoint_version(nemo_checkpoint_filepath) == NEMO2:
return MegatronLLMDeployableNemo2(
nemo_checkpoint_filepath=nemo_checkpoint_filepath,
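For orientation, a minimal usage sketch of the factory method documented in the hunk above. The signature comes from this diff; the import path, checkpoint path, and parallelism values are illustrative assumptions rather than part of the commit.

# Sketch only: the import path and all argument values below are placeholders.
from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeploy

deployable = MegatronLLMDeploy.get_deployable(
    nemo_checkpoint_filepath="/path/to/model.nemo",  # placeholder checkpoint
    num_devices=2,
    num_nodes=1,
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    context_parallel_size=1,
)
# Per the new docstring, the returned object is an ITritonDeployable that can
# be served with the Triton Inference Server.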
@@ -290,11 +308,14 @@ def __init__(
raise IMPORT_ERROR
if nemo_checkpoint_filepath is None and existing_model is None:
raise ValueError(
"MegatronLLMDeployable requires either a .nemo checkpoint filepath or an existing MegatronGPTModel, but both provided were None"
"MegatronLLMDeployable requires either a .nemo checkpoint filepath "
"or an existing MegatronGPTModel, but both provided were None."
)
if num_devices > 1:
LOGGER.warning(
"Creating a MegatronLLMDeployable with num_devices>1 will assume running with a PyTorch Lightning DDP-variant strategy, which will run the main script once per device. Make sure any user code is compatible with multiple executions!"
"Creating a MegatronLLMDeployable with num_devices > 1 will assume running with "
"a PyTorch Lightning DDP-variant strategy, which will run the main script once per device. "
"Make sure any user code is compatible with multiple executions!"
)

# if both existing_model and nemo_checkpoint_filepath are provided, existing_model will take precedence
@@ -319,14 +340,16 @@ def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices:
# transformer_engine should always be true according to EricH, but GPT-2B model will fail if it is enabled
if not custom_config.transformer_engine:
LOGGER.warning(
"MegatronLLMDeployable expects model config transformer_engine=True, but this model has it =False. "
"Overriding it to =True, but this may break certain checkpoints converted on older Nemo versions. "
"MegatronLLMDeployable expects model config transformer_engine=True, but this model uses False. "
"Overriding it to True, but this may break certain checkpoints converted on older Nemo versions. "
"If your model breaks, please try re-converting the checkpoint on the current Nemo version."
)
custom_config.transformer_engine = True
- # using multi-gpu for tensor parallelism directly for now, could do pipeline parallel instead or a combination
+ # using multi-gpu for tensor parallelism directly for now,
+ # could do pipeline parallel instead or a combination
custom_config.tensor_model_parallel_size = num_devices
- # had to override these to make Nemotron3-22B work, see sample_sequence_batch() in text_generation_utils.py
+ # had to override these to make Nemotron3-22B work,
+ # see sample_sequence_batch() in text_generation_utils.py
custom_config.activations_checkpoint_granularity = None
custom_config.activations_checkpoint_method = None
# Models trained with TE < 1.10 and loaded with TE >= 1.10 require
Expand Down Expand Up @@ -425,7 +448,8 @@ def generate(self, inputs: List[str], length_params: LengthParam, sampling_param
distributed_rank = torch.distributed.get_rank()
if distributed_rank != 0:
raise ValueError(
f"Triton inference function should not be called on a thread with torch.distributed rank != 0, but this thread is rank {distributed_rank}"
"Triton inference function should not be called on a thread with "
f"torch.distributed rank != 0, but this thread is rank {distributed_rank}."
)
signal_value = ServerSync.SIGNAL.to_long_tensor()
torch.distributed.broadcast(signal_value, 0)
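These two lines are the rank-0 side of the synchronization that the ServerSync enum (first hunk) supports: the Triton-facing rank broadcasts a SIGNAL so the other ranks start generation together. A self-contained sketch of the pattern follows; the tensor construction and the worker-side wait are assumptions, since only the rank-0 side appears in this diff.

import torch
from enum import IntEnum, auto


class ServerSync(IntEnum):
    """Enum for synchronization messages using torch.distributed."""

    WAIT = auto()
    SIGNAL = auto()

    def to_long_tensor(self):
        # Assumed helper: wrap the enum value in a long tensor for broadcast.
        return torch.tensor([self], dtype=torch.long)


def notify_workers():
    # Mirrors the rank-0 logic above: only the Triton-serving rank may call this.
    if torch.distributed.get_rank() != 0:
        raise ValueError("Must be called from rank 0 only")
    torch.distributed.broadcast(ServerSync.SIGNAL.to_long_tensor(), 0)


def wait_for_signal():
    # Assumed worker-side behaviour: block on the same broadcast from rank 0.
    message = ServerSync.WAIT.to_long_tensor()
    torch.distributed.broadcast(message, 0)
    return ServerSync(message.item())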
2 changes: 1 addition & 1 deletion nemo/deploy/nlp/query_llm.py
@@ -13,7 +13,7 @@
# limitations under the License.

import time
- from abc import ABC, abstractmethod
+ from abc import ABC

import numpy as np

4 changes: 2 additions & 2 deletions tests/export/nemo_export.py
@@ -49,8 +49,8 @@
)
except Exception as e:
LOGGER.warning(
"Cannot import MegatronLLMDeployable or NemoQueryLLMPyTorch,"
f" in-framework inference will not be available. {type(e).__name__}: {e}"
"Cannot import MegatronLLMDeploy* classes, or NemoQueryLLMPyTorch, or CommonInferenceParams, "
f"in-framework inference will not be available. Reason: {type(e).__name__}: {e}"
)
in_framework_supported = False
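The hunk above only rewords the warning inside an existing optional-import guard. For completeness, a sketch of the full pattern as it presumably looks in the test file; the guarded import paths are assumptions, and only the except block and the in_framework_supported flag appear in this diff.

import logging

LOGGER = logging.getLogger(__name__)

in_framework_supported = True
try:
    # Assumed import targets; the real module paths in the test may differ.
    from nemo.deploy.nlp import MegatronLLMDeploy, NemoQueryLLMPyTorch
    from megatron.core.inference.common_inference_params import CommonInferenceParams
except Exception as e:
    LOGGER.warning(
        "Cannot import MegatronLLMDeploy* classes, or NemoQueryLLMPyTorch, or CommonInferenceParams, "
        f"in-framework inference will not be available. Reason: {type(e).__name__}: {e}"
    )
    in_framework_supported = False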

