Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for newstructt #1209

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,12 @@ end
@revfunc(new_structv_rev),
@fwdfunc(new_structv_fwd),
)
register_handler!(
("jl_new_structt","ijl_new_structt"),
@augfunc(new_structt_augfwd),
@revfunc(new_structt_rev),
@fwdfunc(new_structt_fwd),
)
register_handler!(
("jl_get_binding_or_error", "ijl_get_binding_or_error"),
@augfunc(get_binding_or_error_augfwd),
Expand Down
59 changes: 59 additions & 0 deletions src/rules/typeunstablerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ function common_newstructv_rev(offset, B, orig, gutils, tape)
if is_constant_value(gutils, orig)
return true
end
needsShadowP = Ref{UInt8}(0)
needsPrimalP = Ref{UInt8}(0)
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal)
needsPrimal = needsPrimalP[] != 0
needsShadow = needsShadowP[] != 0

if !needsShadow
return
end
emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_struct "*string(orig)*" "*string(operands(orig)[offset])*"\n"*string(LLVM.parent(orig)))
return nothing
end
Expand Down Expand Up @@ -100,6 +109,56 @@ function new_structv_rev(B, orig, gutils, tape)
return nothing
end

function new_structt_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL
return true
end
origops = collect(operands(orig))
width = get_width(gutils)

@assert is_constant_value(gutils, origops[1])
if is_constant_value(gutils, origops[2])
emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct_t"*string(orig))
end

shadowsin = invert_pointer(gutils, origops[2], B)
if width == 1
vals = [new_from_original(gutils, origops[1]), shadowsin]
shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals)
callconv!(shadowres, callconv(orig))
else
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
for idx in 1:width
vals = [new_from_original(gutils, origops[1]), extract_value!(B, shadowsin, idx-1)]
tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args)
callconv!(tmp, callconv(orig))
shadowres = insert_value!(B, shadowres, tmp, idx-1)
end
end
unsafe_store!(shadowR, shadowres.ref)
return false
end
function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
new_structt_fwd(B, orig, gutils, normalR, shadowR)
end

function new_structt_rev(B, orig, gutils, tape)
if is_constant_value(gutils, orig)
return true
end
needsShadowP = Ref{UInt8}(0)
needsPrimalP = Ref{UInt8}(0)
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal)
needsPrimal = needsPrimalP[] != 0
needsShadow = needsShadowP[] != 0

if !needsShadow
return
end
emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_structt "*string(orig))
return nothing
end

function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR)
if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL
return true
Expand Down
Loading