
How to compute second-order hessian for custom module? #340

Open
XuZikang opened this issue Nov 18, 2024 · 2 comments

Comments

@XuZikang

Hi, I want to compute the Hessian diagonal (diag_h) for my custom nn.Module, which replaces nn.Conv2d with a CP-decomposed convolution, and I get this warning:

/data/miniforge3/envs/fairmae/lib/python3.9/site-packages/backpack/custom_module/graph_utils.py:86: UserWarning: Encountered node that may break second-order extensions: op=get_attr, target=V.1. If you encounter this problem, please open an issue at https://github.com/f-dangel/backpack/issues.

The architecture of my module is defined below:

import torch
import torch.nn as nn
import torch.nn.functional as F


class CPConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, rank, stride=1, padding=0, bias=True):
        super(CPConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size) if isinstance(
            kernel_size, int) else kernel_size
        self.rank = rank
        self.stride = stride
        self.padding = padding
        self.bias = bias

        # U: (rank,)
        # V: [(out_channels, rank), (in_channels, rank), (kernel_size, rank), (kernel_size, rank)]

        self.U = nn.Parameter(torch.randn(rank))
        self.V = nn.ParameterList([
            nn.Parameter(torch.randn(out_channels, rank)),
            nn.Parameter(torch.randn(in_channels, rank)),
            nn.Parameter(torch.randn(self.kernel_size[0], rank)),
            nn.Parameter(torch.randn(self.kernel_size[1], rank))
        ])

        if bias:
            self.b = nn.Parameter(torch.randn(out_channels))
        else:
            self.register_parameter('b', None)

    def forward(self, x):
        # Reconstruct the full (out_channels, in_channels, kH, kW) kernel from the CP factors.
        W = torch.einsum('r,or,ir,kr,lr->oikl',
                         self.U, self.V[0], self.V[1], self.V[2], self.V[3])

        return F.conv2d(x, W, self.b, self.stride, self.padding)
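For context, here is roughly how I am driving BackPACK (a minimal sketch; the surrounding model, shapes, and data are made-up placeholders, not my real training code):

import torch
import torch.nn as nn
from backpack import backpack, extend
from backpack.extensions import DiagHessian

# Toy model: one CPConv2d followed by a linear head (placeholder sizes).
model = extend(
    nn.Sequential(
        CPConv2d(3, 8, 3, rank=4, padding=1),
        nn.Flatten(),
        nn.Linear(8 * 32 * 32, 10),
    ),
    use_converter=True,  # CPConv2d is not a built-in layer; the converter step prints the warning above
)
lossfunc = extend(nn.CrossEntropyLoss())

X = torch.randn(2, 3, 32, 32)
y = torch.randint(0, 10, (2,))

loss = lossfunc(model(X), y)
with backpack(DiagHessian()):
    loss.backward()

# Supported layers get p.diag_h attached to their parameters;
# for CPConv2d's parameters nothing is attached.
for name, p in model.named_parameters():
    print(name, hasattr(p, "diag_h"))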

Could you tell me how to compute diag_h for self.U? Thank you!

@f-dangel
Owner

Hi,

thanks for the clear explanation. We have a tutorial in the docs which explains how to implement second-order extensions for new layers in BackPACK (see here). The tutorial explains how to support the GGN diagonal, which is slightly easier to implement than the Hessian diagonal.
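
To give a bit of intuition for why the GGN is easier (this is just the standard decomposition, written as a reminder): with network output $f(\theta)$, loss $\ell$, and Jacobian $J = \partial f / \partial \theta$, the Hessian splits as

$$\nabla_\theta^2 \, \ell(f(\theta)) = J^\top \left(\nabla_f^2 \ell\right) J \;+\; \sum_k \frac{\partial \ell}{\partial f_k} \, \nabla_\theta^2 f_k .$$

The GGN is the first (positive semi-definite) term, so backpropagating it only requires first derivatives (Jacobians) of your layer. The exact Hessian additionally contains the second term, which involves second derivatives of CPConv2d's forward pass and is therefore a bit more work.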

Could you try following the tutorial and implementing support for the GGN diagonal first?
I can then help you to generalize it to the Hessian diagonal.
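
If it helps, the skeleton from the tutorial adapted to your layer looks roughly like this (written from memory, so please double-check the exact method signatures against the tutorial and your installed BackPACK version; the actual math in the method bodies is deliberately left out):

from backpack.extensions import DiagGGNExact
from backpack.extensions.module_extension import ModuleExtension


class CPConv2dDiagGGNExact(ModuleExtension):
    """Sketch of a DiagGGNExact extension for CPConv2d."""

    def __init__(self):
        # Parameter names of CPConv2d for which the extension computes quantities.
        # The factors inside the ParameterList `V` need extra treatment and are
        # omitted from this sketch.
        super().__init__(params=["U", "b"])

    def backpropagate(self, ext, module, g_inp, g_out, bpQuantities):
        # Push the backpropagated GGN factorization through the layer,
        # i.e. apply the transposed Jacobian of the output w.r.t. the input.
        raise NotImplementedError

    def U(self, ext, module, g_inp, g_out, bpQuantities):
        # Contract bpQuantities with the Jacobian of the output w.r.t. U,
        # then square and sum to obtain the GGN diagonal for U.
        raise NotImplementedError

    def b(self, ext, module, g_inp, g_out, bpQuantities):
        # Same idea for the bias.
        raise NotImplementedError


# Register the custom extension, then use DiagGGNExact as usual:
extension = DiagGGNExact()
extension.set_module_extension(CPConv2d, CPConv2dDiagGGNExact())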

Best,
Felix

@XuZikang
Author

Thanks for your reply! I will try it :)
