Commit

changes to include the distributed operations in the aten_ops lib
apbose committed Dec 17, 2024
1 parent 6707c6f commit 38335b9
Showing 6 changed files with 257 additions and 199 deletions.
64 changes: 64 additions & 0 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -0,0 +1,64 @@
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def initialize_logger(rank, logger_file_name):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    return logger


# Required for environment initialization since the examples are launched with mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up environment variables to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use the nccl backend
    dist.init_process_group("nccl")

    # Set a manual seed for reproducibility
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert rank == local_rank
    logger = initialize_logger(rank, logger_file_name)
    device_id = (
        rank % torch.cuda.device_count()
    )  # Ensure each rank gets a unique device
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank, logger
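For context, a minimal usage sketch (not part of this commit): it assumes the script lives next to tensor_parallel_initialize_dist.py and is launched with mpirun so that the OMPI_COMM_WORLD_* variables are populated; the script name and log-file prefix are illustrative.

# Usage sketch: launch with e.g. `mpirun -n 2 python tensor_parallel_example.py`
from tensor_parallel_initialize_dist import initialize_distributed_env

device_mesh, world_size, rank, logger = initialize_distributed_env(
    "./tensor_parallel_example"  # log-file prefix; a per-rank .log file is created
)
logger.info(f"Initialized rank {rank} of {world_size}")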
9 changes: 7 additions & 2 deletions examples/distributed_inference/tensor_parallel_llama3.py
@@ -7,15 +7,20 @@
import torch
import torch_tensorrt
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_nccl_ops import register_nccl_ops
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3")
from tensor_parallel_initialize_dist import initialize_distributed_env

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_llama3"
)

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
197 changes: 0 additions & 197 deletions examples/distributed_inference/tensor_parallel_nccl_ops.py

This file was deleted.

59 changes: 59 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,10 +1,12 @@
# mypy: disallow-untyped-decorators=False

import ctypes
import logging
import operator
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -19,6 +21,11 @@
    enforce_tensor_types,
    get_positive_dim,
    is_only_operator_on_placeholder,
    plugin_lib_path,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    tensorrt_fused_nccl_all_gather_op,
    tensorrt_fused_nccl_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor

@@ -3558,3 +3565,55 @@ def aten_ops_full(
        fill_value=args[1],
        dtype=kwargs.get("dtype", None),
    )


try:
    import tensorrt_llm as trt_llm
except (ImportError, AssertionError) as e:
    _LOGGER.warning(
        "tensorrt_llm is not installed. Please install tensorrt_llm: %s", e
    )
    # Note: this fallback is for Linux only
    trt_llm_plugin_path = plugin_lib_path()
    handle = ctypes.CDLL(trt_llm_plugin_path)
    try:
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as e_1:
        _LOGGER.warning("TensorRT-LLM plugin is unavailable: %s", e_1)
    try:
        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
        assert handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8"))
    except Exception as e_2:
        _LOGGER.warning(
            "Exception occurred while initializing the TensorRT-LLM plugins: %s", e_2
        )
else:

    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
    def insert_nccl_gather_op(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        return impl.distributed.gather_op(
            ctx,
            target,
            SourceIR.ATEN,
            name,
            args[0],
        )

    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
    def insert_nccl_reduce_scatter_plugin(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        # Dispatch to the reduce-scatter implementation rather than gather_op;
        # the helper name assumes a reduce_scatter_op counterpart in impl.distributed.
        return impl.distributed.reduce_scatter_op(
            ctx,
            target,
            SourceIR.ATEN,
            name,
            args[0],
        )
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,6 +1,7 @@
import collections
import functools
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
@@ -913,3 +914,9 @@ def set_layer_name(
        else f"{source_ir}_ops.{target.__name__}"
    )
    layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"


def plugin_lib_path() -> str:
    project_dir = Path(__file__).parent.parent.parent.parent.absolute()
    dyn_lib = "libnvinfer_plugin_tensorrt_llm.so"
    return str(project_dir.joinpath("libs", dyn_lib))
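A hedged sketch (not part of this commit) of how the new helper feeds the ctypes-based plugin initialization used in aten_ops_converters.py; it assumes the shared library actually exists at the path plugin_lib_path() resolves to (Linux only).

# Sketch: resolve and load the TensorRT-LLM plugin library via the new helper
import ctypes

from torch_tensorrt.dynamo.conversion.converter_utils import plugin_lib_path

handle = ctypes.CDLL(plugin_lib_path())
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
handle.initTrtLlmPlugins.restype = ctypes.c_bool
assert handle.initTrtLlmPlugins(None, "tensorrt_llm".encode("utf-8"))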
