
Problems with the sparseKT-ktop algorithm #165

Open
zncj2fdx opened this issue Jan 26, 2024 · 0 comments

Comments

@zncj2fdx

Hello, and thank you for providing this code; it has given me a chance to learn from it.
While reading the sparseKT code, I found some problems in the implementation of ktop.
q1: In the attention method, after scores = F.softmax(scores, dim=-1), shouldn't the result also be multiplied by the mask, i.e. scores = scores * mask.float().to(device)? The first row of the attention weights is entirely masked, and after softmax those entries are not actually 0 but small values close to 0.
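The effect described in q1 can be reproduced with a small sketch (assuming an AKT-style strictly causal mask where position 0 attends to nothing; seqlen=4 is just for illustration and is not from the repository):

```python
import torch
import torch.nn.functional as F

seqlen = 4
scores = torch.randn(seqlen, seqlen)
# Strictly causal mask: position i may attend only to positions j < i,
# so row 0 is fully masked.
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=-1)

masked = scores.masked_fill(mask == 0, -1e32)
attn = F.softmax(masked, dim=-1)
# Row 0 is all -1e32, so softmax yields a uniform 1/seqlen per entry
# (at seqlen=200 that is 1/200 = 0.005): small, but not zero.
print(attn[0])  # tensor([0.2500, 0.2500, 0.2500, 0.2500])

attn = attn * mask.float()  # multiplying by the mask zeroes row 0 exactly
print(attn[0])  # tensor([0., 0., 0., 0.])
```

This is why the extra multiplication by the mask matters: without it, position 0's output is a (small) average over every position, including future ones.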
q2: The concrete implementation of the ktop algorithm. The code splits scores into scores_a and scores_b. In the original scores, the first row is all zeros, and the second row should be a single 1 followed by 199 zeros. The current implementation concatenates scores_a and scores_b and then applies softmax to the whole result. This re-normalizes every row of scores_a, so positions that the mask had zeroed out regain attention weight, which means the model peeks at future positions. A possible fix: apply softmax to scores_b first, and only then concatenate scores_b with scores_a.
After the change:
scores_a = scores[:, :, :k_index, :]
scores_b = scores[:, :, k_index:, :].reshape(bs * head * (seqlen - k_index), -1)
sorted_scores,sorted_idx = torch.sort(scores_b,descending=True)
scores_t = sorted_scores[:,k_index-1:k_index].repeat(1,seqlen)
scores_b = torch.where(scores_b - scores_t >= torch.tensor(0), scores_b, torch.tensor(-1e32)).reshape(bs,head,seqlen-k_index,-1)
scores_b = F.softmax(scores_b, dim=-1) # BS,8,seqlen,seqlen
scores = torch.cat([scores_a, scores_b], dim=2)
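The leakage described in q2 can be seen on a single row: once a row of scores_a has been normalized by the first softmax, its masked future positions hold an exact 0 rather than -1e32, so re-applying softmax after the concatenation hands weight back to those positions. A minimal sketch (the length-4 row is illustrative, not taken from the repository):

```python
import torch
import torch.nn.functional as F

# A row of scores_a after the first softmax: position 1 attends only to
# position 0, and positions 1..3 were zeroed by the causal mask.
row = torch.tensor([1.0, 0.0, 0.0, 0.0])

# What the concatenate-then-softmax version effectively does to this row:
leaked = F.softmax(row, dim=-1)
print(leaked)  # ~tensor([0.4754, 0.1749, 0.1749, 0.1749])
# The masked future positions regain ~17% weight each, i.e. position 1
# now "sees" positions 2 and 3.
```

Applying softmax to scores_b alone and leaving scores_a untouched, as in the change above, keeps these masked entries at exactly 0 after concatenation.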
Since I am just getting started, I worry that the ordering in the author's implementation may serve some other purpose. If I have misunderstood anything, I would appreciate a correction. Looking forward to your reply!
