Skip to content

Commit

Permalink
Add support for newstructt
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 18, 2023
1 parent cb3037f commit f4f33eb
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,12 @@ end
@revfunc(new_structv_rev),
@fwdfunc(new_structv_fwd),
)
register_handler!(
("jl_new_structt","ijl_new_structt"),
@augfunc(new_structt_augfwd),
@revfunc(new_structt_rev),
@fwdfunc(new_structt_fwd),
)
register_handler!(
("jl_get_binding_or_error", "ijl_get_binding_or_error"),
@augfunc(get_binding_or_error_augfwd),
Expand Down
59 changes: 59 additions & 0 deletions src/rules/typeunstablerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ function common_newstructv_rev(offset, B, orig, gutils, tape)
if is_constant_value(gutils, orig)
return true
end
needsShadowP = Ref{UInt8}(0)
needsPrimalP = Ref{UInt8}(0)
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal)
needsPrimal = needsPrimalP[] != 0
needsShadow = needsShadowP[] != 0

if !needsShadow
return
end
emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_struct "*string(orig)*" "*string(operands(orig)[offset])*"\n"*string(LLVM.parent(orig)))
return nothing
end
Expand Down Expand Up @@ -100,6 +109,56 @@ function new_structv_rev(B, orig, gutils, tape)
return nothing
end

function new_structt_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL
return true
end
origops = collect(operands(orig))
width = get_width(gutils)

@assert is_constant_value(gutils, origops[1])
if is_constant_value(gutils, origops[2])
emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct_t"*string(orig))
end

shadowsin = invert_pointer(gutils, origops[2], B)
if width == 1
vals = [new_from_original(gutils, origops[1]), shadowsin]
shadowres = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), vals)
callconv!(shadowres, callconv(orig))
else
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
for idx in 1:width
vals = [new_from_original(gutils, origops[1]), extract_value!(B, shadowsin, idx-1)]
tmp = LLVM.call!(B, called_type(orig), LLVM.called_operand(orig), args)
callconv!(tmp, callconv(orig))
shadowres = insert_value!(B, shadowres, tmp, idx-1)
end
end
unsafe_store!(shadowR, shadowres.ref)
return false
end
function new_structt_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)
new_structt_fwd(offset, B, orig, gutils, normalR, shadowR)
end

function new_structt_rev(B, orig, gutils, tape)
if is_constant_value(gutils, orig)
return true
end
needsShadowP = Ref{UInt8}(0)
needsPrimalP = Ref{UInt8}(0)
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal)
needsPrimal = needsPrimalP[] != 0
needsShadow = needsShadowP[] != 0

if !needsShadow
return
end
emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_structt "*string(orig))
return nothing
end

function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR)
if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL
return true
Expand Down

0 comments on commit f4f33eb

Please sign in to comment.