diff --git a/Project.toml b/Project.toml index 0a063ae9..ffe0e999 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.5.1" [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756" @@ -29,6 +30,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] Aqua = "0.8.7" BayesBase = "1.2" +BlockArrays = "1.1.1" Distributions = "0.25" DomainSets = "0.5.2, 0.6, 0.7" FastCholesky = "1.0" @@ -57,8 +59,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CpuId = "adafc99b-e345-5852-983c-f28acb93d879" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/docs/src/library.md b/docs/src/library.md index f1786834..a271e018 100644 --- a/docs/src/library.md +++ b/docs/src/library.md @@ -14,6 +14,7 @@ ExponentialFamily.NormalWeightedMeanPrecision ExponentialFamily.MvNormalMeanPrecision ExponentialFamily.MvNormalMeanCovariance ExponentialFamily.MvNormalWeightedMeanPrecision +ExponentialFamily.MvNormalMeanScalePrecision ExponentialFamily.JointNormal ExponentialFamily.JointGaussian ExponentialFamily.WishartFast diff --git a/src/ExponentialFamily.jl b/src/ExponentialFamily.jl index 0366326e..f2ec4c1c 100644 --- a/src/ExponentialFamily.jl +++ b/src/ExponentialFamily.jl @@ -46,6 +46,7 @@ include("distributions/normal_family/mv_normal_mean_covariance.jl") include("distributions/normal_family/mv_normal_mean_precision.jl") include("distributions/normal_family/mv_normal_weighted_mean_precision.jl") 
include("distributions/normal_family/normal_family.jl") +include("distributions/normal_family/mv_normal_mean_scale_precision.jl") include("distributions/gamma_inverse.jl") include("distributions/geometric.jl") include("distributions/matrix_dirichlet.jl") diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl new file mode 100644 index 00000000..44f74558 --- /dev/null +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -0,0 +1,288 @@ +export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision + +import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal +import LinearAlgebra: diag, Diagonal, dot +import Base: ndims, precision, length, size, prod +import BlockArrays: Block, BlockArray, undef_blocks + +""" + MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal + +A multivariate normal distribution with mean `μ` and scale parameter `γ` that scales the identity precision matrix. + +# Type Parameters +- `T`: The element type of the mean vector and scale parameter +- `M`: The type of the mean vector, which must be a subtype of `AbstractVector{T}` + +# Fields +- `μ::M`: The mean vector of the multivariate normal distribution +- `γ::T`: The scale parameter that scales the identity precision matrix + +# Notes +The precision matrix of this distribution is `γ * I`, where `I` is the identity matrix. +The covariance matrix is the inverse of the precision matrix, i.e., `(1/γ) * I`. 
+""" +struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal + μ::M + γ::T +end + +const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision + +function MvNormalMeanScalePrecision(μ::AbstractVector{<:Real}, γ::Real) + T = promote_type(eltype(μ), eltype(γ)) + return MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) +end + +function MvNormalMeanScalePrecision(μ::AbstractVector{<:Integer}, γ::Real) + return MvNormalMeanScalePrecision(float.(μ), float(γ)) +end + +function MvNormalMeanScalePrecision(μ::AbstractVector{T}) where {T} + return MvNormalMeanScalePrecision(μ, convert(T, 1)) +end + +function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T2} + T = promote_type(T1, T2) + μ_new = convert(AbstractArray{T}, μ) + γ_new = convert(T, γ)(length(μ)) + return MvNormalMeanScalePrecision(μ_new, γ_new) +end + +function unpack_parameters(::Type{MvNormalMeanScalePrecision}, packed) + p₁ = view(packed, 1:length(packed)-1) + p₂ = packed[end] + + return (p₁, p₂) +end + +function isproper(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}, η, conditioner) + k = length(η) - 1 + if length(η) < 2 || (length(η) !== k + 1) + return false + end + (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) + return isnothing(conditioner) && isone(size(η₂, 1)) && isposdef(-η₂) +end + +function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any}) + (μ, γ) = tuple_of_θ + return (γ * μ, - γ / 2) +end + +function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any}) + (η₁, η₂) = tuple_of_η + γ = -2 * η₂ + return (η₁ / γ, γ) +end + +function nabs2(x) + return sum(map(abs2, x)) +end + +getsufficientstatistics(::Type{MvNormalMeanScalePrecision}) = (identity, nabs2) + +# Conversions +function Base.convert( + ::Type{MvNormal{T, C, M}}, + dist::MvNormalMeanScalePrecision +) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}} + 
m, σ = mean(dist), std(dist) + return MvNormal(convert(M, m), convert(T, σ)) +end + +function Base.convert( + ::Type{MvNormalMeanScalePrecision{T, M}}, + dist::MvNormalMeanScalePrecision +) where {T <: Real, M <: AbstractArray{T}} + m, γ = mean(dist), dist.γ + return MvNormalMeanScalePrecision{T, M}(convert(M, m), convert(T, γ)) +end + +function Base.convert( + ::Type{MvNormalMeanScalePrecision{T}}, + dist::MvNormalMeanScalePrecision +) where {T <: Real} + return convert(MvNormalMeanScalePrecision{T, AbstractArray{T, 1}}, dist) +end + +function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision) + m, σ = mean(dist), cov(dist) + return MvNormalMeanCovariance(m, σ * diagm(ones(length(m)))) +end + +function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision) + m, γ = mean(dist), precision(dist) + return MvNormalMeanPrecision(m, γ * diagm(ones(length(m)))) +end + +function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision) + m, γ = mean(dist), precision(dist) + return MvNormalWeightedMeanPrecision(γ * m, γ * diagm(ones(length(m)))) +end + +Distributions.distrname(::MvNormalMeanScalePrecision) = "MvNormalMeanScalePrecision" + +BayesBase.weightedmean(dist::MvNormalMeanScalePrecision) = precision(dist) * mean(dist) + +BayesBase.mean(dist::MvNormalMeanScalePrecision) = dist.μ +BayesBase.mode(dist::MvNormalMeanScalePrecision) = mean(dist) +BayesBase.var(dist::MvNormalMeanScalePrecision) = diag(cov(dist)) +BayesBase.cov(dist::MvNormalMeanScalePrecision) = cholinv(invcov(dist)) +BayesBase.invcov(dist::MvNormalMeanScalePrecision) = scale(dist) * I(length(mean(dist))) +BayesBase.std(dist::MvNormalMeanScalePrecision) = cholsqrt(cov(dist)) +BayesBase.logdetcov(dist::MvNormalMeanScalePrecision) = -chollogdet(invcov(dist)) +BayesBase.scale(dist::MvNormalMeanScalePrecision) = dist.γ +BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), scale(dist)) + +function 
Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVector) + T = promote_type(eltype(x), paramfloattype(dist)) + return sqmahal!(similar(x, T), dist, x) +end + +function Distributions.sqmahal!(r, dist::MvNormalMeanScalePrecision, x::AbstractVector) + μ, γ = params(dist) + @inbounds @simd for i in 1:length(r) + r[i] = μ[i] - x[i] + end + return dot3arg(r, γ, r) # x' * A * x +end + +Base.eltype(::MvNormalMeanScalePrecision{T}) where {T} = T +Base.precision(dist::MvNormalMeanScalePrecision) = invcov(dist) +Base.length(dist::MvNormalMeanScalePrecision) = length(mean(dist)) +Base.ndims(dist::MvNormalMeanScalePrecision) = length(dist) +Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),) + +Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ) + +function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::Real) where {T <: Real} + MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) +end + +BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) = + MvNormalMeanScalePrecision(zeros(Float64, dims), convert(Float64, tiny)) + +BayesBase.default_prod_rule(::Type{<:MvNormalMeanScalePrecision}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScalePrecision, right::MvNormalMeanScalePrecision) + w = scale(left) + scale(right) + m = (scale(left) * mean(left) + scale(right) * mean(right)) / w + return MvNormalMeanScalePrecision(m, w) +end + +BayesBase.default_prod_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution) + +function BayesBase.prod( + ::PreserveTypeProd{Distribution}, + left::L, + right::R +) where {L <: MultivariateNormalDistributionsFamily, R <: MvNormalMeanScalePrecision} + wleft = convert(MvNormalWeightedMeanPrecision, left) + wright = 
convert(MvNormalWeightedMeanPrecision, right) + return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright) +end + +function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T} + μ, γ = mean(dist), scale(dist) + return μ + 1 / sqrt(γ) .* randn(rng, T, length(μ)) +end + +function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T} + container = Matrix{T}(undef, length(dist), size) + return rand!(rng, dist, container) +end + +# FIXME: This is not the most efficient way to generate random samples within container +# it needs to work with scale method, not with std +function BayesBase.rand!( + rng::AbstractRNG, + dist::MvGaussianMeanScalePrecision, + container::AbstractArray{T} +) where {T <: Real} + preallocated = similar(container) + randn!(rng, reshape(preallocated, length(preallocated))) + μ, L = mean_std(dist) + @views for i in axes(preallocated, 2) + copyto!(container[:, i], μ) + mul!(container[:, i], L, preallocated[:, i], 1, 1) + end + container +end + +function getsupport(ef::ExponentialFamilyDistribution{MvNormalMeanScalePrecision}) + dim = length(getnaturalparameters(ef)) - 1 + return Domain(IndicatorFunction{AbstractVector}(MvNormalDomainIndicator(dim))) +end + +isbasemeasureconstant(::Type{MvNormalMeanScalePrecision}) = ConstantBaseMeasure() + +getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(-length(x) / 2) + +getlogbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> -length(x) / 2 * log2π + +getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (η) -> begin + η1 = @view η[1:end-1] + η2 = η[end] + k = length(η1) + Cinv = inv(η2) + return -dot(η1, 1/4*Cinv, η1) - (k / 2)*log(-2*η2) + end + +getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (η) -> begin + η1 = @view η[1:end-1] + η2 = η[end] + inv2 = inv(η2) + k = length(η1) + return pack_parameters(MvNormalMeanCovariance, (-1/(2*η2) * η1, 
dot(η1,η1) / 4*inv2^2 - k/2 * inv2)) + end + +getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (η) -> begin + η1 = @view η[1:end-1] + η2 = η[end] + k = length(η1) + + η1_part = -inv(2*η2)* I(length(η1)) + η1η2 = zeros(k, 1) + η1η2 .= η1*inv(2*η2^2) + + η2_part = zeros(1, 1) + η2_part .= k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3) + # inv(2abs2(η₂))-abs2(η₁)/(2(η₂^3)) + + fisher = BlockArray{eltype(η)}(undef_blocks, [k, 1], [k, 1]) + + fisher[Block(1), Block(1)] = η1_part + fisher[Block(1), Block(2)] = η1η2 + fisher[Block(2), Block(1)] = η1η2' + fisher[Block(2), Block(2)] = η2_part + return fisher + end + + +getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (θ) -> begin + μ = @view θ[1:end-1] + γ = θ[end] + k = length(μ) + + μ_part = γ * I(k) + + μγ_part = zeros(k, 1) + μγ_part .= 0 + + γ_part = zeros(1, 1) + γ_part .= k*inv(2abs2(γ)) + + fisher = BlockArray{eltype(θ)}(undef_blocks, [k, 1], [k, 1]) + + fisher[Block(1), Block(1)] = μ_part + fisher[Block(1), Block(2)] = μγ_part + fisher[Block(2), Block(1)] = μγ_part' + fisher[Block(2), Block(2)] = γ_part + + return fisher + end diff --git a/test/distributions/distributions_setuptests.jl b/test/distributions/distributions_setuptests.jl index a6186e20..7292f8f7 100644 --- a/test/distributions/distributions_setuptests.jl +++ b/test/distributions/distributions_setuptests.jl @@ -557,4 +557,4 @@ function test_generic_simple_exponentialfamily_product( end return true -end +end \ No newline at end of file diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl new file mode 100644 index 00000000..d8f43f9c --- /dev/null +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -0,0 +1,225 @@ + +@testitem "MvNormalMeanScalePrecision: Constructor" begin + include("./normal_family_setuptests.jl") + + @test MvNormalMeanScalePrecision <: 
AbstractMvNormal + + @test MvNormalMeanScalePrecision([1.0, 1.0]) == MvNormalMeanScalePrecision([1.0, 1.0], 1.0) + @test MvNormalMeanScalePrecision([1.0, 2.0]) == MvNormalMeanScalePrecision([1.0, 2.0], 1.0) + @test MvNormalMeanScalePrecision([1, 2]) == MvNormalMeanScalePrecision([1.0, 2.0], 1.0) + @test MvNormalMeanScalePrecision([1.0f0, 2.0f0]) == MvNormalMeanScalePrecision([1.0f0, 2.0f0], 1.0f0) + + @test eltype(MvNormalMeanScalePrecision([1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanScalePrecision([1.0, 1.0], 1.0)) === Float64 + @test eltype(MvNormalMeanScalePrecision([1, 1])) === Float64 + @test eltype(MvNormalMeanScalePrecision([1, 1], 1)) === Float64 + @test eltype(MvNormalMeanScalePrecision([1.0f0, 1.0f0])) === Float32 + @test eltype(MvNormalMeanScalePrecision([1.0f0, 1.0f0], 1.0f0)) === Float32 + + @test MvNormalMeanScalePrecision(ones(3), 5) == MvNormalMeanScalePrecision(ones(3), 5) + @test MvNormalMeanScalePrecision([1, 2, 3, 4], 7.0) == MvNormalMeanScalePrecision([1.0, 2.0, 3.0, 4.0], 7.0) +end + +@testitem "MvNormalMeanScalePrecision: distrname" begin + include("./normal_family_setuptests.jl") + + @test ExponentialFamily.distrname(MvNormalMeanScalePrecision(zeros(2))) === "MvNormalMeanScalePrecision" +end + +@testitem "MvNormalMeanScalePrecision: ExponentialFamilyDistribution" begin + include("../distributions_setuptests.jl") + + rng = StableRNG(42) + + for s in 1:6 + μ = randn(rng, s) + γ = rand(rng) + + @testset let d = MvNormalMeanScalePrecision(μ, γ) + ef = test_exponentialfamily_interface(d;) + end + end + + μ = randn(rng, 1) + γ = rand(rng) + + d = MvNormalMeanScalePrecision(μ, γ) + ef = convert(ExponentialFamilyDistribution, d) + + d1d = NormalMeanPrecision(μ[1], γ) + ef1d = convert(ExponentialFamilyDistribution, d1d) + + @test logpartition(ef) ≈ logpartition(ef1d) + @test gradlogpartition(ef) ≈ gradlogpartition(ef1d) + @test fisherinformation(ef) ≈ fisherinformation(ef1d) +end + +@testitem "MvNormalMeanScalePrecision: Stats methods" begin 
+ include("./normal_family_setuptests.jl") + + μ = [0.2, 3.0, 4.0] + γ = 2.0 + dist = MvNormalMeanScalePrecision(μ, γ) + rdist = MvNormalMeanPrecision(μ, γ * ones(length(μ))) + + @test mean(dist) == μ + @test mode(dist) == μ + @test scale(dist) == γ + @test weightedmean(dist) == weightedmean(rdist) + @test invcov(dist) == invcov(rdist) + @test precision(dist) == precision(rdist) + @test cov(dist) ≈ cov(rdist) + @test std(dist) * std(dist)' ≈ std(rdist) * std(rdist)' + @test all(mean_cov(dist) .≈ mean_cov(rdist)) + @test all(mean_invcov(dist) .≈ mean_invcov(rdist)) + @test all(mean_precision(dist) .≈ mean_precision(rdist)) + @test all(weightedmean_cov(dist) .≈ weightedmean_cov(rdist)) + @test all(weightedmean_invcov(dist) .≈ weightedmean_invcov(rdist)) + @test all(weightedmean_precision(dist) .≈ weightedmean_precision(rdist)) + + @test length(dist) == 3 + @test entropy(dist) ≈ entropy(rdist) + @test pdf(dist, [0.2, 3.0, 4.0]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) + @test pdf(dist, [0.202, 3.002, 4.002]) ≈ pdf(rdist, [0.202, 3.002, 4.002]) atol = 1e-4 + @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ logpdf(rdist, [0.2, 3.0, 4.0]) + @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ logpdf(rdist, [0.202, 3.002, 4.002]) atol = 1e-4 + @test rand(StableRNG(42), dist, 1000) ≈ rand(StableRNG(42), rdist, 1000) +end + +@testitem "MvNormalMeanScalePrecision: Base methods" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalMeanScalePrecision{Float32}, MvNormalMeanScalePrecision([0.0, 0.0])) == + MvNormalMeanScalePrecision([0.0f0, 0.0f0], 1.0f0) + @test convert(MvNormalMeanScalePrecision{Float64}, [0.0, 0.0], 2.0) == + MvNormalMeanScalePrecision([0.0, 0.0], 2.0) + + @test length(MvNormalMeanScalePrecision([0.0, 0.0])) === 2 + @test length(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === 3 + @test ndims(MvNormalMeanScalePrecision([0.0, 0.0])) === 2 + @test ndims(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === 3 + @test size(MvNormalMeanScalePrecision([0.0, 0.0])) === (2,) + 
@test size(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === (3,) + + μ, γ = zeros(2), 2.0 + distribution = MvNormalMeanScalePrecision(μ, γ) + + @test distribution ≈ distribution + @test convert(MvNormalMeanCovariance, distribution) == MvNormalMeanCovariance(μ, inv(γ) * I(length(μ))) + @test convert(MvNormalMeanPrecision, distribution) == MvNormalMeanPrecision(μ, γ * I(length(μ))) + @test convert(MvNormalWeightedMeanPrecision, distribution) == MvNormalWeightedMeanPrecision(γ * μ, γ * I(length(μ))) +end + +@testitem "MvNormalMeanScalePrecision: vague" begin + include("./normal_family_setuptests.jl") + + @test_throws MethodError vague(MvNormalMeanScalePrecision) + + d1 = vague(MvNormalMeanScalePrecision, 2) + + @test typeof(d1) <: MvNormalMeanScalePrecision + @test mean(d1) == zeros(2) + @test invcov(d1) == Matrix(Diagonal(1e-12 * ones(2))) + @test ndims(d1) == 2 + + d2 = vague(MvNormalMeanScalePrecision, 3) + + @test typeof(d2) <: MvNormalMeanScalePrecision + @test mean(d2) == zeros(3) + @test invcov(d2) == Matrix(Diagonal(1e-12 * ones(3))) + @test ndims(d2) == 3 +end + +@testitem "MvNormalMeanScalePrecision: prod" begin + include("./normal_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + @test prod(strategy, MvNormalMeanScalePrecision([-1, -1], 2), MvNormalMeanPrecision([1, 1], [2, 4])) ≈ + MvNormalWeightedMeanPrecision([0, 2], [4, 6]) + + μ = [1.0, 2.0, 3.0] + γ = 2.0 + dist = MvNormalMeanScalePrecision(μ, γ) + + @test prod(strategy, dist, dist) ≈ + MvNormalMeanScalePrecision([1.0, 2.0, 3.0], 2γ) + end +end + +@testitem "MvNormalMeanScalePrecision: convert" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalMeanScalePrecision, zeros(2), 1.0) == + MvNormalMeanScalePrecision(zeros(2), 1.0) + @test begin + m = rand(5) + c = rand() + convert(MvNormalMeanScalePrecision, m, c) == MvNormalMeanScalePrecision(m, c) + end +end + +@testitem "MvNormalMeanScalePrecision: rand" begin + 
include("./normal_family_setuptests.jl") + + rng = MersenneTwister(42) + + for T in (Float32, Float64) + @testset "Basic functionality" begin + μ = [1.0, 2.0, 3.0] + γ = 2.0 + dist = convert(MvNormalMeanScalePrecision{T}, μ, γ) + + @test typeof(rand(dist)) <: Vector{T} + + samples = rand(rng, dist, 5_000) + + @test isapprox(mean(samples), mean(μ), atol = 0.5) + end + end +end + +@testitem "MvNormalMeanScalePrecision: Fisher is faster then for full parametrization" begin + include("./normal_family_setuptests.jl") + using BenchmarkTools + using LinearAlgebra + using JET + + rng = StableRNG(42) + for k in 20:40 + μ = randn(rng, k) + γ = rand(rng) + cov = γ * I(k) + + ef_small = convert(ExponentialFamilyDistribution, MvNormalMeanScalePrecision(μ, γ)) + ef_full = convert(ExponentialFamilyDistribution, MvNormalMeanCovariance(μ, cov)) + + fi_small = fisherinformation(ef_small) + fi_full = fisherinformation(ef_full) + + @test_opt fisherinformation(ef_small) + @test_opt fisherinformation(ef_full) + + fi_mvsp_time = @elapsed fisherinformation(ef_small) + fi_mvsp_alloc = @allocated fisherinformation(ef_small) + + fi_full_time = @elapsed fisherinformation(ef_full) + fi_full_alloc = @allocated fisherinformation(ef_full) + + @test_opt cholinv(fi_small) + @test_opt cholinv(fi_full) + + cholinv_time_small = @elapsed cholinv(fi_small) + cholinv_alloc_small = @allocated cholinv(fi_small) + + cholinv_time_full = @elapsed cholinv(fi_full) + cholinv_alloc_full = @allocated cholinv(fi_full) + + # small time is supposed to be O(k) and full time is supposed to O(k^2) + # the constant C is selected to account to fluctuations in test runs + C = 0.9 + @test fi_mvsp_time < fi_full_time/(C*k) + @test fi_mvsp_alloc < fi_full_alloc/(C*k) + @test cholinv_time_small < cholinv_time_full/(C*k) + @test cholinv_alloc_small < cholinv_alloc_full/(C*k) + end +end \ No newline at end of file