Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MvNormalMeanScalePrecision distribution #206

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
97488a3
Add tests for MvNormalMeanScalePrecision
albertpod Aug 8, 2024
a590f7f
Add MvNormalMeanScalePrecision
albertpod Aug 8, 2024
d78da98
Fix distribution
albertpod Aug 8, 2024
45d0f59
Fix tests
albertpod Aug 9, 2024
9818bd3
Update structure and tests
albertpod Aug 9, 2024
22e611a
Add natural parameters related functions
albertpod Aug 12, 2024
3f467b8
Merge branch 'main' into dev_mvscalenormal
albertpod Aug 12, 2024
c9ad326
WIP: Parameters transforamtion
albertpod Aug 14, 2024
a0ca848
Add fisher information
albertpod Aug 15, 2024
8a37b2c
Add fisher tests
albertpod Aug 21, 2024
1260dd3
Add rand
albertpod Aug 21, 2024
d8b2370
Add MvNormalMeanScalePrecision to library.md
albertpod Aug 21, 2024
44e2ce6
test: add test exponentialfamily interface for MvNormalMeanScalePreci…
Nimrais Sep 20, 2024
49670f8
feat: add basic functions for MvNormalMeanScalePrecision
Nimrais Sep 20, 2024
1877a70
feat: draft MvNormalMeanScalePrecision
Nimrais Sep 20, 2024
118ccfd
fix: dimension match
Nimrais Sep 20, 2024
9d6159a
test: add check that samples are correct
Nimrais Sep 23, 2024
2fb5717
feat: implement getfisherinformation(::NaturalParametersSpace, ::Type…
Nimrais Sep 23, 2024
89a4932
feat: implement getfisherinformation(::NaturalParametersSpace, ::Type…
Nimrais Sep 23, 2024
e319963
fix: correct getfisherinformation(::MeanParametersSpace, ::Type{MvNor…
Nimrais Sep 24, 2024
77f4a0d
test: use test_exponentialfamily_interface and add MvNormalMeanScaleP…
Nimrais Sep 24, 2024
0b87569
Delete test/repopack-output.txt
Nimrais Sep 25, 2024
575c4af
Update test/distributions/normal_family/mv_normal_mean_scale_precisio…
Nimrais Sep 26, 2024
2178b10
test(fix): typo in @allocated cholinv(fi_small)
Nimrais Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.5.1"

[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756"
Expand All @@ -29,6 +30,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
[compat]
Aqua = "0.8.7"
BayesBase = "1.2"
BlockArrays = "1.1.1"
Distributions = "0.25"
DomainSets = "0.5.2, 0.6, 0.7"
FastCholesky = "1.0"
Expand Down Expand Up @@ -57,8 +59,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
1 change: 1 addition & 0 deletions docs/src/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ExponentialFamily.NormalWeightedMeanPrecision
ExponentialFamily.MvNormalMeanPrecision
ExponentialFamily.MvNormalMeanCovariance
ExponentialFamily.MvNormalWeightedMeanPrecision
ExponentialFamily.MvNormalMeanScalePrecision
ExponentialFamily.JointNormal
ExponentialFamily.JointGaussian
ExponentialFamily.WishartFast
Expand Down
1 change: 1 addition & 0 deletions src/ExponentialFamily.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ include("distributions/normal_family/mv_normal_mean_covariance.jl")
include("distributions/normal_family/mv_normal_mean_precision.jl")
include("distributions/normal_family/mv_normal_weighted_mean_precision.jl")
include("distributions/normal_family/normal_family.jl")
include("distributions/normal_family/mv_normal_mean_scale_precision.jl")
include("distributions/gamma_inverse.jl")
include("distributions/geometric.jl")
include("distributions/matrix_dirichlet.jl")
Expand Down
288 changes: 288 additions & 0 deletions src/distributions/normal_family/mv_normal_mean_scale_precision.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision

import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal
import LinearAlgebra: diag, Diagonal, dot
import Base: ndims, precision, length, size, prod
import BlockArrays: Block, BlockArray, undef_blocks

"""
MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal

A multivariate normal distribution with mean `μ` and scale parameter `γ` that scales the identity precision matrix.

# Type Parameters
- `T`: The element type of the mean vector and scale parameter
- `M`: The type of the mean vector, which must be a subtype of `AbstractVector{T}`

# Fields
- `μ::M`: The mean vector of the multivariate normal distribution
- `γ::T`: The scale parameter that scales the identity precision matrix

# Notes
The precision matrix of this distribution is `γ * I`, where `I` is the identity matrix.
The covariance matrix is the inverse of the precision matrix, i.e., `(1/γ) * I`.
"""
struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal
μ::M
γ::T
end

const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision

function MvNormalMeanScalePrecision(μ::AbstractVector{<:Real}, γ::Real)
T = promote_type(eltype(μ), eltype(γ))
return MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ))
end

function MvNormalMeanScalePrecision(μ::AbstractVector{<:Integer}, γ::Real)
return MvNormalMeanScalePrecision(float.(μ), float(γ))
end

function MvNormalMeanScalePrecision(μ::AbstractVector{T}) where {T}
return MvNormalMeanScalePrecision(μ, convert(T, 1))
end

function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T2}
T = promote_type(T1, T2)
μ_new = convert(AbstractArray{T}, μ)
γ_new = convert(T, γ)(length(μ))
return MvNormalMeanScalePrecision(μ_new, γ_new)
end

function unpack_parameters(::Type{MvNormalMeanScalePrecision}, packed)
p₁ = view(packed, 1:length(packed)-1)
p₂ = packed[end]

return (p₁, p₂)
end

function isproper(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}, η, conditioner)
k = length(η) - 1
if length(η) < 2 || (length(η) !== k + 1)
return false
end
(η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η)
return isnothing(conditioner) && isone(size(η₂, 1)) && isposdef(-η₂)
end

function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any})
(μ, γ) = tuple_of_θ
return (γ * μ, - γ / 2)
end

