diff --git a/src/absint.jl b/src/absint.jl index 585b1625a3..03cf53cf4e 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -158,7 +158,7 @@ function absint(arg::LLVM.Value, partial::Bool = false) end function actual_size(@nospecialize(typ2)) - if typ2 <: Array || typ2 <: AbstractString + if typ2 <: Array || typ2 <: AbstractString || typ2 <: Symbol return sizeof(Int) elseif Base.isconcretetype(typ2) return sizeof(typ2) @@ -359,52 +359,78 @@ function abs_typeof( end if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF dl = LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(arg)))) - if offset === nothing - byref = GPUCompiler.BITS_VALUE - legal = true - typ2 = typ - while actual_size(typ2) != sizeof(dl, value_type(arg)) - if fieldcount(typ2) > 0 - typ2 = fieldtype(typ, 1) - if !Base.allocatedinline(typ2) - if byref != GPUCompiler.BITS_VALUE - legal = false - break + function should_recurse(typ2, arg_t) + if actual_size(typ2) != sizeof(dl, arg_t) + return true + else + if Base.isconcretetype(typ2) + if fieldcount(typ2) > 0 + if actual_size(fieldtype(typ2,1)) == actual_size(fieldtype(typ2, 1)) + return true end - byref = GPUCompiler.MUT_REF - continue end end - legal = false - break - end - if legal - return (true, typ2, byref) + return false end - else + end + + byref = GPUCompiler.BITS_VALUE + legal = true + + while offset !== nothing && legal @assert Base.isconcretetype(typ) + seen = false + lasti = 1 for i = 1:fieldcount(typ) + fo = fieldoffset(typ, i) if fieldoffset(typ, i) == offset - subT = fieldtype(typ, i) - fsize = if i == fieldcount(typ) - sizeof(typ) - else - fieldoffset(typ, i + 1) - end - offset - if fsize == sizeof(dl, value_type(arg)) - if Base.isconcretetype(subT) && - is_concrete_tuple(subT) && - length(subT.parameters) == 1 - subT = subT.parameters[1] - end - if Base.allocatedinline(subT) - return (true, subT, GPUCompiler.BITS_VALUE) - else - return (true, subT, GPUCompiler.MUT_REF) + offset = nothing + typ = fieldtype(typ, i) + if !Base.allocatedinline(typ) + if byref != GPUCompiler.BITS_VALUE + legal = false end + byref = GPUCompiler.MUT_REF + end + seen = true + break + elseif fieldoffset(typ, i) > offset + offset = offset - fieldoffset(typ, lasti) + typ = fieldtype(typ, lasti) + if !Base.allocatedinline(typ) + legal = false + end + seen = true + break + end + + if fo != 0 && fo != fieldoffset(typ, i-1) + lasti = i + end + end + if !seen + legal = false + end + end + + typ2 = typ + while should_recurse(typ2, value_type(arg)) + if fieldcount(typ2) > 0 + typ2 = fieldtype(typ2, 1) + if !Base.allocatedinline(typ2) + if byref != GPUCompiler.BITS_VALUE + legal = false + break end + byref = GPUCompiler.MUT_REF + continue end end + legal = false + break + end + if legal + return (true, typ2, byref) end end elseif legal && if typ <: Ptr && Base.isconcretetype(typ) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 32a206c62e..0d5bbdae01 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -591,7 +591,7 @@ end TM in (Const, Duplicated, BatchDuplicated), TB in (Const, Duplicated, BatchDuplicated) are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const)) + test_reverse(f!, TY, (Y, TY), (M, TM), (B, TB), (_A, Const); atol = 1.0e-5, rtol = 1.0e-5) end end @testset "test through `Adjoint` wrapper (regression test for #1306)" begin diff --git a/test/runtests.jl b/test/runtests.jl index 92bfa47513..69e6d51cd5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3959,6 +3959,35 @@ function harmonic_f!(inter_list, coords, inters) return si end +function invwsumsq(w::AbstractVector, a::AbstractVector) + s = zero(zero(eltype(a)) / zero(eltype(w))) + for i in eachindex(w) + s += abs2(a[i]) / w[i] + end + return s +end + +_logpdf(d, x) = invwsumsq(d.Σ.diag, x .- d.μ) + +function demo_func(x::Any=transpose([1.5 2.0;]);) + m = [-0.30725218207431315, 0.5492115788562757] + d = (; Σ = LinearAlgebra.Diagonal([1.0, 1.0]), μ = m) + logp = _logpdf(d, reshape(x, (2,))) + return logp +end + +demof(x) = demo_func() + +@testset "Type checks" begin + x = [0.0, 0.0] + Enzyme.autodiff( + Enzyme.Reverse, + Enzyme.Const(demof), + Enzyme.Active, + Enzyme.Duplicated(x, zero(x)), + ) +end + @testset "Decay preservation" begin inters = [HarmonicAngle(1.0, 0.1), HarmonicAngle(2.0, 0.3)] inter_list = [1, 3]