diff --git a/src/compiler.jl b/src/compiler.jl
index a85bc90b52..6f21fb399d 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -9294,7 +9294,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
 
     if params.run_enzyme
         # Generate the adjoint
-        jlrules = String[]
+        jlrules = String["enzyme_custom"]
         for (fname, (ftyp, mi)) in foundTys
             haskey(functions(mod), fname) || continue
             push!(jlrules, fname)
diff --git a/src/internal_rules.jl b/src/internal_rules.jl
index 1be9c157b1..2e61ce9cc6 100644
--- a/src/internal_rules.jl
+++ b/src/internal_rules.jl
@@ -428,3 +428,65 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,
 
     return (nothing,nothing)
 end
+
+# Force a rule around hvcat_fill! as it is type unstable if the tuple is not of the same
+# type (e.g., int, float, int, float).
+#
+# Augmented forward pass: fills `out.val` from `inp.val` and hands back the primal and/or
+# shadow of `out` when `config` asks for them. The shadow is deliberately left untouched:
+# it carries the adjoint seed that the reverse rule reads and clears.
+function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
+    primal = if EnzymeRules.needs_primal(config)
+        out.val
+    else
+        nothing
+    end
+    shadow = if EnzymeRules.needs_shadow(config)
+        out.dval
+    else
+        nothing
+    end
+    func.val(out.val, inp.val)
+
+    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
+end
+
+# Reverse pass: walks `out` in row-major order, one position per tuple entry. The shadow
+# value at each position becomes that entry's derivative and is then cleared. Only
+# `AbstractFloat` entries receive the shadow value; any other entry type gets `T(0)`.
+# A non-`Active` tuple gets no derivative, but the shadow of `out` is still cleared.
+function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}
+    nc = size(out.val, 2)
+    width = EnzymeRules.width(config)
+    for b in 1:width
+        da = if width == 1
+            out.dval
+        else
+            out.dval[b]
+        end
+        if inp isa Active
+            dinp = ntuple(Val(length(inp.val))) do k
+                Base.@_inline_meta
+                # Derive the position from k rather than mutating captured counters, so
+                # the closure neither boxes state nor depends on evaluation order.
+                i, j = divrem(k - 1, nc)
+                res = da[i + 1, j + 1]
+                da[i + 1, j + 1] = 0
+                T = fieldtype(BT, k)
+                if T <: AbstractFloat
+                    T(res)
+                else
+                    T(0)
+                end
+            end
+            # NOTE(review): this returns after the first batch, so for width > 1 the
+            # remaining shadows are never processed -- confirm the batched contract.
+            return (nothing, dinp)::Tuple{Nothing, BT}
+        else
+            # The tuple carries no derivative, but the fill still overwrote `out`, so
+            # the adjoint of its previous contents is zero.
+            fill!(da, 0)
+        end
+    end
+    return (nothing, nothing)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 2855206cac..7dee77817a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2543,3 +2543,21 @@ end
     y = A \ b
     @test dA ≈ (-z * transpose(y))
 end
+
+@testset "hvcat_fill" begin
+    ar = Matrix{Float64}(undef, 2, 3)
+    dar = [1.0 2.0 3.0; 4.0 5.0 6.0]
+
+    # Reverse-mode autodiff wraps the per-argument derivatives in an outer 1-tuple.
+    res = first(Enzyme.autodiff(Reverse, Base.hvcat_fill!, Const, Duplicated(ar, dar), Active((1, 2.2, 3, 4.4, 5, 6.6))))
+
+    @test res[2][1] == 0
+    @test res[2][2] ≈ 2.0
+    @test res[2][3] ≈ 0
+    @test res[2][4] ≈ 4.0
+    @test res[2][5] ≈ 0
+    @test res[2][6] ≈ 6.0
+    # The rule consumes the seed, so the shadow must be cleared afterwards.
+    @test all(iszero, dar)
+end
+