Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test arbitrary functions with f/rrule-like API #166

Merged
merged 23 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.7"
version = "0.7.8"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
4 changes: 4 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ makedocs(;
format=Documenter.HTML(; prettyurls=false, assets=["assets/chainrules.css"]),
sitename="ChainRulesTestUtils",
authors="JuliaDiff contributors",
pages=[
"ChainRulesTestUtils" => "index.md",
"API" => "api.md",
],
strict=true,
checkdocs=:exports,
)
Expand Down
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# API Documentation

```@autodocs
Modules = [ChainRulesTestUtils]
Private = false
```
68 changes: 61 additions & 7 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,67 @@ In particular, when specifying the input tangents to [`test_frule`](@ref) and th
As these tangents are used to seed the derivative computation.
Inserting inappropriate zeros can thus hide errors.

## Testing higher order functions

Higher order functions, such as `map`, take a function (or a functor) `f` as an argument.
`f/rrule`s for these functions call back into AD to compute the `f/rrule` of `f`.
To test these functions, we use a dummy AD system, which simply calls the appropriate rule for `f` directly.
For this reason, when testing `map(f, collection)`, the rules for `f` need to be defined.
The `RuleConfig` for this dummy AD system is the default one, and does not need to be provided.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be worth hard-coding a link to the ChainRulesCore docs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The `RuleConfig` for this dummy AD system is the default one, and does not need to be provided.
The [`RuleConfig`](https://juliadiff.org/ChainRulesCore.jl/stable/config.html) for this dummy AD system is the default one, and does not need to be provided.

```julia
test_rrule(map, x->2x [1, 2, 3.]) # fails, because there is no rrule for x->2x

mydouble(x) = 2x
function ChainRulesCore.rrule(::typeof(mydouble), x)
mydouble_pullback(ȳ) = (NoTangent(), ȳ)
return mydouble(x), mydouble_pullback
end
test_rrule(map, mydouble, [1, 2, 3.]) # works
```

## Testing AD systems

The gradients computed by AD systems can be also be tested using `test_rrule`.
To do that, one needs to provide an `rrule_f`/`frule_f` keyword argument, as well as the `RuleConfig` used by the AD system.
`rrule_f` is a function that wraps the gradient computation by an AD system in the same API as the `rrule`.
`RuleConfig` is an object that determines which sets of rules are defined for an AD system.
For example, let's say we have a complicated function

```julia
function complicated(x, y)
return do(x + y) + some(x) * hard(y) + maths(x * y)
end
```

that we do not know an `rrule` for, and we want to check whether the gradients provided by the AD system are correct.

Firstly, we need to define an `rrule`-like function which wraps the gradients computed by AD.

Let's say the AD package uses some custom differential types and does not provide a gradient w.r.t. the function itself.
In order to make the pullback compatible with the `rrule` API we need to add a `NoTangent()` to represent the differential w.r.t. the function itself.
We also need to transform the `ChainRules` differential types to the custom types (`cr2custom`) before feeding the `Δ` to the AD-generated pullback, and back to `ChainRules` differential types when returning from the `rrule` (`custom2cr`).

```julia
function ad_rrule(f::Function, args...)
y, ad_pullback = ADSystem.pullback(f, args...)
function rrulelike_pullback(Δ)
diffs = custom2cr(ad_pullback(cr2custom(Δ)))
return NoTangent(), diffs...
end

return y, rrulelike_pullback
end

custom2cr(differential) = ...
cr2custom(differential) = ...
```
Secondly, we use the `test_rrule` function to test the gradients using the config used by the AD system
```julia
config = MyAD.CustomRuleConfig()
test_rrule(config, complicated, 2.3, 6.1; rrule_f=ad_rrule)
```
by specifying the `ad_rrule` as the `rrule_f` keyword argument.

## Custom finite differencing

If a package is using a custom finite differencing method of testing the `frule`s and `rrule`s, `test_approx` function provides a convenient way of comparing [various types](https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html#Design-Notes:-The-many-to-many-relationship-between-differential-types-and-primal-types.) of differentials.
Expand Down Expand Up @@ -199,10 +260,3 @@ Test.DefaultTestSet("test_rrule: abs on Float64", Any[], 5, false, false)

This behavior can also be overridden globally by setting the environment variable `CHAINRULES_TEST_INFERRED` before ChainRulesTestUtils is loaded or by changing `ChainRulesTestUtils.TEST_INFERRED[]` from inside Julia.
ChainRulesTestUtils can detect whether a test is run as part of [PkgEval](https://github.com/JuliaCI/PkgEval.jl)and in this case disables inference tests automatically. Packages can use [`@maybe_inferred`](@ref) to get the same behavior for other inference tests.

# API Documentation

```@autodocs
Modules = [ChainRulesTestUtils]
Private = false
```
1 change: 1 addition & 0 deletions src/ChainRulesTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ include("iterator.jl")
include("output_control.jl")
include("check_result.jl")

include("rule_config.jl")
include("finite_difference_calls.jl")
include("testers.jl")

Expand Down
4 changes: 2 additions & 2 deletions src/check_result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs
end
end

test_approx(::ZeroTangent, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
test_approx(x, ::ZeroTangent, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
test_approx(::AbstractZero, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...)
test_approx(x, ::AbstractZero, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I've changed this is somewhat convoluted and the dependency goes the wrong way. But hear me out:

When writing a helper zygote_rrule function in Zygote (see FluxML/Zygote.jl#987) one needs to transform nothing (representing a zero gradient) to either ZeroTangent or NoTangent.
If we decide to transform to ZeroTangent, test_rrule complains that NoTangent should be used for non-perturbable arguments (say indices).
If we decide to transform to NoTangent, the test_approx function throws a fit, since test_approx(::NoTangent, x) is not defined and it falls back to the default implementation.

Alternatives:

  • Do not provide zygote_rrule function in Zygote. It would still be possible to write a conversion for each separate function (where we know which AbstractZero the nothing represents). This seems to be the cleanest option but on the other hand it would greatly reduce the usability of rrule_f kwarg for testing Zygote gradients, since one would need to write a custom rrule wrapper for each function that is being tested. Eww.
  • Pirate test_approx(::NoTangent, x) in Zygote. (probably very bad)

All of the above aside, I am leaning towards the current state where test_approx(::NoTangent, x) is not defined. However in this case I suspect it would be better to be practical. After all, we are happy to define +(NoTangent, x) to be equal to x, so for all intents and purposes we are implicitly saying that test_approx(::NoTangent, x) = test_approx(zero(x), x).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than allowing it in general between AbstractZero and anything, which would allow between NoTangent() and 0.0 which seems bad.
would it be pragmatic to allow it only between AbstractZero and NoTangent ?

Pirate test_approx(::NoTangent, x) in Zygote. (probably very bad)

Doing this in Zygote's tests though doesn't seem that awful.
(With appropriate comments)

Copy link
Member Author

@mzgubic mzgubic Jun 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is exactly the NoTangent() and 0.0 that we need, sorry that wasn't clear. Which is essentially what we say when we define +(NoTangent, x) = x.

Doing this in Zygote's tests though doesn't seem that awful.

I thought rrule_via_ad would be in src/, to allow usage outside the testing environment, so that one can just do

using Zygote

test_rrule(myfunc, args...; rrule_f=rrule_via_ad)

If we put the test_approx(::NoTangent, x) in /test the test_rrule will behave differently in the test suite than in the REPL when using Zygote, which might lead to confusion.

Copy link
Member

@oxinabox oxinabox Jun 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm Ok. lets do it.

aside: the new way to do this is going to be

test_rrule(ZygoteRuleConfig, myfunc, args...; rrule_f=rrule_via_ad)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, thanks, it's already changed in the code

test_approx(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true

# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
Expand Down
16 changes: 16 additions & 0 deletions src/rule_config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# For testing this config re-dispatches Xrule_via_ad to Xrule without config argument
struct ADviaRuleConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end

function ChainRulesCore.frule_via_ad(config::ADviaRuleConfig, ȧrgs, f, args...; kws...)
ret = frule(config, ȧrgs, f, args...; kws...)
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
ret === nothing && throw(MethodError(frule, (ȧrgs, f, args...)))
return ret
end

function ChainRulesCore.rrule_via_ad(config::ADviaRuleConfig, f, args...; kws...)
ret = rrule(config, f, args...; kws...)
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
ret === nothing && throw(MethodError(rrule, (f, args...)))
return ret
end
42 changes: 31 additions & 11 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,39 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
end

"""
test_frule(f, args..; kwargs...)
test_frule([config::RuleConfig,] f, args..; kwargs...)

# Arguments
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
- `f`: Function for which the `frule` should be tested. Can also provide `f ⊢ ḟ`.
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- ``: differential w.r.t. `x`, will be generated automatically if not provided
Non-differentiable arguments, such as indices, should have `` set as `NoTangent()`.
- ``: differential w.r.t. `x`, will be generated automatically if not provided
Non-differentiable arguments, such as indices, should have `` set as `NoTangent()`.

# Keyword Arguments
- `output_tangent` tangent to test accumulation of derivatives against
should be a differential for the output of `f`. Is set automatically if not provided.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
`frule`). Used for testing gradients from AD systems.
- If `check_inferred=true`, then the inferrability of the `frule` is checked,
as long as `f` is itself inferrable.
- `fkwargs` are passed to `f` as keyword arguments.
- All remaining keyword arguments are passed to `isapprox`.
"""
function test_frule(args...; kwargs...)
config = ChainRulesTestUtils.ADviaRuleConfig()
test_frule(config, args...; kwargs...)
end

function test_frule(
config::RuleConfig,
f,
args...;
output_tangent=Auto(),
fdm=_fdm,
frule_f=ChainRulesCore.frule,
check_inferred::Bool=true,
fkwargs::NamedTuple=NamedTuple(),
rtol::Real=1e-9,
Expand All @@ -109,11 +119,11 @@ function test_frule(
tangents = tangent.(primals_and_tangents)

if check_inferred && _is_inferrable(deepcopy(primals)...; deepcopy(fkwargs)...)
_test_inferred(frule, deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
_test_inferred(frule_f, deepcopy(config), deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
end

res = frule(deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof(primals)))
res = frule_f(deepcopy(config), deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule_f, typeof(primals)))
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
Ω_ad, dΩ_ad = res
Ω = call_on_copy(primals...)
Expand All @@ -139,9 +149,10 @@ function test_frule(
end

"""
test_rrule(f, args...; kwargs...)
test_rrule([config::RuleConfig,] f, args...; kwargs...)

# Arguments
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
- `f`: Function to which rule should be applied. Can also provide `f ⊢ f̄`.
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ x̄`
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
Expand All @@ -152,16 +163,25 @@ end
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
should be a differential for the output of `f`. Is set automatically if not provided.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
Used for testing gradients from AD systems.
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
— if `f` is itself inferrable — along with the inferrability of the pullback it returns.
- `fkwargs` are passed to `f` as keyword arguments.
- All remaining keyword arguments are passed to `isapprox`.
"""
function test_rrule(args...; kwargs...)
config = ChainRulesTestUtils.ADviaRuleConfig()
test_rrule(config, args...; kwargs...)
end

function test_rrule(
config::RuleConfig,
f,
args...;
output_tangent=Auto(),
fdm=_fdm,
rrule_f=ChainRulesCore.rrule,
check_inferred::Bool=true,
fkwargs::NamedTuple=NamedTuple(),
rtol::Real=1e-9,
Expand All @@ -182,15 +202,15 @@ function test_rrule(
accum_cotangents = tangent.(primals_and_tangents)

if check_inferred && _is_inferrable(primals...; fkwargs...)
_test_inferred(rrule, primals...; fkwargs...)
_test_inferred(rrule_f, config, primals...; fkwargs...)
end
res = rrule(primals...; fkwargs...)
res === nothing && throw(MethodError(rrule, typeof((primals...))))
res = rrule_f(config, primals...; fkwargs...)
res === nothing && throw(MethodError(rrule_f, typeof(primals)))
y_ad, pullback = res
y = call(primals...)
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct

= output_tangent isa Auto ? rand_tangent(y) : output_tangent
ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

check_inferred && _test_inferred(pullback, ȳ)
ad_cotangents = pullback(ȳ)
Expand Down
50 changes: 50 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,56 @@ end
test_rrule(f_notimplemented, randn(), randn())
end

@testset "custom rrule_f" begin
only2x(x, y) = 2x
custom(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2Δ, ZeroTangent())
wrong1(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (ZeroTangent(), 2Δ, ZeroTangent())
wrong2(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2.1Δ, ZeroTangent())
wrong3(::RuleConfig, ::typeof(only2x), x, y) = only2x(x, y), Δ -> (NoTangent(), 2Δ)

test_rrule(only2x, 2.0, 3.0; rrule_f=custom, check_inferred=false)
@test errors(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong1, check_inferred=false))
@test fails(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong2, check_inferred=false))
@test fails(() -> test_rrule(only2x, 2.0, 3.0; rrule_f=wrong3, check_inferred=false))
end

@testset "custom frule_f" begin
mytuple(x, y) = return 2x, 1.0
T = Tuple{Float64, Float64}
custom(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2Δx, ZeroTangent())
wrong1(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2.1Δx, ZeroTangent())
wrong2(::RuleConfig, (Δf, Δx, Δy), ::typeof(mytuple), x, y) = mytuple(x, y), Tangent{T}(2Δx, 1.0)

test_frule(mytuple, 2.0, 3.0; frule_f=custom, check_inferred=false)
@test fails(() -> test_frule(mytuple, 2.0, 3.0; frule_f=wrong1, check_inferred=false))
@test fails(() -> test_frule(mytuple, 2.0, 3.0; frule_f=wrong2, check_inferred=false))
end

@testset "custom_config" begin
abstract type MySpecialTrait end
struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end

has_config(x) = 2x
function ChainRulesCore.rrule(::MySpecialConfig, ::typeof(has_config), x)
has_config_pullback(ȳ) = return (NoTangent(), 2ȳ)
return has_config(x), has_config_pullback
end

has_trait(x) = 2x
function ChainRulesCore.rrule(::RuleConfig{<:MySpecialTrait}, ::typeof(has_trait), x)
has_trait_pullback(ȳ) = return (NoTangent(), 2ȳ)
return has_trait(x), has_trait_pullback
end

# it works if the special config is provided
test_rrule(MySpecialConfig(), has_config, rand())
test_rrule(MySpecialConfig(), has_trait, rand())

# but it doesn't work for the default config
errors(() -> test_rrule(has_config, rand()), "no method matching rrule")
errors(() -> test_rrule(has_trait, rand()), "no method matching rrule")
end

@testset "@maybe_inferred" begin
f_noninferrable(x) = Ref{Real}(x)[]

Expand Down