diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py index 2852bfc13..624dc2c76 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py @@ -6,10 +6,14 @@ UnaryOpQuantity, BinaryOpQuantity) from .generic import (FrameEulerAngles, + MultiFrameEulerAngles, FrameXYZQuat, + MultiFrameXYZQuat, + MultiFrameMeanXYZQuat, AverageFrameSpatialVelocity, ActuatedJointPositions) -from .locomotion import (OdometryPose, +from .locomotion import (BaseOdometryPose, + FootOdometryPose, AverageOdometryVelocity, CenterOfMass, CapturePoint, @@ -24,8 +28,12 @@ 'BinaryOpQuantity', 'ActuatedJointPositions', 'FrameEulerAngles', + 'MultiFrameEulerAngles', 'FrameXYZQuat', - 'OdometryPose', + 'MultiFrameXYZQuat', + 'MultiFrameMeanXYZQuat', + 'BaseOdometryPose', + 'FootOdometryPose', 'AverageFrameSpatialVelocity', 'AverageOdometryVelocity', 'CenterOfMass', diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py index 324bbc2ad..24e348677 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py @@ -6,9 +6,10 @@ from functools import partial from dataclasses import dataclass from typing import ( - List, Dict, Set, Optional, Protocol, Sequence, runtime_checkable) + List, Dict, Optional, Protocol, Sequence, Tuple, Union, runtime_checkable) import numpy as np +import numba as nb import jiminy_py.core as jiminy from jiminy_py.core import ( # pylint: disable=no-name-in-module @@ -47,11 +48,122 @@ class MultiFrameQuantity(Protocol): This protocol is involved in automatic computation vectorization. See `FrameQuantity` documentation for details. """ - frame_names: Sequence[str] + frame_names: Tuple[str, ...] 
+ + +def aggregate_frame_names(quantity: InterfaceQuantity) -> Tuple[ + Tuple[str, ...], + Dict[Union[str, Tuple[str, ...]], Union[int, Tuple[()], slice]]]: + """Generate a sequence of frame names that contains all the sub-sequences + specified by the parents of all the cache owners of a given quantity. + + Ideally, the generated sequence should be the shortest possible. Since + finding the optimal sequence is a complex problem, a heuristic is used + instead. It consists in aggregating first all multi-frame quantities + sequentially after ordering them by decreasing length, followed by all + single-frame quantities. + + .. note:: + Only active quantities are considered for best performance, which may + change dynamically. Delegating this responsibility to cache owners may + be possible but difficult to implement because `frame_names` must be + cleared first before re-registering themselves, just in case of optimal + computation graph has changed, not only once to avoid getting rid of + quantities that just registered themselves. Nevertheless, whenever + re-initializing this quantity to take into account changes of the + active set must be decided by cache owners. + + :param quantity: Quantity whose parent implements either `FrameQuantity` or + `MultiFrameQuantity` protocol. All the parents of all its + cache owners must also implement one of these protocol. + """ + # Make sure that parent quantity implement multi- or single-frame protocol + assert isinstance(quantity.parent, (FrameQuantity, MultiFrameQuantity)) + quantities = (quantity.cache.owners if quantity.has_cache else (quantity,)) + + # First, order all multi-frame quantities by decreasing length + frame_names_chunks: List[Tuple[str, ...]] = [] + for owner in quantities: + if owner.parent.is_active(any_cache_owner=False): + if isinstance(owner.parent, MultiFrameQuantity): + frame_names_chunks.append(owner.parent.frame_names) + + # Next, process ordered multi-frame quantities sequentially. 
+ # For each of them, we first check if its set of frames is completely + # included in the current full set. If so, then there is nothing to do and + # we can move to the next quantity. If not, then we check if a part of its + # tail or head is contained at the beginning or end of the full set + # respectively. If so, only the missing part is prepended or appended + # respectively. If not, then the whole set of frames is appended to the + # current full set before moving to the next quantity. + frame_names: List[str] = [] + frame_names_chunks = sorted(frame_names_chunks, key=len)[::-1] + for frame_names_ in map(list, frame_names_chunks): + nframes, nframes_ = map(len, (frame_names, frame_names_)) + for i in range(nframes - nframes_ + 1): + # Check if the sub-chain is completely included in the + # current full set. + if frame_names_ == frame_names[i:(i + nframes_)]: + break + else: + for i in range(1, nframes_ + 1): + # Check if part of the frame names matches with the + # tail of the current full set. If so, append the + # disjoint tail only. + if (frame_names[(nframes - nframes_ + i):] == + frame_names_[:(nframes_ - i)]): + frame_names += frame_names_[(nframes_ - i):] + break + # Check if part of the frame names matches with the + # head of the current full set. If so, prepend the + # disjoint head only. + if frame_names[:(nframes_ - i)] == frame_names_[i:]: + frame_names = frame_names_[:i] + frame_names + break + + # Finally, loop over all single-frame quantities. + # If a frame name is missing, then it is appended to the current full set. + # Otherwise, we just move to the next quantity. 
+ frame_name_chunks: List[str] = [] + for owner in quantities: + if owner.parent.is_active(any_cache_owner=False): + if isinstance(owner.parent, FrameQuantity): + frame_name_chunks.append(owner.parent.frame_name) + frame_name = frame_name_chunks[-1] + if frame_name not in frame_names: + frame_names.append(frame_name) + frame_names = tuple(frame_names) + + # Compute mapping from frame names to their corresponding indices in the + # generated sequence of frame names. + # The indices are stored as a slice for non-empty multi-frame quantities, + # as an empty tuple for empty multi-frame quantities, or as an integer for + # single-frame quantities. + frame_slices: Dict[ + Union[str, Tuple[str, ...]], Union[int, Tuple[()], slice]] = {} + nframes = len(frame_names) + for frame_names_ in frame_names_chunks: + if frame_names_ in frame_slices: + continue + if not frame_names_: + frame_slices[frame_names_] = () + continue + nframes_ = len(frame_names_) + for i in range(nframes - nframes_ + 1): + if frame_names_ == frame_names[i:(i + nframes_)]: + break + frame_slices[frame_names_] = slice(i, i + nframes_) + for frame_name in frame_name_chunks: + if frame_name in frame_slices: + continue + frame_slices[frame_name] = frame_names.index(frame_name) + + return frame_names, frame_slices @dataclass(unsafe_hash=True) -class _MultiFramesRotationMatrix(AbstractQuantity[np.ndarray]): +class _BatchedMultiFrameRotationMatrix( + AbstractQuantity[Dict[Union[str, Tuple[str, ...]], np.ndarray]]): """3D rotation matrix of the orientation of all frames involved in quantities relying on it and are active since last reset of computation tracking if shared cache is available, its parent otherwise. @@ -71,15 +183,6 @@ class _MultiFramesRotationMatrix(AbstractQuantity[np.ndarray]): no way to get the orientation of multiple frames at once for now. """ - identifier: int - """Uniquely identify its parent type. 
- This implies that quantities specifying `_MultiFramesRotationMatrix` as a - requirement will shared a unique batch with all the other ones having the - same type but not the others. This is essential to provide data access as a - batched ND contiguous array. - """ - def __init__(self, env: InterfaceJiminyEnv, parent: InterfaceQuantity, @@ -93,15 +196,12 @@ def __init__(self, # Make sure that a parent has been specified assert isinstance(parent, (FrameQuantity, MultiFrameQuantity)) - # Set unique identifier based on parent type - self.identifier = hash(type(parent)) - # Call base implementation super().__init__( env, parent, requirements={}, mode=mode, auto_refresh=False) # Initialize the ordered list of frame names - self.frame_names: Set[str] = set() + self.frame_names: Tuple[str, ...] = () # Store all rotation matrices at once self._rot_mat_batch: np.ndarray = np.array([]) @@ -112,38 +212,19 @@ def __init__(self, # Define proxy for the rotation matrices of all frames self._rot_mat_list: List[np.ndarray] = [] + # Mapping from frame names to slices of batched rotation matrices + self._rot_mat_map: Dict[Union[str, Tuple[str, ...]], np.ndarray] = {} + def initialize(self) -> None: + # Clear all cache owners first since only one tracks frames at once + for quantity in (self.cache.owners if self.has_cache else (self,)): + quantity.reset(reset_tracking=True) + # Call base implementation super().initialize() - # Update the frame names based on the cache owners of this quantity. - # Note that only active quantities are considered for efficiency, which - # may change dynamically. Delegating this responsibility to cache - # owners may be possible but difficult to implement because - # `frame_names` must be cleared first before re-registering themselves, - # just in case of optimal computation graph has changed, not only once - # to avoid getting rid of quantities that just registered themselves. 
- # Nevertheless, whenever re-initializing this quantity to take into - # account changes of the active set must be decided by cache owners. - assert isinstance(self.parent, (FrameQuantity, MultiFrameQuantity)) - if isinstance(self.parent, FrameQuantity): - self.frame_names = {self.parent.frame_name} - else: - self.frame_names = set(self.parent.frame_names) - if self.has_cache: - for owner in self.cache.owners: - # We only consider active `_MultiFramesEulerAngles` instances - # instead of their parents. This is necessary because a derived - # quantity may feature `_MultiFramesEulerAngles` as requirement - # without actually relying on it depending on whether it is - # part of the optimal computation path at that time. - if owner.is_active(any_cache_owner=False): - assert isinstance( - owner.parent, (FrameQuantity, MultiFrameQuantity)) - if isinstance(owner.parent, FrameQuantity): - self.frame_names.add(owner.parent.frame_name) - else: - self.frame_names.union(owner.parent.frame_names) + # Update the frame names based on the cache owners of this quantity + self.frame_names, frame_slices = aggregate_frame_names(self) # Re-allocate memory as the number of frames is not known in advance. 
# Note that Fortran memory layout (column-major) is used for speed up @@ -160,22 +241,28 @@ def initialize(self) -> None: self._rot_mat_slices.append(self._rot_mat_batch[..., i]) self._rot_mat_list.append(rot_matrix) - def refresh(self) -> np.ndarray: + # Re-assign mapping from frame names to their corresponding data + self._rot_mat_map = { + key: self._rot_mat_batch[:, :, frame_slice] + for key, frame_slice in frame_slices.items()} + + def refresh(self) -> Dict[Union[str, Tuple[str, ...]], np.ndarray]: # Copy all rotation matrices in contiguous buffer multi_array_copyto(self._rot_mat_slices, self._rot_mat_list) # Return proxy directly without copy - return self._rot_mat_batch + return self._rot_mat_map @dataclass(unsafe_hash=True) -class _MultiFramesEulerAngles(InterfaceQuantity[Dict[str, np.ndarray]]): +class _BatchedMultiFrameEulerAngles( + InterfaceQuantity[Dict[Union[str, Tuple[str, ...]], np.ndarray]]): """Euler angles (Roll-Pitch-Yaw) of the orientation of all frames involved in quantities relying on it and are active since last reset of computation tracking if shared cache is available, its parent otherwise. It is not supposed to be instantiated manually but use internally by - `FrameEulerAngles`. See `_MultiFramesRotationMatrix` documentation. + `FrameEulerAngles`. See `_BatchedMultiFrameRotationMatrix` documentation. The orientation of all frames is exposed to the user as a dictionary whose keys are the individual frame names. Internally, data are stored in batched @@ -199,7 +286,7 @@ class _MultiFramesEulerAngles(InterfaceQuantity[Dict[str, np.ndarray]]): def __init__(self, env: InterfaceJiminyEnv, - parent: "FrameEulerAngles", + parent: Union["FrameEulerAngles", "MultiFrameEulerAngles"], mode: QuantityEvalMode) -> None: """ :param env: Base or wrapped jiminy environment. @@ -208,7 +295,7 @@ def __init__(self, :param mode: Desired mode of evaluation for this quantity. 
""" # Make sure that a suitable parent has been provided - assert isinstance(parent, FrameEulerAngles) + assert isinstance(parent, (FrameEulerAngles, MultiFrameEulerAngles)) # Backup some user argument(s) self.mode = mode @@ -216,14 +303,14 @@ def __init__(self, # Initialize the ordered list of frame names. # Note that this must be done BEFORE calling base `__init__`, otherwise # `isinstance(..., (FrameQuantity, MultiFrameQuantity))` will fail. - self.frame_names: Set[str] = set() + self.frame_names: Tuple[str, ...] = () # Call base implementation super().__init__( env, parent, requirements=dict( - rot_mat_batch=(_MultiFramesRotationMatrix, dict( + rot_mat_batch=(_BatchedMultiFrameRotationMatrix, dict( mode=mode))), auto_refresh=False) @@ -231,31 +318,34 @@ def __init__(self, self._rpy_batch: np.ndarray = np.array([]) # Mapping from frame name to individual Roll-Pitch-Yaw slices - self._rpy_map: Dict[str, np.ndarray] = {} + self._rpy_map: Dict[Union[str, Tuple[str, ...]], np.ndarray] = {} def initialize(self) -> None: + # Clear all cache owners first since only is tracking frames at once + for quantity in (self.cache.owners if self.has_cache else (self,)): + quantity.reset(reset_tracking=True) + # Call base implementation super().initialize() # Update the frame names based on the cache owners of this quantity - assert isinstance(self.parent, FrameEulerAngles) - self.frame_names = {self.parent.frame_name} - if self.has_cache: - for owner in self.cache.owners: - if owner.is_active(any_cache_owner=False): - assert isinstance(owner.parent, FrameEulerAngles) - self.frame_names.add(owner.parent.frame_name) + self.frame_names, frame_slices = aggregate_frame_names(self) # Re-allocate memory as the number of frames is not known in advance nframes = len(self.frame_names) self._rpy_batch = np.zeros((3, nframes), order='F') # Re-assign mapping from frame name to their corresponding data - self._rpy_map = dict(zip(self.frame_names, self._rpy_batch.T)) + self._rpy_map = { + key: 
self._rpy_batch[:, frame_slice] + for key, frame_slice in frame_slices.items()} + + def refresh(self) -> Dict[Union[str, Tuple[str, ...]], np.ndarray]: + # Get batch of rotation matrices + rot_mat_batch = self.rot_mat_batch[self.frame_names] - def refresh(self) -> Dict[str, np.ndarray]: # Convert all rotation matrices at once to Roll-Pitch-Yaw - matrix_to_rpy(self.rot_mat_batch, self._rpy_batch) + matrix_to_rpy(rot_mat_batch, self._rpy_batch) # Return proxy directly without copy return self._rpy_map @@ -301,7 +391,71 @@ def __init__(self, super().__init__( env, parent, - requirements=dict(data=(_MultiFramesEulerAngles, dict(mode=mode))), + requirements=dict( + data=(_BatchedMultiFrameEulerAngles, dict(mode=mode))), + auto_refresh=False) + + def initialize(self) -> None: + # Check if the quantity is already active + was_active = self._is_active + + # Call base implementation. + # The quantity is now considered active at this point. + super().initialize() + + # Force re-initializing shared data if the active set has changed + if not was_active: + # Must reset the tracking for shared computation systematically, + # just in case the optimal computation path has changed to the + # point that relying on batched quantity is no longer relevant. + self.requirements["data"].reset(reset_tracking=True) + + def refresh(self) -> np.ndarray: + return self.data[self.frame_name] + + +@dataclass(unsafe_hash=True) +class MultiFrameEulerAngles(InterfaceQuantity[np.ndarray]): + """Euler angles (Roll-Pitch-Yaw) of the orientation of a given set of + frames in world reference frame at the end of the agent step. + """ + + frame_names: Tuple[str, ...] + """Name of the frames on which to operate. + """ + + mode: QuantityEvalMode + """Specify on which state to evaluate this quantity. See `Mode` + documentation for details about each mode. + + .. warning:: + Mode `REFERENCE` requires a reference trajectory to be selected + manually prior to evaluating this quantity for the first time. 
+ """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + frame_names: Sequence[str], + *, + mode: QuantityEvalMode = QuantityEvalMode.TRUE) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param frame_names: Name of the frames on which to operate. + :param mode: Desired mode of evaluation for this quantity. + """ + # Backup some user argument(s) + self.frame_names = tuple(frame_names) + self.mode = mode + + # Call base implementation + super().__init__( + env, + parent, + requirements=dict( + data=(_BatchedMultiFrameEulerAngles, dict(mode=mode))), auto_refresh=False) def initialize(self) -> None: @@ -321,31 +475,33 @@ def initialize(self) -> None: def refresh(self) -> np.ndarray: # Return a slice of batched data. - # Note that mapping from frame name to frame index in batched data + # Note that mapping from frame names to frame index in batched data # cannot be pre-computed as it may changed dynamically. - return self.data[self.frame_name] + return self.data[self.frame_names] @dataclass(unsafe_hash=True) -class _MultiFramesXYZQuat(AbstractQuantity[Dict[str, np.ndarray]]): +class _BatchedMultiFrameXYZQuat( + AbstractQuantity[Dict[Union[str, Tuple[str, ...]], np.ndarray]]): """Vector representation (X, Y, Z, QuatX, QuatY, QuatZ, QuatW) of the transform of all frames involved in quantities relying on it and are active since last reset of computation tracking if shared cache is available, its parent otherwise. It is not supposed to be instantiated manually but use internally by - `FrameXYZQuat`. See `_MultiFramesRotationMatrix` documentation. + `FrameXYZQuat`. See `_BatchedMultiFrameRotationMatrix` documentation. The transform of all frames is exposed to the user as a dictionary whose - keys are the individual frame names. Internally, data are stored in batched - 2D contiguous array for efficiency. 
The first dimension gathers the 6 - components (X, Y, Z, QuatX, QuatY, QuatZ, QuatW), while the second one are - individual frames with the same ordering as 'self.frame_names'. + keys are the individual frame names and/or set of frame names as a tuple. + Internally, data are stored in batched 2D contiguous array for efficiency. + The first dimension gathers the 7 components (X, Y, Z, QuatX, QuatY, QuatZ, + QuatW), while the second one are individual frames with the same ordering + as 'self.frame_names'. """ def __init__(self, env: InterfaceJiminyEnv, - parent: "FrameXYZQuat", + parent: Union["FrameXYZQuat", "MultiFrameXYZQuat"], mode: QuantityEvalMode) -> None: """ :param env: Base or wrapped jiminy environment. @@ -354,17 +510,17 @@ def __init__(self, :param mode: Desired mode of evaluation for this quantity. """ # Make sure that a suitable parent has been provided - assert isinstance(parent, FrameXYZQuat) + assert isinstance(parent, (FrameXYZQuat, MultiFrameXYZQuat)) # Initialize the ordered list of frame names - self.frame_names: Set[str] = set() + self.frame_names: Tuple[str, ...] 
= () # Call base implementation super().__init__( env, parent, requirements=dict( - rot_mat_batch=(_MultiFramesRotationMatrix, dict( + rot_mat_batch=(_BatchedMultiFrameRotationMatrix, dict( mode=mode))), mode=mode, auto_refresh=False) @@ -379,20 +535,18 @@ def __init__(self, self._xyzquat_batch: np.ndarray = np.array([]) # Mapping from frame name to individual XYZQuat slices - self._xyzquat_map: Dict[str, np.ndarray] = {} + self._xyzquat_map: Dict[Union[str, Tuple[str, ...]], np.ndarray] = {} def initialize(self) -> None: + # Clear all cache owners first since only one tracks frames at once + for quantity in (self.cache.owners if self.has_cache else (self,)): + quantity.reset(reset_tracking=True) + # Call base implementation super().initialize() # Update the frame names based on the cache owners of this quantity - assert isinstance(self.parent, FrameXYZQuat) - self.frame_names = {self.parent.frame_name} - if self.has_cache: - for owner in self.cache.owners: - if owner.is_active(any_cache_owner=False): - assert isinstance(owner.parent, FrameXYZQuat) - self.frame_names.add(owner.parent.frame_name) + self.frame_names, frame_slices = aggregate_frame_names(self) # Re-allocate memory as the number of frames is not known in advance nframes = len(self.frame_names) @@ -407,15 +561,20 @@ def initialize(self) -> None: self._translation_slices.append(self._xyzquat_batch[:3, i]) self._translation_list.append(translation) - # Re-assign mapping from frame name to their corresponding data - self._xyzquat_map = dict(zip(self.frame_names, self._xyzquat_batch.T)) + # Re-assign mapping from frame names to their corresponding data + self._xyzquat_map = { + key: self._xyzquat_batch[:, frame_slice] + for key, frame_slice in frame_slices.items()} - def refresh(self) -> Dict[str, np.ndarray]: + def refresh(self) -> Dict[Union[str, Tuple[str, ...]], np.ndarray]: # Copy all translations in contiguous buffer multi_array_copyto(self._translation_slices, self._translation_list) + # Get batch 
of rotation matrices + rot_mat_batch = self.rot_mat_batch[self.frame_names] + # Convert all rotation matrices at once to XYZQuat representation - matrix_to_quat(self.rot_mat_batch, self._xyzquat_batch[-4:]) + matrix_to_quat(rot_mat_batch, self._xyzquat_batch[-4:]) # Return proxy directly without copy return self._xyzquat_map @@ -462,7 +621,8 @@ def __init__(self, super().__init__( env, parent, - requirements=dict(data=(_MultiFramesXYZQuat, dict(mode=mode))), + requirements=dict( + data=(_BatchedMultiFrameXYZQuat, dict(mode=mode))), auto_refresh=False) def initialize(self) -> None: @@ -477,10 +637,163 @@ def initialize(self) -> None: self.requirements["data"].reset(reset_tracking=True) def refresh(self) -> np.ndarray: - # Return a slice of batched data return self.data[self.frame_name] +@dataclass(unsafe_hash=True) +class MultiFrameXYZQuat(InterfaceQuantity[np.ndarray]): + """Vector representation (X, Y, Z, QuatX, QuatY, QuatZ, QuatW) of the + transform of a given set of frames in world reference frame at the end of + the agent step. + """ + + frame_names: Tuple[str, ...] + """Name of the frames on which to operate. + """ + + mode: QuantityEvalMode + """Specify on which state to evaluate this quantity. See `Mode` + documentation for details about each mode. + + .. warning:: + Mode `REFERENCE` requires a reference trajectory to be selected + manually prior to evaluating this quantity for the first time. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + frame_names: Sequence[str], + *, + mode: QuantityEvalMode = QuantityEvalMode.TRUE) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param frame_names: Name of the frames on which to operate. + :param mode: Desired mode of evaluation for this quantity. 
+ """ + # Backup some user argument(s) + self.frame_names = tuple(frame_names) + self.mode = mode + + # Call base implementation + super().__init__( + env, + parent, + requirements=dict( + data=(_BatchedMultiFrameXYZQuat, dict(mode=mode))), + auto_refresh=False) + + def initialize(self) -> None: + # Check if the quantity is already active + was_active = self._is_active + + # Call base implementation + super().initialize() + + # Force re-initializing shared data if the active set has changed + if not was_active: + self.requirements["data"].reset(reset_tracking=True) + + def refresh(self) -> np.ndarray: + return self.data[self.frame_names] + + +@dataclass(unsafe_hash=True) +class MultiFrameMeanXYZQuat(InterfaceQuantity[np.ndarray]): + """Vector representation (X, Y, Z, QuatX, QuatY, QuatZ, QuatW) of the + average transform of a given set of frames in world reference frame at the + end of the agent step. + + The average position (X, Y, Z) and orientation as a quaternion vector + (QuatX, QuatY, QuatZ, QuatW) are computed separately. The average is + defined as the value minimizing the mean error wrt every individual + elements, considering some distance metric. See `quaternion_average` for + details about the distance metric being used. + """ + + frame_names: Tuple[str, ...] + """Name of the frames on which to operate. + """ + + mode: QuantityEvalMode + """Specify on which state to evaluate this quantity. See `Mode` + documentation for details about each mode. + + .. warning:: + Mode `REFERENCE` requires a reference trajectory to be selected + manually prior to evaluating this quantity for the first time. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + frame_names: Sequence[str], + *, + mode: QuantityEvalMode = QuantityEvalMode.TRUE) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. 
+ :param frame_names: Name of the frames on which to operate. + :param mode: Desired mode of evaluation for this quantity. + """ + # Backup some user argument(s) + self.frame_names = tuple(frame_names) + self.mode = mode + + # Call base implementation + super().__init__( + env, + parent, + requirements=dict(data=(MultiFrameXYZQuat, dict( + frame_names=frame_names, + mode=mode))), + auto_refresh=False) + + # Define jit-able specialization of `quat_average` for 2D matrices + @nb.jit(nopython=True, cache=True, fastmath=True) + def quat_average_2d(quat: np.ndarray, + out: np.ndarray) -> np.ndarray: + """Compute the average of a batch of quaternions [qx, qy, qz, qw]. + + .. note:: + Jit-able specialization of `quat_average` for 2D matrices, with + further optimization for the special case where there are only 2 + quaternions. + + :param quat: 2D array whose first dimension + gathers the 4 quaternion coordinates [qx, qy, qz, qw]. + :param out: Pre-allocated array into which the result is stored. + """ + if quat.shape[1] == 2: + return quat_interpolate_middle(quat[:, 0], quat[:, 1], out) + + quat = np.ascontiguousarray(quat) + _, eigvec = np.linalg.eigh(quat @ quat.T) + out[:] = eigvec[..., -1] + return out + + self._quat_average = quat_average_2d + + # Pre-allocate memory for the mean pose vector XYZQuat + self._xyzquat_mean = np.zeros((7,)) + + # Define position and orientation proxies for fast access + self._xyz_view = self._xyzquat_mean[:3] + self._quat_view = self._xyzquat_mean[3:] + + def refresh(self) -> np.ndarray: + # Compute the mean translation + np.mean(self.data[:3], axis=-1, out=self._xyz_view) + + # Compute the mean quaternion + self._quat_average(self.data[3:], self._quat_view) + + return self._xyzquat_mean + + @dataclass(unsafe_hash=True) class AverageFrameSpatialVelocity(InterfaceQuantity[np.ndarray]): """Average spatial velocity of a given frame at the end of the agent step. 
diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py index 3ce30b29d..9485aa611 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py @@ -1,28 +1,31 @@ """Quantities mainly relevant for locomotion tasks on floating-base robots. """ import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Sequence, Literal, Union from dataclasses import dataclass import numpy as np +import jiminy_py.core as jiminy from jiminy_py.core import array_copyto # pylint: disable=no-name-in-module from jiminy_py.dynamics import update_quantities import pinocchio as pin from ..bases import ( InterfaceJiminyEnv, InterfaceQuantity, AbstractQuantity, QuantityEvalMode) -from ..utils import fill, matrix_to_yaw +from ..utils import fill, matrix_to_yaw, quat_to_yaw -from ..quantities import MaskedQuantity, AverageFrameSpatialVelocity +from ..quantities import ( + MaskedQuantity, AverageFrameSpatialVelocity, MultiFrameMeanXYZQuat) @dataclass(unsafe_hash=True) -class OdometryPose(AbstractQuantity[np.ndarray]): - """Odometry pose agent step. +class BaseOdometryPose(AbstractQuantity[np.ndarray]): + """Odometry pose of the floating base of the robot at the end of the agent + step. - The odometry pose fully specifies the position and orientation of the robot - in 2D world plane. As such, it comprises the linear translation (X, Y) and + The odometry pose fully specifies the position and heading of the robot in + 2D world plane. As such, it comprises the linear translation (X, Y) and the rotation around Z axis (namely rate of change of Yaw Euler angle). 
""" @@ -75,13 +78,121 @@ def refresh(self) -> np.ndarray: return self.data +@dataclass(unsafe_hash=True) +class FootOdometryPose(InterfaceQuantity[np.ndarray]): + """Odometry pose of the average position and orientation of the feet of a + legged robot at the end of the agent step. + + Using the average foot pose to characterize the position and heading of + the robot in world plane may be more appropriate than using the floating + pose, especially when it comes to assessing the tracking error of the foot + trajectories. It has the advantage to make foot tracking independent from + floating base tracking, giving the opportunity to the robot to locally + recover stability by moving its upper body without impeding foot tracking. + + The odometry pose fully specifies the position and orientation of the robot + in 2D world plane. See `BaseOdometryPose` documentation for details. + """ + + frame_names: Tuple[str, ...] + """Name of the frames corresponding to the feet of the robot. + + These frames must be part of the end-effectors, ie being associated with a + leaf joint in the kinematic tree of the robot. + """ + + mode: QuantityEvalMode + """Specify on which state to evaluate this quantity. See `Mode` + documentation for details about each mode. + + .. warning:: + Mode `REFERENCE` requires a reference trajectory to be selected + manually prior to evaluating this quantity for the first time. + """ + + def __init__(self, + env: InterfaceJiminyEnv, + parent: Optional[InterfaceQuantity], + frame_names: Union[Sequence[str], Literal['auto']] = 'auto', + *, + mode: QuantityEvalMode = QuantityEvalMode.TRUE) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param parent: Higher-level quantity from which this quantity is a + requirement if any, `None` otherwise. + :param frame_names: Name of the frames corresponding to the feet of the + robot. 'auto' to automatically detect them from the + set of contact and force sensors of the robot. 
+ :param mode: Desired mode of evaluation for this quantity. + """ + # Determine the leaf joints of the kinematic tree + pinocchio_model = env.robot.pinocchio_model_th + parents = pinocchio_model.parents + leaf_joint_indices = set(range(len(parents))) - set(parents) + leaf_frame_names = tuple( + frame.name for frame in pinocchio_model.frames + if frame.parent in leaf_joint_indices) + + if frame_names == 'auto': + # Determine the most likely set of frames corresponding to the feet + foot_frame_names = set() + for sensor_class in (jiminy.ContactSensor, jiminy.ForceSensor): + for sensor in env.robot.sensors.get(sensor_class.type, ()): + assert isinstance(sensor, (( + jiminy.ContactSensor, jiminy.ForceSensor))) + # Skip sensors not attached to a leaf joint + if sensor.frame_name in leaf_frame_names: + # The joint name is used as frame name. This avoids + # considering multiple fixed frames wrt to the same + # joint. They would be completely redundant, slowing + # down computations for no reason. 
+ frame = pinocchio_model.frames[sensor.frame_index] + joint_name = pinocchio_model.names[frame.parent] + foot_frame_names.add(joint_name) + frame_names = tuple(foot_frame_names) + + # Make sure that the frame names are end-effectors + if any(name not in leaf_frame_names for name in frame_names): + raise ValueError("All frames must correspond to end-effectors.") + + # Backup some user argument(s) + self.frame_names = tuple(frame_names) + self.mode = mode + + # Call base implementation + super().__init__( + env, + parent, + requirements=dict( + data=(MultiFrameMeanXYZQuat, dict( + frame_names=frame_names, + mode=mode))), + auto_refresh=False) + + # Pre-allocate memory for the odometry pose (X, Y, Yaw) + self._odom_pose = np.zeros((3,)) + + # Split odometry pose in translation (X, Y) and yaw angle + self._xy_view = self._odom_pose[:2] + self._yaw_view = self._odom_pose[-1:].reshape(()) + + def refresh(self) -> np.ndarray: + # Copy translation part + array_copyto(self._xy_view, self.data[:2]) + + # Compute Yaw angle + quat_to_yaw(self.data[-4:], self._yaw_view) + + return self._odom_pose + + @dataclass(unsafe_hash=True) class AverageOdometryVelocity(InterfaceQuantity[np.ndarray]): - """Average odometry velocity in local-world-aligned frame at the end of the - agent step. + """Average odometry velocity of the floating base of the robot in + local-world-aligned frame at the end of the agent step. The odometry pose fully specifies the position and orientation of the robot - in 2D world plane. See `OdometryPose` documentation for details. + in 2D world plane. See `BaseOdometryPose` documentation for details. The average spatial velocity is obtained by finite difference. See `AverageFrameSpatialVelocity` documentation for details. @@ -207,8 +318,9 @@ class ZeroMomentPoint(AbstractQuantity[np.ndarray]): """ reference_frame: pin.ReferenceFrame - """Whether the ZMP must be computed in local odometry frame or aligned with - world axes. 
+ """Whether to compute the ZMP in the local frame specified by the odometry + pose of the floating base of the robot, or in the frame located at the + position of the floating base with axes kept aligned with the world frame. """ def __init__(self, @@ -242,7 +354,7 @@ def __init__(self, com=(CenterOfMass, dict( kinematic_level=pin.POSITION, mode=mode)), - odom_pose=(OdometryPose, dict(mode=mode)) + odom_pose=(BaseOdometryPose, dict(mode=mode)) ), mode=mode, auto_refresh=False) @@ -316,8 +428,9 @@ class CapturePoint(AbstractQuantity[np.ndarray]): """ reference_frame: pin.ReferenceFrame - """Whether the DCM must be computed in local odometry frame or aligned with - world axes. + """Whether to compute the DCM in the local frame specified by the odometry + pose of the floating base of the robot, or in the frame located at the + position of the floating base with axes kept aligned with the world frame. """ def __init__(self, @@ -354,7 +467,7 @@ def __init__(self, com_velocity=(CenterOfMass, dict( kinematic_level=pin.VELOCITY, mode=mode)), - odom_pose=(OdometryPose, dict(mode=mode)) + odom_pose=(BaseOdometryPose, dict(mode=mode)) ), mode=mode, auto_refresh=False) diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py index a11226e16..0ccefb5ab 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py @@ -181,7 +181,7 @@ def __init__(self, slice_ = slice(self.indices[0], self.indices[-1] + 1, stride) if axis is None: self._slices = (slice_,) - elif axis > 0: + elif axis >= 0: self._slices = (*((slice(None),) * axis), slice_) else: self._slices = ( diff --git a/python/gym_jiminy/common/gym_jiminy/common/utils/math.py b/python/gym_jiminy/common/gym_jiminy/common/utils/math.py index 40d9a803c..afe3b2e33 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/utils/math.py +++
b/python/gym_jiminy/common/gym_jiminy/common/utils/math.py @@ -25,8 +25,7 @@ def squared_norm_2(array: np.ndarray) -> float: @nb.jit(nopython=True, cache=True) def matrix_to_yaw(mat: np.ndarray, - out: Optional[np.ndarray] = None - ) -> Union[float, np.ndarray]: + out: Optional[np.ndarray] = None) -> np.ndarray: """Compute the yaw from Yaw-Pitch-Roll Euler angles representation of a rotation matrix in 3D Euclidean space. @@ -37,12 +36,13 @@ def matrix_to_yaw(mat: np.ndarray, # Allocate memory for the output array if out is None: - out_ = np.atleast_1d(np.empty(mat.shape[2:])) + out_ = np.empty(mat.shape[2:]) else: assert out.shape == mat.shape[2:] - out_ = np.atleast_1d(out) + out_ = out - out_[:] = np.arctan2(mat[1, 0], mat[0, 0]) + out__ = np.atleast_1d(out_) + out__[:] = np.arctan2(mat[1, 0], mat[0, 0]) return out_ @@ -62,16 +62,30 @@ def quat_to_yaw_cos_sin(quat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: @nb.jit(nopython=True, cache=True) -def quat_to_yaw(quat: np.ndarray) -> Union[float, np.ndarray]: +def quat_to_yaw(quat: np.ndarray, + out: Optional[np.ndarray] = None) -> np.ndarray: """Compute the yaw from Yaw-Pitch-Roll Euler angles representation of a single or a batch of quaternions. :param quat: N-dimensional array whose first dimension gathers the 4 quaternion coordinates [qx, qy, qz, qw]. + :param out: A pre-allocated array into which the result is stored. If not + provided, a new array is freshly-allocated, which is slower.
""" assert quat.ndim >= 1 + + # Allocate memory for the output array + if out is None: + out_ = np.empty(quat.shape[1:]) + else: + assert out.shape == quat.shape[1:] + out_ = out + cos_yaw, sin_yaw = quat_to_yaw_cos_sin(quat) - return np.arctan2(sin_yaw, cos_yaw) + out__ = np.atleast_1d(out_) + out__[:] = np.arctan2(sin_yaw, cos_yaw) + + return out_ @nb.jit(nopython=True, cache=True) @@ -491,7 +505,7 @@ def remove_twist_from_quat(q: np.ndarray) -> None: def quat_average(quat: np.ndarray, - axis: Optional[Union[Tuple[int, ...], int]] = None + axes: Optional[Union[Tuple[int, ...], int]] = None ) -> np.ndarray: """Compute the average of a batch of quaternions [qx, qy, qz, qw] over some or all axes. @@ -505,21 +519,20 @@ def quat_average(quat: np.ndarray, :param quat: N-dimensional (N >= 2) array whose first dimension gathers the 4 quaternion coordinates [qx, qy, qz, qw]. - :param out: A pre-allocated array into which the result is stored. If not - provided, a new array is freshly-allocated, which is slower. + :param axes: Batch dimensions over which to compute the average. """ # TODO: This function cannot be jitted because numba does not support # batched matrix multiplication for now.
See official issue for details: # https://github.com/numba/numba/issues/3804 assert quat.ndim >= 2 - if axis is None: - axis = tuple(range(1, quat.ndim)) - if isinstance(axis, int): - axis = (axis,) - assert len(axis) > 0 and 0 not in axis + if axes is None: + axes = tuple(range(1, quat.ndim)) + elif isinstance(axes, int): + axes = (axes,) + assert len(axes) > 0 and 0 not in axes q_perm = quat.transpose(( - *(i for i in range(1, quat.ndim) if i not in axis), 0, *axis)) - q_flat = q_perm.reshape((*q_perm.shape[:-len(axis)], -1)) + *(i for i in range(1, quat.ndim) if i not in axes), 0, *axes)) + q_flat = q_perm.reshape((*q_perm.shape[:-len(axes)], -1)) _, eigvec = np.linalg.eigh(q_flat @ np.swapaxes(q_flat, -1, -2)) return np.moveaxis(eigvec[..., -1], -1, 0) diff --git a/python/gym_jiminy/unit_py/test_quantities.py b/python/gym_jiminy/unit_py/test_quantities.py index 88d85cff3..8873addc5 100644 --- a/python/gym_jiminy/unit_py/test_quantities.py +++ b/python/gym_jiminy/unit_py/test_quantities.py @@ -10,12 +10,17 @@ from jiminy_py.log import extract_trajectory_from_log import pinocchio as pin +from gym_jiminy.common.utils import ( + matrix_to_quat, quat_average, quat_to_matrix, quat_to_yaw) from gym_jiminy.common.bases import QuantityEvalMode, DatasetTrajectoryQuantity from gym_jiminy.common.quantities import ( QuantityManager, FrameEulerAngles, + MultiFrameEulerAngles, FrameXYZQuat, + MultiFrameMeanXYZQuat, MaskedQuantity, + FootOdometryPose, AverageFrameSpatialVelocity, AverageOdometryVelocity, ActuatedJointPositions, @@ -73,26 +78,46 @@ def test_dynamic_batching(self): env.reset(seed=0) env.step(env.action_space.sample()) + frame_names = [ + frame.name for frame in env.robot.pinocchio_model.frames] + quantity_manager = QuantityManager(env) for name, cls, kwargs in ( ("xyzquat_0", FrameXYZQuat, dict( - frame_name=env.robot.pinocchio_model.frames[2].name)), + frame_name=frame_names[2])), ("rpy_0", FrameEulerAngles, dict( - 
frame_name=env.robot.pinocchio_model.frames[1].name)), + frame_name=frame_names[1])), ("rpy_1", FrameEulerAngles, dict( - frame_name=env.robot.pinocchio_model.frames[1].name)), + frame_name=frame_names[1])), ("rpy_2", FrameEulerAngles, dict( - frame_name=env.robot.pinocchio_model.frames[-1].name))): + frame_name=frame_names[-1])), + ("rpy_batch_0", MultiFrameEulerAngles, dict( # Intersection + frame_names=(frame_names[-3], frame_names[1]))), + ("rpy_batch_1", MultiFrameEulerAngles, dict( # Inclusion + frame_names=(frame_names[1], frame_names[-1]))), + ("rpy_batch_2", MultiFrameEulerAngles, dict( # Disjoint + frame_names=(frame_names[1], frame_names[-4])))): quantity_manager[name] = (cls, kwargs) quantities = quantity_manager.registry - xyzquat_0 = quantity_manager.xyzquat_0.copy() + xyzquat_0 = quantity_manager.xyzquat_0.copy() rpy_0 = quantity_manager.rpy_0.copy() assert len(quantities['rpy_0'].requirements['data'].frame_names) == 1 assert np.all(rpy_0 == quantity_manager.rpy_1) rpy_2 = quantity_manager.rpy_2.copy() assert np.any(rpy_0 != rpy_2) assert len(quantities['rpy_2'].requirements['data'].frame_names) == 2 + quantity_manager.rpy_batch_0 + assert len(quantities['rpy_batch_0'].requirements['data']. + frame_names) == 3 + quantity_manager.rpy_batch_1 + assert len(quantities['rpy_batch_1'].requirements['data']. + frame_names) == 3 + quantity_manager.rpy_batch_2 + assert len(quantities['rpy_batch_2'].requirements['data']. + frame_names) == 5 + assert len(quantities['rpy_batch_2'].requirements['data']. 
+ requirements['rot_mat_batch'].frame_names) == 6 env.step(env.action_space.sample()) quantity_manager.reset() @@ -102,9 +127,9 @@ def test_dynamic_batching(self): assert np.any(xyzquat_0 != xyzquat_0_next) assert len(quantities['rpy_2'].requirements['data'].frame_names) == 2 - assert len(quantities['rpy_1'].requirements['data'].cache.owners) == 3 + assert len(quantities['rpy_1'].requirements['data'].cache.owners) == 6 del quantity_manager['rpy_2'] - assert len(quantities['rpy_1'].requirements['data'].cache.owners) == 2 + assert len(quantities['rpy_1'].requirements['data'].cache.owners) == 5 quantity_manager.rpy_1 assert len(quantities['rpy_1'].requirements['data'].frame_names) == 1 @@ -263,7 +288,7 @@ def test_average_odometry_velocity(self): base_velocity_mean_local = base_pose_diff / env.step_dt base_pose_mean = pin.LieGroup.integrate( se3, base_pose_prev, 0.5 * base_pose_diff) - rot_mat = pin.Quaternion(base_pose_mean[-4:]).matrix() + rot_mat = quat_to_matrix(base_pose_mean[-4:]) base_velocity_mean_world = np.concatenate(( rot_mat @ base_velocity_mean_local[:3], rot_mat @ base_velocity_mean_local[3:])) @@ -329,3 +354,59 @@ def test_capture_point(self): np.testing.assert_allclose( env.quantities["dcm"], com_position[:2] + com_velocity[:2] / omega) + + def test_mean_pose(self): + """ TODO: Write documentation + """ + env = gym.make("gym_jiminy.envs:atlas") + + frame_names = [ + frame.name for frame in env.robot.pinocchio_model.frames] + + env.quantities["mean_pose"] = ( + MultiFrameMeanXYZQuat, dict( + frame_names=frame_names[:5], + mode=QuantityEvalMode.TRUE)) + + env.reset(seed=0) + env.step(env.action_space.sample()) + + pos = np.mean(np.stack([ + oMf.translation for oMf in env.robot.pinocchio_data.oMf + ][:5], axis=-1), axis=-1) + quat = quat_average(np.stack([ + matrix_to_quat(oMf.rotation) + for oMf in env.robot.pinocchio_data.oMf][:5], axis=-1)) + if quat[-1] < 0.0: + quat *= -1 + + value = env.quantities["mean_pose"] + if value[-1] < 0.0: + value[-4:] *= 
-1 + + np.testing.assert_allclose(value, np.concatenate((pos, quat))) + + def test_foot_odometry_pose(self): + """ TODO: Write documentation + """ + env = gym.make("gym_jiminy.envs:atlas") + + env.quantities["foot_odom_pose"] = (FootOdometryPose, {}) + + env.reset(seed=0) + env.step(env.action_space.sample()) + + foot_left_index, foot_right_index = ( + env.robot.pinocchio_model.getFrameId(name) + for name in ("l_foot", "r_foot")) + foot_left_pose, foot_right_pose = ( + env.robot.pinocchio_data.oMf[frame_index] + for frame_index in (foot_left_index, foot_right_index)) + + mean_pos = (foot_left_pose.translation[:2] + + foot_right_pose.translation[:2]) / 2.0 + mean_yaw = quat_to_yaw(quat_average(np.stack(tuple(map(matrix_to_quat, + (foot_left_pose.rotation, foot_right_pose.rotation))), axis=-1))) + value = env.quantities["foot_odom_pose"] + + np.testing.assert_allclose(value, np.array((*mean_pos, mean_yaw)))