Attentionpooling nonlinear_ops #3
Howdy! You can pass in custom operations using the dictionary of nonlinear operations that the attribution function accepts.

This should be fairly straightforward for activation functions in whatever implementation you're using. I'm not actually sure how to handle attention pooling layers, though -- pooling is sort of its own beast because you need a function to unpool the data as well. If you could figure out what the appropriate thing to do is, I could see how to incorporate it. A challenge is that most of the non-standard ops are implemented in the same package as Enformer and, since the dictionary requires a reference to the object type (like torch.nn.ReLU), you would need to import that package just to register them.
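For a standard element-wise activation, the registration would look something like the sketch below. This is untested, and the names additional_nonlinear_ops and _nonlinear are from memory, so double-check them against the tangermeme source before relying on this:

    import torch
    from tangermeme.deep_lift_shap import deep_lift_shap, _nonlinear

    # A stand-in custom activation. It is element-wise, so the generic rescale
    # rule (_nonlinear) should apply; the dictionary maps the layer *class* to
    # the function that overrides its backward pass.
    class ScaledTanh(torch.nn.Module):
        def forward(self, x):
            return 2.0 * torch.tanh(x)

    # A toy model and one-hot input just so the example is complete.
    model = torch.nn.Sequential(
        torch.nn.Conv1d(4, 8, 5, padding=2),
        ScaledTanh(),
        torch.nn.AdaptiveAvgPool1d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(8, 1),
    )
    X = torch.nn.functional.one_hot(torch.randint(4, (3, 100)), 4).float().permute(0, 2, 1)

    attributions = deep_lift_shap(model, X,
        additional_nonlinear_ops={ScaledTanh: _nonlinear})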
Hi Jacob,

Still on this topic, I wanted to ask if it's that straightforward to expose ReLU functions that are not defined as layers. For context, I'm playing with tangermeme to generate attributions for the Pangolin model, where ReLU activations are applied within the residual blocks. It seems that those ReLUs are called functionally rather than being defined as layer objects.

Thanks for the package, it will be super useful.
No, it's not straightforward to register operations that are not defined as layers. I'm not sure it's even possible. You need the layer object to assign the hook that overrides the backward pass. You'd have to make your own model definition that uses ReLU layers and load the weights into it. That should work because you're not changing the weight values or shapes.
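As a rough sketch of what I mean, with a made-up block rather than Pangolin's actual architecture, the point is just that the layered version has exactly the same parameters, so the original weights load directly:

    import torch

    # Functional version: the ReLU is called inside forward(), so there is no
    # layer object to attach a backward hook to.
    class ResidualBlockFunctional(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.conv = torch.nn.Conv1d(dim, dim, 3, padding=1)

        def forward(self, x):
            return x + torch.nn.functional.relu(self.conv(x))

    # Layered version: the same computation, but the ReLU is a torch.nn.ReLU
    # module that the attribution code can hook.
    class ResidualBlockLayered(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.conv = torch.nn.Conv1d(dim, dim, 3, padding=1)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return x + self.relu(self.conv(x))

    original = ResidualBlockFunctional(64)
    layered = ResidualBlockLayered(64)

    # torch.nn.ReLU has no parameters, so the state dicts match exactly.
    layered.load_state_dict(original.state_dict())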
Just to add to this thread, we are using the AttentionPool layer in enformer-pytorch here: https://github.com/lucidrains/enformer-pytorch/blob/9ffeb8b62927d752b4983ef308a28bf70b34b160/enformer_pytorch/modeling_enformer.py#L159. Any suggestions for how to create a dictionary entry for this layer to use with deepshap?
Attention pooling looks tricky because it seems like it might be linear except for the softmax. Annoyingly, rather than using a torch.nn.Softmax layer, the implementation calls the softmax functionally on the logits, so there is no layer object to register.
So, if I redefine the layer to use a torch.nn.Softmax object and register that in the dictionary, that should just work?
Hm. Looking closer at the implementation, I think the solution might be a little more complicated. You need to account for the dot product in the final line as well; sorry for not mentioning that before. We encountered a similar issue with the profile attributions: https://github.com/jmschrei/bpnet-lite/blob/master/bpnetlite/bpnet.py#L54

Basically, registered operations must have the same input and output size. I think you'll need to do something conceptually similar to what I did, except that you'll also need to run the convolution. I haven't tested this code; it's just to show you what needs to go where:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from einops.layers.torch import Rearrange

    class AttentionPool(nn.Module):
        def __init__(self, dim, pool_size = 2):
            super().__init__()
            self.pool_size = pool_size
            self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)
            self.attention = _AttentionPoolOp(dim)

        def forward(self, x):
            b, _, n = x.shape
            remainder = n % self.pool_size
            needs_padding = remainder > 0

            mask = None
            if needs_padding:
                # pad the sequence out to a multiple of pool_size and build a
                # mask marking the padded positions
                x = F.pad(x, (0, remainder), value = 0)
                mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
                mask = F.pad(mask, (0, remainder), value = True)
                mask = self.pool_fn(mask)

            x = self.pool_fn(x)
            return self.attention(x, mask)

    class _AttentionPoolOp(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.softmax = nn.Softmax(dim = -1)
            self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)

            nn.init.dirac_(self.to_attn_logits.weight)
            with torch.no_grad():
                self.to_attn_logits.weight.mul_(2)

        def forward(self, X, mask = None):
            # convolution producing the attention logits
            logits = self.to_attn_logits(X)

            if mask is not None:
                # padded positions get the largest negative value so they
                # receive zero attention after the softmax
                mask_value = -torch.finfo(logits.dtype).max
                logits = logits.masked_fill(mask, mask_value)

            # the nonlinear part: softmax followed by the weighted sum
            attn = self.softmax(logits)
            return (X * attn).sum(dim = -1)

Then, all you'll need to do is register _AttentionPoolOp in the dictionary of nonlinear operations.
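If it helps, here is a rough way to swap the redefined layer into a loaded enformer-pytorch model and copy the trained weights across. Also untested, and the attribute names I'm reaching into (to_attn_logits, pool_size) are taken from the implementation linked above, so they may differ in other versions:

    import torch
    from enformer_pytorch import from_pretrained
    from enformer_pytorch.modeling_enformer import AttentionPool as OriginalAttentionPool

    # load the model however you normally do; from_pretrained is the helper
    # shown in the enformer-pytorch README
    model = from_pretrained('EleutherAI/enformer-official-rough')

    for module in list(model.modules()):
        for name, child in list(module.named_children()):
            if isinstance(child, OriginalAttentionPool):
                # rebuild the layer with the same dimensions and copy the
                # trained convolution weights into the new _AttentionPoolOp
                dim = child.to_attn_logits.in_channels
                replacement = AttentionPool(dim, pool_size = child.pool_size)
                replacement.attention.to_attn_logits.load_state_dict(
                    child.to_attn_logits.state_dict())
                setattr(module, name, replacement)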
Hi Jacob,
Thanks a lot for this neat package. I don't know if you take on requests, but it would be great to see support for attention pooling layers and the other non-standard layers present in Enformer. Do you plan to write functions for these in the future?
Thanks