Commit
changes to include the distributed operations in the aten_ops lib
Showing 6 changed files with 257 additions and 199 deletions.
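The converter registrations that the commit summary refers to live in files whose diffs are not rendered on this page. Purely as a hedged sketch of what registering the fused NCCL collectives in the aten_ops converter library can look like, following the torch_tensorrt dynamo converter-registry conventions (the function name and body below are illustrative placeholders, not the code added by this commit):

# Hedged sketch: registering a converter for the fused NCCL all-gather op.
# The function name and body are placeholders, not this commit's implementation.
from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    tensorrt_fused_nccl_all_gather_op,
)
from torch_tensorrt.dynamo.types import TRTTensor


@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # A real converter would add the TensorRT-LLM NCCL AllGather plugin layer
    # to the network in ctx and return its output tensor.
    raise NotImplementedError("placeholder body; see the aten_ops changes in this commit")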
examples/distributed_inference/tensor_parallel_initialize_dist.py: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    tensorrt_fused_nccl_all_gather_op,
    tensorrt_fused_nccl_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name


def initialize_logger(rank, logger_file_name):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up environment variables to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use the NCCL backend
    dist.init_process_group("nccl")

    # Set a manual seed for reproducibility
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert rank == local_rank
    logger = initialize_logger(rank, logger_file_name)
    device_id = (
        rank % torch.cuda.device_count()
    )  # Ensure each rank gets a unique device
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank, logger
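As a quick orientation, a tensor-parallel example script could use this helper as sketched below; the script name, the toy model, and the two-process launch line (e.g. mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py) are illustrative assumptions, not part of this diff:

# Illustrative usage of the new helper (script and model names are assumptions).
import torch

from tensor_parallel_initialize_dist import initialize_distributed_env

device_mesh, world_size, rank, logger = initialize_distributed_env(
    "./tensor_parallel_simple_example"
)

# Each rank now has its own CUDA device, a per-rank log file, and a 1-D
# device mesh of size `world_size` to shard a model over.
logger.info(f"rank {rank}/{world_size} on cuda:{torch.cuda.current_device()}")
model = torch.nn.Linear(16, 16).cuda()  # stand-in for a tensor-parallel model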
examples/distributed_inference/tensor_parallel_nccl_ops.py: 0 additions & 197 deletions
This file was deleted.