[Performance]: Empirical Measurement of how to broadcast python object in vLLM #4440
Note: the memory alignment feature depends on the fact that pickled data is self-terminated:

```python
import pickle
import pickletools

s = [1] * 5
d = pickle.dumps(s)
d = d + b"whatever"
pickletools.dis(d)
```

Output: there is a STOP opcode at the end (before the appended bytes). Therefore it is safe to pad/align the pickled data.
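For completeness, a small sketch (plain Python, not vLLM code) showing that padding bytes after the STOP opcode are simply ignored when unpickling:

```python
import pickle

s = [1] * 5
d = pickle.dumps(s)

# Pad to a 4-byte boundary with arbitrary filler bytes.
padded = d + b"\x00" * (-len(d) % 4)

# pickle.loads stops at the STOP opcode, so the padding is ignored.
assert pickle.loads(padded) == s
```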
The optimization makes sense to me (nice writeup!)
The result of pickle.dumps does not always seem to be aligned to 4 bytes.
It does not matter though. The point is that the data is self-terminated, so we can pad with arbitrary bytes. Padding does not affect unpickling.
Very cool!
There are two
The performance of broadcasting Python objects is largely resolved by #5399 in the single-node case.
Proposal to improve performance
When we use tensor parallelism in vLLM, the driver worker needs to broadcast some metadata to all workers, such as the input, the LoRA requests, etc. This functionality is currently implemented in:
vllm/vllm/distributed/communication_op.py, line 143 (commit 9c7306a)
In essence, it uses torch.distributed.broadcast_object_list to broadcast a Python object. This function carries a lot of overhead. There are three layers of overhead:

1. With a GPU communication backend such as NCCL, the serialized bytes have to be copied from CPU to GPU on the driver and from GPU back to CPU on every worker.
2. It broadcasts a list of objects, so every object in the list incurs its own serialization and deserialization work.
3. The receivers do not know the size of the serialized data in advance, so two broadcasts are needed: one for the size, and one for the data itself.
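Schematically, the procedure looks roughly like the following sketch. This is a simplification for illustration, not PyTorch's actual implementation, and the function name is made up:

```python
import pickle

import torch
import torch.distributed as dist


def broadcast_object_list_sketch(object_list, src, device):
    """Simplified sketch of the broadcast_object_list pattern."""
    rank = dist.get_rank()

    # Broadcast no. 1: per-object sizes, because receivers do not know how
    # large the serialized payload is (overhead 3).
    if rank == src:
        payloads = [pickle.dumps(obj) for obj in object_list]
        sizes = torch.tensor([len(p) for p in payloads],
                             dtype=torch.int64, device=device)
    else:
        sizes = torch.empty(len(object_list), dtype=torch.int64, device=device)
    dist.broadcast(sizes, src=src)

    # Broadcast no. 2: the concatenated bytes. With the NCCL backend the
    # tensor must live on GPU, so the bytes take a CPU -> GPU -> CPU round
    # trip (overhead 1).
    if rank == src:
        data = torch.frombuffer(bytearray(b"".join(payloads)),
                                dtype=torch.uint8).to(device)
    else:
        data = torch.empty(int(sizes.sum().item()),
                           dtype=torch.uint8, device=device)
    dist.broadcast(data, src=src)

    # Deserialize object by object on the receivers; this work scales with
    # the number of objects in the list (overhead 2).
    if rank != src:
        raw = data.cpu().numpy().tobytes()
        offset = 0
        for i, n in enumerate(sizes.tolist()):
            object_list[i] = pickle.loads(raw[offset:offset + n])
            offset += n
```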
The current vLLM implementation packs the data into a list of size one, so overhead 2 is eliminated:
vllm/vllm/distributed/communication_op.py, lines 173 to 175 (commit 9c7306a)
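For illustration, a minimal sketch of that packing pattern. The function name and the assumption that the metadata is a single dict are mine, not vLLM's:

```python
import torch.distributed as dist


def broadcast_metadata(metadata=None, src=0):
    """Broadcast one metadata object from `src` to all ranks.

    Wrapping the object in a list of size one means broadcast_object_list
    only has a single item to handle, which is what avoids the per-object
    overhead (overhead 2).
    """
    container = [metadata] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(container, src=src)
    return container[0]
```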
To remove overhead 1, we can use a CPU operation to broadcast this kind of metadata.
In addition, if we know the rough size of the pickled object in advance, we can remove overhead 3 as well. Only one broadcast is then required, which is the optimal case for broadcasting a Python object.
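To make the idea concrete, here is a rough sketch combining both points: a fixed-size CPU buffer broadcast once over a CPU backend such as gloo. The buffer size, function name, and error handling are illustrative assumptions, not the actual vLLM implementation:

```python
import pickle

import torch
import torch.distributed as dist

# Assumed upper bound on the pickled metadata size; a real implementation
# would need a fallback path for objects larger than this.
MAX_METADATA_BYTES = 1 << 20  # 1 MiB, chosen arbitrarily for illustration


def broadcast_metadata_cpu(metadata=None, src=0, group=None):
    """Broadcast metadata with a single CPU broadcast (assumes a gloo group).

    Removing overhead 1: the buffer stays on CPU, so no GPU round trip.
    Removing overhead 3: the buffer has a fixed size, so receivers need no
    separate size broadcast; pickle's STOP opcode marks the real end, and
    the trailing padding bytes are ignored by pickle.loads.
    """
    buf = torch.empty(MAX_METADATA_BYTES, dtype=torch.uint8)
    if dist.get_rank() == src:
        payload = pickle.dumps(metadata)
        assert len(payload) <= MAX_METADATA_BYTES, "metadata too large"
        buf[: len(payload)] = torch.frombuffer(bytearray(payload),
                                               dtype=torch.uint8)
    dist.broadcast(buf, src=src, group=group)  # the only collective call
    if dist.get_rank() == src:
        return metadata
    return pickle.loads(buf.numpy().tobytes())
```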
I have written some benchmark code at https://gist.github.com/youkaichao/b33fcd70286eb45a4a2d5a6dc32d096b and the results are in https://docs.google.com/spreadsheets/d/1c9xgR0fGvm6SROfk7vrjwOZdYnKQk9oOafWK4_KgOyo/edit?usp=sharing .
The short conclusion is:
Report of performance regression
No response
Misc discussion on performance
No response
Your current environment (if you think it is necessary)