Skip to content

Commit

Permalink
fix up pool
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 25, 2023
1 parent ee36810 commit f8d4717
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ end

for pool in [:maxpool, :meanpool, :lpnormpool]
pool! = Symbol(pool, :!)
∇pool = Symbol(:∇, pool)
∇pool = Symbol(:∇, pool, :!)

@eval begin

Expand Down Expand Up @@ -277,7 +277,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($p
if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val

if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...)
NNlib.$(∇pool)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...)
end

dy .= 0
Expand Down
30 changes: 17 additions & 13 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -878,24 +878,28 @@ end
gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end

@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
@testset "EnzymeRules: $conv ! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3),
name in (:conv, :depthwiseconv)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)

cdims = if name == :conv
DenseConvDims(x, w)
else
DepthwiseConvDims(x, w)
end

for name in (:conv, :depthwiseconv)
curconv = @eval $(Symbol("$(name)"))
curconv! = @eval $(Symbol("$(name)!"))
dst = curconv(x, w, cdims)
curconv = @eval $(Symbol("$(name)"))
curconv! = @eval $(Symbol("$(name)!"))
dst = curconv(x, w, cdims)

for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)
for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue
EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue

EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (cdims, EnzymeCore.Const))
end
EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (cdims, EnzymeCore.Const))
end
end

0 comments on commit f8d4717

Please sign in to comment.