Skip to content

Commit

Permalink
Implement single source of truth for structure
Browse files Browse the repository at this point in the history
  • Loading branch information
danielwe committed Oct 1, 2024
1 parent a91b179 commit 5c770e8
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions src/make_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ element of `xs` such that `yi == x1i`. If `copy_if_inactive == false`, this is d
sharing, `yi = x1i`; if `copy_if_inactive == true`, it is done by copying,
`yi = deepcopy(x1i)`.
The first element of `xs` is also used to keep track of values that reference the same
memory. This structure is reproduced in the return value `y`.
Each element in `xs` is assumed to have the same structure as `x1 = first(xs)`, including
which fields, if any, reference the same memory or are undefined. This structure will be
mirrored in the return value `y`. If this assumption does not hold, errors or incorrect
results may occur.
A function `isleaftype` can be provided to customize which types are considered leafs:
values of type `T` such that `isleaftype(T) == true` are not recursed into, but instead
Expand Down Expand Up @@ -67,10 +69,11 @@ end
end

nf = fieldcount(RT)
x1 = first(xs)
if ismutabletype(RT)
y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)
for i in 1:nf
if all(x -> isdefined(x, i), xs)
if isdefined(x1, i)
yi = newyi(i)
if Base.isconst(RT, i)
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, yi)
Expand All @@ -82,15 +85,15 @@ end
return y
elseif nf == 0
return f(xs...)::RT
elseif all(x -> isdefined(x, nf), xs)
elseif isdefined(x1, nf)
# fast path when all fields are set
return splatnew(RT, ntuple(newyi, Val(nf)))
else
flds = Vector{Any}(undef, nf)
nset = nf
for i in 1:nf
if all(x -> isdefined(x, i), xs)
flds[i] = newyi(i)
if isdefined(x1, i)
@inbounds flds[i] = newyi(i)
else
nset = i - 1 # rest of tail must be undefined values
break
Expand All @@ -104,11 +107,12 @@ end
::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args...
) where {RT<:Array,F,N}
y = RT(undef, size(first(xs)))
for I in eachindex(xs...)
if all(x -> isassigned(x, I), xs)
x1 = first(xs)
for I in eachindex(y, xs...)
@inbounds if isassigned(x1, I)
xIs = ntuple(j -> xs[j][I], N)
ST = Core.Typeof(first(xIs))
@inbounds y[I] = recursive_map(ST, f, seen, xIs, args...)
y[I] = recursive_map(ST, f, seen, xIs, args...)
end
end
return y
Expand All @@ -125,6 +129,10 @@ equal `f(x1i, x2i, ..., xNi)`, where `x1i, x2i, ..., xNi` are the corresponding
in the `xs`. Each subtree in `y` that can be proven by type to only contain
non-differentiable values is left unchanged.
Each element in `xs` is assumed to have the same structure as `y`, including which fields,
if any, reference the same memory or are undefined. If this assumption does not hold, errors
or incorrect results may occur.
If every differentiable value in `y` is contained in a mutable object (i.e., `y` has
inferred activity state Duplicated), this function performs a fully in-place update and
returns `y`. If every differentiable value is held in immutable storage (i.e., `y`
Expand Down Expand Up @@ -187,7 +195,7 @@ end
return nothing
end
for i = 1:nf
if isdefined(y, i) && all(x -> isdefined(x, i), xs)
if isdefined(y, i)
yi = getfield(y, i)
xis = ntuple(j -> getfield(xs[j], i), N)
newyi = recursive_map!!(f, yi, seen, xis, isleaftype)
Expand All @@ -207,12 +215,12 @@ end
f::F, y::Array{T,M}, seen, xs::NTuple{N,Array{T,M}}, isleaftype
) where {F,T,M,N}
for I in eachindex(y, xs...)
if isassigned(y, I) && all(x -> isassigned(x, I), xs)
@inbounds if isassigned(y, I)
yvalue = y[I]
xvalues = ntuple(j -> xs[j][I], N)
newyvalue = recursive_map!!(f, yvalue, seen, xvalues, isleaftype)
if newyvalue !== yvalue
@inbounds y[I] = newyvalue
y[I] = newyvalue
end
end
end
Expand All @@ -235,15 +243,15 @@ end
nf = fieldcount(T)
if nf == 0
return f(xs...)::T
elseif isdefined(y, nf) && all(x -> isdefined(x, nf), xs)
elseif isdefined(y, nf)
# fast path when all fields are set
return splatnew(T, ntuple(newyi, Val(nf)))
else
flds = Vector{Any}(undef, nf)
nset = nf
for i = 1:nf
if isdefined(y, i) && all(x -> isdefined(x, i), xs)
flds[i] = newyi(i)
if isdefined(y, i)
@inbounds flds[i] = newyi(i)
else
nset = i - 1 # rest of tail must be undefined values
break
Expand Down

0 comments on commit 5c770e8

Please sign in to comment.