Skip to content

Commit

Permalink
Static array return for forward gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 8, 2024
1 parent 3c0871d commit 9f0e9f8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
2 changes: 2 additions & 0 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ end
reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...))
end

@inline Enzyme.specialize_output(output, input::StaticArray) = convert(SArray, output)

@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L}
ntuple(Val(L)) do i
Base.@_inline_meta
Expand Down
4 changes: 3 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,8 @@ end
end
end

@inline specialize_output(output, input) = output

"""
gradient(::ForwardMode, f, x; shadows=onehot(x), chunk=nothing)
Expand Down Expand Up @@ -2010,7 +2012,7 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0])
# st : outshape x total inputs
tupstack(cols, outshape, inshape)
elseif x isa AbstractArray
TupleArray(cols, size(x))
specialize_output(TupleArray(cols, size(x)), x)
else
cols
end
Expand Down
14 changes: 9 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2869,19 +2869,22 @@ end

x = @SVector Float64[1, 2]


@inferred gradient(Forward, f0, x)
dx = gradient(Forward, f0, x)[1]
@test dx isa Enzyme.TupleArray
@test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works
@test dx isa SVector
@test dx == [2.0, 2.0] # test to make sure conversion works
@test gradient(Forward, f1, x)[1] isa SMatrix
@test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0]
@test Enzyme.jacobian(Forward, f2, x)[1] isa SArray
@test Enzyme.jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2))

x = @SMatrix Float64[1 2; 3 4]

@inferred gradient(Forward, f0, x)
dx = gradient(Forward, f0, x)[1]
@test dx isa Enzyme.TupleArray
@test convert(SArray, dx) == fill(2.0, (2,2))
@test dx isa SVector
@test dx == fill(2.0, (2,2))
@test gradient(Forward, f1, x)[1] isa SArray
@test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2))
@test Enzyme.jacobian(Forward, f2, x)[1] isa SArray
Expand All @@ -2891,9 +2894,10 @@ end

x = @SVector Float64[1, 2]

@inferred gradient(Reverse, f0, x)
dx = gradient(Reverse, f0, x)[1]
@test dx isa SVector
@test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works
@test dx == [2.0, 2.0] # test to make sure conversion works
@test_broken gradient(Reverse, f1, x)[1] isa SMatrix
@test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0]
@test_broken Enzyme.jacobian(Reverse, f2, x)[1] isa SArray
Expand Down

0 comments on commit 9f0e9f8

Please sign in to comment.