[Question] About MultiHeadAttention's inputs shape. #3831
Comments
Do you mean bertQKVToContextPlugin's fused_multihead_attention_v2?

Ref: https://github.com/NVIDIA/TensorRT/blob/release/8.5/plugin/bertQKVToContextPlugin/README.md

```python
# S: seq_len; x, y, z are embedding tensors, each of shape (S, B, E)
Q = self.Wq(x)
K = self.Wk(y)
V = self.Wv(z)
qkv = torch.cat([Q, K, V], dim=2)
qkv = qkv.view(x.size(0), x.size(1), 3, self.num_heads, self.size_per_head)
# last_qkv is one of the plugin's inputs
last_qkv = qkv.transpose(2, 3).contiguous().view(x.size(0), x.size(1), 3*self.hidden_size, 1, 1)
```

The weights of self.Wq, self.Wk, and self.Wv can be different. Viewed through torch.nn.MultiheadAttention's interface:

```python
# https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
forward(query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
        is_causal=False)
```

If Q, K, V mean the query, key, and value embeddings, they correspond to x, y, z above. They can be the same tensor or different ones.
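To make the packing above concrete: after `transpose(2, 3)` the packed 3*E axis is ordered per head (head, then Q/K/V, then per-head channel), not as three whole matrices side by side. Below is a minimal pure-Python sketch of that index mapping; it is my own illustration derived from the reshape in the snippet above, and `packed_channel` is a hypothetical helper, not a plugin API.

```python
# Assumed layout (derived from view + transpose(2, 3) + view above, not
# NVIDIA-verified): the 3*E axis of the [S, B, 3*E, 1, 1] tensor is
# ordered as (head, matrix, per-head channel), i.e. Q/K/V are
# interleaved per head.

num_heads = 2
size_per_head = 4
hidden_size = num_heads * size_per_head  # E = 8

def packed_channel(matrix: int, head: int, ch: int) -> int:
    """Map (matrix 0=Q / 1=K / 2=V, head index, per-head channel)
    to an index along the packed 3*E axis."""
    return head * 3 * size_per_head + matrix * size_per_head + ch

# Q channels of head 0 occupy indices 0..3, K of head 0 indices 4..7,
# while Q of head 1 starts only at index 12:
print([packed_channel(0, 0, c) for c in range(size_per_head)])  # [0, 1, 2, 3]
print([packed_channel(1, 0, c) for c in range(size_per_head)])  # [4, 5, 6, 7]
print([packed_channel(0, 1, c) for c in range(size_per_head)])  # [12, 13, 14, 15]
```

This per-head interleaving is why the plugin wants a single fused [S, B, 3*E, 1, 1] tensor instead of three separate Q, K, V inputs.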
Yes, the mha requires
Yes.
@matrix97317 Could you try directly importing the ONNX model and letting TRT do the MHA fusion? The native MHA fusion in TRT supports different sequence lengths. See https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#mha-fusion
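For reference, a hypothetical version of that workflow might look like the following; the script name, file paths, and model are placeholders, not something prescribed by TensorRT.

```shell
# Export the PyTorch module to ONNX (export_model.py is a placeholder
# script that would call torch.onnx.export on your attention module),
# then let trtexec build the engine so TRT's native MHA fusion applies.
python export_model.py
trtexec --onnx=model.onnx --saveEngine=model.engine --fp16
```

With this route you skip hand-wiring the fused [S, B, 3*E, 1, 1] plugin input entirely and let the builder decide whether the pattern is fusable.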
Closing since there has been no activity for more than 3 weeks; please reopen if you still have questions. Thanks all!
Hi, I have a question about MHAv2. MHAv2 uses [S, B, 3*E, 1, 1] as its input shape. Must S be the same for Q, K, and V? I think K and V must share the same sequence length, but Q need not.