test arbitrary functions with `f/rrule`-like API #166
We should think about the relationship between this and JuliaDiff/ChainRulesCore.jl#363
test_approx(::ZeroTangent, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
test_approx(x, ::ZeroTangent, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
test_approx(::AbstractZero, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
test_approx(x, ::AbstractZero, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
The reason I've changed this is somewhat convoluted and the dependency goes the wrong way. But hear me out:
When writing a helper `zygote_rrule` function in Zygote (see FluxML/Zygote.jl#987) one needs to transform `nothing` (representing a zero gradient) to either `ZeroTangent` or `NoTangent`.
If we decide to transform to `ZeroTangent`, `test_rrule` complains that `NoTangent` should be used for non-perturbable arguments (say indices).
If we decide to transform to `NoTangent`, the `test_approx` function throws a fit, since `test_approx(::NoTangent, x)` is not defined and it falls back to the default implementation.
Alternatives:
- Do not provide a `zygote_rrule` function in Zygote. It would still be possible to write a conversion for each separate function (where we know which `AbstractZero` the `nothing` represents). This seems to be the cleanest option, but on the other hand it would greatly reduce the usability of the `rrule_f` kwarg for testing Zygote gradients, since one would need to write a custom rrule wrapper for each function that is being tested. Eww.
- Pirate `test_approx(::NoTangent, x)` in Zygote. (probably very bad)
All of the above aside, I am leaning towards the current state where `test_approx(::NoTangent, x)` is not defined. However in this case I suspect it would be better to be practical. After all, we are happy to define `+(NoTangent, x)` to be equal to `x`, so for all intents and purposes we are implicitly saying that `test_approx(::NoTangent, x) = test_approx(zero(x), x)`.
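To make the proposed fallback concrete, here is a minimal sketch that uses stand-in types rather than the real ChainRulesCore ones, so it runs without any packages. The names `approx_check` and `to_tangent` are hypothetical and exist only for this illustration; they are not part of ChainRulesTestUtils or Zygote, and the real `test_approx` takes additional arguments.

```julia
# Stand-in types for illustration only; the real ones live in ChainRulesCore.
abstract type AbstractZero end
struct ZeroTangent <: AbstractZero end
struct NoTangent <: AbstractZero end

# Hypothetical comparison helper standing in for `test_approx`.
approx_check(a, b; kwargs...) = isapprox(a, b; kwargs...)

# Any zero-like tangent is compared as `zero(x)` against the other side.
approx_check(::AbstractZero, x; kwargs...) = approx_check(zero(x), x; kwargs...)
approx_check(x, ::AbstractZero; kwargs...) = approx_check(x, zero(x); kwargs...)

# Two zero-like tangents always agree; this method also removes the
# dispatch ambiguity between the two methods above.
approx_check(::AbstractZero, ::AbstractZero; kwargs...) = true

# Hypothetical conversion of a Zygote-style `nothing` gradient.
to_tangent(::Nothing) = NoTangent()
to_tangent(x) = x

approx_check(to_tangent(nothing), 0.0)  # true: NoTangent() is treated as zero(0.0)
approx_check(NoTangent(), 1.0)          # false: zero(1.0) is not approximately 1.0
```

With only `ZeroTangent` methods defined, the `NoTangent()` calls would instead fall through to the generic method, which is the failure described above.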
Rather than allowing it in general between `AbstractZero` and anything, which would allow it between `NoTangent()` and `0.0`, which seems bad, would it be pragmatic to allow it only between `AbstractZero` and `NoTangent`?
> Pirate test_approx(::NoTangent, x) in Zygote. (probably very bad)
Doing this in Zygote's tests though doesn't seem that awful. (With appropriate comments)
It is exactly the `NoTangent()` and `0.0` case that we need, sorry that wasn't clear. Which is essentially what we say when we define `+(NoTangent, x) = x`.
> Doing this in Zygote's tests though doesn't seem that awful.
I thought `rrule_via_ad` would be in `src/`, to allow usage outside the testing environment, so that one can just do
using Zygote
test_rrule(myfunc, args...; rrule_f=rrule_via_ad)
If we put the `test_approx(::NoTangent, x)` in `/test`, the `test_rrule` will behave differently in the test suite than in the REPL when `using Zygote`, which might lead to confusion.
Hmm, OK. Let's do it.
Aside: the new way to do this is going to be
test_rrule(ZygoteRuleConfig, myfunc, args...; rrule_f=rrule_via_ad)
Yeah, thanks, it's already changed in the code
Yeah, that is a good point; I can't think of any differences off the top of my head. One consideration perhaps is how to deal with kwargs, since Zygote pretty much ignores them.
So do `rrule` and `frule`.
The difference is that in
pullback((x, y) -> cat(x, y; dims=1), rand(3), rand(4))
# pullback(cat, rand(3), rand(4); dims=1) # does not work
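The closure workaround in the snippet above can be sketched without Zygote. In this sketch `positional_only` is a hypothetical stand-in for an entry point that forwards only positional arguments; it is not Zygote's `pullback` and computes no gradients.

```julia
# Hypothetical stand-in for an API that accepts positional arguments only.
positional_only(f, args...) = f(args...)

# Close over the keyword argument so that only positional arguments remain.
cat_dims1 = (x, y) -> cat(x, y; dims=1)

# The wrapped function can now be passed without any keywords.
z = positional_only(cat_dims1, rand(3), rand(4))
size(z)  # (7,)
```

The same wrapping applies to any function whose keyword arguments need to be fixed before it is handed to a positional-only interface.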
True. So yes Zygote's
Let's hold off merging this PR until we are about to merge the calling-back-into-AD one, just on the principle of delaying decisions until the last minute so that we have maximum information.
`f/rrule`s for these functions call back into AD to compute the `f/rrule` of `f`.
To test these functions, we use a dummy AD system, which simply calls the appropriate rule for `f` directly.
For this reason, when testing `map(f, collection)`, the rules for `f` need to be defined.
The `RuleConfig` for this dummy AD system is the default one, and does not need to be provided.
might be worth hard-coding a link to the ChainRulesCore docs
The `RuleConfig` for this dummy AD system is the default one, and does not need to be provided.
The [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/config.html) for this dummy AD system is the default one, and does not need to be provided.
Julia 1.0 error is real
Yeah, on it
Codecov Report
@@ Coverage Diff @@
## master #166 +/- ##
==========================================
- Coverage 89.17% 86.92% -2.26%
==========================================
Files 9 10 +1
Lines 268 283 +15
==========================================
+ Hits 239 246 +7
- Misses 29 37 +8
Closes #114
Closes #173
To do:
- `xrule_f`
- `rrule_f`
- example in Zygote to test real examples
- `f/rrule_via_ad`
- `test_approx(NoTangent, x) = test_approx(zero(x), x)`