Skip to content

Commit

Permalink
Init
Browse files Browse the repository at this point in the history
  • Loading branch information
mattBrzezinski committed Jan 31, 2020
1 parent 1c39ca4 commit 963922d
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# JetBrains meta files
.idea/*
19 changes: 19 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
language: julia
os:
- linux
- osx
julia:
- 1.0
- 1.3
- nightly

notifications:
email:
recipients:
- nightly-rse@invenia.ca
on_success: never
on_failure: always
if: type = cron
matrix:
allow_failures:
- julia: nightly
19 changes: 19 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"

[compat]
ChainRulesCore = "0.5, 0.6"
FiniteDifferences = "0.7, 0.8, 0.9"
julia = "1"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "Test"]
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
# ChainRulesTestUtils.jl
# ChainRulesTestUtils.jl

[![Travis](https://travis-ci.org/JuliaDiff/ChainRulesTestUtils.jl.svg?branch=master)](https://travis-ci.org/JuliaDiff/ChainRulesTestUtils.jl)

`ChainRulesTestUtils.jl` provides a variety of common utilities for testing forward- and reverse- primitives.
182 changes: 182 additions & 0 deletions src/ChainRulesTestUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
module ChainRulesTestUtils

using ChainRulesCore
using ChainRulesCore: frule, rrule
using ChainRulesCore: AbstractDifferential
using FiniteDifferences
using Test

const _fdm = central_fdm(5, 1)

export test_scalar, frule_test, rrule_test, isapprox, generate_well_conditioned_matrix

Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `DoesNotExist`")
Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...)

function _make_fdm_call(fdm, f, ȳ, xs, ignores)
sig = Expr(:tuple)
call = Expr(:call, f)
newxs = Any[]
arginds = Int[]
i = 1
for (x, ignore) in zip(xs, ignores)
if ignore
push!(call.args, x)
else
push!(call.args, Symbol(:x, i))
push!(sig.args, Symbol(:x, i))
push!(newxs, x)
push!(arginds, i)
end
i += 1
end
fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...)))
fd = eval(fdexpr)
fd isa Tuple || (fd = (fd,))
args = Any[nothing for _ in 1:length(xs)]
for (dx, ind) in zip(fd, arginds)
args[ind] = dx
end
return (args...,)
end

# Useful for LinearAlgebra tests
function generate_well_conditioned_matrix(rng, N)
A = randn(rng, N, N)
return A * A' + I
end

"""
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
Given a function `f` with scalar input and scalar output, perform finite differencing checks,
at input point `x` to confirm that there are correct `frule` and `rrule`s provided.
# Arguments
- `f`: Function for which the `frule` and `rrule` should be tested.
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
All keyword arguments except for `fdm` is passed to `isapprox`.
"""
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
ensure_not_running_on_functor(f, "test_scalar")

r_res = rrule(f, x)
f_res = frule(f, x, Zero(), 1)
@test r_res !== nothing # Check the rule was defined
@test f_res !== nothing
r_fx, prop_rule = r_res
f_fx, f_∂x = f_res
@testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in (
(rrule, r_fx, prop_rule(1)),
(frule, f_fx, f_∂x)
)
@test fx == f(x) # Check we still get the normal value, right

if rule == rrule
∂self, ∂x = ∂x
@test ∂self === NO_FIELDS
end
@test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...)
end
end

function ensure_not_running_on_functor(f, name)
# if x itself is a Type, then it is a constructor, thus not a functor.
# This also catchs UnionAll constructors which have a `:var` and `:body` fields
f isa Type && return

if fieldcount(typeof(f)) > 0
throw(ArgumentError(
"$name cannot be used on closures/functors (such as $f)"
))
end
end

"""
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
# Arguments
- `f`: Function for which the `frule` should be tested.
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
All keyword arguments except for `fdm` are passed to `isapprox`.
"""
function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm, kwargs...)
end

function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
ensure_not_running_on_functor(f, "frule_test")
xs, ẋs = collect(zip(xẋs...))
Ω, dΩ_ad = frule(f, xs..., NO_FIELDS, ẋs...)
@test f(xs...) == Ω

# Correctness testing via finite differencing.
dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs))
@test isapprox(
collect(extern.(dΩ_ad)), # Use collect so can use vector equality
collect(dΩ_fd);
rtol=rtol,
atol=atol,
kwargs...
)
end


"""
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
# Arguments
- `f`: Function to which rule should be applied.
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
Should be same structure as `f(x)` (so if multiple returns should be a tuple)
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `x̄`: currently accumulated adjoint (should generally be set randomly).
All keyword arguments except for `fdm` are passed to `isapprox`.
"""
function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
ensure_not_running_on_functor(f, "rrule_test")

# Check correctness of evaluation.
fx, pullback = rrule(f, x)
@test collect(fx) collect(f(x)) # use collect so can do vector equality
(∂self, x̄_ad) = if fx isa Tuple
# If the function returned multiple values,
# then it must have multiple seeds for propagating backwards
pullback(ȳ...)
else
pullback(ȳ)
end

@test ∂self === NO_FIELDS # No internal fields
# Correctness testing via finite differencing.
x̄_fd = j′vp(fdm, f, ȳ, x)
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
end

# case where `f` takes multiple arguments
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
ensure_not_running_on_functor(f, "rrule_test")

# Check correctness of evaluation.
xs, x̄s = collect(zip(xx̄s...))
y, pullback = rrule(f, xs...)
@test f(xs...) == y

@assert !(isa(ȳ, Thunk))
∂s = pullback(ȳ)
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NO_FIELDS

# Correctness testing via finite differencing.
x̄s_fd = j′vp(fdm, f, ȳ, xs...)
map(x̄s_ad, x̄s_fd) do x̄_ad, x̄_fd
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
end
end

end # module
23 changes: 23 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using ChainRulesCore
using ChainRulesTestUtils
using Random
using Test

@testset "ChainRulesTestUtils.jl" begin
double(x) = 2x
@scalar_rule(double(x), 2)
test_scalar(double, 2)

fst(x, y) = x
ChainRulesCore.frule(::typeof(fst), x, y, _, dx, dy) = (x, dx)

function ChainRulesCore.rrule(::typeof(fst), x, y)
function fst_pullback(Δx)
return (NO_FIELDS, Δx, Zero())
end
return x, fst_pullback
end

frule_test(fst, (2, 4.0), (3, 5.0))
rrule_test(fst, rand(), (2.0, 4.0), (3.0, 5.0))
end

0 comments on commit 963922d

Please sign in to comment.