Skip to content

Commit

Permalink
[python/dynamics] Take into account stride offset when trajectory tim…
Browse files Browse the repository at this point in the history
…e is wrapping.
  • Loading branch information
duburcqa committed Nov 27, 2024
1 parent 5e0c6c5 commit e7257d0
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 21 deletions.
3 changes: 1 addition & 2 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from abc import abstractmethod
from collections import OrderedDict
from typing import (
Dict, Any, List, Sequence, Optional, Tuple, Set, Union, Generic, TypeVar,
Dict, Any, List, Sequence, Optional, Tuple, Union, Generic, TypeVar,
Type, Mapping, SupportsFloat, Callable, cast, overload, TYPE_CHECKING)

import numpy as np
Expand All @@ -38,7 +38,6 @@
InfoType,
EngineObsType,
InterfaceJiminyEnv)
from .quantities import DatasetTrajectoryQuantity
from .compositions import AbstractReward, AbstractTerminationCondition
from .blocks import BaseControllerBlock, BaseObserverBlock

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dataclasses import dataclass, replace
from functools import wraps
from typing import (
Any, Dict, List, Optional, Tuple, Generic, TypeVar, Type, Iterator, Set,
Any, Dict, List, Optional, Tuple, Generic, TypeVar, Type, Iterator,
Collection, Callable, Literal, ClassVar, TYPE_CHECKING)

import numpy as np
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
"""Implementation of Mahony filter block compatible with gym_jiminy
reinforcement learning pipeline environment design.
"""
import logging
from collections import OrderedDict
from typing import Any, List, Union, Dict, Optional, Type, TypeVar, cast
from typing import Any, Type, TypeVar, cast

import numpy as np
import numba as nb
import gymnasium as gym

from jiminy_py.core import ImuSensor # pylint: disable=no-name-in-module
from jiminy_py import tree

from ..bases import (
BaseObs, BaseAct, BaseObserverBlock, InterfaceJiminyEnv, InterfaceQuantity)
from ..utils import DataNested, fill, build_copyto
from ..utils import DataNested, build_copyto


ValueT = TypeVar('ValueT', bound=DataNested)


LOGGER = logging.getLogger(__name__)


def get_space(data: DataNested) -> gym.Space[DataNested]:
"""Infer space from a given value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
quantities in a large batch to leverage vectorization of math instructions.
"""
from collections.abc import MutableMapping
from typing import Any, Dict, List, Tuple, Iterator, Literal, Type, cast

from jiminy_py.dynamics import Trajectory
from typing import Any, Dict, List, Tuple, Iterator, Type, cast

from ..bases import (
QuantityCreator, InterfaceJiminyEnv, InterfaceQuantity, SharedCache,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Sequence
from typing import (
Dict, Any, Optional, Union, Type, Sequence as SequenceT, Callable,
TypedDict, List, Literal, overload, cast)
TypedDict, Literal, overload, cast)

import h5py
import tomlkit
Expand Down
29 changes: 23 additions & 6 deletions python/jiminy_py/src/jiminy_py/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def velocityXYZQuatToXYZRPY(xyzquat: np.ndarray,
# #################### State and Trajectory ###########################
# #####################################################################

@dataclass(unsafe_hash=True)
@dataclass
class State:
"""Basic data structure storing kinematics and dynamics information at a
given time.
Expand Down Expand Up @@ -143,7 +143,7 @@ class State:
"""


@dataclass(unsafe_hash=True)
@dataclass
class Trajectory:
"""Trajectory of a robot.
Expand Down Expand Up @@ -200,6 +200,13 @@ def __init__(self,
else:
self._pinocchio_model = robot.pinocchio_model

# Compute the trajectory stride.
# Ensure continuity of the freeflyer when time is wrapping.
if self.has_data:
M_start = pin.XYZQUATToSE3(self.states[0].q[:7])
M_end = pin.XYZQUATToSE3(self.states[-1].q[:7])
self._stride_offset_log6 = pin.log6(M_end * M_start.inverse())

# Keep track of last request to speed up nearest neighbors search
self._t_prev = 0.0
self._index_prev = 1
Expand Down Expand Up @@ -309,13 +316,15 @@ def get(self,
t_orig = t

# Handling of the desired mode
n_steps = 0.0
t_start, t_end = self.time_interval
if mode == "raise":
if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL:
raise RuntimeError("Time is out-of-range.")
elif mode == "wrap":
if t_end > t_start:
t = ((t - t_start) % (t_end - t_start)) + t_start
n_steps, t_rel = divmod(t - t_start, t_end - t_start)
t = t_rel + t_start
else:
t = t_start
else:
Expand All @@ -334,7 +343,7 @@ def get(self,
self._times, t, self._index_prev, len(self._times) - 1)
self._t_prev = t

# Skip interpolation if not necessary.
# Skip interpolation if not necessary
index_left, index_right = self._index_prev - 1, self._index_prev
t_left, s_left = self._times[index_left], self.states[index_left]
if t - t_left < TRAJ_INTERP_TOL:
Expand All @@ -345,12 +354,20 @@ def get(self,
alpha = (t - t_left) / (t_right - t_left)

# Interpolate state
data = {"q": pin.interpolate(
self._pinocchio_model, s_left.q, s_right.q, alpha)}
position = pin.interpolate(
self._pinocchio_model, s_left.q, s_right.q, alpha)
data = {"q": position}
for field in self._fields:
value_left = getattr(s_left, field)
value_right = getattr(s_right, field)
data[field] = value_left + alpha * (value_right - value_left)

# Perform odometry if the time is wrapping
if n_steps:
stride_offset = pin.exp6(n_steps * self._stride_offset_log6)
ff_xyzquat = stride_offset * pin.XYZQUATToSE3(position[:7])
position[:7] = pin.SE3ToXYZQUAT(ff_xyzquat)

return State(t=t_orig, **data)

# #####################################################################
Expand Down

0 comments on commit e7257d0

Please sign in to comment.