Skip to content

Commit

Permalink
Enable prevously broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 26, 2023
1 parent cd59223 commit d648fdd
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,47 +407,46 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
)
#α, β = 2*rand(rng) - 1, 2*rand(rng) - 1
α, β = 2e0, -1e0
flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!)

@testset "$(∇conv_filter!)/$(∇conv_data!)" begin
# First, your basic convolution with no parameters
cdims = DenseConvDims(x, w)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7)

# Next, test convolution on views and alternate datatypes:
@test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7)

@test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7)

# Next, introduce stride:
cdims = DenseConvDims(x, w; stride=2)
dy = NNlib.conv(x, w, cdims)
flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7)

# Next, introduce dilation:
cdims = DenseConvDims(x, w; dilation=2)
dy = NNlib.conv(x, w, cdims)
flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag_

# Next, introduce padding:
cdims = DenseConvDims(x, w; padding=1)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7)

# Next, test crosscor/conv with a flipped kernel
cdims = DenseConvDims(x, w; flipkernel=true)
dy = NNlib.conv(x, w, cdims)
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7)
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7)
end
end
end
Expand Down

0 comments on commit d648fdd

Please sign in to comment.