diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index a117a019e9..3d379467de 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -78,11 +78,13 @@ function __init__() "jl_get_keyword_sorter", "ijl_get_keyword_sorter", "jl_ptr_to_array", + "ijl_ptr_to_array", "jl_box_float32", "ijl_box_float32", "jl_box_float64", "ijl_box_float64", "jl_ptr_to_array_1d", + "ijl_ptr_to_array_1d", "jl_eqtable_get", "ijl_eqtable_get", "memcmp", diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 18b25a4a8b..9e696f0ade 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1779,6 +1779,57 @@ end return nothing end +@register_fwd function jl_ptr_to_array_fwd(B, orig, gutils, normalR, shadowR) + if is_constant_inst(gutils, orig) + return true + end + origops = collect(operands(orig)) + width = get_width(gutils) + shadowin = invert_pointer(gutils, origops[2], B) + + valTys = API.CValueType[ + API.VT_Primal, + API.VT_Shadow, + API.VT_Primal, + API.VT_Primal, + ] + + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx = 1:width + ev = if width == 1 + shadowin + else + extract_value!(B, shadowin, idx - 1) + end + + args = LLVM.Value[ + new_from_original(gutils, origops[1]), + ev, # data + new_from_original(gutils, origops[3]), + new_from_original(gutils, origops[4]), + ] + # TODO do runtime activity relevant errors and checks + + cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, valTys, false) #=lookup=# + debug_from_orig!(gutils, cal, orig) + callconv!(cal, callconv(orig)) + if width == 1 + shadowres = cal + else + shadowres = insert_value!(B, shadowres, call, idx - 1) + end + end + unsafe_store!(shadowR, shadowres.ref) + + return false +end +@register_aug function jl_ptr_to_array_augfwd(B, orig, gutils, normalR, shadowR, tapeR) + jl_ptr_to_array_fwd(B, orig, gutils, normalR, shadowR) +end +@register_rev function jl_ptr_to_array_rev(B, orig, gutils, tape) + return nothing +end + @register_fwd function genericmemory_copyto_fwd(B, orig, gutils, normalR, shadowR) if is_constant_inst(gutils, orig) return true @@ -2400,6 +2451,12 @@ end @revfunc(jl_array_ptr_copy_rev), @fwdfunc(jl_array_ptr_copy_fwd), ) + register_handler!( + ("jl_ptr_to_array_1d", "ijl_ptr_to_array_1d", "jl_ptr_to_array", "ijl_ptr_to_array"), + @augfunc(jl_ptr_to_array_augfwd), + @revfunc(jl_ptr_to_array_rev), + @fwdfunc(jl_ptr_to_array_fwd), + ) register_handler!( (), @augfunc(jl_unhandled_augfwd), diff --git a/test/array.jl b/test/array.jl index c246cc5e1e..129b1c5e2f 100644 --- a/test/array.jl +++ b/test/array.jl @@ -23,3 +23,18 @@ end @test dB[1] === dA1 @test dB[2] === dA2 end + +function unsafe_wrap_test(a, i, x) + GC.@preserve a begin + ptr = pointer(a) + b = Base.unsafe_wrap(Array, ptr, length(a)) + b[i] = x + end + a[i] +end + +@testset "Unsafe wrap" begin + autodiff(Forward, unsafe_wrap_test, Duplicated(zeros(1), zeros(1)), Const(1), Duplicated(1.0, 2.0)) + + # TODO test for batch and reverse +end