From 195b1c457cb8f0598d6c81b96d452ce281d9f56e Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 23 Sep 2024 11:09:23 -0700
Subject: [PATCH] using nccl ops from TRT-LLM namespace

---
 .../tensor_parallel_simple_example.py         | 57 ++++++++++++++-
 .../lowering/passes/_aten_lowering_pass.py    |  1 +
 .../lowering/passes/fuse_distributed_ops.py   | 69 +++++++++++++++++++
 3 files changed, 126 insertions(+), 1 deletion(-)
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py

diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py
index 470487a751..eaa75111d6 100755
--- a/examples/distributed_inference/tensor_parallel_simple_example.py
+++ b/examples/distributed_inference/tensor_parallel_simple_example.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 import torch_tensorrt
+import torch.distributed as dist
 from torch.distributed._tensor import Shard
 from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import (
@@ -12,11 +13,65 @@
     RowwiseParallel,
     parallelize_module,
 )
-
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch.fx.node import Target, Argument
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+from torch_tensorrt.dynamo.types import TRTTensor
+import numpy as np
+from torch_tensorrt.fx.converters.converter_utils import (
+    set_layer_name,
+)
+import tensorrt as trt
+import tensorrt_llm
+import ctypes
+import logging
 
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """
+plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+try:
+    ctypes.CDLL(plugin_lib_path)
+    print("plugin loaded successfully")
+except OSError as e:
+    print(f"unsuccessful load: {e}")
+logger = trt.Logger(trt.Logger.VERBOSE)
+trt.init_libnvinfer_plugins(None, '')
+# Iterate over all registered plugin creators
+plugin_registry = trt.get_plugin_registry()
+for plugin_creator in plugin_registry.plugin_creator_list:
+    print(f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}")
+
+
+@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default)
+def insert_gather_op(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    plug_inputs = [args[0]]
+    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+        "AllGather", "1", "tensorrt_llm"
+    )
+    assert allgather_plg_creator is not None
+    world_size = dist.get_world_size()
+    group = list(range(world_size))
+    group = trt.PluginField("group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32)
+    p_dtype = trt.float16
+    pf_type = trt.PluginField(
+        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc = trt.PluginFieldCollection([group, pf_type])
+    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
+    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
 
 class ToyModel(nn.Module):
     """MLP based model"""
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index b6435c0d8c..f659968a8f 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -6,6 +6,7 @@
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
+from .fuse_distributed_ops import fuse_distributed_ops
 from .lower_linear import lower_linear
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
new file mode 100644
index 0000000000..d49de3190c
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
@@ -0,0 +1,69 @@
+import logging
+from typing import Sequence
+
+import torch
+
+# dead-code elimination, linting, and recompilation for graph, in-place
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def custom_fused_all_gather_op(args0, args1, args2):
+    return torch.ops._c10d_functional.wait_tensor.default(
+        torch.ops._c10d_functional.all_gather_into_tensor.default(args0, args1, args2)
+    )
+
+
+def custom_fused_reduce_scatter_op(args0, args1, args2, args3):
+    return torch.ops._c10d_functional.wait_tensor.default(
+        torch.ops._c10d_functional.reduce_scatter_tensor.default(
+            args0, args1, args2, args3
+        )
+    )
+
+
+def fuse_distributed_ops(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    modified_graph = False
+    for node in gm.graph.nodes:
+        if (
+            node.target
+            in (
+                torch.ops._c10d_functional.all_gather_into_tensor.default,
+                torch.ops._c10d_functional.reduce_scatter_tensor.default,
+            )
+            and len(node.users) == 1
+            and list(node.users)[0].target
+            == torch.ops._c10d_functional.wait_tensor.default
+        ):
+            wait_tensor_node = list(node.users)[0]
+            fused_op = None
+            if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
+                fused_op = custom_fused_all_gather_op
+                fused_op_args = (node.args[0], node.args[1], node.args[2])
+            else:
+                fused_op = custom_fused_reduce_scatter_op
+                fused_op_args = (node.args[0], node.args[1], node.args[2], node.args[3])
+            with gm.graph.inserting_after(wait_tensor_node):
+                fused_node = gm.graph.create_node(
+                    op="call_function",
+                    target=fused_op,  # the fused wrapper defined above
+                    args=fused_op_args,
+                )
+
+            wait_tensor_node.replace_all_uses_with(fused_node)
+            fused_node.meta.update(node.meta)
+            modified_graph = True
+            gm.graph.erase_node(wait_tensor_node)
+            gm.graph.erase_node(node)
+
+    # If the graph was modified, clean it up
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}"
+        )
+
+    return gm
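
For reference, a minimal sketch of what the new fuse_distributed_ops pass does to the collective + wait pattern. The graph below is built by hand, so it does not need an initialized process group; the group size (2) and group name ("0") are illustrative placeholders, not values taken from the patch:

    import torch
    import torch.fx as fx

    from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
        fuse_distributed_ops,
    )

    # Build a tiny graph containing the all_gather_into_tensor -> wait_tensor
    # pattern that the pass looks for. Arguments (group_size=2, group_name="0")
    # are placeholders for illustration only.
    g = fx.Graph()
    x = g.placeholder("x")
    ag = g.call_function(
        torch.ops._c10d_functional.all_gather_into_tensor.default, (x, 2, "0")
    )
    out = g.call_function(torch.ops._c10d_functional.wait_tensor.default, (ag,))
    g.output(out)
    gm = fx.GraphModule(torch.nn.Module(), g)

    gm = fuse_distributed_ops(gm)
    # The two nodes are now a single call to custom_fused_all_gather_op, which
    # the converter registered in the example maps to the TRT-LLM AllGather
    # plugin during TensorRT conversion.
    print(gm.graph)

The example script itself is presumably launched with one process per GPU (e.g. via torchrun or mpirun) so that dist.get_world_size() and init_device_mesh see the expected ranks.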