Skip to content

Commit

Permalink
Merge pull request #105 from decargroup/add_tutorial
Browse files Browse the repository at this point in the history
Moved `CompositeState` and friends to dedicated file. Ran black.
  • Loading branch information
CharlesCossette authored Jan 20, 2024
2 parents f3b9d38 + 262e200 commit 36bbef4
Show file tree
Hide file tree
Showing 14 changed files with 734 additions and 666 deletions.
58 changes: 35 additions & 23 deletions docs/source/fun_figs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

sns.set_style("whitegrid")


# %% Banana distribution plot
def banana_plot(ax = None):
def banana_plot(ax=None):
N = 500
x0 = nav.lib.SE2State([0.3, 3, 4], direction="right")
covariance = np.diag([0.2**2, 0.05**2, 0.05**2])
Expand Down Expand Up @@ -38,15 +39,22 @@ def banana_plot(ax = None):

# random greyscale color
color = np.random.uniform(0.3, 0.9)
ax.plot(traj_pos[:, 0], traj_pos[:, 1], color=(color, color, color),zorder=1)
ax.plot(
traj_pos[:, 0],
traj_pos[:, 1],
color=(color, color, color),
zorder=1,
)

# save the final state
final_states.append(x_traj[-1])

final_positions = np.array([x.position for x in final_states])
ax.scatter(final_positions[:, 0], final_positions[:, 1], color="C0", zorder=2)
ax.scatter(
final_positions[:, 0], final_positions[:, 1], color="C0", zorder=2
)

# Propagate the mean with EKF
# Propagate the mean with EKF
kf = nav.ExtendedKalmanFilter(process_model)
x0_hat = nav.StateWithCovariance(x0, covariance)

Expand All @@ -59,10 +67,12 @@ def banana_plot(ax = None):
ax.plot(mean_traj[:, 0], mean_traj[:, 1], color="r", zorder=3, linewidth=3)
ax.set_aspect("equal")


# banana_plot()


# %%
def pose3d_plot(ax = None):
def pose3d_plot(ax=None):
N = 500
x0 = nav.lib.SE3State([0.3, 3, 4, 0, 0, 0], direction="right")
process_model = nav.lib.BodyFrameVelocity(np.zeros(6))
Expand All @@ -78,24 +88,27 @@ def pose3d_plot(ax = None):
x = process_model.evaluate(x, u, dt)
x_traj.append(x.copy())

fig, ax = nav.plot_poses(x_traj, ax = ax)
fig, ax = nav.plot_poses(x_traj, ax=ax)


# pose3d_plot()


# %%
def three_sigma_plot(axs = None):
def three_sigma_plot(axs=None):
dataset = nav.lib.datasets.SimulatedPoseRangingDataset()

estimates = nav.run_filter(
nav.ExtendedKalmanFilter(dataset.process_model),
dataset.get_ground_truth()[0],
np.diag([0.1**2, 0.1**2, 0.1**2, 0.1**2, 0.1**2, 0.1**2]),
dataset.get_input_data(),
dataset.get_measurement_data()
)

results = nav.GaussianResultList.from_estimates(estimates, dataset.get_ground_truth())
dataset.get_measurement_data(),
)

results = nav.GaussianResultList.from_estimates(
estimates, dataset.get_ground_truth()
)

fig, axs = nav.plot_error(results[:, :3], axs=axs)
axs[2].set_xlabel("Time (s)")
Expand All @@ -105,7 +118,6 @@ def three_sigma_plot(axs = None):


if __name__ == "__main__":

# Make one large figure which has all the plots. This will be a 1x3 grid, with the
# last plot itself being a three vertically stacked plots.

Expand All @@ -119,13 +131,11 @@ def three_sigma_plot(axs = None):

# which will be used here:



fig = plt.figure(figsize=(20, 6))
gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 1])
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1], projection='3d')
ax2 = fig.add_subplot(gs[1], projection="3d")

# The last plot is a 3x1 grid
gs2 = gs[2].subgridspec(3, 1, hspace=0.1)
ax3 = fig.add_subplot(gs2[0])
Expand All @@ -141,25 +151,27 @@ def three_sigma_plot(axs = None):
ax2.set_yticklabels([])
ax2.set_zticklabels([])


banana_plot(ax1)
pose3d_plot(ax2)
three_sigma_plot(np.array([ax3, ax4, ax5]))

# Set spacing to the above values
# Set spacing to the above values
fig.subplots_adjust(
top=0.975,
bottom=0.097,
left=0.025,
right=0.992,
hspace=0.2,
wspace=0.117
wspace=0.117,
)


# Save the figure with transparent background, next to this file
import os
fig.savefig(os.path.join(os.path.dirname(__file__), "fun_figs.png"), transparent=True)
import os

fig.savefig(
os.path.join(os.path.dirname(__file__), "fun_figs.png"),
transparent=True,
)

plt.show()
# %%
8 changes: 7 additions & 1 deletion navlie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,11 @@
jacobian,
)

from .composite import (
CompositeState,
CompositeProcessModel,
CompositeMeasurementModel,
CompositeInput,
)

from .lib.states import StampedValue # for backwards compatibility
from .lib.states import StampedValue # for backwards compatibility
Loading

0 comments on commit 36bbef4

Please sign in to comment.