Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PP comm optimization: replace send with partial send + allgather #6695

Merged
merged 12 commits into from
Aug 1, 2024

Conversation

aurickq
Copy link
Contributor

@aurickq aurickq commented Jul 23, 2024

Submitting on behalf of @zhisbug as he is on break. Communication optimization for pipeline parallelism, we observed 5% improvement in throughput for llama 3.1 405b.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I have few questions:

  1. Do you have benchmark results?
  2. At the first glance it seems always better to enable this mechanism. Have you observed performance regression with this PR in certain circumstances? Like when the tensor size is small the allgather overhead is not negligible (just guess)?

vllm/distributed/parallel_state.py Outdated Show resolved Hide resolved
vllm/distributed/parallel_state.py Outdated Show resolved Hide resolved
vllm/distributed/parallel_state.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good approach. This is really similar to the optimization from Section 4.1 of https://arxiv.org/abs/2104.04473, except that the sender pipeline stage will still do an AllReduce instead of a ReduceScatter. Is that right?

I think the ReduceScatter -> Send to the next pipeline stage -> AllGather is an optimization that we should explore along these lines in the future but it's going to require a lot more software engineering than this :)

vllm/model_executor/layers/sampler.py Outdated Show resolved Hide resolved
@andoorve
Copy link
Collaborator

LGTM from a first pass. Thanks for the contribution!

@aurickq
Copy link
Contributor Author

aurickq commented Jul 24, 2024

Thanks for the PR! I have few questions:

  1. Do you have benchmark results?
  2. At the first glance it seems always better to enable this mechanism. Have you observed performance regression with this PR in certain circumstances? Like when the tensor size is small the allgather overhead is not negligible (just guess)?

Re: benchmark results, in our tests using llama 3.1 405b (yellow=before, gray=after):
Screenshot 2024-07-24 at 11 21 47 AM

We haven't observed any performance regressions yet but also haven't tested very many scenarios.

@andoorve
Copy link
Collaborator

Hey! @youkaichao any blockers for merge here?

@zhisbug
Copy link
Collaborator

zhisbug commented Jul 29, 2024

This one should have better improvement when your inter-node bandwidth is not high.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to add an integration test? vLLM does have some 4 GPU CI runs in buildkite, so it should be possible to add one. Otherwise this code won't be exercised and may break.

Other than that, LGTM

@andoorve
Copy link
Collaborator

Does it make sense to add an integration test? vLLM does have some 4 GPU CI runs in buildkite, so it should be possible to add one. Otherwise this code won't be exercised and may break.

+1. Should be just a matter of adding a parameter to existing pipeline tests

@youkaichao
Copy link
Member

Sorry for the delay, just notice this PR and the message.

The idea looks good to me, we use send --> allgather to replace send full tensor on all ranks.

Implementation wise, can we use this by default? As long as tensor.numel() % tp_size == 0, you can use it. send_tensor_dict is only called in pipeline parallel. And you don't need to create a full suit of _ENABLE_ALLGATHER_PIPELINE_COMM stuff. All you need, should be adding a all_gather_group argument to send_tensor_dict:

def send_tensor_dict(

    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
+     all_gather_group=Optional[GroupCoordinator],
    )
...

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
-                torch.distributed.send(tensor,
+                torch.distributed.send(tensor[slice],
                                       dst=self.ranks[dst],
                                       group=metadata_group)
+                # add allgather with `all_gather_group.cpu_group`
            else:
                # use group for GPU tensors
-                torch.distributed.send(tensor,
+                torch.distributed.send(tensor[slice],
                                       dst=self.ranks[dst],
                                       group=group)
+                # add allgather with `all_gather_group.device_group`

# change `recv_tensor_dict` accordingly

And update this line

get_pp_group().send_tensor_dict(output.tensors)

to get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group())

@aurickq
Copy link
Contributor Author

aurickq commented Aug 1, 2024

Sorry for the delay, just notice this PR and the message.

The idea looks good to me, we use send --> allgather to replace send full tensor on all ranks.

Implementation wise, can we use this by default? As long as tensor.numel() % tp_size == 0, you can use it. send_tensor_dict is only called in pipeline parallel. And you don't need to create a full suit of _ENABLE_ALLGATHER_PIPELINE_COMM stuff. All you need, should be adding a all_gather_group argument to send_tensor_dict:

def send_tensor_dict(

    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
+     all_gather_group=Optional[GroupCoordinator],
    )
...

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
-                torch.distributed.send(tensor,
+                torch.distributed.send(tensor[slice],
                                       dst=self.ranks[dst],
                                       group=metadata_group)
+                # add allgather with `all_gather_group.cpu_group`
            else:
                # use group for GPU tensors
-                torch.distributed.send(tensor,
+                torch.distributed.send(tensor[slice],
                                       dst=self.ranks[dst],
                                       group=group)
+                # add allgather with `all_gather_group.device_group`

# change `recv_tensor_dict` accordingly

And update this line

get_pp_group().send_tensor_dict(output.tensors)

to get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group())

Done and tested

@youkaichao youkaichao changed the title [Core] Add allgather comm for PP PP comm optimization: replace send with partial send + allgather Aug 1, 2024
Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great work!

@youkaichao youkaichao added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 1, 2024
@youkaichao
Copy link
Member

merge as failing tests are unrelated.

@youkaichao youkaichao merged commit 0437492 into vllm-project:main Aug 1, 2024
71 of 76 checks passed
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants