Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test arbitrary functions with f/rrule-like API #166

Merged
merged 23 commits into from
Jun 11, 2021
Merged

test arbitrary functions with f/rrule-like API #166

merged 23 commits into from
Jun 11, 2021

Conversation

mzgubic
Copy link
Member

@mzgubic mzgubic commented Jun 2, 2021

Closes #114
Closes #173

To do:

  • tests for passing in xrule_f
  • tests for config
  • documentation
  • implement an actual rrule_f example in Zygote to test real examples
  • think about matching requirements for f/rrule_via_ad
  • decide whether to define test_approx(NoTangent, x) = test_approx(zero(x), x)
  • wait for configs to be merged and tagged, and update compat

@oxinabox
Copy link
Member

oxinabox commented Jun 3, 2021

We should think a bout the relationship between this and JuliaDiff/ChainRulesCore.jl#363
In particular that frule_f and rrule_f
have basically (literally?) the same requirements as the frule_via_ad and rrule_via_ad from that PR.
and so when using this for your AD you are probably having both defined.

test_approx(::ZeroTangent, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
test_approx(x, ::ZeroTangent, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
test_approx(::AbstractZero, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
test_approx(x, ::AbstractZero, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I've changed this is somewhat convoluted and the dependency goes the wrong way. But hear me out:

When writing a helper zygote_rrule function in Zygote (see FluxML/Zygote.jl#987) one needs to transform nothing (representing a zero gradient) to either ZeroTangent or NoTangent.
If we decide to transform to ZeroTangent, test_rrule complains that NoTangent should be used for non-perturbable arguments (say indices).
If we decide to transform to NoTangent, the test_approx function throws a fit, since test_approx(::NoTangent, x) is not defined and it falls back to the default implementation.

Alternatives:

  • Do not provide zygote_rrule function in Zygote. It would still be possible to write a conversion for each separate function (where we know which AbstractZero the nothing represents). This seems to be the cleanest option but on the other hand it would greatly reduce the usability of rrule_f kwarg for testing Zygote gradients, since one would need to write a custom rrule wrapper for each function that is being tested. Eww.
  • Pirate test_approx(::NoTangent, x) in Zygote. (probably very bad)

All of the above aside, I am leaning towards the current state where test_approx(::NoTangent, x) is not defined. However in this case I suspect it would be better to be practical. After all, we are happy to define +(NoTangent, x) to be equal to x, so for all intents and purposes we are implicitly saying that test_approx(::NoTangent, x) = test_approx(zero(x), x).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than allowing it in general between AbstractZero and anything, which would allow between NoTangent() and 0.0 which seems bad.
would it be pragmatic to allow it only between AbstractZero and NoTangent ?

Pirate test_approx(::NoTangent, x) in Zygote. (probably very bad)

Doing this in Zygote's tests though doesn't seem that awful.
(With appropriate comments)

Copy link
Member Author

@mzgubic mzgubic Jun 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is exactly the NoTangent() and 0.0 that we need, sorry that wasn't clear. Which is essentially what we say when we define +(NoTangent, x) = x.

Doing this in Zygote's tests though doesn't seem that awful.

I thought rrule_via_ad would be in src/, to allow usage outside the testing environment, so that one can just do

using Zygote

test_rrule(myfunc, args...; rrule_f=rrule_via_ad)

If we put the test_approx(::NoTangent, x) in /test the test_rrule will behave differently in the test suite than in the REPL when using Zygote, which might lead to confusion.

Copy link
Member

@oxinabox oxinabox Jun 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm Ok. lets do it.

aside: the new way to do this is going to be

test_rrule(ZygoteRuleConfig, myfunc, args...; rrule_f=rrule_via_ad)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, thanks, it's already changed in the code

test/testers.jl Outdated Show resolved Hide resolved
@mzgubic mzgubic changed the title WIP: test arbitrary functions with f/rrule-like API test arbitrary functions with f/rrule-like API Jun 3, 2021
Project.toml Outdated Show resolved Hide resolved
@mzgubic
Copy link
Member Author

mzgubic commented Jun 4, 2021

We should think a bout the relationship between this and JuliaDiff/ChainRulesCore.jl#363
In particular that frule_f and rrule_f
have basically (literally?) the same requirements as the frule_via_ad and rrule_via_ad from that PR.
and so when using this for your AD you are probably having both defined.

Yeah that is a good point, I can't think of any differences on the top of my head. One consideration perhaps is how to deal with kwargs, since Zygote pretty much ignores them.

@oxinabox
Copy link
Member

oxinabox commented Jun 4, 2021

One consideration perhaps is how to deal with kwargs, since Zygote pretty much ignores them.

So does rrule and frule.
I have an idea how to make them work, that works for Zygote and rrule/frule: defining it at the positional kwsorter level.
But we don't need to worry about it til later.
It's so far never come up as a problem.
In the years of people doing AD in julia

@mzgubic
Copy link
Member Author

mzgubic commented Jun 4, 2021

One consideration perhaps is how to deal with kwargs, since Zygote pretty much ignores them.

So does rrule and frule.
I have an idea how to make them work, that works for Zygote and rrule/frule: defining it at the positional kwsorter level.
But we don't need to worry about it til later.
It's so far never come up as a problem.
In the years of people doing AD in julia

The differences is that in f/rrules one can pass the kwargs to the primal pass, but in Zygote

pullback((x, y) -> cat(x, y; dims=1), rand(3), rand(4))
# pullback(cat, rand(3), rand(4); dims=1) # does not work

@oxinabox
Copy link
Member

oxinabox commented Jun 4, 2021

The differences is that in f/rrules one can pass the kwargs to the primal pass, but in Zygote

pullback((x, y) -> cat(x, y; dims=1), rand(3), rand(4))
# pullback(cat, rand(3), rand(4); dims=1) # does not work

True. So yes Zygote's rrule_f and it's rrule_via_ad would need to accept kwargs, then put them into a closure which it calls it's pullback on (maan I hate that name).

Lets hold off merging this PR til we are about to merge the calling back into AD one.

Just on the principle of delaying decisions til the last minute so that we have maximum information.
We might spot some additonal need or synergy

`f/rrule`s for these functions call back into AD to compute the `f/rrule` of `f`.
To test these functions, we use a dummy AD system, which simply calls the appropriate rule for `f` directly.
For this reason, when testing `map(f, collection)`, the rules for `f` need to be defined.
The `RuleConfig` for this dummy AD system is the default one, and does not need to be provided.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be worth hard-coding a link to the ChainRulesCore docs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The `RuleConfig` for this dummy AD system is the default one, and does not need to be provided.
The [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/config.html) for this dummy AD system is the default one, and does not need to be provided.

@oxinabox
Copy link
Member

oxinabox commented Jun 11, 2021

Julia 1.0 error is real
LoadError: error compiling top-level scope: type definition not allowed inside a local scope
types in tests need to be declared at top-level scope .

@mzgubic
Copy link
Member Author

mzgubic commented Jun 11, 2021

yeah on it

@codecov-commenter
Copy link

Codecov Report

Merging #166 (71a08c9) into master (83f6ad3) will decrease coverage by 2.25%.
The diff coverage is 66.66%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #166      +/-   ##
==========================================
- Coverage   89.17%   86.92%   -2.26%     
==========================================
  Files           9       10       +1     
  Lines         268      283      +15     
==========================================
+ Hits          239      246       +7     
- Misses         29       37       +8     
Impacted Files Coverage Δ
src/ChainRulesTestUtils.jl 80.00% <ø> (ø)
src/rule_config.jl 11.11% <11.11%> (ø)
src/check_result.jl 88.88% <100.00%> (ø)
src/testers.jl 93.06% <100.00%> (+0.43%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 83f6ad3...71a08c9. Read the comment docs.

@mzgubic mzgubic merged commit 8c5da46 into master Jun 11, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants