remove custom op for all_reduce
SageMoore committed Aug 19, 2024
1 parent a7be101 commit 7ab9b00
Showing 2 changed files with 5 additions and 20 deletions.
17 changes: 1 addition & 16 deletions vllm/distributed/communication_op.py
@@ -5,27 +5,12 @@

 from .parallel_state import get_tp_group
 
-torch.library.define("vllm::tensor_model_parallel_all_reduce",
-                     ("(Tensor(a!) input_ ) -> Tensor"))
-
-
-@torch.library.register_kernel("vllm::tensor_model_parallel_all_reduce",
-                               ("cuda", "cpu"))
-def _tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     """All-reduce the input tensor across model parallel group."""
     return get_tp_group().all_reduce(input_)
 
 
-@torch.library.register_fake("vllm::tensor_model_parallel_all_reduce")
-def _tensor_model_parallel_all_reduce_fake(
-        input_: torch.Tensor) -> torch.Tensor:
-    return input_
-
-
-def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
-    return torch.ops.vllm.tensor_model_parallel_all_reduce(input_)
-
-
 def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                      dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
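The deleted block above registered the all-reduce as a PyTorch custom op: `torch.library.define` declares the schema, `register_kernel` binds the real CUDA/CPU implementation, and `register_fake` supplies a shape-only variant for fake-tensor tracing (e.g. under `torch.compile`). The commit replaces all of that with a plain function calling `get_tp_group().all_reduce(input_)` directly. For reference, a minimal sketch of the same registration pattern on a toy op; the `mylib::scale` name and kernels are illustrative only, not vLLM code:

```python
# Minimal sketch of the torch.library custom-op pattern removed above,
# shown on a hypothetical toy op rather than the vLLM all-reduce.
import torch

# Declare the op schema under a made-up "mylib" namespace.
torch.library.define("mylib::scale", "(Tensor x, float factor) -> Tensor")


@torch.library.register_kernel("mylib::scale", ("cpu", "cuda"))
def _scale_impl(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real kernel: executed on concrete tensors at runtime.
    return x * factor


@torch.library.register_fake("mylib::scale")
def _scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake kernel: shape/dtype propagation only, used during tracing,
    # mirroring the removed *_fake function.
    return torch.empty_like(x)


if __name__ == "__main__":
    out = torch.ops.mylib.scale(torch.ones(4), 2.0)
    print(out)  # tensor([2., 2., 2., 2.])
```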
8 changes: 4 additions & 4 deletions vllm/distributed/parallel_state.py
@@ -272,17 +272,17 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:

         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
-            return input_.clone()
+            return input_
 
         # For TPUs, use TPU communicator.
         tpu_comm = self.tpu_communicator
         if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_reduce(input_).clone()
+            return tpu_comm.all_reduce(input_)
 
         if ca_comm is not None:
             out = ca_comm.custom_all_reduce(input_)
             if out is not None:
-                return out.clone()
+                return out
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
@@ -291,7 +291,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
-        return input_.clone()
+        return input_
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
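The `.clone()` calls dropped in these hunks were presumably there to satisfy the aliasing constraints of the mutating custom-op schema (`Tensor(a!)`) removed in the first file; with the wrapper gone, `all_reduce` can return the communicator's output, or the input tensor itself after an in-place collective, without an extra copy. Below is a small standalone illustration (assumed setup: two local CPU processes on the gloo backend, arbitrary port; none of this is vLLM code) of why returning `input_` is enough once `torch.distributed.all_reduce` has mutated it in place:

```python
# Toy demo showing that torch.distributed.all_reduce writes the result into
# its argument, so the caller can return the same tensor with no clone.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    x = torch.full((4,), float(rank))
    dist.all_reduce(x)  # in-place sum across ranks
    # Every rank now holds 0.0 + 1.0 = 1.0 in x itself; no copy was made,
    # which is why the simplified all_reduce above can simply return input_.
    print(f"rank {rank}: {x.tolist()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```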
