Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

using nccl ops from TRT-LLM namespace #3250

Merged
merged 6 commits into from
Jan 3, 2025
Merged

using nccl ops from TRT-LLM namespace #3250

merged 6 commits into from
Jan 3, 2025

Conversation

apbose
Copy link
Collaborator

@apbose apbose commented Oct 19, 2024

This PR illustrates the use of nccl ops from TRT-LLM for the example examples/distributed_inference/tensor_parallel_simple_example.py

@github-actions github-actions bot added component: lowering Issues re: The lowering / preprocessing passes component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Oct 19, 2024
@github-actions github-actions bot requested a review from gs-olive October 19, 2024 00:55
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-19 00:55:11.232553+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-19 00:55:32.513756+00:00
@@ -84,11 +84,11 @@
    ctypes.CDLL(plugin_lib_path)
    logger.info(f"plugin loaded successfully")
except OSError as e:
    logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
-#Iterate over all registered plugin creators
+# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
    logger.info(
        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
    )

@apbose apbose marked this pull request as draft October 19, 2024 00:56
@apbose apbose removed the request for review from gs-olive October 19, 2024 00:56
@apbose apbose force-pushed the nccl_ops_multi_gpu branch 3 times, most recently from c916bf6 to 195b1c4 Compare October 21, 2024 20:25
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-21 20:25:45.697459+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-21 20:26:10.941910+00:00
@@ -26,44 +26,51 @@
)
import tensorrt as trt
import tensorrt_llm
import ctypes
import logging
+
"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""

plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
try:
-    ctypes.CDLL("/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so")
+    ctypes.CDLL(
+        "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+    )
    print("plugin loaded sucessfully")
except OSError as e:
    print(f"unsuccessful load : {e}")
logger = trt.Logger(trt.Logger.VERBOSE)
-trt.init_libnvinfer_plugins(None, '')
-#-[p;Iterate over all registered plugin creators
+trt.init_libnvinfer_plugins(None, "")
+# -[p;Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
-    print(f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}")
+    print(
+        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
+    )


@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default)
def insert_gather_op(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
-    name: str,    
+    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    plug_inputs = [args[0]]
    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        "AllGather", "1", "tensorrt_llm"
    )
    assert allgather_plg_creator is not None
    world_size = dist.get_world_size()
    group = list(range(world_size))
-    group = trt.PluginField("group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32)
+    group = trt.PluginField(
+        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+    )
    p_dtype = trt.float16
    pf_type = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    pfc = trt.PluginFieldCollection([group, pf_type])

@apbose apbose force-pushed the nccl_ops_multi_gpu branch 5 times, most recently from 8015490 to a27b719 Compare October 25, 2024 00:25
@apbose apbose marked this pull request as ready for review October 25, 2024 00:26
@apbose apbose requested review from narendasan and peri044 October 25, 2024 00:49
logger.info(f"plugin loaded successfully")
except OSError as e:
logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need these lines as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these lines are required actually. Just tested the code without these lines and having "import tensorrt_llm" should be fine to have the plugins with namespace as tensorrt_llm to be loaded.

logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just for debugging purposes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes to see if the the plugins with "tensorrt_llm" namespace have been loaded properly or not

"AllGather", "1", "tensorrt_llm"
)
assert allgather_plg_creator is not None
world_size = dist.get_world_size()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How might the converter get this info if it was in library?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not clear what is meant by library here? You mean the aten_ops_converters.py? Generally the converter should get this info when the distributed environment is initialized. It is implicitly done when using torhrun but we explicitly initialize this in the initialize_distributed_env()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok so we dont need a dist object? can we use that version here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we could if we maintain a global variable for it and use that in the file. But the dist object would be required for initialization

Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you verify that numerical results are correct here?

@apbose
Copy link
Collaborator Author

apbose commented Oct 25, 2024

Yes @narendasan , the numerical results come out to be correct for this example and the llama3 within 0.01 error threshold

@apbose apbose force-pushed the nccl_ops_multi_gpu branch from a27b719 to b6f5980 Compare November 8, 2024 01:19
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)
p_dtype = trt.float16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these kernels only support FP16?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No they can support FP32 too

logger = logging.getLogger(__name__)


def custom_fused_all_gather_op(args0, args1, args2):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets call this something like tensorrt_fused_nccl_all_gather or something

f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
)

@dynamo_tensorrt_converter(custom_fused_all_gather_op)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to start thinking about how these might get added as actual converters like how we support quantization. I think the global variable dependency is a issue. How might we work around that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes pulling in the global variable assuming that the environment variable is set and initialized in the initialization part can be done, instead of using the dist package

# Initialization
initialize_distributed_env()
# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])
Copy link
Collaborator

@narendasan narendasan Nov 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Things like this I am ok pulling in "globally", since we can assume the env variable is set and presumably this is what people are doing aready

