Skip to content

Commit

Permalink
Fix make_zero(!) corner case bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
danielwe committed Oct 7, 2024
1 parent 3c0871d commit 7a6ca9f
Showing 1 changed file with 47 additions and 18 deletions.
65 changes: 47 additions & 18 deletions src/make_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end
prev::Complex{RT},
::Val{copy_if_inactive} = Val(false),
)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat}
return RT(0)
return Complex{RT}(0)
end

@inline function EnzymeCore.make_zero(
Expand Down Expand Up @@ -117,11 +117,10 @@ end
return seen[prev]
end
prev2 = prev.contents
res = Core.Box()
seen[prev] = res
res.contents = Base.Ref(
res = Core.Box(
EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)),
)
seen[prev] = res
return res
end

Expand Down Expand Up @@ -160,7 +159,12 @@ end
end

if nf == 0
return prev
# I can't think of a type that would wind up here and not get caught by specialized
# methods or guaranteed_const_nongen, but I guess if it has a zero method...
if applicable(Base.zero, prev)
return Base.zero(prev)::RT
end
error("zero of type $RT not defined")
end

flds = Vector{Any}(undef, nf)
Expand All @@ -187,26 +191,43 @@ function make_zero_immutable!(
prev::Complex{T},
seen::S,
)::Complex{T} where {T<:AbstractFloat,S}
zero(T)
zero(Complex{T})
end

function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S}
ntuple(Val(length(T.parameters))) do i
Base.@_inline_meta
make_zero_immutable!(prev[i], seen)
p = prev[i]
SBT = Core.Typeof(p)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
make_zero_immutable!(p, seen)
else
EnzymeCore.make_zero!(p, seen)
p
end
end
end

function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S}
NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i
NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i
Base.@_inline_meta
make_zero_immutable!(prev[a[i]], seen)
p = prev[a[i]]
SBT = Core.Typeof(p)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
make_zero_immutable!(p, seen)
else
EnzymeCore.make_zero!(p, seen)
p
end
end)
end


function make_zero_immutable!(prev::T, seen::S)::T where {T,S}
if guaranteed_const_nongen(T, nothing)
# This branch will never be reached from make_zero!, as `make_zero_immutable!` is
# only called when active_reg_inner(T, (), nothing, Val(true)) == ActiveStyle, which
# implies guaranteed_const_nongen(T, nothing) == false
return prev
end
@assert !ismutable(prev)
Expand Down Expand Up @@ -239,15 +260,15 @@ end
prev::Base.RefValue{T},
seen::ST,
)::Nothing where {T<:AbstractFloat,ST}
T[] = zero(T)
prev[] = zero(T)
nothing
end

@inline function EnzymeCore.make_zero!(
prev::Base.RefValue{Complex{T}},
seen::ST,
)::Nothing where {T<:AbstractFloat,ST}
T[] = zero(Complex{T})
prev[] = zero(Complex{T})
nothing
end

Expand Down Expand Up @@ -297,7 +318,7 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if in(seen, prev)
if prev in seen
return
end
push!(seen, prev)
Expand Down Expand Up @@ -325,7 +346,7 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if in(seen, prev)
if prev in seen
return
end
push!(seen, prev)
Expand All @@ -348,13 +369,13 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if in(seen, prev)
if prev in seen
return
end
push!(seen, prev)
SBT = Core.Typeof(pv)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
prev.contents = EnzymeCore.make_zero_immutable!(pv, seen)
prev.contents = make_zero_immutable!(pv, seen)
nothing
else
EnzymeCore.make_zero!(pv, seen)
Expand All @@ -370,7 +391,7 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if in(prev, seen)
if prev in seen
return
end
@assert !Base.isabstracttype(T)
Expand All @@ -379,7 +400,10 @@ end


if nf == 0
return
# I can't think of a type that would wind up here and not get caught by specialized
# methods or guaranteed_const_nongen under valid use of make_zero!, however,
# nonsensical things like make_zero!(::Float64) gets you here
error("cannot zero $T in-place: it is apparently differentiable but has no fields")
end

push!(seen, prev)
Expand All @@ -392,7 +416,12 @@ end
continue
end
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
setfield!(prev, i, make_zero_immutable!(xi, seen))
yi = make_zero_immutable!(xi, seen)
if Base.isconst(T, i)
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi)
else
setfield!(prev, i, yi)
end
nothing
else
EnzymeCore.make_zero!(xi, seen)
Expand Down

0 comments on commit 7a6ca9f

Please sign in to comment.