Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SparseVariationalApproximation cleanup #86

Merged
merged 20 commits into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
187 changes: 87 additions & 100 deletions src/sparse_variational.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using PDMats: chol_lower

@doc raw"""
Centered()

Expand Down Expand Up @@ -56,7 +58,7 @@ end
SparseVariationalApproximation(fz::FiniteGP, q::AbstractMvNormal)

Packages the prior over the pseudo-points `fz`, and the approximate posterior at the
pseudo-points, which is `mean(fz) + cholesky(cov(fz)).L * ε`, `ε ∼ q`.

Shorthand for
```julia
Expand Down Expand Up @@ -86,82 +88,26 @@ variational Gaussian process classification." Artificial Intelligence and
Statistics. PMLR, 2015.
"""
function AbstractGPs.posterior(sva::SparseVariationalApproximation{Centered})
    # m* = K*u Kuu⁻¹ (mean(q) - mean(fz))
    #    = K*u α
    # α = Kuu⁻¹ (m - mean(fz))
    # V** = K** - K*u (Kuu⁻¹ - Kuu⁻¹ cov(q) Kuu⁻¹) Ku*
    #     = K** - (K*u Lk⁻ᵀ) (Lk⁻¹ Ku*) + (K*u Lk⁻ᵀ) Lk⁻¹ cov(q) Lk⁻ᵀ (Lk⁻¹ Ku*)
    #     = K** - A'A + A' Lk⁻¹ Lq Lqᵀ Lk⁻ᵀ A
    #     = K** - A'A + A' B B' A
    # A = Lk⁻¹ Ku*
    # B = Lk⁻¹ Lq
    q, fz = sva.q, sva.fz
    m, S = mean(q), _chol_cov(q)
    Kuu = _chol_cov(fz)
    # Whitened square root of cov(q): B = Lk⁻¹ Lq.
    B = chol_lower(Kuu) \ chol_lower(S)
    # Weights for the posterior-mean projection: α = Kuu⁻¹ (m - mean(fz)).
    α = Kuu \ (m - mean(fz))
    data = (Kuu=Kuu, B=B, α=α)
    return ApproxPosteriorGP(sva, fz.f, data)
end

# Internal AbstractGPs API entry point. The observations (`fx` and the data
# vector) are not used beyond checking that `fx` wraps the same latent GP as
# the inducing-point prior `sva.fz`; the result is the approximate posterior
# determined entirely by `sva`.
function AbstractGPs.posterior(
    sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real}
)
    @assert sva.fz.f === fx.f
    return posterior(sva)
end

#
# Various methods implementing the Internal AbstractGPs API.
# See AbstractGPs.jl API docs for more info.
#

# Posterior mean at `x`: the prior mean plus the projection K(x, z) * α,
# where α was precomputed in `posterior`.
function Statistics.mean(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector
)
    Kxz = cov(f.prior, x, inducing_points(f))
    return mean(f.prior, x) + Kxz * f.data.α
end

# Posterior covariance at `x`: K** - A'A + (B'A)'(B'A) with A = Lk⁻¹ K(z, x).
function Statistics.cov(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector
)
    Kzx = cov(f.prior, inducing_points(f), x)
    A = f.data.Kuu.L \ Kzx
    BtA = f.data.B' * A
    return cov(f.prior, x) - At_A(A) + At_A(BtA)
end

# Marginal posterior variances at `x`: the diagonal-only counterpart of `cov`.
function Statistics.var(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector
)
    Kzx = cov(f.prior, inducing_points(f), x)
    A = f.data.Kuu.L \ Kzx
    BtA = f.data.B' * A
    return var(f.prior, x) - diag_At_A(A) + diag_At_A(BtA)
end

# Posterior cross-covariance between f(x) and f(y):
# K(x, y) - (K(x,z) Lk⁻ᵀ)(Lk⁻¹ K(z,y)) + (K(x,z) Lk⁻ᵀ) B B' (Lk⁻¹ K(z,y)).
function Statistics.cov(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}},
    x::AbstractVector,
    y::AbstractVector,
)
    z = inducing_points(f)
    Lk = f.data.Kuu.L
    right = Lk \ cov(f.prior, z, y)
    left = cov(f.prior, x, z) / Lk'
    W = f.data.B
    return cov(f.prior, x, y) - left * right + left * W * W' * right
end