function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any})
(η₁, η₂) = tuple_of_η
γ = -2 * η₂
return (η₁ / γ, γ)
end

function nabs2(x)
return sum(map(abs2, x))
end

getsufficientstatistics(::Type{MvNormalMeanScalePrecision}) = (identity, nabs2)

# Conversions
function Base.convert(
::Type{MvNormal{T, C, M}},
dist::MvNormalMeanScalePrecision
) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}}
m, σ = mean(dist), std(dist)
return MvNormal(convert(M, m), convert(T, σ))
end

function Base.convert(
::Type{MvNormalMeanScalePrecision{T, M}},
dist::MvNormalMeanScalePrecision
) where {T <: Real, M <: AbstractArray{T}}
m, γ = mean(dist), dist.γ
return MvNormalMeanScalePrecision{T, M}(convert(M, m), convert(T, γ))
end

function Base.convert(
::Type{MvNormalMeanScalePrecision{T}},
dist::MvNormalMeanScalePrecision
) where {T <: Real}
return convert(MvNormalMeanScalePrecision{T, AbstractArray{T, 1}}, dist)
end

function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision)
m, σ = mean(dist), cov(dist)
return MvNormalMeanCovariance(m, σ * diagm(ones(length(m))))
end

function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision)
m, γ = mean(dist), precision(dist)
return MvNormalMeanPrecision(m, γ * diagm(ones(length(m))))
end

function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision)
m, γ = mean(dist), precision(dist)
return MvNormalWeightedMeanPrecision(γ * m, γ * diagm(ones(length(m))))
end

Distributions.distrname(::MvNormalMeanScalePrecision) = "MvNormalMeanScalePrecision"

BayesBase.weightedmean(dist::MvNormalMeanScalePrecision) = precision(dist) * mean(dist)

BayesBase.mean(dist::MvNormalMeanScalePrecision) = dist.μ
BayesBase.mode(dist::MvNormalMeanScalePrecision) = mean(dist)
BayesBase.var(dist::MvNormalMeanScalePrecision) = diag(cov(dist))
BayesBase.cov(dist::MvNormalMeanScalePrecision) = cholinv(invcov(dist))
BayesBase.invcov(dist::MvNormalMeanScalePrecision) = scale(dist) * I(length(mean(dist)))
BayesBase.std(dist::MvNormalMeanScalePrecision) = cholsqrt(cov(dist))
BayesBase.logdetcov(dist::MvNormalMeanScalePrecision) = -chollogdet(invcov(dist))
BayesBase.scale(dist::MvNormalMeanScalePrecision) = dist.γ
BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), scale(dist))

function Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVector)
T = promote_type(eltype(x), paramfloattype(dist))
return sqmahal!(similar(x, T), dist, x)
end

function Distributions.sqmahal!(r, dist::MvNormalMeanScalePrecision, x::AbstractVector)
Nimrais marked this conversation as resolved.
Show resolved Hide resolved
μ, γ = params(dist)
@inbounds @simd for i in 1:length(r)
r[i] = μ[i] - x[i]
end
return dot3arg(r, γ, r) # x' * A * x
end

Base.eltype(::MvNormalMeanScalePrecision{T}) where {T} = T
Base.precision(dist::MvNormalMeanScalePrecision) = invcov(dist)
Base.length(dist::MvNormalMeanScalePrecision) = length(mean(dist))
Base.ndims(dist::MvNormalMeanScalePrecision) = length(dist)
Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),)

Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ)

