
[Question] About MultiHeadAttention's inputs shape. #3831

Closed

matrix97317 opened this issue Apr 28, 2024 · 7 comments
Assignees
Labels
triaged Issue has been triaged by maintainers

Comments

@matrix97317

Hi, I have a question about MHAv2. MHAv2 uses [S, B, 3*E, 1, 1] as its input shape. Must S be the same for Q, K, and V? I think K and V must have the same S, but Q need not.

@matrix97317
Author

matrix97317 commented Apr 29, 2024

@lix19937

lix19937 commented May 2, 2024

Do you mean bertQKVToContextPlugin's fused_multihead_attention_v2?

The input tensor contains all 3 matrices Q, K, V:

- This input tensor is computed by multiplying a tensor of size [S, B, E] with the weights W_qkv of size [E, 3 * E].
- The weight matrix W_qkv is NOT just the vertical concatenation of the individual matrices, W_tmp = [W_q', W_k', W_v']'. Instead, start with W_tmp, reshape it into [E, 3, N, H] (where N * H = E, N is the number of heads, and H is the head size), transpose it into [E, N, 3, H], and reshape it back to [E, 3 * E]. The interpretation is to lay out the k-th heads of Q, K, and V next to each other, instead of first all N heads of Q, then all N heads of K, then all N heads of V.

ref https://github.com/NVIDIA/TensorRT/blob/release/8.5/plugin/bertQKVToContextPlugin/README.md
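A minimal sketch of that weight re-layout (not from the thread; it assumes right-multiplying per-projection weights of shape [E, E], i.e. out = x @ W):

```python
import torch

E = 768          # hidden size
N = 12           # number of heads
H = E // N       # head size

# Hypothetical per-projection weights, laid out [in_features, out_features]
# so a (S, B, E) activation is right-multiplied: out = x @ W.
W_q = torch.randn(E, E)
W_k = torch.randn(E, E)
W_v = torch.randn(E, E)

# Naive concatenation: all of Q, then all of K, then all of V.
W_tmp = torch.cat([W_q, W_k, W_v], dim=1)        # [E, 3*E]

# Re-layout as the plugin expects: interleave per head,
# [E, 3*E] -> [E, 3, N, H] -> [E, N, 3, H] -> [E, 3*E]
W_qkv = (W_tmp.view(E, 3, N, H)
              .transpose(1, 2)
              .contiguous()
              .view(E, 3 * E))
```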


S : seq_len
B : batch_size
E : hidden size

```python
## x, y, z are the input embeddings; each has shape (S, B, E)
## self.Wq, self.Wk, self.Wv are separate E -> E projections (e.g. nn.Linear(E, E))
Q = self.Wq(x)
K = self.Wk(y)
V = self.Wv(z)
qkv = torch.cat([Q, K, V], dim=2)   # (S, B, 3*E)
qkv = qkv.view(x.size(0), x.size(1), 3, self.num_heads, self.size_per_head)

## last_qkv is the plugin's single input, shape (S, B, 3*E, 1, 1)
last_qkv = qkv.transpose(2, 3).contiguous().view(x.size(0), x.size(1), 3*self.hidden_size, 1, 1)
```

The weights of self.Wq, self.Wk, and self.Wv can be different.
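For reference, a self-contained version of the snippet above with hypothetical sizes (S = 128, B = 8, E = 768, 12 heads), just to check the final shape the plugin expects:

```python
import torch
import torch.nn as nn

S, B, E, num_heads = 128, 8, 768, 12      # hypothetical sizes
size_per_head = E // num_heads

Wq, Wk, Wv = nn.Linear(E, E), nn.Linear(E, E), nn.Linear(E, E)
x = y = z = torch.randn(S, B, E)          # self-attention: same tensor for all three

Q, K, V = Wq(x), Wk(y), Wv(z)
qkv = torch.cat([Q, K, V], dim=2)                          # (S, B, 3*E)
qkv = qkv.view(S, B, 3, num_heads, size_per_head)
last_qkv = qkv.transpose(2, 3).contiguous().view(S, B, 3 * E, 1, 1)

print(last_qkv.shape)   # torch.Size([128, 8, 2304, 1, 1])
```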

From the torch.nn.MultiheadAttention point of view:

```python
forward(query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
        is_causal=False)
# https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
```

If Q, K, V are the query, key, and value embeddings (corresponding to x, y, z above), they can be the same or different.
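For example, a minimal sketch (sizes are hypothetical) where the query sequence is shorter than the key/value sequence:

```python
import torch
import torch.nn as nn

E, num_heads = 768, 12
S_q, S_kv, B = 16, 128, 8            # query vs. key/value sequence lengths differ

mha = nn.MultiheadAttention(embed_dim=E, num_heads=num_heads)   # default layout is (S, B, E)

query = torch.randn(S_q, B, E)
key   = torch.randn(S_kv, B, E)
value = torch.randn(S_kv, B, E)      # key and value must share a sequence length

out, attn_weights = mha(query, key, value)
print(out.shape)            # torch.Size([16, 8, 768])  -- follows the query length
print(attn_weights.shape)   # torch.Size([8, 16, 128])  -- averaged over heads by default
```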

@ttyio
Collaborator

ttyio commented May 2, 2024

Yes, MHA uses sequenceTo for K and V and sequenceFrom for Q, so you are right. In demoBERT, since Q, K, and V have the same sequence length, we simplified the problem: the plugin only supports the same sequence length, with Q, K, and V horizontally merged into a single buffer as the input.

@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label May 3, 2024
@matrix97317
Author

@lix19937 @ttyio So the current bertQKVToContextPlugin requires Q's seqlen == K's seqlen == V's seqlen? If so, my problem is solved. Are you considering supporting Q's seqlen != K's seqlen?

@lix19937

lix19937 commented May 3, 2024

> the current bertQKVToContextPlugin requires Q's seqlen == K's seqlen == V's seqlen?

yes

@ttyio
Collaborator

ttyio commented May 3, 2024

@matrix97317 Could you try importing the ONNX directly and letting TRT do the MHA fusion? The native MHA fusion in TRT supports different sequence lengths. See https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#mha-fusion
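A minimal sketch of that path (the module, sizes, and file name are hypothetical; whether TRT actually fuses the pattern into an MHA kernel depends on the TRT version and the exported graph):

```python
import torch
import torch.nn as nn

# Cross-attention with different query / key-value sequence lengths,
# exported to ONNX so TensorRT can attempt its native MHA fusion.
class CrossAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, q, kv):
        out, _ = self.mha(q, kv, kv, need_weights=False)
        return out

model = CrossAttention().eval()
q  = torch.randn(8, 16, 768)     # (B, S_q, E)
kv = torch.randn(8, 128, 768)    # (B, S_kv, E)

torch.onnx.export(model, (q, kv), "cross_attention.onnx",
                  input_names=["q", "kv"], output_names=["out"],
                  opset_version=17)

# Then build an engine and let TRT fuse the attention, e.g.:
#   trtexec --onnx=cross_attention.onnx --fp16
```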

@ttyio
Collaborator

ttyio commented Jul 2, 2024

Closing since there has been no activity for more than 3 weeks. Please reopen if you still have questions, thanks all!

@ttyio ttyio closed this as completed Jul 2, 2024