From 5c770e888ef7afc7cb2d9cab6b77077c85e64b08 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 1 Oct 2024 00:01:44 -0700 Subject: [PATCH] Implement single source of truth for structure --- src/make_zero.jl | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/src/make_zero.jl b/src/make_zero.jl index 25777b41a0..90daf6aa08 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -21,8 +21,10 @@ element of `xs` such that `yi == x1i`. If `copy_if_inactive == false`, this is d sharing, `yi = x1i`; if `copy_if_inactive == true`, it is done by copying, `yi = deepcopy(x1i)`. -The first element of `xs` is also used to keep track of values that reference the same -memory. This structure is reproduced in the return value `y`. +Each element in `xs` is assumed to have the same structure as `x1 = first(xs)`, including +which fields, if any, reference the same memory or are undefined. This structure will be +mirrored in the return value `y`. If this assumption does not hold, errors or incorrect +results may occur. A function `isleaftype` can be provided to customize which types are considered leafs: values of type `T` such that `isleaftype(T) == true` are not recursed into, but instead @@ -67,10 +69,11 @@ end end nf = fieldcount(RT) + x1 = first(xs) if ismutabletype(RT) y = ccall(:jl_new_struct_uninit, Any, (Any,), RT) for i in 1:nf - if all(x -> isdefined(x, i), xs) + if isdefined(x1, i) yi = newyi(i) if Base.isconst(RT, i) ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, yi) @@ -82,15 +85,15 @@ end return y elseif nf == 0 return f(xs...)::RT - elseif all(x -> isdefined(x, nf), xs) + elseif isdefined(x1, nf) # fast path when all fields are set return splatnew(RT, ntuple(newyi, Val(nf))) else flds = Vector{Any}(undef, nf) nset = nf for i in 1:nf - if all(x -> isdefined(x, i), xs) - flds[i] = newyi(i) + if isdefined(x1, i) + @inbounds flds[i] = newyi(i) else nset = i - 1 # rest of tail must be undefined values break @@ -104,11 +107,12 @@ end ::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args... ) where {RT<:Array,F,N} y = RT(undef, size(first(xs))) - for I in eachindex(xs...) - if all(x -> isassigned(x, I), xs) + x1 = first(xs) + for I in eachindex(y, xs...) + @inbounds if isassigned(x1, I) xIs = ntuple(j -> xs[j][I], N) ST = Core.Typeof(first(xIs)) - @inbounds y[I] = recursive_map(ST, f, seen, xIs, args...) + y[I] = recursive_map(ST, f, seen, xIs, args...) end end return y @@ -125,6 +129,10 @@ equal `f(x1i, x2i, ..., xNi)`, where `x1i, x2i, ..., xNi` are the corresponding in the `xs`. Each subtree in `y` that can be proven by type to only contain non-differentiable values is left unchanged. +Each element in `xs` is assumed to have the same structure as `y`, including which fields, +if any, reference the same memory or are undefined. If this assumption does not hold, errors +or incorrect results may occur. + If every differentiable value in `y` is contained in a mutable object (i.e., `y` has inferred activity state Duplicated), this function performs a fully in-place update and returns `y`. If every differentiable value is held in immutable storage (i.e., `y` @@ -187,7 +195,7 @@ end return nothing end for i = 1:nf - if isdefined(y, i) && all(x -> isdefined(x, i), xs) + if isdefined(y, i) yi = getfield(y, i) xis = ntuple(j -> getfield(xs[j], i), N) newyi = recursive_map!!(f, yi, seen, xis, isleaftype) @@ -207,12 +215,12 @@ end f::F, y::Array{T,M}, seen, xs::NTuple{N,Array{T,M}}, isleaftype ) where {F,T,M,N} for I in eachindex(y, xs...) - if isassigned(y, I) && all(x -> isassigned(x, I), xs) + @inbounds if isassigned(y, I) yvalue = y[I] xvalues = ntuple(j -> xs[j][I], N) newyvalue = recursive_map!!(f, yvalue, seen, xvalues, isleaftype) if newyvalue !== yvalue - @inbounds y[I] = newyvalue + y[I] = newyvalue end end end @@ -235,15 +243,15 @@ end nf = fieldcount(T) if nf == 0 return f(xs...)::T - elseif isdefined(y, nf) && all(x -> isdefined(x, nf), xs) + elseif isdefined(y, nf) # fast path when all fields are set return splatnew(T, ntuple(newyi, Val(nf))) else flds = Vector{Any}(undef, nf) nset = nf for i = 1:nf - if isdefined(y, i) && all(x -> isdefined(x, i), xs) - flds[i] = newyi(i) + if isdefined(y, i) + @inbounds flds[i] = newyi(i) else nset = i - 1 # rest of tail must be undefined values break