Skip to content

Commit

Permalink
Merge pull request #6 from JuliaDiff/MB/v1
Browse files Browse the repository at this point in the history
Initial Port from ChainRules.jl
  • Loading branch information
oxinabox authored Jan 31, 2020
2 parents 1c39ca4 + 963922d commit 90ddcae
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# JetBrains meta files
.idea/*
19 changes: 19 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
language: julia
os:
- linux
- osx
julia:
- 1.0
- 1.3
- nightly

notifications:
email:
recipients:
- nightly-rse@invenia.ca
on_success: never
on_failure: always
if: type = cron
matrix:
allow_failures:
- julia: nightly
19 changes: 19 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"

[compat]
ChainRulesCore = "0.5, 0.6"
FiniteDifferences = "0.7, 0.8, 0.9"
julia = "1"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "Test"]
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
# ChainRulesTestUtils.jl
# ChainRulesTestUtils.jl

[![Travis](https://travis-ci.org/JuliaDiff/ChainRulesTestUtils.jl.svg?branch=master)](https://travis-ci.org/JuliaDiff/ChainRulesTestUtils.jl)

`ChainRulesTestUtils.jl` provides a variety of common utilities for testing forward- and reverse- primitives.
182 changes: 182 additions & 0 deletions src/ChainRulesTestUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
module ChainRulesTestUtils

using ChainRulesCore
using ChainRulesCore: frule, rrule
using ChainRulesCore: AbstractDifferential
using FiniteDifferences
using Test

const _fdm = central_fdm(5, 1)

export test_scalar, frule_test, rrule_test, isapprox, generate_well_conditioned_matrix

Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `DoesNotExist`")
Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...)

function _make_fdm_call(fdm, f, ȳ, xs, ignores)
sig = Expr(:tuple)
call = Expr(:call, f)
newxs = Any[]
arginds = Int[]
i = 1
for (x, ignore) in zip(xs, ignores)
if ignore
push!(call.args, x)
else
push!(call.args, Symbol(:x, i))
push!(sig.args, Symbol(:x, i))
push!(newxs, x)
push!(arginds, i)
end
i += 1
end
fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...)))
fd = eval(fdexpr)
fd isa Tuple || (fd = (fd,))
args = Any[nothing for _ in 1:length(xs)]
for (dx, ind) in zip(fd, arginds)
args[ind] = dx
end
return (args...,)
end

# Useful for LinearAlgebra tests
function generate_well_conditioned_matrix(rng, N)
A = randn(rng, N, N)
return A * A' + I
end

"""
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
Given a function `f` with scalar input and scalar output, perform finite differencing checks,
at input point `x` to confirm that there are correct `frule` and `rrule`s provided.
# Arguments
- `f`: Function for which the `frule` and `rrule` should be tested.
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
All keyword arguments except for `fdm` is passed to `isapprox`.
"""
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
ensure_not_running_on_functor(f, "test_scalar")

r_res = rrule(f, x)
f_res = frule(f, x, Zero(), 1)
@test r_res !== nothing # Check the rule was defined
@test f_res !== nothing
r_fx, prop_rule = r_res
f_fx, f_∂x = f_res
@testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in (
(rrule, r_fx, prop_rule(1)),
(frule, f_fx, f_∂x)
)
@test fx == f(x) # Check we still get the normal value, right

if rule == rrule
∂self, ∂x = ∂x
@test ∂self === NO_FIELDS
end
@test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...)
end
end

function ensure_not_running_on_functor(f, name)
# if x itself is a Type, then it is a constructor, thus not a functor.
# This also catchs UnionAll constructors which have a `:var` and `:body` fields
f isa Type && return

if fieldcount(typeof(f)) > 0
throw(ArgumentError(
"$name cannot be used on closures/functors (such as $f)"
))
end
end

"""
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
# Arguments
- `f`: Function for which the `frule` should be tested.
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
All keyword arguments except for `fdm` are passed to `isapprox`.
"""
function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm, kwargs...)
end

function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
ensure_not_running_on_functor(f, "frule_test")
xs, ẋs = collect(zip(xẋs...))
Ω, dΩ_ad = frule(f, xs..., NO_FIELDS, ẋs...)
@test f(xs...) == Ω

# Correctness testing via finite differencing.
dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs))
@test isapprox(
collect(extern.(dΩ_ad)), # Use collect so can use vector equality
collect(dΩ_fd);
rtol=rtol,
atol=atol,
kwargs...
)
end


"""
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
# Arguments
- `f`: Function to which rule should be applied.
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
Should be same structure as `f(x)` (so if multiple returns should be a tuple)
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `x̄`: currently accumulated adjoint (should generally be set randomly).
All keyword arguments except for `fdm` are passed to `isapprox`.
"""
function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
ensure_not_running_on_functor(f, "rrule_test")

# Check correctness of evaluation.
fx, pullback = rrule(f, x)
@test collect(fx) collect(f(x)) # use collect so can do vector equality
(∂self, x̄_ad) = if fx isa Tuple
# If the function returned multiple values,
# then it must have multiple seeds for propagating backwards
pullback(ȳ...)
else
pullback(ȳ)
end

@test ∂self === NO_FIELDS # No internal fields
# Correctness testing via finite differencing.
x̄_fd = j′vp(fdm, f, ȳ, x)
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
end

# case where `f` takes multiple arguments
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
ensure_not_running_on_functor(f, "rrule_test")

# Check correctness of evaluation.
xs, x̄s = collect(zip(xx̄s...))
y, pullback = rrule(f, xs...)
@test f(xs...) == y

@assert !(isa(ȳ, Thunk))
∂s = pullback(ȳ)
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NO_FIELDS

# Correctness testing via finite differencing.
x̄s_fd = j′vp(fdm, f, ȳ, xs...)
map(x̄s_ad, x̄s_fd) do x̄_ad, x̄_fd
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
end
end

end # module
23 changes: 23 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using ChainRulesCore
using ChainRulesTestUtils
using Random
using Test

@testset "ChainRulesTestUtils.jl" begin
double(x) = 2x
@scalar_rule(double(x), 2)
test_scalar(double, 2)

fst(x, y) = x
ChainRulesCore.frule(::typeof(fst), x, y, _, dx, dy) = (x, dx)

function ChainRulesCore.rrule(::typeof(fst), x, y)
function fst_pullback(Δx)
return (NO_FIELDS, Δx, Zero())
end
return x, fst_pullback
end

frule_test(fst, (2, 4.0), (3, 5.0))
rrule_test(fst, rand(), (2.0, 4.0), (3.0, 5.0))
end

2 comments on commit 90ddcae

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/8713

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" 90ddcae91ac6af04effaf8631fc0fce5655c4a5b
git push origin v0.1.0

Please sign in to comment.