Skip to content

Commit

Permalink
small fix for static array onehot (#1732)
Browse files Browse the repository at this point in the history
* small fix for static array onehot

* weaken tests for old julia versions
  • Loading branch information
ExpandingMan authored Aug 21, 2024
1 parent ffc1035 commit c75bbd4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end
ntuple(Val(endl-start+1)) do i
Base.@_inline_meta
StaticArrays.SArray{S, T, N, L}(
ntuple(Val(N)) do idx
ntuple(Val(L)) do idx
Base.@_inline_meta
return (i + start - 1 == idx) ? 1.0 : 0.0
end)
Expand Down
16 changes: 16 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2774,6 +2774,22 @@ end
@test dx isa SArray
@test dx [0 30 0]

x = @SVector [1.0, 2.0, 3.0]
y = onehot(x)
# this should be a very specific type of SArray, but there
# is a bizarre issue with older julia versions where it can be MArray
@test eltype(y) <: StaticVector
@test length(y) == 3
@test y[1] == [1.0, 0.0, 0.0]
@test y[2] == [0.0, 1.0, 0.0]
@test y[3] == [0.0, 0.0, 1.0]

y = onehot(x, 2, 3)
@test eltype(y) <: StaticVector
@test length(y) == 2
@test y[1] == [0.0, 1.0, 0.0]
@test y[2] == [0.0, 0.0, 1.0]

@static if VERSION v"1.9-"
x = @SArray [5.0 0.0 6.0]
dx = Enzyme.gradient(Forward, prod, x)
Expand Down

0 comments on commit c75bbd4

Please sign in to comment.