Skip to content

Commit

Permalink
fix: MIOpen only supports till dimension 5 (#601)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Jul 18, 2024
1 parent ddfe49b commit 86473fc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
16 changes: 8 additions & 8 deletions ext/NNlibAMDGPUExt/activations.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
for (f, op) in [
NNlib.relu => MIOpen.relu,
NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6),
NNlib.softplus => MIOpen.softrelu,
NNlib.σ => MIOpen.sigmoid,
Base.tanh => MIOpen.tanh,
# TODO define for leakyrelu, elu, etc.?
]
NNlib.relu => MIOpen.relu,
NNlib.relu6 => x -> MIOpen.clippedrelu(x, 6),
NNlib.softplus => MIOpen.softrelu,
NNlib.σ => MIOpen.sigmoid,
Base.tanh => MIOpen.tanh,
# TODO define for leakyrelu, elu, etc.?
], N in 1:5
@eval function Base.materialize(
bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat}}}
bc::Broadcast.Broadcasted{<:Any,<:Any,typeof($f),<:Tuple{ROCArray{<:MIOPENFloat,$N}}}
)
return $op(bc.args[1])
end
Expand Down
15 changes: 8 additions & 7 deletions test/ext_amdgpu/activations.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
@testset "Compare CPU & GPU" begin
for (T, atol) in ((Float16, 1f-2), (Float32, 1f-5))
x = randn(T, 16)
gputest(x -> NNlib.relu.(x), x; atol)
gputest(x -> NNlib.relu6.(x), x; atol)
gputest(x -> NNlib.softplus.(x), x; atol)
gputest(x -> tanh.(x), x; atol)
gputest(x -> identity.(x), x; atol)
for (T, atol) in ((Float16, 1.0f-2), (Float32, 1.0f-5))
@testset "ndims: $(ndims(x))" for x in (randn(T, 16), randn(T, ntuple(_ -> 2, 5)...), randn(T, ntuple(_ -> 2, 6)...))
gputest(x -> NNlib.relu.(x), x; atol)
gputest(x -> NNlib.relu6.(x), x; atol)
gputest(x -> NNlib.softplus.(x), x; atol)
gputest(x -> tanh.(x), x; atol)
gputest(x -> identity.(x), x; atol)
end
end
end

0 comments on commit 86473fc

Please sign in to comment.