Skip to content

Commit

Permalink
Static array return for forward gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 15, 2024
1 parent c0c5e51 commit 422e1df
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
2 changes: 2 additions & 0 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ end
reshape(reduce(hcat, map(vec, rows)), Size(outshape..., inshape...))
end

"""
    specialize_output(output, input::StaticArray)

When the differentiated `input` was a `StaticArray`, convert the computed
`output` to an `SArray` so callers get a static array back instead of the
generic container produced internally.
"""
@inline function Enzyme.specialize_output(output, input::StaticArray)
    return convert(SArray, output)
end

@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L}
ntuple(Val(L)) do i
Base.@_inline_meta
Expand Down
6 changes: 4 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1884,6 +1884,8 @@ end
end
end

"""
    specialize_output(output, input)

Generic fallback: return `output` unchanged. Package extensions may add
methods for particular `input` types (e.g. static arrays) to convert
`output` into a more specialized representation.
"""
@inline function specialize_output(output, input)
    return output
end

"""
gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing)
Expand Down Expand Up @@ -2135,11 +2137,11 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
# st : outshape x total inputs
tupstack($tmp, outshape, inshape)
else
TupleArray($tmp, size($arg))
specialize_output(TupleArray($tmp, size($arg)), $(vals[1]))
end
end
else
:(TupleArray($tmp, size($arg)))
:(specialize_output(TupleArray($tmp, size($arg)), $(vals[1])))
end
else
tmp
Expand Down
14 changes: 9 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2860,19 +2860,22 @@ end

x = @SVector Float64[1, 2]


@inferred gradient(Forward, f0, x)
dx = gradient(Forward, f0, x)[1]
@test dx isa Enzyme.TupleArray
@test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works
@test dx isa SVector
@test dx == [2.0, 2.0] # test to make sure conversion works
@test gradient(Forward, f1, x)[1] isa SMatrix
@test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0]
@test Enzyme.jacobian(Forward, f2, x)[1] isa SArray
@test Enzyme.jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2))

x = @SMatrix Float64[1 2; 3 4]

@inferred gradient(Forward, f0, x)
dx = gradient(Forward, f0, x)[1]
@test dx isa Enzyme.TupleArray
@test convert(SArray, dx) == fill(2.0, (2,2))
@test dx isa SVector
@test dx == fill(2.0, (2,2))
@test gradient(Forward, f1, x)[1] isa SArray
@test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2))
@test Enzyme.jacobian(Forward, f2, x)[1] isa SArray
Expand All @@ -2882,9 +2885,10 @@ end

x = @SVector Float64[1, 2]

@inferred gradient(Reverse, f0, x)
dx = gradient(Reverse, f0, x)[1]
@test dx isa SVector
@test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works
@test dx == [2.0, 2.0] # test to make sure conversion works
@test_broken gradient(Reverse, f1, x)[1] isa SMatrix
@test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0]
@test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray
Expand Down

0 comments on commit 422e1df

Please sign in to comment.