
Attentionpooling nonlinear_ops #3

Open
Vejni opened this issue May 9, 2024 · 7 comments

@Vejni commented May 9, 2024

Hi Jacob,

Thanks a lot for this neat package. I don't know if you take on requests, but it would be great to see support for attention pooling layers and the other custom layers present in Enformer. Do you plan to write functions for these in the future?
Thanks

@jmschrei (Owner) commented May 9, 2024

Howdy

You can pass in custom operations using the additional_nonlinear_ops parameter in deep_lift_shap. So, if you have a GeLU function implemented somewhere, you would do something like...

additional_nonlinear_ops = {GeLU: _nonlinear} where _nonlinear is from tangermeme.deep_lift_shap.

This should be fairly straightforward for activation functions in whatever implementation you're using.
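
For concreteness, here is a minimal, untested sketch of that call. The toy model and the locally defined GeLU class are just stand-ins for whatever implementation you're actually using, and the other deep_lift_shap arguments are left at their defaults:

import torch
from tangermeme.deep_lift_shap import deep_lift_shap, _nonlinear

# stand-in for whatever GeLU implementation your model actually uses
class GeLU(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

# toy model for illustration: conv -> GeLU -> pool over the length dimension
model = torch.nn.Sequential(
    torch.nn.Conv1d(4, 8, kernel_size=5, padding=2),
    GeLU(),
    torch.nn.AdaptiveAvgPool1d(1),
    torch.nn.Flatten(),
)

# random one-hot encoded sequences with shape (batch, 4, length)
idx = torch.randint(0, 4, (2, 100))
X = torch.nn.functional.one_hot(idx, num_classes=4).permute(0, 2, 1).float()

# register the custom activation so deep_lift_shap overrides its backward pass
attributions = deep_lift_shap(model, X,
    additional_nonlinear_ops={GeLU: _nonlinear})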

I'm not actually sure how to handle attention pooling layers though -- pooling is sort of its own beast because you need a function to unpool the data as well. If you could figure out what the appropriate thing to do is, I could see how to incorporate it.

A challenge is that most of the non-standard ops are implemented in the same package as Enformer itself and, since the dictionary requires a reference to the object type (like GeLU above), building these in would mean building the Enformer layers into tangermeme as well.

@PedroBarbosa
Hi Jacob,

Still on this topic, I wanted to ask whether it's as straightforward to register ReLU activations that are not defined as layers.

For context, I'm playing with tangermeme to generate attributions for the Pangolin model, where ReLU activations are applied functionally within the residual blocks rather than as torch.nn.ReLU layers. It seems that deep_lift_shap does not register any non-linear operations in this case.

Thanks for the package, it will be super useful.

@jmschrei (Owner)

No, it's not straightforward to register operations that are not defined as layers. I'm not sure it's even possible. You need the layer object to assign the hook that overrides the backward pass.

You'd have to make your own model definition that uses ReLU layers and load up the weights into it. That should work because you're not changing the weight values or shapes.
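
As a rough sketch of that approach (the block below is a hypothetical illustration, not Pangolin's actual architecture): define the block with an explicit nn.ReLU module and then load the original weights into it, since the state dict only depends on parameter names and shapes, not on how the activation is applied.

import torch
from torch import nn

# hypothetical block in the style of the original model, with ReLU applied
# functionally inside forward (so there is no layer object to hook)
class ResBlockFunctional(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return x + torch.relu(self.conv(x))

# the same block rewritten with an nn.ReLU layer that deep_lift_shap can hook
class ResBlockLayer(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return x + self.relu(self.conv(x))

old, new = ResBlockFunctional(8), ResBlockLayer(8)

# parameter names and shapes are identical and nn.ReLU has no parameters,
# so the original weights load directly into the rewritten block
new.load_state_dict(old.state_dict())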

@avantikalal

Just to add to this thread, we are using the AttentionPool layer in enformer-pytorch here: https://github.com/lucidrains/enformer-pytorch/blob/9ffeb8b62927d752b4983ef308a28bf70b34b160/enformer_pytorch/modeling_enformer.py#L159. Any suggestions for how to create a dictionary for this layer to use with deepshap?

@jmschrei (Owner)

Attention pooling looks tricky because it appears to be linear except for the softmax. Annoyingly, rather than using a torch.nn.Softmax layer, which would already be registered, they use the built-in .softmax method call, which is not even a layer that can be registered. There is a _softmax function built into tangermeme that can be used for other layers, but it can't handle shape changes between a layer's input and output because it's meant to be applied to activations. Unfortunately, the cleanest solution may be to push a PR to that repo changing the layer... but that might break their saved models. I will think more about it.

@avantikalal

So, if I:

  1. Define a class identical to AttentionPool except that it uses nn.Softmax
  2. Replace all the attention pools in enformer with my new class
  3. Copy over the weights from the old model

That should just work?

@jmschrei (Owner) commented Jun 20, 2024

Hm. Looking closer at the implementation, I think that the solution might be a little more complicated. You need to account for the dot product in the final line as well. Sorry for not mentioning that before. We encountered a similar issue with the profile attributions: https://github.com/jmschrei/bpnet-lite/blob/master/bpnetlite/bpnet.py#L54

Basically, registered operations:

(1) must have the same input and output size
(2) cannot have multiple inputs to the forward pass
(3) seem to need to include the dot product when a non-linear set of operations is applied to a value and the result is then multiplied back against the original values

I think you'll need to do something conceptually similar to what I did, except that you'll also need to run the convolution. I haven't tested this code; it's just to show you what needs to go where:

import torch
from torch import nn
from torch.nn import functional as F
from einops.layers.torch import Rearrange


class AttentionPool(nn.Module):
    def __init__(self, dim, pool_size = 2):
        super().__init__()
        self.pool_size = pool_size
        self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)
        self.attention = _AttentionPoolOp(dim)

    def forward(self, x):
        b, _, n = x.shape
        remainder = n % self.pool_size
        needs_padding = remainder > 0

        if needs_padding:
            # pad the final dimension and build a mask marking the padded positions
            x = F.pad(x, (0, remainder), value = 0)
            mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
            mask = F.pad(mask, (0, remainder), value = True)

            # the registered op can only take a single input, so stash the
            # pooled mask on it rather than passing it into forward
            self.attention.mask = self.pool_fn(mask)
        else:
            self.attention.mask = None

        x = self.pool_fn(x)
        return self.attention(x)


class _AttentionPoolOp(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.softmax = nn.Softmax(dim = -1)
        self.conv = nn.Conv2d(dim, dim, 1, bias = False)
        self.mask = None

        # match the enformer-pytorch initialization: a 2x-scaled identity
        nn.init.dirac_(self.conv.weight)
        with torch.no_grad():
            self.conv.weight.mul_(2)

    def forward(self, X):
        logits = self.conv(X)

        if self.mask is not None:
            mask_value = -torch.finfo(logits.dtype).max
            logits = logits.masked_fill(self.mask, mask_value)

        attn = self.softmax(logits)
        return (X * attn).sum(dim = -1)

Then, all you'll need to do is register _AttentionPoolOp with _nonlinear from tangermeme.
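
For example, following the same pattern as the GeLU example above (again untested; model and X stand in for the rebuilt network and a one-hot encoded batch):

from tangermeme.deep_lift_shap import deep_lift_shap, _nonlinear

# `model` is the network rebuilt with the AttentionPool / _AttentionPoolOp
# modules above, and `X` is a one-hot encoded batch; both are placeholders
attributions = deep_lift_shap(model, X,
    additional_nonlinear_ops={_AttentionPoolOp: _nonlinear})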
