Skip to content

Commit

Permalink
add recursive add to accumulate (#1213)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 19, 2023
1 parent d35a0d0 commit 09df042
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
25 changes: 24 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4964,12 +4964,35 @@ function jl_set_typeof(v::Ptr{Cvoid}, T)
return nothing
end

@generated function splatnew(::Type{T}, args::NTuple{N,AT}) where {T,N,AT}
return quote
Base.@_inline_meta
$(Expr(:splatnew, :T, :args))
end
end

@inline function recursive_add(x::T, y::T) where T
if guaranteed_const(T)
return x
end
splatnew(T, Val(fieldcount(T)) do i
Base.@_inline_meta
prev = getfield(x, i)
next = getfield(y, i)
recursive_add(prev, next)
end)
end

@inline function recursive_add(x::T, y::T) where {T<:AbstractFloat}
return x + y
end

function add_one_in_place(x)
ty = typeof(x)
# ptr = Base.pointer_from_objref(x)
ptr = unsafe_to_pointer(x)
if ty <: Base.RefValue || ty == Base.RefValue{Float64}
x[] += one(eltype(ty))
x[] = recursive_add(x[], one(eltype(ty)))
else
error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string(x))
end
Expand Down
2 changes: 1 addition & 1 deletion src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs)
shad = shadowargs[i][w]
out = :(if $expr === nothing
elseif $shad isa Base.RefValue
$shad[] += $expr
$shad[] = recursive_add($shad[], $expr)
else
error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad))
end
Expand Down

0 comments on commit 09df042

Please sign in to comment.