This package provides a simple implementation of the Expectation Maximization algorithm used to fit mixture models. Thanks to Julia's amazing multiple dispatch system and the Distributions package, the code is very generic, i.e., many kinds of mixtures should work:
- Univariate continuous distributions
- Univariate discrete distributions
- Multivariate distributions (continuous or discrete).
- Mixture of mixtures (univariate or multivariate, continuous or discrete). Note that Distributions currently does not allow MixtureModel to have both discrete and continuous components (but who needs that? Rain, for one).
Have a look at the tests section to see examples.
To work, the only requirements are that the dist<:Distribution considered has implemented
1. logpdf(dist, y) (used in the E-step)
2. fit_mle(dist, y, weights) (used in the M-step)
In general 1. is easy, while 2. is only known explicitly for a few common distributions.
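To make the roles of the two requirements concrete, here is a minimal sketch of a single EM iteration written with Distributions only. It is not the package's implementation: em_step is a hypothetical helper, and it assumes every component is a Normal so that the weighted fit_mle(Normal, y, w) from Distributions can serve as the M-step.

```julia
using Distributions

# Hypothetical helper (not part of ExpectationMaximization.jl): one EM iteration
# for a univariate mixture whose components are all Normal.
function em_step(mix::MixtureModel, y::AbstractVector{<:Real})
    comps, w = components(mix), probs(mix)
    K = length(comps)

    # E-step: log of the unnormalized responsibilities, log(w_k) + logpdf(component_k, y_n)
    logγ = [log(w[k]) + logpdf(comps[k], y[n]) for n in eachindex(y), k in 1:K]
    # Normalize each row in a numerically stable way (subtract the row maximum first)
    γ = exp.(logγ .- maximum(logγ, dims = 2))
    γ ./= sum(γ, dims = 2)

    # M-step: new weights are the average responsibilities,
    # new components are weighted maximum likelihood fits
    new_w = vec(sum(γ, dims = 1)) ./ length(y)
    new_comps = [fit_mle(Normal, y, γ[:, k]) for k in 1:K]
    return MixtureModel(new_comps, new_w)
end

# Usage: simulate data, then apply 50 EM iterations starting from a rough guess
y = rand(MixtureModel([Normal(-2, 1), Normal(3, 0.5)], [0.3, 0.7]), 10_000)
mix_init = MixtureModel([Normal(-1, 1), Normal(1, 1)], [0.5, 0.5])
mix_fit = foldl((m, _) -> em_step(m, y), 1:50; init = mix_init)
```

The E-step only ever calls logpdf on each component, and the M-step only ever calls a weighted fit_mle, which is why those two methods are enough for a distribution to take part in a mixture.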
If 2. is not known explicitly, you can always implement a numerical scheme, if one exists, for fit_mle(dist, y); see the Gamma distribution example.
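As an illustration of what such a numerical scheme can look like, the sketch below computes a weighted maximum likelihood fit of a Gamma without any closed form. It is an assumption-laden example, not the scheme used by this package or by Distributions: fit_gamma_numeric is a hypothetical name, and the method is a plain golden-section search over the log of the shape, with the scale profiled out (for a fixed shape α, the best scale is the weighted mean divided by α).

```julia
using Distributions

# Hypothetical helper (not part of Distributions or ExpectationMaximization.jl):
# weighted MLE of a Gamma by golden-section search on the profile log-likelihood.
function fit_gamma_numeric(y::AbstractVector, w::AbstractVector; lo = 1e-3, hi = 1e3, tol = 1e-8)
    m = sum(w .* y) / sum(w)          # weighted mean of the data
    # Weighted log-likelihood as a function of log(shape), with scale = m / shape
    ll(logα) = sum(w .* logpdf.(Gamma(exp(logα), m / exp(logα)), y))
    φ = (sqrt(5) - 1) / 2             # inverse golden ratio
    a, b = log(lo), log(hi)           # search bracket, assumed to contain the optimum
    c, d = b - φ * (b - a), a + φ * (b - a)
    while b - a > tol
        if ll(c) > ll(d)
            b = d                     # the maximum lies in [a, d]
        else
            a = c                     # the maximum lies in [c, b]
        end
        c, d = b - φ * (b - a), a + φ * (b - a)
    end
    α = exp((a + b) / 2)
    return Gamma(α, m / α)
end

# Usage: compare with the iterative weighted fit that Distributions provides for Gamma
y = rand(Gamma(2.0, 3.0), 5_000)
w = rand(5_000)                       # arbitrary positive weights, standing in for E-step responsibilities
fit_numeric = fit_gamma_numeric(y, w)
fit_builtin = fit_mle(Gamma, y, w)
```

A routine of this kind could back the weighted fit_mle method required for the M-step when no explicit formula is available. Golden-section search was chosen here only because it needs nothing beyond logpdf; a Newton-type method on the score equation would typically converge faster.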
Or, when possible, represent your “difficult” distribution as a mixture of simple terms.
(I had this in mind, but it is not directly a mixture model.)
- Add different variants of the E-step and M-step, like stochastic EM and others.
- Add examples e.g., MNIST dataset and Bernoulli mixtures.
- Add Docs
- Benchmark against popular EM implementations + Speed up code
using Distributions
using ExpectationMaximization
N = 50_000
θ₁ = 10
θ₂ = 5
α = 0.2
β = 0.3
# Mixture model: any classical distributions can be used as components here
mix_true = MixtureModel([Exponential(θ₁), Gamma(α, θ₂)], [β, 1 - β])
# Generate N samples from the mixture
y = rand(mix_true, N)
# Initial guess
mix_guess = MixtureModel([Exponential(1), Gamma(0.5, 1)], [0.5, 1 - 0.5])
# Fit the MLE with the EM algorithm
mix_mle = fit_mle(mix_guess, y; display = :iter, tol = 1e-3, robust = false, infos = false)
rtol = 5e-2
p = params(mix_mle)[1] # per-component parameters: [(θ₁,), (α, θ₂)]
isapprox(β, probs(mix_mle)[1]; rtol = rtol)
isapprox(θ₁, p[1]...; rtol = rtol)
isapprox(α, p[2][1]; rtol = rtol)
isapprox(θ₂, p[2][2]; rtol = rtol)