Adding a new transposed convolution function (similar to torch.nn.ConvTranspose2d()) #1872

Open
andsteing opened this issue Feb 8, 2022 · 8 comments
Labels
Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)
Status: blocked (The issue/PR is blocked by another issue/PR.)

Comments

@andsteing
Collaborator

Adding a transposed convolution as proposed in jax-ml/jax#5772 would also be very useful when porting models from PyTorch to Flax (as in #1848).

@jheek added the labels "Priority: P2 - no schedule" and "Status: blocked" on Mar 14, 2022
@codeboy5

Hey, is this issue being actively worked on?

@marcvanzee
Collaborator

We actually have an implementation already: https://github.com/google/flax/blob/main/flax/linen/linear.py#L447

I don't really understand why we have this issue.

@andsteing can you please clarify?

@codeboy5

According to the docs:

torch.nn.ConvTranspose2d and nn.ConvTranspose are not compatible. nn.ConvTranspose is a wrapper around jax.lax.conv_transpose which computes a fractionally strided convolution, while torch.nn.ConvTranspose2d computes a gradient based transposed convolution.

@marcvanzee, this might be the reason, and it is also why this issue was created.
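The distinction in the quoted docs can be illustrated in one dimension. The sketch below is plain Python and assumes a single channel, no padding and no dilation; the helper names (`strided_corr`, `grad_transpose`, `fractional_transpose`) are hypothetical and the code does not reproduce the exact padding conventions of either `jax.lax.conv_transpose` or `torch.nn.ConvTranspose2d`. The gradient-based version scatters each input value back through the kernel (the adjoint of a strided cross-correlation), while the fractionally strided version inserts zeros between inputs, pads, and then cross-correlates. The two only agree once the kernel is flipped.

```python
def strided_corr(x, w, stride):
    """Valid (unpadded) 1-D cross-correlation with a stride, like a forward conv layer."""
    k = len(w)
    n_out = (len(x) - k) // stride + 1
    return [sum(x[i * stride + j] * w[j] for j in range(k)) for i in range(n_out)]


def grad_transpose(y, w, stride):
    """Gradient-based transposed conv: the adjoint (vector-Jacobian product) of strided_corr.

    Each y[i] is scattered back through the kernel into the input positions
    that produced it, so the output length is (len(y) - 1) * stride + len(w).
    """
    k = len(w)
    out = [0] * ((len(y) - 1) * stride + k)
    for i, yi in enumerate(y):
        for j in range(k):
            out[i * stride + j] += yi * w[j]
    return out


def fractional_transpose(y, w, stride):
    """Fractionally strided conv: dilate y with zeros, pad, then cross-correlate with w.

    The kernel is used as given (not flipped).
    """
    k = len(w)
    dilated = [0] * ((len(y) - 1) * stride + 1)
    for i, yi in enumerate(y):
        dilated[i * stride] = yi
    padded = [0] * (k - 1) + dilated + [0] * (k - 1)
    return strided_corr(padded, w, 1)


y, w, s = [1, 2], [1, 2, 3], 2
print(grad_transpose(y, w, s))              # [1, 2, 5, 4, 6]
print(fractional_transpose(y, w, s))        # [3, 2, 7, 4, 2]
print(fractional_transpose(y, w[::-1], s))  # [1, 2, 5, 4, 6]
```

With the kernel flipped, the fractionally strided result matches the gradient-based one. `jax.lax.conv_transpose` has a `transpose_kernel` argument that performs this kind of flip, but the padding conventions of the two libraries would still need to be matched separately.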

@marcvanzee
Collaborator

Ahh, sorry, I didn't see that; thanks for pointing it out! No, it isn't being worked on. Do you want to work on it?

@codeboy5

Yes, I would be interested in working on it.

@marcvanzee assigned and then unassigned codeboy5 on Apr 13, 2022
@marcvanzee
Collaborator

I just noticed this issue has the "blocked" label. @jheek, could you please explain why? I suppose it is blocked on the JAX issue jax-ml/jax#5772?

@codeboy5, in that case I guess we have to hold off on this issue until that one is merged.

@codeboy5

Oh, okay. I looked at that issue too; there haven't been any updates on it for a year.
Thanks.

@marcvanzee
Collaborator

Hmm, I see. Maybe you can reply to that issue and ask whether they plan to merge it soon? Otherwise, if you are really interested, you could ask whether you can pick it up yourself!
