diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index af31d405d7..bf3bfbd177 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -12,6 +12,8 @@ end reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...)) end +@inline Enzyme.specialize_output(output, input::StaticArray) = convert(SArray, output) + @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i Base.@_inline_meta diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 2e7789d660..8ade350d3c 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1847,6 +1847,8 @@ end end end +@inline specialize_output(output, input) = output + """ gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing) @@ -2010,7 +2012,7 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) # st : outshape x total inputs tupstack(cols, outshape, inshape) elseif x isa AbstractArray - TupleArray(cols, size(x)) + specialize_output(TupleArray(cols, size(x)), x) else cols end diff --git a/test/runtests.jl b/test/runtests.jl index 902b9e4f65..f6aeef835a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2869,9 +2869,11 @@ end x = @SVector Float64[1, 2] + + @inferred gradient(Forward, f0, x) dx = gradient(Forward, f0, x)[1] - @test dx isa Enzyme.TupleArray - @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test dx isa SVector + @test dx == [2.0, 2.0] # test to make sure conversion works @test gradient(Forward, f1, x)[1] isa SMatrix @test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0] @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray @@ -2879,9 +2881,10 @@ end x = @SMatrix Float64[1 2; 3 4] + @inferred gradient(Forward, f0, x) dx = gradient(Forward, f0, x)[1] - @test dx isa Enzyme.TupleArray - @test convert(SArray, dx) == fill(2.0, (2,2)) + @test dx isa SVector + @test dx == fill(2.0, (2,2)) @test gradient(Forward, f1, x)[1] isa SArray @test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray @@ -2891,9 +2894,10 @@ end x = @SVector Float64[1, 2] + @inferred gradient(Reverse, f0, x) dx = gradient(Reverse, f0, x)[1] @test dx isa SVector - @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test dx == [2.0, 2.0] # test to make sure conversion works @test_broken gradient(Reverse, f1, x)[1] isa SMatrix @test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0] @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray