Skip to content

Commit

Permalink
using nccl ops from TRT-LLM namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Oct 24, 2024
1 parent 6d40ff1 commit 8015490
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 8 deletions.
4 changes: 4 additions & 0 deletions examples/distributed_inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ See the examples started with `data_parallel` for more details.
Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

torchrun --nproc_per_node=2 tensor_parallel_llama2.py

3. Tensor parallel distributed inference using nccl ops plugin

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
5 changes: 4 additions & 1 deletion examples/distributed_inference/requirement.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
accelerate
transformers
diffusers
diffusers
site
# Install tensorrt-llm without its dependencies (use the command separately). pip install tensorrt-llm --no-deps
tensorrt-llm
191 changes: 184 additions & 7 deletions examples/distributed_inference/tensor_parallel_simple_example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import ctypes
import logging
import os
import site
import sys
import time
from enum import IntEnum, IntFlag, auto
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from torch.distributed._tensor import Shard
Expand All @@ -12,6 +21,181 @@
RowwiseParallel,
parallelize_module,
)
from torch.fx import GraphModule, Node
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
custom_fused_all_gather_op,
custom_fused_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name


# This is required for env initialization since we use mpirun
def initialize(rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

# Set up environment variable to run with mpirun
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)

# We use nccl backend
dist.init_process_group("nccl")

# set a manual seed for reproducibility
torch.manual_seed(1111)

return local_rank, world_size


initialize()
# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
device_id = _rank % torch.cuda.device_count() # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)


logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(f"./tensor_parallel_simple_example_{_rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)


# TensorRT NCCL plugins
tensorrt_llm_lib_path = tensorrt_llm.__file__
plugin_lib_path = tensorrt_llm_lib_path + "/libs/libnvinfer_plugin_tensorrt_llm.so"
try:
ctypes.CDLL(plugin_lib_path)
logger.info(f"plugin loaded successfully")
except OSError as e:
logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
logger.info(
f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
)


# class for AllReduce
class AllReduceStrategy(IntEnum):
"""Warning: actual definition is in kernels/customAllReduceKernels.h.
They must be kept in sync.
"""

NCCL = 0
ONESHOT = 1
TWOSHOT = 2
AUTO = 3


class AllReduceConfig(IntFlag):
"""Warning: actual definition is in kernels/customAllReduceKernels.h.
They must be kept in sync
"""

USE_MEMCPY = auto()
PUSH_MODE = auto()


@dynamo_tensorrt_converter(custom_fused_all_gather_op)
def insert_nccl_gather_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
plug_inputs = [args[0]]
allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"AllGather", "1", "tensorrt_llm"
)
assert allgather_plg_creator is not None
world_size = dist.get_world_size()
group = list(range(world_size))
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)
p_dtype = trt.float16
pf_type = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
pfc = trt.PluginFieldCollection([group, pf_type])
allgather = allgather_plg_creator.create_plugin("allgather", pfc)
layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
set_layer_name(layer, target, name)
return layer.get_output(0)


@dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
def insert_nccl_reduce_scatter_plugin(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
plug_inputs = [args[0]]
allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"ReduceScatter", "1", "tensorrt_llm"
)

assert allreduce_plg_creator is not None

counter = 0
strategy = AllReduceStrategy.NCCL
config = AllReduceConfig(0)

world_size = dist.get_world_size()
group = list(range(world_size))
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)

p_dtype = trt.float16
pf_dtype = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
pfc = [group, pf_dtype]
p_strategy = trt.PluginField(
"strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
)
pfc.append(p_strategy)
p_config = trt.PluginField(
"config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
)
pfc.append(p_config)
p_counter = trt.PluginField(
"counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
)
pfc.append(p_counter)

pfc = trt.PluginFieldCollection(pfc)
ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)

layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
set_layer_name(layer, target, name)
return layer.get_output(0)


"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
Expand All @@ -36,13 +220,6 @@ def forward(self, x):
return x


# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()


print(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .constant_folding import constant_fold
from .fuse_distributed_ops import fuse_distributed_ops
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_linear import lower_linear
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
Expand All @@ -26,6 +27,7 @@
lower_scaled_dot_product_attention,
lower_linear,
fuse_prims_broadcast,
fuse_distributed_ops,
replace_max_pool_with_indices,
replace_full_like_with_full,
view_to_reshape,
Expand Down
72 changes: 72 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import logging
from typing import Sequence

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings

# dead-code elimination, linting, and recompilation for graph, in-place
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def custom_fused_all_gather_op(args0, args1, args2):
return torch.ops._c10d_functional.wait_tensor.default(
torch.ops._c10d_functional.all_gather_into_tensor.default(args0, args1, args2)
)


def custom_fused_reduce_scatter_op(args0, args1, args2, args3):
return torch.ops._c10d_functional.wait_tensor.default(
torch.ops._c10d_functional.reduce_scatter_tensor.default(
args0, args1, args2, args3
)
)


def fuse_distributed_ops(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
modified_graph = False
for node in gm.graph.nodes:
if (
node.target
in (
torch.ops._c10d_functional.all_gather_into_tensor.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
)
and len(node.users) == 1
and list(node.users)[0].target
== torch.ops._c10d_functional.wait_tensor.default
):
wait_tensor_node = list(node.users)[0]
fused_op = None
if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
fused_op = custom_fused_all_gather_op
fused_op_args = (node.args[0], node.args[1], node.args[2])
else:
fused_op = custom_fused_reduce_scatter_op
fused_op_args = (node.args[0], node.args[1], node.args[2], node.args[3])
with gm.graph.inserting_after(wait_tensor_node):
fused_node = gm.graph.create_node(
op="call_function",
target=fused_op, # Define your custom fused function
args=fused_op_args,
)

wait_tensor_node.replace_all_uses_with(fused_node)
fused_node.meta.update(node.meta)
modified_graph = True
gm.graph.erase_node(wait_tensor_node)
gm.graph.erase_node(node)

# If graph was modified, clean it up
if modified_graph:
gm = clean_up_graph_after_modifications(gm)
logger.debug(
f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}"
)

return gm

0 comments on commit 8015490

Please sign in to comment.