diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index eaae75cccb..c2639a4c99 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -12,6 +12,8 @@ end reshape(reduce(hcat, map(vec, rows)), Size(outshape..., inshape...)) end +@inline Enzyme.specialize_output(output, input::StaticArray) = convert(SArray, output) + @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} ntuple(Val(L)) do i Base.@_inline_meta diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 51b40cee37..7caa06c281 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1884,6 +1884,8 @@ end end end +@inline specialize_output(output, input) = output + """ gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing) @@ -2135,11 +2137,11 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) # st : outshape x total inputs tupstack($tmp, outshape, inshape) else - TupleArray($tmp, size($arg)) + specialize_output(TupleArray($tmp, size($arg)), $(vals[1])) end end else - :(TupleArray($tmp, size($arg))) + :(specialize_output(TupleArray($tmp, size($arg)), $(vals[1]))) end else tmp diff --git a/test/runtests.jl b/test/runtests.jl index c3856aabf1..ba206cb536 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2860,9 +2860,11 @@ end x = @SVector Float64[1, 2] + + @inferred gradient(Forward, f0, x) dx = gradient(Forward, f0, x)[1] - @test dx isa Enzyme.TupleArray - @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test dx isa SVector + @test dx == [2.0, 2.0] # test to make sure conversion works @test gradient(Forward, f1, x)[1] isa SMatrix @test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0] @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray @@ -2870,9 +2872,10 @@ end x = @SMatrix Float64[1 2; 3 4] + @inferred gradient(Forward, f0, x) dx = gradient(Forward, f0, x)[1] - @test dx isa Enzyme.TupleArray - @test convert(SArray, dx) == fill(2.0, (2,2)) + @test dx isa SVector + @test dx == fill(2.0, (2,2)) @test gradient(Forward, f1, x)[1] isa SArray @test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2)) @test Enzyme.jacobian(Forward, f2, x)[1] isa SArray @@ -2882,9 +2885,10 @@ end x = @SVector Float64[1, 2] + @inferred gradient(Reverse, f0, x) dx = gradient(Reverse, f0, x)[1] @test dx isa SVector - @test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works + @test dx == [2.0, 2.0] # test to make sure conversion works @test_broken gradient(Reverse, f1, x)[1] isa SMatrix @test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0] @test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray