diff --git a/Project.toml b/Project.toml index 0b19c3cead..2fd5da4586 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7.3" +EnzymeCore = "0.7.4" Enzyme_jll = "0.0.119" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index 836d299c1e..c2098006c2 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -134,7 +134,7 @@ function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNee if !(x isa Const) && !(y isa Const) y.dval .= 2 .* x.val .* x.dval elseif !(y isa Const) - y.dval .= 0 + make_zero!(y.dval) end dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val)) if RT <: Const @@ -211,7 +211,7 @@ function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, t x.dval .+= 2 .* xval .* dret.val ## also accumulate any derivative in y's shadow into x's shadow. x.dval .+= 2 .* xval .* y.dval - y.dval .= 0 + make_zero!(y.dval) return (nothing, nothing) end @@ -251,8 +251,8 @@ end x = [3.0, 1.0] y = [0.0, 0.0] -dx .= 0 -dy .= 0 +make_zero!(dx) +make_zero!(dy) autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx)) @show dx # derivative of h w.r.t. x diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 670e1f3014..20a89b9a05 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.3" +version = "0.7.4" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 482561607b..fb788fd5a6 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -226,10 +226,10 @@ function autodiff_deferred_thunk end Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. """ -function make_zero +function make_zero end """ - make_zero!(prev::T)::T + make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. """ diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 911d1801ad..7626304944 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -14,8 +14,8 @@ export BatchDuplicatedFunc import EnzymeCore: batch_size, get_func export batch_size, get_func -import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero -export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero +import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! +export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! export jacobian, gradient, gradient! export markType, batch_size, onehot, chunkedonehot @@ -1007,7 +1007,7 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) ``` """ @inline function gradient!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} - dx .= 0 + make_zero!(dx) autodiff(Reverse, f, Active, Duplicated(x, dx)) dx end diff --git a/src/api.jl b/src/api.jl index 3c626635b0..d68d904d5a 100644 --- a/src/api.jl +++ b/src/api.jl @@ -104,7 +104,7 @@ struct CFnTypeInfo end -@static if isdefined(LLVM, :InstructionMetadataDict) +@static if !isdefined(LLVM, :ValueMetadataDict) Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL diff --git a/src/compiler.jl b/src/compiler.jl index 30bf6f0d9c..83ff363dbd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1298,7 +1298,7 @@ end xi = getfield(prev, i) T = Core.Typeof(xi) xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) + setfield!(y, i, xi) end end return y @@ -1324,6 +1324,201 @@ end return y end +function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S} + zero(T) +end + +function make_zero_immutable!(prev::Complex{T}, seen::S)::Complex{T} where {T <: AbstractFloat, S} + zero(T) +end + +function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S} + ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable(prev[i], seen) + end +end + +function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} where {a,b, S} + NamedTuple{a, b}( + ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable(prev[a[i]], seen) + end + ) +end + + +function make_zero_immutable!(prev::T, seen::S)::T where {T, S} + if guaranteed_const_nongen(T, nothing) + return prev + end + @assert !mutable_register(T) + + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + + flds = Vector{Any}(undef, nf) + for i in 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + ST = Core.Typeof(xi) + flds[i] = if mutable_register(ST) + EnzymeCore.make_zero!(xi, seen) + xi + else + make_zero_immutable!(xi, seen) + end + else + nf = i - 1 # rest of tail must be undefined values + break + end + end + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T <: AbstractFloat, ST} + T[] = zero(T) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}}, seen::ST)::Nothing where {T <: AbstractFloat, ST} + T[] = zero(Complex{T}) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} + fill!(prev, zero(T)) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} + fill!(prev, zero(Complex{T})) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T})::Nothing where {T <: AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}})::Nothing where {T <: AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N})::Nothing where {T <: AbstractFloat, N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N})::Nothing where {T <: AbstractFloat, N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T, N, ST} + if guaranteed_const_nongen(T, nothing) + return + end + if haskey(seen, prev) + return + end + insert!(seen, prev) + + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + SBT = Core.Typeof(pv) + if mutable_register(SBT) + EnzymeCore.make_zero!(pv, seen) + nothing + else + @inbounds prev[I] = EnzymeCore.make_zero_immutable!(pv, seen) + nothing + end + end + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T, ST} + if guaranteed_const_nongen(T, nothing) + return + end + if haskey(seen, prev) + return + end + insert!(seen, prev) + + pv = prev[] + SBT = Core.Typeof(pv) + if mutable_register(SBT) + EnzymeCore.make_zero!(pv, seen) + nothing + else + prev[] = EnzymeCore.make_zero_immutable!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} + pv = prev.contents + T = Core.Typeof(pv) + if guaranteed_const_nongen(T, nothing) + return + end + if haskey(seen, prev) + return + end + insert!(seen, prev) + SBT = Core.Typeof(pv) + if mutable_register(SBT) + EnzymeCore.make_zero!(pv, seen) + nothing + else + prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::T, seen::S=IdSet{Any}())::Nothing where {T, S} + if guaranteed_const_nongen(T, nothing) + return + end + if haskey(seen, prev) + return + end + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + + + if nf == 0 + return + end + + insert!(seen, prev) + + for i in 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + SBT = Core.Typeof(pv) + if mutable_register(SBT) + EnzymeCore.make_zero!(xi, seen) + nothing + else + setfield!(prev, i, make_zero_immutable!(xi, seen)) + nothing + end + end + end + return +end + struct EnzymeRuntimeException <: Base.Exception msg::Cstring end @@ -5536,7 +5731,7 @@ end @assert ismutable(x) yi = getfield(y, i) nexti = recursive_add(xi, yi, f, mutable_register) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i-1, nexti) + setfield!(x, i, nexti) end end end