Skip to content

Commit

Permalink
nonzero beta + flipkernel bugfix (#519)
Browse files Browse the repository at this point in the history
* nonzero beta + flipkernel bugfix
* conv! alpha/beta tests added, conv_filter_direct flipkernel with view
  • Loading branch information
nikopj authored Jul 8, 2023
1 parent ace7d53 commit 629475a
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/impl/conv_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
dy = transpose_swapbatch(predilate(dy, stride(cdims)))
ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims),
stride=dilation(cdims))
conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta)
if flipkernel(cdims)
dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :]
dw_ = if flipkernel(cdims)
view(dw, reverse(axes(dw, 1)), reverse(axes(dw, 2)), reverse(axes(dw, 3)), :, :)
else
dw
end
conv_direct!(dw_, dy, x, ctdims; alpha=alpha, beta=beta)
return dw
end
95 changes: 95 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,49 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
end
end

# Test all in-place implementations/interfaces
convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,]
NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack!)
for conv! in convs
if NNlib.is_nnpack_available()
if conv! == NNlib.conv_nnpack! && !NNlib.nnpack_supported_operation(DenseConvDims(x, w))
continue
end
end
α, β = 2e0, -1e0

@testset "$(conv!)" begin
# First, your basic convolution with no parameters
cdims = DenseConvDims(x, w)
y0 = rand(rng, -9e0:9e0, size(y_plain)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7)

# Next, test convolution on views and alternate datatypes:
@test isapprox(ddims(conv!(copy(y0), view(x, repeat([:], ndims(x))...), w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7)
@test isapprox(ddims(conv!(Float32.(copy(y0)), Float32.(x), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), Float32.(α*y_plain + β*y0), rtol = 1.0e-7)

# Next, introduce stride:
cdims = DenseConvDims(x, w; stride=2)
y0 = rand(rng, -9e0:9e0, size(y_stride)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_stride + β*y0, rtol = 1.0e-7)

# Next, introduce dilation:
cdims = DenseConvDims(x, w; dilation=2)
y0 = rand(rng, -9e0:9e0, size(y_dil)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_dil + β*y0, rtol = 1.0e-7)

# Next, introduce padding:
cdims = DenseConvDims(x, w; padding=1)
y0 = rand(rng, -9e0:9e0, size(y_pad)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_pad + β*y0, rtol = 1.0e-7)

# Next, test crosscor/conv with a flipped kernel
cdims = DenseConvDims(x, w; flipkernel=true)
y0 = rand(rng, -9e0:9e0, size(y_flip)..., 1, 1)
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_flip + β*y0, rtol = 1.0e-7)
end
end

# Test all implementations/interfaces
for (∇conv_filter, ∇conv_data) in (
(NNlib.∇conv_filter, NNlib.∇conv_data),
Expand Down Expand Up @@ -355,6 +398,58 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
@test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_flip, rtol = 1.0e-7)
end
end

# Test all in-place implementations/interfaces
for (∇conv_filter!, ∇conv_data!) in (
(NNlib.∇conv_filter!, NNlib.∇conv_data!),
(NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!),
(NNlib.∇conv_filter_direct!, NNlib.∇conv_data_direct!),
)
#α, β = 2*rand(rng) - 1, 2*rand(rng) - 1
α, β = 2e0, -1e0
flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!)

@testset "$(∇conv_filter!)/$(∇conv_data!)" begin
# First, your basic convolution with no parameters
cdims = DenseConvDims(x, w)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag

# Next, test convolution on views and alternate datatypes:
@test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag

@test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag

# Next, introduce stride:
cdims = DenseConvDims(x, w; stride=2)
dy = NNlib.conv(x, w, cdims)
flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag

# Next, introduce dilation:
cdims = DenseConvDims(x, w; dilation=2)
dy = NNlib.conv(x, w, cdims)
flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_

# Next, introduce padding:
cdims = DenseConvDims(x, w; padding=1)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag

# Next, test crosscor/conv with a flipped kernel
cdims = DenseConvDims(x, w; flipkernel=true)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag
end
end
end
end
end
Expand Down

0 comments on commit 629475a

Please sign in to comment.