Skip to content

Commit

Permalink
Merge pull request #102 from decargroup/add_tutorial
Browse files Browse the repository at this point in the history
Added more docs, fun demo figure, minor renaming, some tests
  • Loading branch information
CharlesCossette authored Oct 4, 2023
2 parents 104c0d1 + 2d5b1e0 commit a19dbe3
Show file tree
Hide file tree
Showing 15 changed files with 253 additions and 106 deletions.
6 changes: 6 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ navlie
:target: https://www.python.org/downloads/
:alt: Python Version

.. image:: ./docs/source/fun_figs.png
:alt: Demo Figures
:align: center
:width: 100%


An on-manifold state estimation library for robotics.

The core idea behind this project is to abstract-away the state definition such that a single estimator implementation can operate on a variety of state manifolds, such as the usual vector space, and any Lie group. At the moment, algorithms and features of this package include:
Expand Down
Binary file added docs/source/fun_figs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
165 changes: 165 additions & 0 deletions docs/source/fun_figs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# %%
import navlie as nav
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")

# %% Banana distribution plot
def banana_plot(ax = None):
N = 500
x0 = nav.lib.SE2State([0.3, 3, 4], direction="right")
covariance = np.diag([0.2**2, 0.05**2, 0.05**2])
process_model = nav.lib.BodyFrameVelocity(np.zeros(3))

dx_samples = nav.randvec(covariance, N).T
x0_samples = [x0.plus(dx) for dx in dx_samples]

# Monte-carlo the trajectory forward in time
dt = 0.1
T = 10
stamps = np.arange(0, T, dt)

if ax is None:
fig, ax = plt.subplots(figsize=(8, 8))

final_states = []
for sample in x0_samples:
x_traj = [sample.copy()]
u = nav.lib.VectorInput([0.1, 0.3, 0])
x = sample
for _ in stamps:
x = process_model.evaluate(x, u, dt)
x_traj.append(x.copy())

# plot the trajectory
traj_pos = np.array([x.position for x in x_traj])

# random greyscale color
color = np.random.uniform(0.3, 0.9)
ax.plot(traj_pos[:, 0], traj_pos[:, 1], color=(color, color, color),zorder=1)

# save the final state
final_states.append(x_traj[-1])

final_positions = np.array([x.position for x in final_states])
ax.scatter(final_positions[:, 0], final_positions[:, 1], color="C0", zorder=2)

# Propagate the mean with EKF
kf = nav.ExtendedKalmanFilter(process_model)
x0_hat = nav.StateWithCovariance(x0, covariance)

x_hat_traj = [x0_hat]
for t in stamps:
u.stamp = t
x_hat_traj.append(kf.predict(x_hat_traj[-1], u, dt))

mean_traj = np.array([x.state.position for x in x_hat_traj])
ax.plot(mean_traj[:, 0], mean_traj[:, 1], color="r", zorder=3, linewidth=3)
ax.set_aspect("equal")

# banana_plot()

# %%
def pose3d_plot(ax = None):
N = 500
x0 = nav.lib.SE3State([0.3, 3, 4, 0, 0, 0], direction="right")
process_model = nav.lib.BodyFrameVelocity(np.zeros(6))

dt = 0.1
T = 20
stamps = np.arange(0, T, dt)

x_traj = [x0.copy()]
u = nav.lib.VectorInput([0.1, 0.3, 0, 1, 0, 0])
x = x0.copy()
for _ in stamps:
x = process_model.evaluate(x, u, dt)
x_traj.append(x.copy())

fig, ax = nav.plot_poses(x_traj, ax = ax)


# pose3d_plot()

# %%
def three_sigma_plot(axs = None):
dataset = nav.lib.datasets.SimulatedPoseRangingDataset()

estimates = nav.run_filter(
nav.ExtendedKalmanFilter(dataset.process_model),
dataset.get_ground_truth()[0],
np.diag([0.1**2, 0.1**2, 0.1**2, 0.1**2, 0.1**2, 0.1**2]),
dataset.get_input_data(),
dataset.get_measurement_data()
)

results = nav.GaussianResultList.from_estimates(estimates, dataset.get_ground_truth())

fig, axs = nav.plot_error(results[:, :3], axs=axs)
axs[2].set_xlabel("Time (s)")


# three_sigma_plot()


if __name__ == "__main__":

# Make one large figure which has all the plots. This will be a 1x3 grid, with the
# last plot itself being a three vertically stacked plots.

# The following values where chosen by trial and error
# top=0.975,
# bottom=0.097,
# left=0.025,
# right=0.992,
# hspace=0.2,
# wspace=0.117

# which will be used here:



fig = plt.figure(figsize=(20, 6))
gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 1])
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1], projection='3d')

