From 7dbb2f3ae382b7d62d899982828b435765f3bd2d Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Wed, 27 Nov 2024 15:59:49 +0100 Subject: [PATCH] [python/dynamics] Take into account stride offset when trajectory time is wrapping. --- .../gym_jiminy/common/bases/pipeline.py | 33 ++++++++++++++----- .../gym_jiminy/common/bases/quantities.py | 2 +- .../common/blocks/quantity_observer.py | 10 ++---- .../common/gym_jiminy/common/envs/generic.py | 3 +- .../gym_jiminy/common/quantities/manager.py | 4 +-- .../gym_jiminy/common/utils/pipeline.py | 2 +- .../envs/gym_jiminy/envs/acrobot.py | 5 +-- python/gym_jiminy/envs/gym_jiminy/envs/ant.py | 5 +-- python/jiminy_py/src/jiminy_py/dynamics.py | 31 +++++++++++++---- 9 files changed, 62 insertions(+), 33 deletions(-) diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py index bde567392..e5bcd8c55 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py @@ -17,7 +17,7 @@ from abc import abstractmethod from collections import OrderedDict from typing import ( - Dict, Any, List, Sequence, Optional, Tuple, Set, Union, Generic, TypeVar, + Dict, Any, List, Sequence, Optional, Tuple, Union, Generic, TypeVar, Type, Mapping, SupportsFloat, Callable, cast, overload, TYPE_CHECKING) import numpy as np @@ -38,7 +38,6 @@ InfoType, EngineObsType, InterfaceJiminyEnv) -from .quantities import DatasetTrajectoryQuantity from .compositions import AbstractReward, AbstractTerminationCondition from .blocks import BaseControllerBlock, BaseObserverBlock @@ -93,13 +92,29 @@ def _merge_base_env_with_block( block_feature: Optional[NestedSpaceOrData], block_action: Optional[NestedSpaceOrData], ) -> NestedSpaceOrData: - """ TODO: Write documentation - - :param block_name: TODO: Write documentation - :param base_observation: TODO: Write documentation - :param block_state: TODO: Write 
documentation - :param block_feature: TODO: Write documentation - :param block_action: TODO: Write documentation + """Merge the observation space of a base environment with the state, + feature and action spaces of a given block. + + This method supports specifying both spaces or values for all the input + arguments at once. In both cases, the base observation is shallow-copied + first to avoid altering it while sharing memory with the original leaves. + + If the base observation space is a mapping, then the state, feature and + action of the block are added under nested keys ("states", block_name), + ("feature", block_name), and ("action", block_name). Otherwise, the base + observation is first stored under nested key ("measurement",) of a new + mapping, while block spaces are stored under the same hierarchy as before. + + :param block_name: Name of the block. It will be used as parent key of the + state, feature and action spaces. + :param base_observation: Observation space or value of the base + environment. + :param block_state: State space or value of the block. `None` if it does + not exist for the block at hand. + :param block_feature: Feature space or value of the block. `None` if it + does not exist for the block at hand. + :param block_action: Action space or value of the block. `None` if it does + not exist for the block at hand. 
""" observation: Dict[str, NestedSpaceOrData] = OrderedDict() diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py b/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py index eb0eb47e6..919cc3154 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py @@ -22,7 +22,7 @@ from dataclasses import dataclass, replace from functools import wraps from typing import ( - Any, Dict, List, Optional, Tuple, Generic, TypeVar, Type, Iterator, Set, + Any, Dict, List, Optional, Tuple, Generic, TypeVar, Type, Iterator, Collection, Callable, Literal, ClassVar, TYPE_CHECKING) import numpy as np diff --git a/python/gym_jiminy/common/gym_jiminy/common/blocks/quantity_observer.py b/python/gym_jiminy/common/gym_jiminy/common/blocks/quantity_observer.py index 8e66d6d01..26ea4a8b8 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/blocks/quantity_observer.py +++ b/python/gym_jiminy/common/gym_jiminy/common/blocks/quantity_observer.py @@ -1,28 +1,22 @@ """Implementation of Mahony filter block compatible with gym_jiminy reinforcement learning pipeline environment design. """ -import logging from collections import OrderedDict -from typing import Any, List, Union, Dict, Optional, Type, TypeVar, cast +from typing import Any, Type, TypeVar, cast import numpy as np -import numba as nb import gymnasium as gym -from jiminy_py.core import ImuSensor # pylint: disable=no-name-in-module from jiminy_py import tree from ..bases import ( BaseObs, BaseAct, BaseObserverBlock, InterfaceJiminyEnv, InterfaceQuantity) -from ..utils import DataNested, fill, build_copyto +from ..utils import DataNested, build_copyto ValueT = TypeVar('ValueT', bound=DataNested) -LOGGER = logging.getLogger(__name__) - - def get_space(data: DataNested) -> gym.Space[DataNested]: """Infer space from a given value. 
diff --git a/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py b/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py index 5af110817..242ffd803 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py @@ -1274,7 +1274,8 @@ def _initialize_observation_space(self) -> None: order to define a custom observation space. """ observation_spaces: Dict[str, spaces.Space] = OrderedDict() - observation_spaces['t'] = spaces.Box(low=0.0, + observation_spaces['t'] = spaces.Box( + low=0.0, high=self.simulation_duration_max, shape=(), dtype=np.float64) diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py index e5c46999d..1495e1b25 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py @@ -9,9 +9,7 @@ quantities in a large batch to leverage vectorization of math instructions. 
""" from collections.abc import MutableMapping -from typing import Any, Dict, List, Tuple, Iterator, Literal, Type, cast - -from jiminy_py.dynamics import Trajectory +from typing import Any, Dict, List, Tuple, Iterator, Type, cast from ..bases import ( QuantityCreator, InterfaceJiminyEnv, InterfaceQuantity, SharedCache, diff --git a/python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py b/python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py index 748a65204..6e457a52a 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py +++ b/python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py @@ -15,7 +15,7 @@ from collections.abc import Sequence from typing import ( Dict, Any, Optional, Union, Type, Sequence as SequenceT, Callable, - TypedDict, List, Literal, overload, cast) + TypedDict, Literal, overload, cast) import h5py import tomlkit diff --git a/python/gym_jiminy/envs/gym_jiminy/envs/acrobot.py b/python/gym_jiminy/envs/gym_jiminy/envs/acrobot.py index 96a1e7dfe..2f7d55714 100644 --- a/python/gym_jiminy/envs/gym_jiminy/envs/acrobot.py +++ b/python/gym_jiminy/envs/gym_jiminy/envs/acrobot.py @@ -10,7 +10,7 @@ from gym_jiminy.common.bases import InfoType, EngineObsType from gym_jiminy.common.envs import BaseJiminyEnv -from gym_jiminy.common.utils import sample +from gym_jiminy.common.utils import sample, get_robot_state_space # Stepper update period @@ -150,7 +150,8 @@ def _initialize_observation_space(self) -> None: Only the state is observable, while by default, the current time, state, and sensors data are available. 
""" - state_space = self._get_agent_state_space(use_theoretical_model=True) + state_space = get_robot_state_space( + self.robot, use_theoretical_model=True) position_space, velocity_space = state_space['q'], state_space['v'] assert isinstance(position_space, gym.spaces.Box) assert isinstance(velocity_space, gym.spaces.Box) diff --git a/python/gym_jiminy/envs/gym_jiminy/envs/ant.py b/python/gym_jiminy/envs/gym_jiminy/envs/ant.py index 18f087a91..b7e52202d 100644 --- a/python/gym_jiminy/envs/gym_jiminy/envs/ant.py +++ b/python/gym_jiminy/envs/gym_jiminy/envs/ant.py @@ -20,7 +20,7 @@ from jiminy_py.simulator import Simulator from gym_jiminy.common.bases import InfoType, EngineObsType from gym_jiminy.common.envs import BaseJiminyEnv -from gym_jiminy.common.utils import sample +from gym_jiminy.common.utils import sample, get_robot_state_space # Stepper update period @@ -143,7 +143,8 @@ def _initialize_observation_space(self) -> None: """ # http://www.mujoco.org/book/APIreference.html#mjData - position_space, velocity_space = self._get_agent_state_space().values() + state_space = get_robot_state_space(self.robot) + position_space, velocity_space = state_space["q"], state_space["v"] assert isinstance(position_space, gym.spaces.Box) assert isinstance(velocity_space, gym.spaces.Box) diff --git a/python/jiminy_py/src/jiminy_py/dynamics.py b/python/jiminy_py/src/jiminy_py/dynamics.py index b208788c3..0d31562ad 100644 --- a/python/jiminy_py/src/jiminy_py/dynamics.py +++ b/python/jiminy_py/src/jiminy_py/dynamics.py @@ -96,7 +96,7 @@ def velocityXYZQuatToXYZRPY(xyzquat: np.ndarray, # #################### State and Trajectory ########################### # ##################################################################### -@dataclass(unsafe_hash=True) +@dataclass class State: """Basic data structure storing kinematics and dynamics information at a given time. 
@@ -143,7 +143,7 @@ class State: """ -@dataclass(unsafe_hash=True) +@dataclass class Trajectory: """Trajectory of a robot. @@ -200,6 +200,14 @@ def __init__(self, else: self._pinocchio_model = robot.pinocchio_model + # Compute the trajectory stride. + # Ensure continuity of the freeflyer when time is wrapping. + self._stride_offset_log6: Optional[np.ndarray] = None + if self.robot.has_freeflyer and self.has_data: + M_start = pin.XYZQUATToSE3(self.states[0].q[:7]) + M_end = pin.XYZQUATToSE3(self.states[-1].q[:7]) + self._stride_offset_log6 = pin.log6(M_end * M_start.inverse()) + # Keep track of last request to speed up nearest neighbors search self._t_prev = 0.0 self._index_prev = 1 @@ -309,13 +317,15 @@ def get(self, t_orig = t # Handling of the desired mode + n_steps = 0.0 t_start, t_end = self.time_interval if mode == "raise": if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL: raise RuntimeError("Time is out-of-range.") elif mode == "wrap": if t_end > t_start: - t = ((t - t_start) % (t_end - t_start)) + t_start + n_steps, t_rel = divmod(t - t_start, t_end - t_start) + t = t_rel + t_start else: t = t_start else: @@ -334,7 +344,7 @@ def get(self, self._times, t, self._index_prev, len(self._times) - 1) self._t_prev = t - # Skip interpolation if not necessary. 
+ # Skip interpolation if not necessary index_left, index_right = self._index_prev - 1, self._index_prev t_left, s_left = self._times[index_left], self.states[index_left] if t - t_left < TRAJ_INTERP_TOL: @@ -345,14 +355,23 @@ def get(self, alpha = (t - t_left) / (t_right - t_left) # Interpolate state - data = {"q": pin.interpolate( - self._pinocchio_model, s_left.q, s_right.q, alpha)} + position = pin.interpolate( + self._pinocchio_model, s_left.q, s_right.q, alpha) + data = {"q": position} for field in self._fields: value_left = getattr(s_left, field) value_right = getattr(s_right, field) data[field] = value_left + alpha * (value_right - value_left) + + # Perform odometry if the time is wrapping + if self._stride_offset_log6 is not None and n_steps: + stride_offset = pin.exp6(n_steps * self._stride_offset_log6) + ff_xyzquat = stride_offset * pin.XYZQUATToSE3(position[:7]) + position[:7] = pin.SE3ToXYZQUAT(ff_xyzquat) + return State(t=t_orig, **data) + # ##################################################################### # ################### Kinematic and dynamics ########################## # #####################################################################