Skip to content

Commit

Permalink
Update iterate unwrap (#2238)
Browse files Browse the repository at this point in the history
* Update iterate unwrap

* fix

* Update jitrules.jl

* Update jitrules.jl

* Update jitrules.jl
  • Loading branch information
wsmoses authored Dec 31, 2024
1 parent 8a1dd04 commit cacf326
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -817,31 +817,35 @@ function push_if_not_ref(
return darg
end

struct PushInnerStruct{reverse, Vals}
vals::Vals
end

@inline function (v::PushInnerStruct{reverse})(@nospecialize(arg), @nospecialize(darg)) where reverse
ty = Core.Typeof(arg)
actreg = active_reg_nothrow(ty, Val(nothing))
if actreg == AnyState
Const(arg)
elseif actreg == ActiveState
Active(arg)
elseif actreg == MixedState
darg = Base.inferencebarrier(darg)
MixedDuplicated(
arg,
push_if_not_ref(Val(reverse), v.vals, darg, ty)::Base.RefValue{ty},
)
else
Duplicated(arg, darg)
end
end

@inline function iterate_unwrap_augfwd_dup(
::Val{reverse},
vals,
args,
dargs,
) where {reverse}
ntuple(Val(length(args))) do i
Base.@_inline_meta
arg = args[i]
ty = Core.Typeof(arg)
actreg = active_reg_nothrow(ty, Val(nothing))
if actreg == AnyState
Const(arg)
elseif actreg == ActiveState
Active(arg)
elseif actreg == MixedState
darg = Base.inferencebarrier(dargs[i])
MixedDuplicated(
arg,
push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty},
)
else
Duplicated(arg, dargs[i])
end
end
map(PushInnerStruct{reverse, typeof(vals)}(vals), args, dargs)
end

@inline function iterate_unwrap_augfwd_batchdup(
Expand Down

0 comments on commit cacf326

Please sign in to comment.