Skip to content

Commit

Permalink
add custom handler for ptr_to_array runtime call (#2258)
Browse files Browse the repository at this point in the history
* add custom handler for ptr_to_array runtime call

* Update array.jl

* Update llvmrules.jl

---------

Co-authored-by: William Moses <gh@wsmoses.com>
  • Loading branch information
vchuravy and wsmoses authored Jan 9, 2025
1 parent 2309abd commit 8a0bff4
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ function __init__()
"jl_get_keyword_sorter",
"ijl_get_keyword_sorter",
"jl_ptr_to_array",
"ijl_ptr_to_array",
"jl_box_float32",
"ijl_box_float32",
"jl_box_float64",
"ijl_box_float64",
"jl_ptr_to_array_1d",
"ijl_ptr_to_array_1d",
"jl_eqtable_get",
"ijl_eqtable_get",
"memcmp",
Expand Down
57 changes: 57 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,57 @@ end
return nothing
end

@register_fwd function jl_ptr_to_array_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_inst(gutils, orig)
return true
end
origops = collect(operands(orig))
width = get_width(gutils)
shadowin = invert_pointer(gutils, origops[2], B)

valTys = API.CValueType[
API.VT_Primal,
API.VT_Shadow,
API.VT_Primal,
API.VT_Primal,
]

shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
for idx = 1:width
ev = if width == 1
shadowin
else
extract_value!(B, shadowin, idx - 1)
end

args = LLVM.Value[
new_from_original(gutils, origops[1]),
ev, # data
new_from_original(gutils, origops[3]),
new_from_original(gutils, origops[4]),
]
# TODO do runtime activity relevant errors and checks

cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=#
debug_from_orig!(gutils, cal, orig)
callconv!(cal, callconv(orig))
if width == 1
shadowres = cal
else
shadowres = insert_value!(B, shadowres, call, idx - 1)
end
end
unsafe_store!(shadowR, shadowres.ref)

return false
end
@register_aug function jl_ptr_to_array_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
jl_ptr_to_array_fwd(B, orig, gutils, normalR, shadowR)
end
@register_rev function jl_ptr_to_array_rev(B, orig, gutils, tape)
return nothing
end

@register_fwd function genericmemory_copyto_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_inst(gutils, orig)
return true
Expand Down Expand Up @@ -2400,6 +2451,12 @@ end
@revfunc(jl_array_ptr_copy_rev),
@fwdfunc(jl_array_ptr_copy_fwd),
)
register_handler!(
("jl_ptr_to_array_1d", "ijl_ptr_to_array_1d", "jl_ptr_to_array", "ijl_ptr_to_array"),
@augfunc(jl_ptr_to_array_augfwd),
@revfunc(jl_ptr_to_array_rev),
@fwdfunc(jl_ptr_to_array_fwd),
)
register_handler!(
(),
@augfunc(jl_unhandled_augfwd),
Expand Down
15 changes: 15 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,18 @@ end
@test dB[1] === dA1
@test dB[2] === dA2
end

function unsafe_wrap_test(a, i, x)
GC.@preserve a begin
ptr = pointer(a)
b = Base.unsafe_wrap(Array, ptr, length(a))
b[i] = x
end
a[i]
end

@testset "Unsafe wrap" begin
autodiff(Forward, unsafe_wrap_test, Duplicated(zeros(1), zeros(1)), Const(1), Duplicated(1.0, 2.0))

# TODO test for batch and reverse
end

0 comments on commit 8a0bff4

Please sign in to comment.