Skip to content

Commit

Permalink
add make_zero!
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 7, 2024
1 parent ee3bffb commit 4170e47
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[compat]
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.7.3"
EnzymeCore = "0.7.4"
Enzyme_jll = "0.0.119"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6.1, 7"
Expand Down
8 changes: 4 additions & 4 deletions examples/custom_rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNee
if !(x isa Const) && !(y isa Const)
y.dval .= 2 .* x.val .* x.dval
elseif !(y isa Const)
y.dval .= 0
make_zero!(y.dval)
end
dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val))
if RT <: Const
Expand Down Expand Up @@ -211,7 +211,7 @@ function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, t
x.dval .+= 2 .* xval .* dret.val
## also accumulate any derivative in y's shadow into x's shadow.
x.dval .+= 2 .* xval .* y.dval
y.dval .= 0
make_zero!(y.dval)
return (nothing, nothing)
end

Expand Down Expand Up @@ -251,8 +251,8 @@ end

x = [3.0, 1.0]
y = [0.0, 0.0]
dx .= 0
dy .= 0
make_zero!(dx)
make_zero!(dy)

autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx))
@show dx # derivative of h w.r.t. x
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.7.3"
version = "0.7.4"

[compat]
Adapt = "3, 4"
Expand Down
4 changes: 2 additions & 2 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ function autodiff_deferred_thunk end
Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
"""
function make_zero
function make_zero end

"""
make_zero!(prev::T)::T
make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ export BatchDuplicatedFunc
import EnzymeCore: batch_size, get_func
export batch_size, get_func

import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero
export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero
import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero!
export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero!

export jacobian, gradient, gradient!
export markType, batch_size, onehot, chunkedonehot
Expand Down Expand Up @@ -1007,7 +1007,7 @@ gradient!(Reverse, dx, f, [2.0, 3.0])
```
"""
@inline function gradient!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F}
dx .= 0
make_zero!(dx)
autodiff(Reverse, f, Active, Duplicated(x, dx))
dx
end
Expand Down
2 changes: 1 addition & 1 deletion src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct CFnTypeInfo
end


@static if isdefined(LLVM, :InstructionMetadataDict)
@static if !isdefined(LLVM, :ValueMetadataDict)
Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) =
ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL

Expand Down
192 changes: 191 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1298,7 +1298,7 @@ end
xi = getfield(prev, i)
T = Core.Typeof(xi)
xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive))
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi)
setfield!(y, i, xi)
end
end
return y
Expand All @@ -1324,6 +1324,196 @@ end
return y
end

function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S}
zero(T)
end

function make_zero_immutable!(prev::Complex{T}, seen::S)::Complex{T} where {T <: AbstractFloat, S}
zero(T)
end

function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S}
ntuple(Val(length(T.parameters))) do i
Base.@_inline_meta
make_zero_immutable(prev[i], seen)
end
end

function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} where {a,b, S}
NamedTuple{a, b}(
ntuple(Val(length(T.parameters))) do i
Base.@_inline_meta
make_zero_immutable(prev[a[i]], seen)
end
)
end


function make_zero_immutable!(prev::T, seen::S)::T where {T, S}
if guaranteed_const_nongen(T, nothing)
return prev
end
@assert !ismutable(T)

@assert !Base.isabstracttype(RT)
@assert Base.isconcretetype(RT)
nf = fieldcount(RT)

flds = Vector{Any}(undef, nf)
for i in 1:nf
if isdefined(prev, i)
xi = getfield(prev, i)
flds[i] = if ismutable(xi)
EnzymeCore.make_zero!(xi, seen)
xi
else
make_zero_immutable!(xi, seen)
end
else
nf = i - 1 # rest of tail must be undefined values
break
end
end
ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T
end

@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T <: AbstractFloat, ST}
T[] = zero(T)
nothing
end

@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}}, seen::ST)::Nothing where {T <: AbstractFloat, ST}
T[] = zero(Complex{T})
nothing
end

@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST}
fill!(prev, zero(T))
nothing
end

@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N}, seen::ST)::Nothing where {T <: AbstractFloat, ST}
fill!(prev, zero(Complex{T}))
nothing
end

@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T})::Nothing where {T <: AbstractFloat}
EnzymeCore.make_zero!(prev, nothing)
nothing
end

@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}})::Nothing where {T <: AbstractFloat}
EnzymeCore.make_zero!(prev, nothing)
nothing
end

@inline function EnzymeCore.make_zero!(prev::Array{T, N})::Nothing where {T <: AbstractFloat, N}
EnzymeCore.make_zero!(prev, nothing)
nothing
end

@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N})::Nothing where {T <: AbstractFloat, N}
EnzymeCore.make_zero!(prev, nothing)
nothing
end

@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T, N, ST}
if guaranteed_const_nongen(T, nothing)
return
end
if haskey(seen, prev)
return
end
insert!(seen, prev)

for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
if ismutable(pv)
EnzymeCore.make_zero!(pv, seen)
nothing
else
@inbounds prev[I] = EnzymeCore.make_zero_immutable!(pv, seen)
nothing
end
end
end
nothing
end

@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T, ST}
if guaranteed_const_nongen(T, nothing)
return
end
if haskey(seen, prev)
return
end
insert!(seen, prev)

pv = prev[]
if ismutable(pv)
EnzymeCore.make_zero!(pv, seen)
nothing
else
prev[] = EnzymeCore.make_zero_immutable!(pv, seen)
nothing
end
nothing
end

@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST}
pv = prev.contents
T = Core.Typeof(pv)
if guaranteed_const_nongen(T, nothing)
return
end
if haskey(seen, prev)
return
end
insert!(seen, prev)
if ismutable(pv)
EnzymeCore.make_zero!(pv, seen)
nothing
else
prev.contents = EnzymeCore.make_zero_immutable!(pv, seen)
nothing
end
nothing
end

@inline function EnzymeCore.make_zero!(prev::T, seen::S=IdSet{Any}())::Nothing where {T, S}
if guaranteed_const_nongen(T, nothing)
return
end
if haskey(seen, prev)
return
end
@assert !Base.isabstracttype(RT)
@assert Base.isconcretetype(RT)
nf = fieldcount(RT)


if nf == 0
return
end

insert!(seen, prev)

for i in 1:nf
if isdefined(prev, i)
xi = getfield(prev, i)
if ismutable(xi)
EnzymeCore.make_zero!(xi, seen)
nothing
else
setfield!(prev, i, make_zero_immutable!(xi, seen))
nothing
end
end
end
return
end

struct EnzymeRuntimeException <: Base.Exception
msg::Cstring
end
Expand Down

0 comments on commit 4170e47

Please sign in to comment.