From c75bbd4610b73c57e4acebaa852974abca44a538 Mon Sep 17 00:00:00 2001 From: ExpandingMan Date: Tue, 20 Aug 2024 21:28:43 -0400 Subject: [PATCH] small fix for static array onehot (#1732) * small fix for static array onehot * weaken tests for old julia versions --- ext/EnzymeStaticArraysExt.jl | 2 +- test/runtests.jl | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 672d1c03bc..6dbd390cb7 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -14,7 +14,7 @@ end ntuple(Val(endl-start+1)) do i Base.@_inline_meta StaticArrays.SArray{S, T, N, L}( - ntuple(Val(N)) do idx + ntuple(Val(L)) do idx Base.@_inline_meta return (i + start - 1 == idx) ? 1.0 : 0.0 end) diff --git a/test/runtests.jl b/test/runtests.jl index 114dfb6833..113eb6f531 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2774,6 +2774,22 @@ end @test dx isa SArray @test dx ≈ [0 30 0] + x = @SVector [1.0, 2.0, 3.0] + y = onehot(x) + # this should be a very specific type of SArray, but there + # is a bizarre issue with older julia versions where it can be MArray + @test eltype(y) <: StaticVector + @test length(y) == 3 + @test y[1] == [1.0, 0.0, 0.0] + @test y[2] == [0.0, 1.0, 0.0] + @test y[3] == [0.0, 0.0, 1.0] + + y = onehot(x, 2, 3) + @test eltype(y) <: StaticVector + @test length(y) == 2 + @test y[1] == [0.0, 1.0, 0.0] + @test y[2] == [0.0, 0.0, 1.0] + @static if VERSION ≥ v"1.9-" x = @SArray [5.0 0.0 6.0] dx = Enzyme.gradient(Forward, prod, x)