Merge pull request #6 from rossviljoen/ross/tests
Equivalence tests
Showing 5 changed files with 162 additions and 0 deletions.
@@ -0,0 +1,8 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -0,0 +1,127 @@
@testset "equivalences" begin | ||
rng, N = MersenneTwister(654321), 20 | ||
x = rand(rng, N) * 10 | ||
y = sin.(x) + 0.9 * cos.(x * 1.6) + 0.4 * rand(rng, N) | ||
|
||
z = copy(x) # Set inducing inputs == training inputs | ||
|
||
make_kernel(k) = softplus(k[1]) * (SqExponentialKernel() ∘ ScaleTransform(softplus(k[2]))) | ||
|
||
k_init = [0.2, 0.6] # initial kernel parameters | ||
lik_noise = 0.1 # The (fixed) Gaussian likelihood noise | ||
|
||
@testset "exact posterior" begin | ||
# There is a closed form optimal solution for the variational posterior | ||
# q(u) (e.g. # https://krasserm.github.io/2020/12/12/gaussian-processes-sparse/ | ||
# equations (11) & (12)). The SVGP posterior with this optimal q(u) | ||
# should therefore be equivalent to the sparse GP (Titsias) posterior | ||
# and exact GP regression (when z == x). | ||
|
||
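        #
        # For reference, spelling out the closed form that exact_q below computes
        # (σ² is the likelihood noise, Kuf = cov(fu, fx), Kuu = cov(fu)):
        #     Σ = Kuu + σ⁻² Kuf Kuf'
        #     m = σ⁻² Kuu Σ⁻¹ Kuf y
        #     S = Kuu Σ⁻¹ Kuu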
        function exact_q(fu, fx, y)
            σ² = fx.Σy[1]
            Kuf = cov(fu, fx)
            Kuu = Symmetric(cov(fu))
            Σ = Symmetric(cov(fu) + (1/σ²) * Kuf * Kuf')
            m = ((1/σ²) * Kuu * (Σ \ Kuf)) * y
            S = Symmetric(Kuu * (Σ \ Kuu))
            return MvNormal(m, S)
        end

        kernel = make_kernel(k_init)
        f = GP(kernel)
        fx = f(x, lik_noise)
        fu = f(z)
        q_ex = exact_q(fu, fx, y)

        gpr_post = AbstractGPs.posterior(fx, y) # Exact GP regression
        vfe_post = AbstractGPs.approx_posterior(VFE(), fx, y, fu) # Titsias posterior
        svgp_post = SparseGPs.approx_posterior(SVGP(), fu, q_ex) # Hensman (2013) exact posterior

        @test mean(gpr_post, x) ≈ mean(svgp_post, x) atol=1e-10
        @test cov(gpr_post, x) ≈ cov(svgp_post, x) atol=1e-10

        @test mean(vfe_post, x) ≈ mean(svgp_post, x) atol=1e-10
        @test cov(vfe_post, x) ≈ cov(svgp_post, x) atol=1e-10
    end

@testset "optimised posterior" begin | ||
jitter = 1e-5 | ||
|
||
## FIRST - define the models | ||
# GPR - Exact GP regression | ||
struct GPRModel | ||
k # kernel parameters | ||
end | ||
@Flux.functor GPRModel | ||
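        # (@Flux.functor registers GPRModel's fields, here just the kernel
        # parameters k, so that Flux.params can collect them for training.)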

        function (m::GPRModel)(x)
            f = GP(make_kernel(m.k))
            fx = f(x, lik_noise)
            return fx
        end

        # SVGP - Sparse variational GP regression (Hensman 2014)
        struct SVGPModel
            k # kernel parameters
            z # inducing points
            m # variational mean
            A # variational covariance sqrt (Σ = A'A)
        end
        @Flux.functor SVGPModel (k, m, A,) # Don't train the inducing inputs

        function (m::SVGPModel)(x)
            f = GP(make_kernel(m.k))
            q = MvNormal(m.m, m.A'm.A)
            fx = f(x, lik_noise)
            fz = f(m.z, jitter)
            return fx, fz, q
        end

        ## SECOND - create the models and associated training losses
        gpr = GPRModel(copy(k_init))
        function GPR_loss(x, y)
            fx = gpr(x)
            return -logpdf(fx, y)
        end

        m, A = zeros(N), Matrix{Float64}(I, N, N) # initialise the variational parameters
        svgp = SVGPModel(copy(k_init), copy(z), m, A)
        function SVGP_loss(x, y)
            fx, fz, q = svgp(x)
            return -SparseGPs.elbo(fx, y, fz, q)
        end

        ## THIRD - train the models
        data = [(x, y)]
        opt = ADAM(0.001)

        svgp_ps = Flux.params(svgp)
        delete!(svgp_ps, svgp.k) # Don't train the kernel parameters

        # Optimise q(u)
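        # (ncycle(data, 20000), from IterTools, repeats the single (x, y) batch
        # 20000 times, i.e. train! takes 20000 gradient steps on the full dataset.)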
        Flux.train!((x, y) -> SVGP_loss(x, y), svgp_ps, ncycle(data, 20000), opt)

        ## FOURTH - construct the posteriors
        function posterior(m::GPRModel, x, y)
            f = GP(make_kernel(m.k))
            fx = f(x, lik_noise)
            return AbstractGPs.posterior(fx, y)
        end

        function posterior(m::SVGPModel)
            f = GP(make_kernel(m.k))
            fz = f(m.z, jitter)
            q = MvNormal(m.m, m.A'm.A)
            return SparseGPs.approx_posterior(SVGP(), fz, q)
        end

        gpr_post = posterior(gpr, x, y)
        svgp_post = posterior(svgp)

        ## FIFTH - test equivalences
        @test all(isapprox.(mean(gpr_post, x), mean(svgp_post, x), atol=1e-4))
        @test all(isapprox.(cov(gpr_post, x), cov(svgp_post, x), atol=1e-4))
    end

end
@@ -0,0 +1,23 @@
using Random
using Test
using SparseGPs
using Flux
using IterTools
using AbstractGPs
using Distributions
using LinearAlgebra

const GROUP = get(ENV, "GROUP", "All")
const PKGDIR = dirname(dirname(pathof(SparseGPs)))
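# (GROUP and PKGDIR look like hooks for selecting test groups via the GROUP
# environment variable and for locating the package root; neither is used by
# the test files included below yet.)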

include("test_utils.jl")

@testset "SparseGPs" begin
    include("svgp.jl")
    println(" ")
    @info "Ran svgp tests"

    include("equivalences.jl")
    println(" ")
    @info "Ran equivalences tests"
end
@@ -0,0 +1,4 @@
@testset "svgp" begin | ||
x = 4 | ||
@test x == 4 | ||
end |
Empty file.