Skip to content

Commit

Permalink
Merge pull request #119 from decargroup/add_gaussian_mixtures
Browse files Browse the repository at this point in the history
Add Gaussian Mixture residuals to batch
  • Loading branch information
vkorotkine authored Apr 18, 2024
2 parents 890af51 + 30bbfef commit 1dce296
Show file tree
Hide file tree
Showing 5 changed files with 801 additions and 2 deletions.
91 changes: 91 additions & 0 deletions examples/ex_gaussian_mixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from navlie.batch.gaussian_mixtures import (
GaussianMixtureResidual,
MaxMixtureResidual,
SumMixtureResidual,
MaxSumMixtureResidual,
HessianSumMixtureResidual,
)
from navlie.batch.residuals import PriorResidual

from navlie.lib.states import VectorState
import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pathlib import Path
from navlie.batch.problem import Problem


def main():
key = "x"
component_residuals = []
stamp = 0.0
means = [np.array([0.0]), np.array([0.5]), np.array([1])]
covariances = [
np.atleast_2d(np.array([2])),
np.atleast_2d(np.array([2])),
np.atleast_2d(np.array([3])),
]
weights = [0.5, 0.5]
for lv1 in range(len(means)):
prior_state = VectorState(means[lv1], stamp)

component_residuals.append(PriorResidual([key], prior_state, covariances[lv1]))
res_dict = {
"Max-Mixture": MaxMixtureResidual(component_residuals, weights),
"Sum-Mixture": SumMixtureResidual(component_residuals, weights),
"Max-Sum-Mixture": MaxSumMixtureResidual(component_residuals, weights, 10),
"Hessian-Sum-Mixture": HessianSumMixtureResidual(
component_residuals, weights, True, 0.1
),
}

x0 = VectorState(2, 0.0, "x")
plt.figure()
for key, res in res_dict.items():
x = x0.copy()
print(f"Running {key} optimization...")
problem = Problem(
solver="LM",
max_iters=100,
step_tol=1e-8,
tau=1e-11,
verbose=False,
)
problem.add_residual(res)
problem.add_variable("x", x)
opt_nv_res = problem.solve()
x = np.linspace(-3, 3, 1000)
linestyles = ["-", "--", "-.", ":", "-"] * 6

res: MaxMixtureResidual = res
plt.plot(
x,
np.array([evaluate_log_likelihood(res, val) for val in x]),
label=key,
linestyle=linestyles[lv1],
)
plt.scatter(
opt_nv_res["variables"]["x"].value,
np.sum(res.evaluate([opt_nv_res["variables"]["x"]]) ** 2),
marker="o",
)
plt.xlabel("x")
plt.ylabel("Cost Function Value (offset by norm. constants)")
plt.legend()
plt.show()


def evaluate_error(res: GaussianMixtureResidual, val: np.ndarray):
error = res.evaluate([VectorState(val.squeeze())])
error = np.atleast_1d(error)
return error


def evaluate_log_likelihood(res: GaussianMixtureResidual, val: np.ndarray):
error = evaluate_error(res, np.atleast_1d(val))
return np.linalg.norm(error, 2) ** 2


if __name__ == "__main__":
main()
Loading

0 comments on commit 1dce296

Please sign in to comment.