From afedaac9dac1fc039aa585307398247cbbb54c68 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 14 Sep 2024 23:24:54 -0500 Subject: [PATCH] Improve deferred error message (#1827) * Improve deferred error message * fix --- src/rules/llvmrules.jl | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index fb93016063..962a4f46af 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -1329,6 +1329,46 @@ end end +@register_fwd function deferred_fwd(B, orig, gutils, normalR, shadowR) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + return true + end + err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + newo = new_from_original(gutils, orig) + API.moveBefore(newo, err, B) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + return false +end + +@register_aug function deferred_augfwd(B, orig, gutils, normalR, shadowR, tapeR) + if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + return true + end + err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.") + newo = new_from_original(gutils, orig) + API.moveBefore(newo, err, B) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + if shadowR != C_NULL && normal !== nothing + unsafe_store!(shadowR, normal.ref) + end + # Delete the primal code + if normal !== nothing + unsafe_store!(normalR, C_NULL) + else + ni = new_from_original(gutils, orig) + API.EnzymeGradientUtilsErase(gutils, ni) + end + return false +end + +@register_rev function deferred_rev(B, orig, gutils, tape) + return nothing +end + + function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=nothing) for variant in variants if augfwd_handler !== nothing && rev_handler !== nothing @@ -1522,6 +1562,12 @@ end @revfunc(finalizer_rev), @fwdfunc(finalizer_fwd), ) + register_handler!( + ("deferred_codegen",), + @augfunc(deferred_augfwd), + @revfunc(deferred_rev), + @fwdfunc(deferred_fwd), + ) register_handler!( ("jl_array_grow_end","ijl_array_grow_end"), @augfunc(jl_array_grow_end_augfwd),