Skip to content

Commit

Permalink
Improve deferred error message (#1827)
Browse files Browse the repository at this point in the history
* Improve deferred error message

* fix
  • Loading branch information
wsmoses authored Sep 15, 2024
1 parent 24c58ef commit afedaac
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,46 @@ end
end


@register_fwd function deferred_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
return true
end
err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.")
newo = new_from_original(gutils, orig)
API.moveBefore(newo, err, B)
normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing
if shadowR != C_NULL && normal !== nothing
unsafe_store!(shadowR, normal.ref)
end
return false
end

@register_aug function deferred_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
return true
end
err = emit_error(B, orig, "There is a known issue in GPUCompiler.jl which is preventing higher-order AD of this code.\nPlease see https://github.com/JuliaGPU/GPUCompiler.jl/issues/629 for more information and to alert the GPUCompiler authors of your use case and need.")
newo = new_from_original(gutils, orig)
API.moveBefore(newo, err, B)
normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing
if shadowR != C_NULL && normal !== nothing
unsafe_store!(shadowR, normal.ref)
end
# Delete the primal code
if normal !== nothing
unsafe_store!(normalR, C_NULL)
else
ni = new_from_original(gutils, orig)
API.EnzymeGradientUtilsErase(gutils, ni)
end
return false
end

@register_rev function deferred_rev(B, orig, gutils, tape)
return nothing
end


function register_handler!(variants, augfwd_handler, rev_handler, fwd_handler=nothing)
for variant in variants
if augfwd_handler !== nothing && rev_handler !== nothing
Expand Down Expand Up @@ -1522,6 +1562,12 @@ end
@revfunc(finalizer_rev),
@fwdfunc(finalizer_fwd),
)
register_handler!(
("deferred_codegen",),
@augfunc(deferred_augfwd),
@revfunc(deferred_rev),
@fwdfunc(deferred_fwd),
)
register_handler!(
("jl_array_grow_end","ijl_array_grow_end"),
@augfunc(jl_array_grow_end_augfwd),
Expand Down

0 comments on commit afedaac

Please sign in to comment.