# Jointly compute the posterior mean and covariance at `x`, sharing the
# intermediate solves between the two.
function StatsBase.mean_and_cov(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector
)
    Cux = cov(f.prior, inducing_points(f), x)
    D = f.data.Kuu.L \ Cux
    # Fix: the posterior mean must include the prior mean term, consistent
    # with `Statistics.mean` above.
    μ = mean(f.prior, x) + Cux' * f.data.α
    Σ = cov(f.prior, x) - At_A(D) + At_A(f.data.B' * D)
    return μ, Σ
end

# Jointly compute the posterior mean and marginal variances at `x`.
function StatsBase.mean_and_var(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector
)
    Cux = cov(f.prior, inducing_points(f), x)
    D = f.data.Kuu.L \ Cux
    # Fix: the posterior mean must include the prior mean term, consistent
    # with `Statistics.mean` above.
    μ = mean(f.prior, x) + Cux' * f.data.α
    Σ_diag = var(f.prior, x) - diag_At_A(D) + diag_At_A(f.data.B' * D)
    return μ, Σ_diag
end

#
# NonCentered Parametrization.
#
Expand All @@ -172,7 +118,7 @@ end
Compute the approximate posterior [1] over the process `f =
sva.fz.f`, given inducing inputs `z = sva.fz.x` and a variational
distribution over inducing points `sva.q` (which represents ``q(ε)``
where `ε = cholesky(cov(fz)).L \ (f(z) - mean(f(z)))`). The approximate posterior at test
points ``x^*`` where ``f^* = f(x^*)`` is then given by:

```math
Expand All @@ -184,70 +130,107 @@ which can be found in closed form.
variational Gaussian process classification." Artificial Intelligence and
Statistics. PMLR, 2015.
"""
function AbstractGPs.posterior(sva::SparseVariationalApproximation{NonCentered})
    # u = Lk v + mean(fz), v ~ q
    # m* = K*u Kuu⁻¹ Lk (mean(u) - mean(fz))
    #    = K*u (Lk Lkᵀ)⁻¹ Lk mean(q)
    #    = K*u Lk⁻ᵀ Lk⁻¹ Lk mean(q)
    #    = K*u Lk⁻ᵀ mean(q)
    #    = K*u α
    # NonCentered: α = Lk⁻ᵀ m
    # Centered:    α = Kuu⁻¹ (m - mean(fz))
    # V** = K** - K*u (Kuu⁻¹ - Kuu⁻¹ Lk cov(q) Lkᵀ Kuu⁻¹) Ku*
    #     = K** - K*u (Kuu⁻¹ - Lk⁻ᵀ cov(q) Lk⁻¹) Ku*
    #     = K** - (K*u Lk⁻ᵀ) (Lk⁻¹ Ku*) + (K*u Lk⁻ᵀ) Lq Lqᵀ (Lk⁻¹ Ku*)
    #     = K** - A'A + (B'A)' (B'A)
    # A = Lk⁻¹ Ku*
    # NonCentered: B = Lq
    # Centered:    B = Lk⁻¹ Lq
    q, fz = sva.q, sva.fz
    m = mean(q)
    Kuu = _chol_cov(fz)
    # Fix: the derivation above requires α = Lk⁻ᵀ m (an adjoint solve);
    # `chol_lower(Kuu) \ m` would compute Lk⁻¹ m instead.
    α = chol_lower(Kuu)' \ m
    Sv = _chol_cov(q)
    B = chol_lower(Sv)
    data = (Kuu=Kuu, B=B, α=α)
    return ApproxPosteriorGP(sva, fz.f, data)
end

# Internal AbstractGPs API entry point. The observations (`fx` and the data
# vector) are not used beyond checking that `fx` wraps the same latent GP as
# the inducing-point prior `sva.fz`.
function AbstractGPs.posterior(
    sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real}
)
    # Explicit error instead of `@assert`: assertions may be disabled at
    # higher optimization levels, and this is caller-input validation.
    sva.fz.f === fx.f ||
        throw(ArgumentError("`fx` must wrap the same GP as `sva.fz`"))
    return posterior(sva)
end

#
# Various methods implementing the Internal AbstractGPs API.
# See AbstractGPs.jl API docs for more info.
#

# Posterior mean at `x` (either parametrization): prior mean plus the
# projection K(x, z) * α, where α was precomputed in `posterior`.
function Statistics.mean(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
)
    z = inducing_points(f)
    return mean(f.prior, x) + cov(f.prior, x, z) * f.data.α
end

# Produces the matrix consistently referred to as A in this file,
# A = Lk⁻¹ K(u, x), together with the cross-covariance Cux = K(u, x) itself.
# A more descriptive name is, unfortunately, not obvious: it is just an
# intermediate quantity that happens to get used a lot.
function _A_and_Cux(f, x)
    Cux = cov(f.prior, inducing_points(f), x)
    A = chol_lower(f.data.Kuu) \ Cux
    return A, Cux
end

# Convenience accessor for callers that only need A.
_A(f, x) = first(_A_and_Cux(f, x))

# Posterior covariance at `x` (either parametrization):
# K** - A'A + (B'A)'(B'A); see `posterior` for the derivation of A and B.
function Statistics.cov(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
)
    A = _A(f, x)
    return cov(f.prior, x) - At_A(A) + At_A(f.data.B' * A)
end

# Marginal posterior variances at `x`: the diagonal-only counterpart of `cov`,
# computed without materializing the full covariance matrix.
function Statistics.var(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
)
    A = _A(f, x)
    return var(f.prior, x) - diag_At_A(A) + diag_At_A(f.data.B' * A)
end

# Posterior cross-covariance between f(x) and f(y):
# K(x, y) - Ax'Ay + Ax' B B' Ay, with Az = Lk⁻¹ K(u, z).
function Statistics.cov(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation},
    x::AbstractVector,
    y::AbstractVector,
)
    B = f.data.B
    Ax = _A(f, x)
    Ay = _A(f, y)
    return cov(f.prior, x, y) - Ax'Ay + Xt_A_Y(Ax, B * B', Ay)
end

# Jointly compute the posterior mean and covariance at `x`, reusing `Cux`
# for the mean and `A` for the covariance from a single helper call.
function StatsBase.mean_and_cov(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
)
    A, Cux = _A_and_Cux(f, x)
    μ = mean(f.prior, x) + Cux' * f.data.α
    Σ = cov(f.prior, x) - At_A(A) + At_A(f.data.B' * A)
    return μ, Σ
end

# Jointly compute the posterior mean and marginal variances at `x`.
function StatsBase.mean_and_var(
    f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
)
    A, Cux = _A_and_Cux(f, x)
    μ = mean(f.prior, x) + Cux' * f.data.α
    Σ_diag = var(f.prior, x) - diag_At_A(A) + diag_At_A(f.data.B' * A)
    return μ, Σ_diag
end


st-- marked this conversation as resolved.
Show resolved Hide resolved
#
# Misc utility.
#
Expand Down Expand Up @@ -344,12 +327,16 @@ function _elbo(

n_batch = length(y)
scale = num_data / n_batch
return sum(variational_exp) * scale - kl_term(sva, post)
return sum(variational_exp) * scale - prior_kl(sva)
end

# KL divergence between the variational distribution and the prior at the
# inducing points; this is the regularization term subtracted in the ELBO.
# NOTE(review): `kl_term` is the pre-rename method left over from the diff;
# the updated `_elbo` calls `prior_kl` — confirm `kl_term` can be removed.
kl_term(sva::SparseVariationalApproximation{Centered}, post) = KL(sva.q, sva.fz)
prior_kl(sva::SparseVariationalApproximation{Centered}) = KL(sva.q, sva.fz)

# KL(q(ε) ‖ N(0, I)) in closed form for the NonCentered parametrization:
#   (tr(C_ε) + mᵀm - d - logdet(C_ε)) / 2
function prior_kl(sva::SparseVariationalApproximation{NonCentered})
    m_ε = mean(sva.q)
    # `_cov` reads the stored covariance directly, avoiding the copy made by
    # `cov(q::MvNormal)` (see JuliaStats/Distributions.jl#1373).
    C_ε = _cov(sva.q)
    # trace_term = tr(C_ε) does not work due to PDMat / Zygote issues, so the
    # trace is computed from the Cholesky factor: tr(L Lᵀ) = Σᵢⱼ Lᵢⱼ².
    L = chol_lower(_chol_cov(sva.q))
    trace_term = sum(L .^ 2)
    return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2
end
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ end

# Cholesky factorization of the covariance of `q`; the generic fallback wraps
# in `Symmetric` so `cholesky` accepts covariance matrices that are only
# symmetric up to floating-point error.
function _chol_cov(q::AbstractMvNormal)
    return cholesky(Symmetric(cov(q)))
end

# `MvNormal` stores its covariance in `q.Σ`; factorize it without the copy
# that `cov(q)` would make.
function _chol_cov(q::MvNormal)
    return cholesky(q.Σ)
end

# Direct access to the stored covariance of an `MvNormal` (avoids the copy
# made by `cov(q::MvNormal)`).
function _cov(q::MvNormal)
    return q.Σ
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_cov(q::MvNormal) = q.Σ
_cov(q::AbstractMvNormal) = cov(q)
_cov(q::MvNormal) = q.Σ

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, it's kinda a hack in the first place - so I think I'd rather have _cov(q) error when called with anything else, and see why... no?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@st-- I don't get why it should error - shouldn't you just be able to use any AbstractMvNormal for sva.q?