Fix UNet implementation with arbitrary channel sizes (#243) #276

Open
wants to merge 10 commits into base: master

Conversation

vinayakjeet

#243

Bug Description:
The current UNet implementation in the Metalhead package only works with input tensors that have 3 channels. This restriction causes compatibility issues when users try to use UNet with inputs of other channel sizes.

Patch Description:
To address this limitation, I've modified the UNet implementation to support input tensors with arbitrary channel sizes. The UNet model can now handle inputs with any number of channels.

Test Case:
using Metalhead
UNet((128,128),1,3,Metalhead.backbone(DenseNet(121)))

This UNet model can now be constructed without any errors.

Fix UNet implementation to support inputs with channel sizes other than 3
@theabhirath
Member

Hi Vinayakjeet, thanks for the PR! Unfortunately, I don't think this does what we want yet. The problem is that inchannels isn't being passed to the model backbone. What you've done is change the input passed to the Flux.outputsize function, which actually causes an error when I try to initialise the model:

julia> using Metalhead

julia> model = UNet((128,128),1,3,Metalhead.backbone(DenseNet(121)))
ERROR: DimensionMismatch: layer Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false) expects size(input, 3) == 3, but got 128×128×1×1 Array{Flux.NilNumber.Nil, 4}
Stacktrace:
  [1] _size_check(layer::Flux.Conv{2, 2, typeof(identity), Array{…}, Bool}, x::Array{Flux.NilNumber.Nil, 4}, ::Pair{Int64, Int64})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:195
  [2] (::Flux.Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})(x::Array{Flux.NilNumber.Nil, 4})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/conv.jl:198
  [3] #outputsize#340
    @ ~/.julia/packages/Flux/jgpVj/src/outputsize.jl:93 [inlined]
  [4] outputsize(m::Flux.Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool}, inputsizes::NTuple{4, Int64})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/outputsize.jl:91
  [5] unetlayers(layers::Vector{…}, sz::NTuple{…}; outplanes::Nothing, skip_upscale::Int64, m_middle::typeof(Metalhead.unet_middle_block))
    @ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:34
  [6] unet(encoder_backbone::Flux.Chain{…}, imgdims::Tuple{…}, inchannels::Int64, outplanes::Int64, final::typeof(Metalhead.unet_final_block), fdownscale::Int64)
    @ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:81
  [7] unet
    @ ~/Code/Metalhead.jl/src/convnets/unet.jl:76 [inlined]
  [8] #UNet#175
    @ ~/Code/Metalhead.jl/src/convnets/unet.jl:120 [inlined]
  [9] UNet(imsize::Tuple{Int64, Int64}, inchannels::Int64, outplanes::Int64, encoder_backbone::Flux.Chain{Tuple{…}})
    @ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:118
 [10] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.

I would suggest rewriting the function so that inchannels is passed along to the encoder backbone.
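
For reference, a minimal sketch of that approach at the call site, assuming the DenseNet constructor accepts an inchannels keyword as other Metalhead models do (this is only an illustration, not the final API for UNet):

julia> using Metalhead

julia> inchannels = 1;

julia> encoder = Metalhead.backbone(DenseNet(121; inchannels));

julia> model = UNet((128, 128), inchannels, 3, encoder);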

Comment on lines 119 to 120
      encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false)
-     layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes)
+     layers = unet(encoder_backbone, imsize, inchannels, outplanes)
Member

inchannels should somehow be passed in to the encoder backbone here. Of course, we will have to decide how to handle the case where the user passes in a backbone that is already initialised and, separately, inchannels.
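
One possible way to handle that clash, sketched purely for illustration (check_backbone_inchannels is a hypothetical helper, not part of Metalhead's API, and the check assumes an ungrouped first convolution):

using Flux

# Hypothetical helper: confirm that a user-supplied backbone matches the
# requested `inchannels`, and fail loudly otherwise. For a Conv with
# groups == 1 the input-channel count is the second-to-last weight dimension.
function check_backbone_inchannels(backbone, inchannels::Integer)
    firstconv = first(l for l in Flux.modules(backbone) if l isa Flux.Conv)
    expected = size(firstconv.weight, ndims(firstconv.weight) - 1)
    expected == inchannels ||
        throw(ArgumentError("encoder backbone expects $expected input channels, " *
                            "but inchannels = $inchannels was given"))
    return backbone
end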

Modified the first convolutional layer of the encoder backbone to match the input's channel size, so the dimension mismatch error is prevented #1
skip_upscale = fdownscale)
function unet(encoder_backbone, imgdims, inchannels::Integer, outplanes::Integer,
final::Any = unet_final_block, fdownscale::Integer = 0)
backbonelayers = collect(flatten_chains(encoder_backbone))
Member

Please pay attention to the formatting; you lost the indentation here.

@vinayakjeet
Author

I'm a beginner contributor to the codebase, so could you review the logic I have implemented? Additionally, I have encountered a MethodError indicating a mismatch in method signatures for the unet function. It appears that there might be an issue with how the encoder_backbone is instantiated or used within the unet function. Could you please review the instantiation and usage of the encoder_backbone?