# The last plot is a 3x1 grid
gs2 = gs[2].subgridspec(3, 1, hspace=0.1)
ax3 = fig.add_subplot(gs2[0])
ax4 = fig.add_subplot(gs2[1])
ax5 = fig.add_subplot(gs2[2])

# Remove tick labels for ax3 and ax4
ax3.set_xticklabels([])
ax4.set_xticklabels([])

# Remove all tick labels for ax2
ax2.set_xticklabels([])
ax2.set_yticklabels([])
ax2.set_zticklabels([])


banana_plot(ax1)
pose3d_plot(ax2)
three_sigma_plot(np.array([ax3, ax4, ax5]))

# Set spacing to the above values
fig.subplots_adjust(
top=0.975,
bottom=0.097,
left=0.025,
right=0.992,
hspace=0.2,
wspace=0.117
)


# Save the figure with transparent background, next to this file
import os
fig.savefig(os.path.join(os.path.dirname(__file__), "fun_figs.png"), transparent=True)

plt.show()
# %%
5 changes: 5 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
Welcome to navlie!
------------------

.. make a row of three figures stacked side by side
.. image:: ./fun_figs.png
:width: 100%
:align: center

navlie is a state estimation package specifically designed for both traditional and Lie-group-based estimation problems!

The core idea behind this project is to use abstraction in such a way that both traditional and Lie-group-based problems fall under the exact same interface. Using this, a single estimator implementation can operate on a variety of state definitions, such as the usual vector space, and any Lie group. We allow the user to define their custom state, process model, and measurement models, after which they will have a variety of algorithms available to them, including:
Expand Down
4 changes: 3 additions & 1 deletion examples/ex_ekf_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def main():
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme()
sns.set_style("whitegrid")
fig, ax = plt.subplots(1, 1)
ax.plot(results.value[:, 0], results.value[:, 1], label="Estimate")
ax.plot(
Expand All @@ -113,4 +113,6 @@ def main():
axs[i].plot(results.stamp, results.error[:, i])
axs[0].set_title("Estimation error")
axs[1].set_xlabel("Time (s)")
axs[0].set_ylabel("x error (m)")
axs[1].set_ylabel("y error (m)")
plt.show()
2 changes: 1 addition & 1 deletion examples/ex_inertial_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def main():
data = SimulatedInertialGPSDataset(t_start=0, t_end=20)
gt_states = data.get_ground_truth()
input_data = data.get_input_data()
meas_data = data.get_meas_data()
meas_data = data.get_measurement_data()

# Filter initialization
P0 = np.eye(15)
Expand Down
2 changes: 1 addition & 1 deletion examples/ex_iterated_ekf_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def main():
data = nav.lib.SimulatedPoseRangingDataset()
gt_states = data.get_ground_truth()
input_data = data.get_input_data()
meas_data = data.get_meas_data()
meas_data = data.get_measurement_data()

# %% ###########################################################################
# Perturb initial groundtruth state to initialize filter
Expand Down
2 changes: 1 addition & 1 deletion examples/ex_ukf_se2.py → examples/ex_ukf_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main():
data = SimulatedPoseRangingDataset(x0=x0, Q=Q, noise_active=noise_active)
state_true = data.get_ground_truth()
input_data = data.get_input_data()
meas_data = data.get_meas_data()
meas_data = data.get_measurement_data()
if noise_active:
x0 = x0.plus(randvec(P0))
# %% #######################################################################
Expand Down
92 changes: 0 additions & 92 deletions examples/tutorial.py

This file was deleted.

4 changes: 4 additions & 0 deletions navlie/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def predict(
# If state has no time stamp, load from measurement.
# usually only happens on estimator start-up
if x.state.stamp is None:
if u.stamp is None:
raise ValueError(
"Either state or input must have a time stamp"
)
t_km1 = u.stamp
else:
t_km1 = x.state.stamp
Expand Down
4 changes: 2 additions & 2 deletions navlie/lib/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_ground_truth(self) -> List[SE3State]:
def get_input_data(self) -> List[nav.Input]:
return self.input_data

def get_meas_data(self) -> List[nav.Measurement]:
def get_measurement_data(self) -> List[nav.Measurement]:
return self.meas_data


Expand Down Expand Up @@ -213,5 +213,5 @@ def get_ground_truth(self) -> List[IMUState]:
def get_input_data(self) -> List[IMU]:
return self.input_data

def get_meas_data(self) -> List[nav.Measurement]:
def get_measurement_data(self) -> List[nav.Measurement]:
return self.meas_data
2 changes: 1 addition & 1 deletion navlie/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,6 @@ def get_input_data(self) -> List[Input]:
pass

@abstractmethod
def get_meas_data(self) -> List[Measurement]:
def get_measurement_data(self) -> List[Measurement]:
"""Returns a list of measurements."""
pass
Loading

0 comments on commit a19dbe3

Please sign in to comment.