Skip to content

Commit

Permalink
Bound tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 26, 2023
1 parent cc202a8 commit f430a9b
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 1 deletion.
5 changes: 4 additions & 1 deletion test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,8 @@ end
gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end

@static if Test_Enzyme

@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
Expand All @@ -899,7 +901,6 @@ end
end
end


@testset "EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
Expand All @@ -919,4 +920,6 @@ end

EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const))
end
end

end
4 changes: 4 additions & 0 deletions test/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme
@test_throws ArgumentError dropout!(y1, x1, 3)
end

@static if Test_Enzyme

@testset "EnzymeRules: dropout " begin
rng = Random.default_rng()

Expand All @@ -99,4 +101,6 @@ end
val = convert(Float32, 1/(1-p))

@test dx1[tape[1]] (val * dout)[tape[1]]
end

end
3 changes: 3 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ function gather_testsuite(Backend)
gradtest_fn((s, i) -> gather(s, i), src, idx)
end

@static if Test_Enzyme

@testset "EnzymeRules: gather! gradient for scalar index" begin
src = device(Float64[3, 4, 5, 6, 7])
Expand All @@ -172,6 +173,8 @@ function gather_testsuite(Backend)
end
end

end

@testset "gather gradient for tuple index" begin
src = device(Float64[
3 5 7
Expand Down
3 changes: 3 additions & 0 deletions test/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ end
gradtest(x -> sum(meanpool(x, k)), x)
end

@static if Test_Enzyme

@testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2),
(pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!))
Expand All @@ -985,4 +986,6 @@ end
EnzymeTestUtils.test_reverse(pool!, Tret, (y, Tdst), (x, Tsrc), (pdims, EnzymeCore.Const))
end

end

end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using Adapt
using KernelAbstractions
import ReverseDiff as RD # used in `pooling.jl`

const Test_Enzyme = VERSION <= v"1.10" && !Sys.iswindows()

DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)

# ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
Expand Down
4 changes: 4 additions & 0 deletions test/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ function scatter_testsuite(Backend)
end
end

@static if Test_Enzyme

@testset "EnzymeRules" begin
idx = device([2, 2, 3, 4, 4])
src = device(ones(T, 3, 5))
Expand All @@ -226,5 +228,7 @@ function scatter_testsuite(Backend)
end
end
end

end
end
end

0 comments on commit f430a9b

Please sign in to comment.