-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SparseVariationalApproximation cleanup #86
Changes from 2 commits
c55ef11
fa516c3
4cc3dfa
96ee796
a186e7a
43ce052
8b49e75
cf0cece
90704a0
4fb5ea3
d075138
a05edf5
a4acb03
9250eb5
a6b67fa
34a3b7a
df308a3
55ece79
69fc92f
b6d9026
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
using PDMats: chol_lower | ||
|
||
@doc raw""" | ||
Centered() | ||
|
||
|
@@ -56,7 +58,7 @@ end | |
SparseVariationalApproximation(fz::FiniteGP, q::AbstractMvNormal) | ||
|
||
Packages the prior over the pseudo-points `fz`, and the approximate posterior at the | ||
pseudo-points, which is `mean(fz) + cholesky(cov(fz)).U' * ε`, `ε ∼ q`. | ||
pseudo-points, which is `mean(fz) + cholesky(cov(fz)).L * ε`, `ε ∼ q`. | ||
|
||
Shorthand for | ||
```julia | ||
|
@@ -86,82 +88,26 @@ variational Gaussian process classification." Artificial Intelligence and | |
Statistics. PMLR, 2015. | ||
""" | ||
function AbstractGPs.posterior(sva::SparseVariationalApproximation{Centered}) | ||
# m* = K*u Kuu⁻¹ (mean(q) - mean(fz)) | ||
# = K*u α | ||
# α = Kuu⁻¹ (m - mean(fz)) | ||
# V** = K** - K*u (Kuu⁻¹ - Kuu⁻¹ cov(q) Kuu⁻¹) Ku* | ||
# = K** - K*u (Kuu⁻¹ - Kuu⁻¹ cov(q) Kuu⁻¹) Ku* | ||
# = K** - (K*u Lk⁻ᵀ) (Lk⁻¹ Ku*) + (K*u Lk⁻ᵀ) Lk⁻¹ cov(q) Lk⁻ᵀ (Lk⁻¹ Ku*) | ||
# = K** - A'A + A' Lk⁻¹ cov(q) Lk⁻ᵀ A | ||
# = K** - A'A + A' Lk⁻¹ Lq Lqᵀ Lk⁻ᵀ A | ||
# = K** - A'A + A' B B' A | ||
# A = Lk⁻¹ Ku* | ||
# B = Lk⁻¹ Lq | ||
q, fz = sva.q, sva.fz | ||
m, S = mean(q), _chol_cov(q) | ||
Kuu = _chol_cov(fz) | ||
B = Kuu.L \ S.L | ||
B = chol_lower(Kuu) \ chol_lower(S) | ||
α = Kuu \ (m - mean(fz)) | ||
data = (S=S, m=m, Kuu=Kuu, B=B, α=α) | ||
data = (Kuu=Kuu, B=B, α=α) | ||
return ApproxPosteriorGP(sva, fz.f, data) | ||
end | ||
|
||
function AbstractGPs.posterior( | ||
sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real} | ||
) | ||
@assert sva.fz.f === fx.f | ||
return posterior(sva) | ||
end | ||
|
||
# | ||
# Various methods implementing the Internal AbstractGPs API. | ||
# See AbstractGPs.jl API docs for more info. | ||
# | ||
|
||
function Statistics.mean( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector | ||
) | ||
return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α | ||
end | ||
|
||
function Statistics.cov( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector | ||
) | ||
Cux = cov(f.prior, inducing_points(f), x) | ||
D = f.data.Kuu.L \ Cux | ||
return cov(f.prior, x) - At_A(D) + At_A(f.data.B' * D) | ||
end | ||
|
||
function Statistics.var( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector | ||
) | ||
Cux = cov(f.prior, inducing_points(f), x) | ||
D = f.data.Kuu.L \ Cux | ||
return var(f.prior, x) - diag_At_A(D) + diag_At_A(f.data.B' * D) | ||
end | ||
|
||
function Statistics.cov( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, | ||
x::AbstractVector, | ||
y::AbstractVector, | ||
) | ||
B = f.data.B | ||
Cxu = cov(f.prior, x, inducing_points(f)) | ||
Cuy = cov(f.prior, inducing_points(f), y) | ||
D = f.data.Kuu.L \ Cuy | ||
E = Cxu / f.data.Kuu.L' | ||
return cov(f.prior, x, y) - (E * D) + (E * B * B' * D) | ||
end | ||
|
||
function StatsBase.mean_and_cov( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector | ||
) | ||
Cux = cov(f.prior, inducing_points(f), x) | ||
D = f.data.Kuu.L \ Cux | ||
μ = Cux' * f.data.α | ||
Σ = cov(f.prior, x) - At_A(D) + At_A(f.data.B' * D) | ||
return μ, Σ | ||
end | ||
|
||
function StatsBase.mean_and_var( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector | ||
) | ||
Cux = cov(f.prior, inducing_points(f), x) | ||
D = f.data.Kuu.L \ Cux | ||
μ = Cux' * f.data.α | ||
Σ_diag = var(f.prior, x) - diag_At_A(D) + diag_At_A(f.data.B' * D) | ||
return μ, Σ_diag | ||
end | ||
|
||
# | ||
# NonCentered Parametrization. | ||
# | ||
|
@@ -172,7 +118,7 @@ end | |
Compute the approximate posterior [1] over the process `f = | ||
sva.fz.f`, given inducing inputs `z = sva.fz.x` and a variational | ||
distribution over inducing points `sva.q` (which represents ``q(ε)`` | ||
where `ε = cholesky(cov(fz)).U' \ (f(z) - mean(f(z)))`). The approximate posterior at test | ||
where `ε = cholesky(cov(fz)).L \ (f(z) - mean(f(z)))`). The approximate posterior at test | ||
points ``x^*`` where ``f^* = f(x^*)`` is then given by: | ||
|
||
```math | ||
|
@@ -184,70 +130,107 @@ which can be found in closed form. | |
variational Gaussian process classification." Artificial Intelligence and | ||
Statistics. PMLR, 2015. | ||
""" | ||
function AbstractGPs.posterior(approx::SparseVariationalApproximation{NonCentered}) | ||
fz = approx.fz | ||
data = (Cuu=_chol_cov(fz), C_ε=_chol_cov(approx.q)) | ||
return ApproxPosteriorGP(approx, fz.f, data) | ||
function AbstractGPs.posterior(sva::SparseVariationalApproximation{NonCentered}) | ||
# u = Lk v + mean(fz), v ~ q | ||
# m* = K*u Kuu⁻¹ Lk (mean(u) - mean(fz)) | ||
# = K*u (Lk Lkᵀ)⁻¹ Lk mean(q) | ||
# = K*u Lk⁻ᵀ Lk⁻¹ Lk mean(q) | ||
# = K*u Lk⁻ᵀ mean(q) | ||
# = K*u α | ||
# NonCentered: α = Lk⁻ᵀ m | ||
# Centered: α = Kuu⁻¹ (m - mean(fz)) | ||
# V** = K** - K*u (Kuu⁻¹ - Kuu⁻¹ Lk cov(q) Lkᵀ Kuu⁻¹) Ku* | ||
# = K** - K*u (Kuu⁻¹ - (Lk Lkᵀ)⁻¹ Lk cov(q) Lkᵀ (Lk Lkᵀ)⁻¹) Ku* | ||
# = K** - K*u (Kuu⁻¹ - Lk⁻ᵀ Lk⁻¹ Lk cov(q) Lkᵀ Lk⁻ᵀ Lk⁻¹) Ku* | ||
# = K** - K*u (Kuu⁻¹ - Lk⁻ᵀ cov(q) Lk⁻¹) Ku* | ||
# = K** - (K*u Lk⁻ᵀ) (Lk⁻¹ Ku*) - (K*u Lk⁻ᵀ) Lq Lqᵀ (Lk⁻¹ Ku*) | ||
# = K** - A'A - (K*u Lk⁻ᵀ) Lq Lqᵀ (Lk⁻¹ Ku*) | ||
# A = Lk⁻¹ Ku* | ||
# NonCentered: B = Lq | ||
# Centered: B = Lk⁻¹ Lq | ||
q, fz = sva.q, sva.fz | ||
m = mean(q) | ||
Kuu = _chol_cov(fz) | ||
α = chol_lower(Kuu) \ m | ||
Sv = _chol_cov(q) | ||
B = chol_lower(Sv) | ||
data = (Kuu=Kuu, B=B, α=α) | ||
return ApproxPosteriorGP(sva, fz.f, data) | ||
end | ||
|
||
function AbstractGPs.posterior( | ||
sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real} | ||
) | ||
@assert sva.fz.f === fx.f | ||
return posterior(sva) | ||
end | ||
|
||
# | ||
# Various methods implementing the Internal AbstractGPs API. | ||
# See AbstractGPs.jl API docs for more info. | ||
# | ||
|
||
function Statistics.mean( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector | ||
) | ||
return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α | ||
end | ||
|
||
# Produces a matrix that is consistently referred to as A in this file. A more descriptive | ||
[Review comment] Maybe update this to reflect the new functionality? — [Reply] done |
||
# name is, unfortunately, not obvious. It's just an intermediate quantity that happens to | ||
# get used a lot. | ||
_A(f, x) = f.data.Cuu.U' \ cov(f.prior, inducing_points(f), x) | ||
|
||
function Statistics.mean( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector | ||
) | ||
return mean(f.prior, x) + _A(f, x)' * mean(f.approx.q) | ||
function _A_and_Cux(f, x) | ||
[Review comment] At the risk of being a massive pedant, could we please rename … [comment truncated in page extraction] |
||
Cux = cov(f.prior, inducing_points(f), x) | ||
A = chol_lower(f.data.Kuu) \ Cux | ||
return A, Cux | ||
end | ||
|
||
_A(f, x) = first(_A_and_Cux(f, x)) | ||
|
||
function Statistics.cov( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector | ||
) | ||
A = _A(f, x) | ||
return cov(f.prior, x) - At_A(A) + Xt_A_X(f.data.C_ε, A) | ||
return cov(f.prior, x) - At_A(A) + At_A(f.data.B' * A) | ||
end | ||
|
||
function Statistics.var( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector | ||
) | ||
A = _A(f, x) | ||
return var(f.prior, x) - diag_At_A(A) + diag_Xt_A_X(f.data.C_ε, A) | ||
return var(f.prior, x) - diag_At_A(A) + diag_At_A(f.data.B' * A) | ||
end | ||
|
||
function Statistics.cov( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, | ||
x::AbstractVector, | ||
y::AbstractVector, | ||
) | ||
B = f.data.B | ||
Ax = _A(f, x) | ||
Ay = _A(f, y) | ||
return cov(f.prior, x, y) - Ax'Ay + Xt_A_Y(Ax, f.data.C_ε, Ay) | ||
return cov(f.prior, x, y) - Ax'Ay + Xt_A_Y(Ax, B * B', Ay) | ||
end | ||
|
||
function StatsBase.mean_and_cov( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector | ||
) | ||
A = _A(f, x) | ||
μ = mean(f.prior, x) + A' * mean(f.approx.q) | ||
Σ = cov(f.prior, x) - At_A(A) + Xt_A_X(f.data.C_ε, A) | ||
A, Cux = _A_and_Cux(f, x) | ||
μ = mean(f.prior, x) + Cux' * f.data.α | ||
Σ = cov(f.prior, x) - At_A(A) + At_A(f.data.B' * A) | ||
return μ, Σ | ||
end | ||
|
||
function StatsBase.mean_and_var( | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector | ||
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector | ||
) | ||
A = _A(f, x) | ||
μ = mean(f.prior, x) + A' * mean(f.approx.q) | ||
Σ = var(f.prior, x) - diag_At_A(A) + diag_Xt_A_X(f.data.C_ε, A) | ||
return μ, Σ | ||
A, Cux = _A_and_Cux(f, x) | ||
μ = mean(f.prior, x) + Cux' * f.data.α | ||
Σ_diag = var(f.prior, x) - diag_At_A(A) + diag_At_A(f.data.B' * A) | ||
return μ, Σ_diag | ||
end | ||
|
||
|
||
st-- marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# | ||
# Misc utility. | ||
# | ||
|
@@ -344,12 +327,16 @@ function _elbo( | |
|
||
n_batch = length(y) | ||
scale = num_data / n_batch | ||
return sum(variational_exp) * scale - kl_term(sva, post) | ||
return sum(variational_exp) * scale - prior_kl(sva) | ||
end | ||
|
||
kl_term(sva::SparseVariationalApproximation{Centered}, post) = KL(sva.q, sva.fz) | ||
prior_kl(sva::SparseVariationalApproximation{Centered}) = KL(sva.q, sva.fz) | ||
|
||
function kl_term(sva::SparseVariationalApproximation{NonCentered}, post) | ||
function prior_kl(sva::SparseVariationalApproximation{NonCentered}) | ||
m_ε = mean(sva.q) | ||
return (tr(cov(sva.q)) + m_ε'm_ε - length(m_ε) - logdet(post.data.C_ε)) / 2 | ||
[Review comment] This was unnecessarily making a copy (see JuliaStats/Distributions.jl#1373). |
||
C_ε = _cov(sva.q) | ||
# trace_term = tr(C_ε) # does not work due to PDMat / Zygote issues | ||
L = chol_lower(_chol_cov(sva.q)) | ||
trace_term = sum(L.^2) | ||
st-- marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2 | ||
end |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -11,3 +11,5 @@ end | |||||||
|
||||||||
_chol_cov(q::AbstractMvNormal) = cholesky(Symmetric(cov(q))) | ||||||||
_chol_cov(q::MvNormal) = cholesky(q.Σ) | ||||||||
|
||||||||
_cov(q::MvNormal) = q.Σ | ||||||||
[Review comment with suggested change on `_cov(q::MvNormal) = q.Σ`]
[Reply] Hm, it's kind of a hack in the first place — so I think I'd rather have … [reply truncated in page extraction]
[Reply] @st-- I don't get why it should error — shouldn't you just be able to use any … [reply truncated in page extraction] |
[Review comment] This (and line 160 below in `mean_and_var`) were missing the prior mean.