-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add recursive map generalizing the make_zero mechanism #1852
base: main
Are you sure you want to change the base?
Conversation
2161e03
to
545bf9b
Compare
4fbdc47
to
74b212f
Compare
74b212f
to
3c6591e
Compare
Alright, I could take some feedback/discussion on this now.
TLDR: Should I rewrite @gdalle Promised to tag you when this was ready for review, but note that this PR only deals with the low-level, non-public guts of the implementation. I'll do the vector space wrapper in a separate PR as soon as this is merged (hopefully that won't be long, I really need that QuadGK rule for my research 📐) |
src/make_zero.jl
Outdated
return seen[prev] | ||
xs::NTuple{N,T}, | ||
::Val{copy_if_inactive}=Val(false), | ||
isleaftype::L=Returns(false), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering whether this is necessary or if the leaf types could just be hardcoded to Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}
. I'll make a prototype of the vector space wrapper and the updated QuadGK rules to see if customizable leaf types comes in handy.
Update for anyone who's following: I've implemented the VectorSpace wrapper, which prompted me to adjust the recursive_map implementation a bit, all for the better. It's looking good and will make writing custom higher-order rules as well as the DI wrappers a lot nicer for arbitrary types. However, it dawned on me that you probably want |
awesome, sorry I haven't had a chance to review let [just a bunch of schenanigans atm], I'll try to take a closer look next week and ping me if not |
No worries! I restored the draft label when I realized there was a bit more to do and will remove it again once I think this is ready for review. No need to look at it until then, the current state here on github doesn't reflect what I'm working with locally anyway. |
src/make_zero.jl
Outdated
isleaftype::L=Returns(false), | ||
) where {T,F,N,L,copy_if_inactive} | ||
x1 = first(xs) | ||
if guaranteed_const_nongen(T, nothing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to confirm, this is only for make_zero, and not for add/etc?
Because this case here already feels specific to the context
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's going to look a bit different once I push the next update (hopefully tomorrow), but no, after some experimenting it seemed best to me to always skip guaranteed inactive subtrees and restrict recursive_map
to applying f
to the differentiable values only. I tried doing the opposite initially, leaving it as part of the isleaftype
filter and handling the possible deepcopy within the mapped function f
, but it made things a lot more complicated. I think the main issue was that the whole mechanism with seen
and keeping track of object identity then becomes the purview of the mapped function f
instead of recursive_map
itself, increasing boilerplate and complicating the contract between recursive_map and its callers. I couldn't think of a use case within Enzyme where you're interested in mapping over the guaranteed inactive parts anyway, and not recursing through inactive subtrees saves you from having to deal with deconstruction/reconstruction of a few specialized types (deepcopy
has a lot more methods than recursive_map
). So I went with this solution instead.
Of course, adding a skip_guaranteed_const
flag would be straightforward (or combining it with copy_if_inactive
into a single inactive_mode
parameter). Do you think this is warranted?
3c6591e
to
c2f05d4
Compare
c2f05d4
to
72bda99
Compare
At long last, I think this one's ready for you to take a look. Hit me with any questions and concerns, from major design issues to bikeshedding over names. I put both the implementation and tests in their own modules because they define a lot of helpers and I didn't want to pollute other modules' namespaces. |
Codecov ReportAttention: Patch coverage is
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## main #1852 +/- ##
==========================================
+ Coverage 67.50% 71.34% +3.83%
==========================================
Files 31 42 +11
Lines 12668 15182 +2514
==========================================
+ Hits 8552 10831 +2279
- Misses 4116 4351 +235 ☔ View full report in Codecov by Sentry. |
Eventually, `recursive_accumulate` should be rewritten on top of a new `VectorSpace` wrapper built on `recursive_map`. Until then, this will do.
This is to explore functionality for realizing JuliaMath/QuadGK.jl#120. The current draft cuts time and allocations in half for the MWE in that PR compared to the
make_zero
hack from the comments. Not sure if modifying the existingrecursive_*
functions like this is appropriate or whether it would be better to implement a separatedeep_recursive_accumulate
.This probably breaks some existing uses of
recursive_accumulate
, like the Holomorphic derivative code, becauserecursive_accumulate
now traverses most/all of the structure on its own and will double-accumulate when combined with the iteration over theseen
IdDicts. Curious to see the total impact on the test suite.This doesn't yet have any concept of
seen
and will thus double-accumulate if the structure has internal aliasing. That obviously needs to be fixed. Perhaps we can factor out and share the recursion code frommake_zero
.A bit of a tangent, but perhaps a final version of this PR should include migrating
ClosureVector
to Enzyme from the QuadGK ext as suggested in JuliaMath/QuadGK.jl#110 (comment). Looks like that's the most relevant application of fully recursive accumulation at the moment.Let me also throw out another suggestion: what if we implement a recursive generalization of broadcasting with an arbitrary number of arguments, i.e.,
recursive_broadcast!(f, a, b, c, ...)
as a recursive generalization ofa .= f.(b, c, ...)
, free of intermediate allocations whenever possible (and similarly an out-of-placerecursive_broadcast(f, a, b, c...)
generalizingf.(a, b, c...)
that only materializes/allocates once if possible). That would enable more optimized custom rules with Duplicated args, such as having the QuadGK rule call the in-place versionquadgk!(f!, result, segs...)
. Not sure if it would be hard to correctly handle aliasing without being overly defensive, or if that could mostly be taken care of by proper reuse of the existing broadcasting functionality.