@apbose apbose force-pushed the nccl_ops_multi_gpu branch from 38335b9 to 06fb7a8 Compare December 21, 2024 00:24
@apbose apbose marked this pull request as ready for review December 21, 2024 00:24
@apbose apbose force-pushed the nccl_ops_multi_gpu branch from 06fb7a8 to 509d917 Compare December 21, 2024 00:37
@apbose apbose force-pushed the nccl_ops_multi_gpu branch from 6ffc284 to e96ce78 Compare December 21, 2024 00:46

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = post_lowering(gm, settings)

logger.debug("Lowered Input graph:\n " + str(gm.graph))

complex_nodes = find_complex_nodes(gm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isnt this part of lowering?

Copy link
Collaborator Author

@apbose apbose Dec 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to get the complex_nodes before I do the lowering pass of replace_complex_placeholder_to_tuple(). Can put this in replace_complex_placeholder_to_tuple lowering pass, but since its a util function and pertaining more to modify_complex_nodes I put this in utils

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean seems to me these 3-4 lines are a pass that can be added to lowering

@@ -3590,3 +3592,76 @@ def aten_ops_full(
fill_value=args[1],
dtype=kwargs.get("dtype", None),
)


try:
Copy link
Collaborator

@narendasan narendasan Dec 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turn this into a utility function to load trtllm / plugins lib, return bool on success and use that to condition the converter

counter = 0
strategy = AllReduceStrategy.NCCL
config = AllReduceConfig(0)
_world_size = os.environ.get("WORLD_SIZE")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does WORLD_SIZE get baked into the engine? if i load from serialized do i need the env variable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NoI dont think so. This relies on the initialize_distributed_env() function in tensor_parallel_dist_env.py to do that. In torchrun command it would implicitly do it, but since we do mpirun for the nccl commands TRT-LLM support we need to initialize the variables

@apbose apbose force-pushed the nccl_ops_multi_gpu branch from e96ce78 to d161946 Compare December 23, 2024 18:40
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py	2024-12-23 18:40:27.812736+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py	2024-12-23 18:40:48.084051+00:00
@@ -133,11 +133,11 @@
            gm = post_lowering(gm, settings)

            logger.debug("Lowered Input graph:\n " + str(gm.graph))

            complex_nodes = find_complex_nodes(gm)
-            if (complex_nodes):
+            if complex_nodes:
                replace_complex_placeholder_to_tuple(gm, complexInputIndices)
                modify_complex_nodes(gm, complex_nodes)

            torchtrt_inputs = prepare_inputs(
                torch_inputs, disable_memory_format_check=True

@narendasan
Copy link
Collaborator

@apbose run the linter

os.environ["MASTER_PORT"] = str(port)
# Note this will not work in the initialization here
# You would need to set it externally as a user
os.environ["trtllm_env"] = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be all caps and be something like TRTLLM_PLUGINS_PATH


logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = post_lowering(gm, settings)

logger.debug("Lowered Input graph:\n " + str(gm.graph))

complex_nodes = find_complex_nodes(gm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean seems to me these 3-4 lines are a pass that can be added to lowering

@apbose apbose force-pushed the nccl_ops_multi_gpu branch 4 times, most recently from a18ba8b to 3148697 Compare December 23, 2024 21:53
@apbose apbose force-pushed the nccl_ops_multi_gpu branch from 3148697 to b77a971 Compare December 23, 2024 23:30
visited_nodes.add(node)
update_node_meta(node, fake_mode)
for user in node.users:
if (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have this special case?

Copy link
Collaborator Author

@apbose apbose Dec 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a terminating case for the Llama model complex placeholder nodes. The model is like complex placeholder->reshape->slice->complex mul , we need the meta data for reshape and slice to be amended, stopping at mul node (we are removing the complex mul to a custom torchTRT mul later in modify_reshape_complex_nodes)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it is too broad, either we should denote that this function should only be used as a helper for the complex lowering or we should make sure that it only will run against complex placeholder->reshape->slice->complex mul. Right now propogate_shape_change makes it seem like a generic util. We could call it something like _propogate_complex_num_shape_change

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified it

@github-actions github-actions bot added the component: build system Issues re: Build system label Dec 30, 2024
@apbose apbose force-pushed the nccl_ops_multi_gpu branch from 7fbb857 to 633f0f2 Compare December 31, 2024 21:28
Signed-off-by: Naren Dasan <naren@narendasan.com>
@apbose apbose force-pushed the nccl_ops_multi_gpu branch from 633f0f2 to 45b28b7 Compare December 31, 2024 21:30
@narendasan narendasan merged commit c636b39 into main Jan 3, 2025
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: build system Issues re: Build system component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: torch_compile
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants