[llama] Use horizontal fusion trick from Attention for FeedForward #606
How would you account for
Something like this could work:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedOperation(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.split_sizes = [input_dim, input_dim, output_dim]
        # w1, w3 and a transposed w2 all packed into a single parameter
        self.fused_weight = nn.Parameter(torch.randn(hidden_dim, 2 * input_dim + output_dim))
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x):
        # Split the fused weight back into its three parts
        w1, w3, w2_t = self.fused_weight.split(self.split_sizes, dim=1)
        # Compute the fused operation: silu(x @ w1.T) * (x @ w3.T), then down-project through w2
        hidden = F.silu(x @ w1.t()) * (x @ w3.t())
        return hidden @ w2_t + self.bias
```

Or do you have a simpler alternative in mind?
@sayakpaul Oh, I mean just fusing w1 and w3, not all of w1, w2 and w3. Similar to how wqkv fuses wq, wk and wv but leaves the output projection (wo) alone.

So, more specifically: right now the FeedForward forward pass causes 3 calls to an nn.Linear, but with the above change it's 2 calls. Essentially you stack w1 and w3 horizontally into a single weight instead of keeping them as two separate weights, and you can split the result of the former back into the two halves (and do so without causing a copy, because of striding). See the sketch below.
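A minimal sketch of what this fused FeedForward could look like (class and argument names are illustrative, not the actual torchao module):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedFeedForward(nn.Module):
    """SwiGLU FFN with w1 and w3 fused into one Linear, analogous to wqkv in Attention."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        # One GEMM produces both the gate projection (w1) and the up projection (w3).
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Single matmul, then split the result; chunk returns strided views, no copy.
        x1, x3 = self.w13(x).chunk(2, dim=-1)
        return self.w2(F.silu(x1) * x3)

# Example with llama-7B-like sizes
ffn = FusedFeedForward(dim=4096, hidden_dim=11008)
y = ffn(torch.randn(2, 16, 4096))   # -> shape (2, 16, 4096)
```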
For the Attention module we can concatenate the weights and do one GEMM instead of three for the input projections to gain a speedup, because each of those GEMMs is applied to the same input.
ao/torchao/_models/llama/model.py
Lines 220 to 225 in 22d6f97
and
ao/torchao/_models/llama/model.py
Lines 230 to 231 in 22d6f97
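For reference, the pattern at those lines is roughly the following (a simplified sketch assuming no grouped-query attention, so q, k and v all have the same size; not a verbatim copy of model.py):

```python
import torch
import torch.nn as nn

dim, n_head, head_dim = 4096, 32, 128      # illustrative llama-7B-like sizes
total_head_dim = 3 * n_head * head_dim     # q, k and v stacked along the output dimension

wqkv = nn.Linear(dim, total_head_dim, bias=False)  # one GEMM instead of three
x = torch.randn(2, 16, dim)

# Single matmul over the shared input, then split back into q, k, v.
# split returns strided views of the fused output, so no copy is made.
q, k, v = wqkv(x).split(n_head * head_dim, dim=-1)
```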
I suspect we can do the exact same thing for FeedForward:
ao/torchao/_models/llama/model.py
Lines 262 to 263 in 22d6f97
Task:
Implement the above trick and rerun the benchmarks to show gains. If you don't have access to an A100, another (ideally similar) GPU is fine too as a proxy. Also, if you can, try to confirm via a trace that indeed two GEMMs have been turned into one.
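One way to capture such a trace (just a suggestion, not something the issue prescribes) is torch.profiler, comparing the number of GEMM kernels per FeedForward call before and after the change:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# `FusedFeedForward` is the illustrative module sketched earlier in this thread.
model = FusedFeedForward(dim=4096, hidden_dim=11008).cuda()
x = torch.randn(8, 128, 4096, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(x)

# The kernel table and the exported Chrome trace should show one fewer GEMM per FFN call.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("ffn_trace.json")
```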