Skip to content

Commit

Permalink
Pass primal to getfield rev for debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 6, 2024
1 parent 754937b commit e59bc8d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3419,7 +3419,7 @@ for (k, v) in (
("enz_runtime_jl_getfield_rev", Enzyme.Compiler.rt_jl_getfield_rev),

("enz_runtime_idx_jl_getfield_aug", Enzyme.Compiler.idx_jl_getfield_aug),
("enz_runtime_idx_jl_getfield_rev", Enzyme.Compiler.idx_jl_getfield_aug),
("enz_runtime_idx_jl_getfield_rev", Enzyme.Compiler.idx_jl_getfield_rev),

("enz_runtime_jl_setfield_aug", Enzyme.Compiler.rt_jl_setfield_aug),
("enz_runtime_jl_setfield_rev", Enzyme.Compiler.rt_jl_setfield_rev),
Expand Down
4 changes: 3 additions & 1 deletion src/rules/typeunstablerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ end
end
end

function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst}
function idx_jl_getfield_rev(primal, dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst}
cur = if dptr isa Base.RefValue
Base.getfield(dptr[], symname+1)
else
Expand Down Expand Up @@ -1135,6 +1135,8 @@ end
end

vals = LLVM.Value[]

push!(vals, lookup_value(gutils, new_from_original(gutils, ops[1]), B))
push!(vals, inps[1])

push!(vals, tape)
Expand Down

0 comments on commit e59bc8d

Please sign in to comment.