From 3336842952d4647fe2b8926d13201f138c29c38f Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 17 Jun 2024 22:48:50 +0200 Subject: [PATCH 1/4] Add test for unusual input datatypes --- test/conv.jl | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/conv.jl b/test/conv.jl index dce01771..2bb05473 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -2,6 +2,7 @@ using NNlib, Test using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier, stride, padding, dilation, flipkernel, output_size, groupcount +using Random: AbstractRNG, SamplerType @testset "ConvDims" begin for T in (DenseConvDims, DepthwiseConvDims) @@ -865,6 +866,42 @@ end @test size(NNlib.∇conv_filter_direct!(w, x, y, cdims)) == w_size end +# https://github.com/FluxML/NNlib.jl/issues/490 +# https://github.com/FluxML/NNlib.jl/issues/405 +@testset "conv_direct! - Unusual input types" begin + # Create test type that can't be indexed when undefined. + # This simulates the worst-case scenario for custom types. + struct MyFloat <: Real + set::Set{Float32} + end + + # Test that direct indexing fails when undefined. + v = Array{MyFloat}(undef, 3) + @test_throws UndefRefError v[1] + + # Define minimal set of functions required for conv_direct! + MyFloat(x::MyFloat) = x + MyFloat(x::Real) = MyFloat(Set(Float32(x))) + + Base.:+(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) + only(y.set)) + Base.:*(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) * only(y.set)) + Base.promote_rule(::Type{MyFloat}, ::Type{Float32}) = MyFloat + Base.rand(::AbstractRNG, ::SamplerType{MyFloat}) = MyFloat(rand(Float32)) + Base.zero(::MyFloat) = MyFloat(zero(Float32)) + Base.zero(::Type{MyFloat}) = MyFloat(zero(Float32)) + + # Test conv_direct! + x_size = (6, 7, 8, 5, 3) + y_size = (5, 6, 7, 4, 3) + w_size = (2, 2, 2, 5, 4) + x = rand(MyFloat, x_size); + w = randn(Float32, w_size); + y = Array{MyFloat}(undef, y_size...); + cdims = DenseConvDims(x_size, w_size) + y_out = NNlib.conv_direct!(y, x, w, cdims) + @test size(y_out) == y_size +end + @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) w = rand(rng, repeat([3], spatial_rank)..., 3, 3) From 0e318fd4fa305c36655615c362667c9a044f31a3 Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 17 Jun 2024 22:49:05 +0200 Subject: [PATCH 2/4] Add fix to `conv_direct!` --- src/impl/conv_direct.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 9f12f1dc..29ee5af5 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -81,6 +81,9 @@ function conv_direct!( # Use `calc_padding_regions` to determine where we do or don't need to worry about padding padded_regions, central_region = calc_padding_regions(cdims) + # Set outputs to zero (https://github.com/FluxML/NNlib.jl/issues/490) + y = fill!(y, zero(yT)) + # Start with the central region w_region, h_region, d_region = central_region @inbounds for batch in 1:size(x, 5), From 68e15a9a68e532a9cf45f413c5c8384ca4b17994 Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 17 Jun 2024 23:11:00 +0200 Subject: [PATCH 3/4] Only set y to zero if beta is false or zero --- src/impl/conv_direct.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/impl/conv_direct.jl b/src/impl/conv_direct.jl index 29ee5af5..497f2e92 100644 --- a/src/impl/conv_direct.jl +++ b/src/impl/conv_direct.jl @@ -81,8 +81,10 @@ function conv_direct!( # Use `calc_padding_regions` to determine where we do or don't need to worry about padding padded_regions, central_region = calc_padding_regions(cdims) - # Set outputs to zero (https://github.com/FluxML/NNlib.jl/issues/490) - y = fill!(y, zero(yT)) + # Set outputs to zero to support custom datatypes (https://github.com/FluxML/NNlib.jl/issues/490) + if iszero(beta) + y = fill!(y, zero(yT)) + end # Start with the central region w_region, h_region, d_region = central_region From 8f8759668c13f51ac1794be70f025aa269899dfc Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 17 Jun 2024 23:14:00 +0200 Subject: [PATCH 4/4] Test output eltype --- test/conv.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/conv.jl b/test/conv.jl index 2bb05473..492de2cc 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -899,6 +899,8 @@ end y = Array{MyFloat}(undef, y_size...); cdims = DenseConvDims(x_size, w_size) y_out = NNlib.conv_direct!(y, x, w, cdims) + + @test eltype(y_out) == MyFloat @test size(y_out) == y_size end