diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 1e06d0fdb8b9..4aae83efd6f6 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -45,7 +45,14 @@ from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine -from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_distributed, refit +from nemo.export.trt_llm.tensorrt_llm_run import ( + generate, + generate_streaming, + load, + load_distributed, + refit, + unload_engine, +) use_deploy = True try: @@ -505,12 +512,12 @@ def build( engine = build_and_save_engine( max_input_len=max_input_len, max_output_len=max_output_len, + max_seq_len=max_input_len + max_output_len, max_batch_size=max_batch_size, model_config=model_config[0], model_weights=weights[0], model_dir=self.model_dir, model_type=model_type, - custom_all_reduce=False, use_refit=use_refit, ) torch.distributed.barrier() @@ -983,3 +990,6 @@ def _load(self): "model needs to be exported again. " "Error message: " + repr(error) ) from error + + def unload_engine(self): + unload_engine() diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 3f9f2a31a307..006d31053a66 100755 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -17,7 +17,7 @@ import numpy as np import tensorrt_llm import torch -from tensorrt_llm._utils import torch_to_numpy +from tensorrt_llm._utils import mpi_comm, torch_to_numpy # A global dicts to store exported weights. # This is set to be a global variable to avoid extra code modification from tensorrt_llm. @@ -586,6 +586,13 @@ def init_model_parallel_from_nemo(reshard_model): pp_size = 1 mp_rank = tp_size * pp_rank + tp_rank + # Need to split cpp MPI World Comm because TensorRT-LLM NCCL plugins refer to the locally split comm. + # High level call structure is: MpiComm::split -> MpiComm::setSession -> LOCAL_COMM_SESSION (used in allReducePlugin.cpp) tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank) + # Also split the python mpi communicator and set the global world one to the local split one + new_comm = mpi_comm().Split(color=dp_rank, key=mp_rank) + from mpi4py import MPI + + MPI.COMM_WORLD = new_comm return mp_rank, dp_rank, tp_size, pp_size, dp_size diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index 4720efc51e53..cdf8eaac6b1c 100755 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -53,6 +53,7 @@ def build_and_save_engine( multiple_profiles: bool = False, gpt_attention_plugin: str = "auto", gemm_plugin: str = "auto", + reduce_fusion: bool = False, ): architecture = "LLaMAForCausalLM" if model_config.architecture == "LlamaForCausalLM" else model_config.architecture try: @@ -71,6 +72,7 @@ def build_and_save_engine( plugin_config.remove_input_padding = remove_input_padding plugin_config.use_paged_context_fmha = paged_context_fmha plugin_config.multiple_profiles = multiple_profiles + plugin_config.reduce_fusion = reduce_fusion max_num_tokens, opt_num_tokens = check_max_num_tokens( max_num_tokens=max_num_tokens, diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index 852eddc6a468..1772c071a745 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -23,29 +23,26 @@ from typing import List, Optional import numpy as np +import tensorrt as trt import tensorrt_llm import torch from mpi4py.futures import MPIPoolExecutor +from tensorrt_llm._utils import mpi_comm +from tensorrt_llm.builder import Engine from tensorrt_llm.lora_manager import LoraManager +from tensorrt_llm.mapping import Mapping from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig - +from tensorrt_llm.runtime import GenerationSession, ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig from transformers import PreTrainedTokenizer LOGGER = logging.getLogger("NeMo") use_trtllm_bindings = True try: - from tensorrt_llm.bindings import GptJsonConfig, GptSession, GptSessionConfig, KvCacheConfig, WorldConfig + from tensorrt_llm.bindings import GptJsonConfig, KvCacheConfig, WorldConfig except Exception as e: use_trtllm_bindings = False -use_cpp_gpt_session = True -try: - from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCppGptSession -except Exception as e: - use_cpp_gpt_session = False - @dataclass class TensorrtLLMHostContext: @@ -63,7 +60,7 @@ class TensorrtLLMHostContext: class TensorrtLLMWorkerContext: """The MPI worker side context for TRT LLM inference.""" - decoder: ModelRunner = None + decoder: ModelRunner | ModelRunnerCpp = None sampling_config: SamplingConfig = None lora_manager: LoraManager = None max_batch_size: int = 0 @@ -123,7 +120,6 @@ def _read_config(config_path: Path): lora_plugin=config["plugin_config"]["lora_plugin"], lora_target_modules=config["builder_config"]["lora_target_modules"], quant_mode=quant_mode, - use_custom_all_reduce=config["plugin_config"]["use_custom_all_reduce"], use_context_fmha_for_generation=config["plugin_config"]["use_context_fmha_for_generation"], gather_context_logits=config["builder_config"]["gather_context_logits"], gather_generation_logits=config["builder_config"]["gather_generation_logits"], @@ -456,7 +452,7 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node): this function creates a custom mapping of device_id to WorldConfig """ global tensorrt_llm_worker_context - if isinstance(tensorrt_llm_worker_context.decoder, ModelRunnerCppGptSession): + if isinstance(tensorrt_llm_worker_context.decoder, ModelRunner): return config_path = Path(engine_dir) / f"config_{torch.distributed.get_rank()}.json" @@ -480,46 +476,102 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node): device_ids = [i for i in range(gpus_per_node)] for _ in range(offset): device_ids.append(device_ids.pop(0)) - world_config = WorldConfig.mpi( - gpus_per_node=gpus_per_node, tensor_parallelism=tp_size, pipeline_parallelism=pp_size, device_ids=device_ids - ) - engine_filename = json_config.engine_filename(world_config) + engine_index = model_parallel_rank + mpi_rank = mpi_comm().Get_rank() + # Copied from worldConfig.h (getDevice()) + mpi_device = mpi_rank % gpus_per_node + # TODO: Consider re-enabling + # assert torch.cuda.current_device() == mpi_device + + # TODO: check if API exists (copied from gptJsonConfig.cpp) + # https://github.com/terrykong/TensorRT-LLM/blob/05316d3313360012536ace46c781518f5afae75e/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp#L478 + engine_filename = f"rank{engine_index}.engine" serialize_path = Path(engine_dir) / engine_filename - assert torch.cuda.current_device() == world_config.device - - session_config = GptSessionConfig( - max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_sequence_length=max_seq_len - ) - session_config.gen_micro_batch_size = max_batch_size - session_config.ctx_micro_batch_size = max_batch_size - session_config.kv_cache_config = KvCacheConfig( - max_tokens=max_seq_len * max_batch_size, max_attention_window=max_seq_len - ) - with open(serialize_path, "rb") as f: engine_data = bytearray(f.read()) - session = GptSession(session_config, model_config, world_config, engine_data) - decoder = ModelRunnerCppGptSession( - session, - lora_manager=None, - max_batch_size=max_batch_size, - max_input_len=max_input_len, - max_seq_len=max_seq_len, - max_beam_width=max_beam_width, + with open(config_path) as f: + json_config_str = f.read() + + engine = Engine.from_buffer(engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank) + decoder = ModelRunner.from_engine( + engine=engine, + # We want the engine to have the mp_rank, but the python runtime to not resassign the device of the current process + # So we will set it to the current device + rank=torch.cuda.current_device(), + _disable_torch_cuda_device_set=True, ) tensorrt_llm_worker_context.decoder = decoder tensorrt_llm_worker_context.max_batch_size = max_batch_size tensorrt_llm_worker_context.max_input_len = max_input_len - # Save the model config in case for refit - tensorrt_llm_worker_context.model_config = model_config -def refit(weights_dict): +def maybe_cast_to_trt_dtype(dtype): + if isinstance(dtype, trt.DataType): + return dtype + elif isinstance(dtype, torch.dtype): + return tensorrt_llm._utils.torch_dtype_to_trt(dtype) + else: + raise NotImplementedError(f"Expects the type to be a tensorrt.DataType or torch.dtype, but got {type(dtype)=}") + + +def refit(weights_dict: dict): global tensorrt_llm_worker_context - dtype = tensorrt_llm_worker_context.model_config.data_type - tensorrt_llm_worker_context.decoder.session.refit_engine(weights_dict, dtype) + decoder = tensorrt_llm_worker_context.decoder + if not isinstance(decoder, ModelRunner): + raise ValueError( + f"Refit is only supported with ModelRunner, but export has been configured with {type(decoder)=}" + ) + + engine = decoder.session.runtime.engine + # The session dtype plumbs the model_config's dtype + model_dtype = maybe_cast_to_trt_dtype(decoder.session.dtype) + assert engine.refittable, "Tried refitting engine without refit enabled" + + refitter = trt.Refitter(engine=engine, logger=trt.Logger(trt.Logger.ERROR)) + remaining_refit_weights = set(refitter.get_all_weights()) + skipped_weights = [] + for trt_name, weight in weights_dict.items(): + if trt_name not in remaining_refit_weights: + skipped_weights.append(trt_name) + continue + trt_weight = trt.Weights(model_dtype, weight.data_ptr(), torch.numel(weight)) + trt_wt_location = trt.TensorLocation.DEVICE if weight.is_cuda else trt.TensorLocation.HOST + assert ( + model_dtype == refitter.get_weights_prototype(trt_name).dtype == maybe_cast_to_trt_dtype(weight.dtype) + ), f"Expected all three of these dtypes to be the same {model_dtype=} {refitter.get_weights_prototype(trt_name).dtype=} weight.dtype={maybe_cast_to_trt_dtype(weight.dtype)}" + + refitter.set_named_weights( + trt_name, trt_weight, trt_wt_location + ), f"Unable to set {trt_name=} {trt_weight=} {trt_wt_location=}" + remaining_refit_weights.remove(trt_name) + if skipped_weights: + logging.warning( + f"These weights were ignored during refit since they are not present in engine: {skipped_weights}" + ) + if remaining_refit_weights: + logging.warning(f"Weights dict did not contain weights for these named TRT weights: {remaining_refit_weights}") + + if not refitter.refit_cuda_engine(): + raise ValueError(f"Refit failed!") + + +def unload_engine(): + """ + Deletes the ModelRunner which should free up device memory + """ + global tensorrt_llm_worker_context + decoder = tensorrt_llm_worker_context.decoder + if not isinstance(decoder, ModelRunner): + raise ValueError( + f"unload_engine is only supported with ModelRunner, but export has been configured with {type(decoder)=}" + ) + + logging.info("Unloading engine...") + del tensorrt_llm_worker_context.decoder + tensorrt_llm_worker_context.decoder = None + logging.info("Engine unloaded!") def prepare_input_tensors(