[llama] Use horizontal fusion trick from Attention for FeedForward #606

Open · cpuhrsch opened this issue Aug 6, 2024 · 3 comments
Labels: good first issue

cpuhrsch (Contributor) commented Aug 6, 2024

For the Attention module we can concatenate the weights and run one GEMM instead of three for the input projections, because each GEMM is applied to the same input; this yields a speedup.

def load_hook(self, state_dict, prefix, *args):
    if prefix + "wq.weight" in state_dict:
        wq = state_dict.pop(prefix + "wq.weight")
        wk = state_dict.pop(prefix + "wk.weight")
        wv = state_dict.pop(prefix + "wv.weight")
        state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

and
kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
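
For reference, a minimal sketch of how the fused projection and the hook fit together; the constructor arguments and dimension names here are assumptions, not the exact gpt-fast code:

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim: int, n_local_heads: int, head_dim: int):
        super().__init__()
        self.dim = dim
        self.n_local_heads = n_local_heads
        self.head_dim = head_dim
        kv_size = n_local_heads * head_dim
        # wq, wk and wv stacked along the output dimension: one GEMM instead of three
        self.wqkv = nn.Linear(dim, dim + 2 * kv_size, bias=False)
        # rewrites checkpoints that still store wq/wk/wv separately
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(self, x):
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
        return q, k, v  # rest of attention omitted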

I suspect we can do the exact same thing for FeedForward:

def forward(self, x: Tensor) -> Tensor:
    return self.w2(F.silu(self.w1(x)) * self.w3(x))

Task:
Implement the above trick and rerun the benchmarks to show the gains. If you don't have access to an A100, another (ideally similar) GPU is fine as a proxy. Also, if you can, try to confirm via a trace that the two GEMMs have indeed been turned into one.
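
One way to get such a trace (a sketch; the module name and shapes below are assumptions, not part of the benchmark setup): torch.profiler can export a Chrome trace in which you can count the GEMM kernels before and after the change.

import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

# `ffn` stands in for the FeedForward under test; a plain Linear here
# just to make the snippet runnable on its own.
ffn = nn.Linear(4096, 4096, bias=False).to("cuda", torch.bfloat16)
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    ffn(x)

# Count the matmul kernels in the table, or open the exported trace
# in chrome://tracing / Perfetto.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("ffn_trace.json")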

msaroufim added the good first issue label Aug 6, 2024

sayakpaul (Contributor) commented

"I suspect we can do the exact same thing for FeedForward"

How would you account for silu here?

self.w2(F.silu(self.w1(x)) * self.w3(x)) 

Something like this could work:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedOperation(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # w1 and w3 stored as a single parameter, stacked along the output
        # dimension; w2 has a different shape, so it stays separate
        self.fused_weight = nn.Parameter(torch.randn(2 * hidden_dim, input_dim))
        self.w2 = nn.Parameter(torch.randn(output_dim, hidden_dim))
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, x):
        # Split the fused weight into its w1 and w3 halves
        w1, w3 = self.fused_weight.chunk(2, dim=0)

        # Compute the gated feed-forward
        hidden = F.silu(x @ w1.t()) * (x @ w3.t())
        return hidden @ self.w2.t() + self.bias

Or do you have a simpler alternative in mind?

cpuhrsch (Contributor, Author) commented Aug 31, 2024

@sayakpaul Oh, I mean

x1, x3 = self.w13(x).split([...])

As in, just fuse w1 and w3, not all of w1, w2 and w3. Similar to how wqkv fuses wq, wk and wv, but leaves the output projection (wo) alone.

So more specifically

x1, x3 = self.w13(x).split([...])
return self.w2(F.silu(x1) * x3)

Right now

self.w2(F.silu(self.w1(x)) * self.w3(x))

will cause three calls to nn.Linear, but with the change above it's two calls, and F.silu(x1) * x3 can also become an epilogue of w13(x) when using torch.compile.

Essentially you stack w1 and w3 horizontally like

[w1,
 w3] @ x

instead of

[w1 @ x,
 w3 @ x]

but you can split the result of the former, and do so without causing a copy, because the split pieces are just strided views of the same buffer.
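
Putting the pieces together, a minimal sketch of the fused module, mirroring the wqkv hook above; the intermediate_size name and the Llama-style layout are assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class FeedForward(nn.Module):
    def __init__(self, dim: int, intermediate_size: int):
        super().__init__()
        self.intermediate_size = intermediate_size
        # w1 and w3 stacked into one weight: one GEMM instead of two
        self.w13 = nn.Linear(dim, 2 * intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, dim, bias=False)
        # rewrites checkpoints that still store w1/w3 separately
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "w1.weight" in state_dict:
            w1 = state_dict.pop(prefix + "w1.weight")
            w3 = state_dict.pop(prefix + "w3.weight")
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

    def forward(self, x: Tensor) -> Tensor:
        # one GEMM, then split the activation (strided views, no copy)
        x1, x3 = self.w13(x).split(
            [self.intermediate_size, self.intermediate_size], dim=-1)
        return self.w2(F.silu(x1) * x3)

Under torch.compile, F.silu(x1) * x3 can then fuse into the epilogue of the single w13 GEMM, as noted above.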

yanbing-j pushed commits to yanbing-j/ao that referenced this issue Dec 9, 2024