diff --git a/src/compiler.jl b/src/compiler.jl index 79d8e1f9d7..ba822f454e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1248,37 +1248,14 @@ end return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) end - - -@generated function fakecopy(T, x) - if T <: AbstractFloat || T <: Complex - ty = convert(LLVMType, x) - T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) - T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - llvm_f, _ = LLVM.Interop.create_function(T_prjlvalue, [ty]) - LLVM.IRBuilder() do builder - entry = BasicBlock(llvm_f, "entry") - position!(builder, entry) - inp = parameters(entry)[1] - obj = emit_allocobj!(builder, Base.RefValue{x}) - obj2 = bitcast!(builder, LLVM.PointerType(ty, Tracked), obj) - store!(builder, obj, obj2) - ret!(builder, obj) - end - push!(function_attributes(llvm_f), EnumAttribute("alwaysinline", 0)) - ir = string(mod) - fn = LLVM.name(llvm_f) - return quote - Base.@_inline_meta - Base.llvmcall(($ir, $fn), Any, - Tuple{x}, x) - end - else - quote - Base.@_inline_meta - x - end +@inline function EnzymeCore.make_zero(::Type{Core.Box}, seen::IdDict, prev::Core.Box, ::Val{copy_if_inactive}=Val(false)) where {copy_if_inactive, RT} + if haskey(seen, prev) + return seen[prev] end + prev2 = prev.contents + res = Core.Box(Base.Ref(EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)))) + seen[prev] = res + return res end @inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT} @@ -1298,10 +1275,8 @@ end for i in 1:nf if isdefined(prev, i) xi = getfield(prev, i) - xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) - # TODO, try this once sarah's test case is available - # ideally we don't need this still - # xi = fakecopy(Core.Typeof(xi), xi) + T = Core.Typeof(xi) + xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) end end