function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::Real) where {T <: Real}
MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ))
end

BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) =
MvNormalMeanScalePrecision(zeros(Float64, dims), convert(Float64, tiny))

BayesBase.default_prod_rule(::Type{<:MvNormalMeanScalePrecision}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScalePrecision, right::MvNormalMeanScalePrecision)
w = scale(left) + scale(right)
m = (scale(left) * mean(left) + scale(right) * mean(right)) / w
return MvNormalMeanScalePrecision(m, w)
end

BayesBase.default_prod_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution)

function BayesBase.prod(
::PreserveTypeProd{Distribution},
left::L,
right::R
) where {L <: MultivariateNormalDistributionsFamily, R <: MvNormalMeanScalePrecision}
wleft = convert(MvNormalWeightedMeanPrecision, left)
wright = convert(MvNormalWeightedMeanPrecision, right)
return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright)
end

function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T}
μ, γ = mean(dist), scale(dist)
return μ .+ (1 / γ) .* randn(rng, T, length(μ))
end

Avoid constructing the identity matrix I(length(μ)) and directly scale the random vector.
Use broadcasting with ., which is more efficient and avoids unnecessary allocations.

μ, γ = mean(dist), scale(dist)
return μ + 1 / γ .* randn(rng, T, length(μ))
end

function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T}
container = Matrix{T}(undef, length(dist), size)
return rand!(rng, dist, container)
end

# FIXME: This is not the most efficient way to generate random samples within container
# it needs to work with scale method, not with std
function BayesBase.rand!(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similiarly to rand

function BayesBase.rand!(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision, container::AbstractArray{T}) where {T <: Real}
    μ, γ = mean(dist), scale(dist)
    randn!(rng, container)
    @. container = μ + (1 / γ) * container
    return container
end

Btw I think rand just need to re-use rand!

rng::AbstractRNG,
dist::MvGaussianMeanScalePrecision,
container::AbstractArray{T}
) where {T <: Real}
preallocated = similar(container)
randn!(rng, reshape(preallocated, length(preallocated)))
μ, L = mean_std(dist)
@views for i in axes(preallocated, 2)
copyto!(container[:, i], μ)
mul!(container[:, i], L, preallocated[:, i], 1, 1)
end
container
end

function getsupport(ef::ExponentialFamilyDistribution{MvNormalMeanScalePrecision})
dim = length(getnaturalparameters(ef)) - 1
return Domain(IndicatorFunction{AbstractVector}(MvNormalDomainIndicator(dim)))
end

isbasemeasureconstant(::Type{MvNormalMeanScalePrecision}) = ConstantBaseMeasure()

getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(-length(x) / 2)

getlogbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> -length(x) / 2 * log2π

getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(η) -> begin
η1 = @view η[1:end-1]
η2 = η[end]
k = length(η1)
Cinv = inv(η2)
return -dot(η1, 1/4*Cinv, η1) - (k / 2)*log(-2*η2)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(η) -> begin
η1 = @view η[1:end-1]
η2 = η[end]
inv2 = inv(η2)
k = length(η1)
return pack_parameters(MvNormalMeanCovariance, (-1/(2*η2) * η1, dot(η1,η1) / 4*inv2^2 - k/2 * inv2))
end

getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(η) -> begin
η1 = @view η[1:end-1]
η2 = η[end]
k = length(η1)

η1_part = -inv(2*η2)* I(length(η1))
η1η2 = zeros(k, 1)
η1η2 .= η1*inv(2*η2^2)

η2_part = zeros(1, 1)
η2_part .= k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3)
# inv(2abs2(η₂))-abs2(η₁)/(2(η₂^3))

fisher = BlockArray{eltype(η)}(undef_blocks, [k, 1], [k, 1])

fisher[Block(1), Block(1)] = η1_part
fisher[Block(1), Block(2)] = η1η2
fisher[Block(2), Block(1)] = η1η2'
fisher[Block(2), Block(2)] = η2_part
return fisher
end


getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(θ) -> begin
μ = @view θ[1:end-1]
γ = θ[end]
k = length(μ)

μ_part = γ * I(k)

μγ_part = zeros(k, 1)
μγ_part .= 0

γ_part = zeros(1, 1)
γ_part .= k*inv(2abs2(γ))

fisher = BlockArray{eltype(θ)}(undef_blocks, [k, 1], [k, 1])

fisher[Block(1), Block(1)] = μ_part
fisher[Block(1), Block(2)] = μγ_part
fisher[Block(2), Block(1)] = μγ_part'
fisher[Block(2), Block(2)] = γ_part

return fisher
end
2 changes: 1 addition & 1 deletion test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,4 +557,4 @@ function test_generic_simple_exponentialfamily_product(
end

return true
end
end
Loading
Loading