From 9122ac73eebd968ecdaa2a5a456564011b4224e6 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Wed, 26 Jun 2024 21:42:45 +0200 Subject: [PATCH] [gym/common] Add base odom pose drift tracking and motor positions shift tracking termination conditions. --- .../common/compositions/__init__.py | 18 ++-- .../gym_jiminy/common/compositions/generic.py | 90 ++++++++++++++++-- .../common/compositions/locomotion.py | 91 ++++++++++++++++++- .../gym_jiminy/common/compositions/mixin.py | 3 +- .../gym_jiminy/common/quantities/__init__.py | 4 +- .../gym_jiminy/common/quantities/generic.py | 4 +- .../common/quantities/locomotion.py | 5 + python/gym_jiminy/unit_py/test_quantities.py | 6 +- python/gym_jiminy/unit_py/test_rewards.py | 11 +-- .../gym_jiminy/unit_py/test_terminations.py | 25 ++++- .../src/jiminy_py/viewer/meshcat/recorder.py | 2 +- 11 files changed, 225 insertions(+), 34 deletions(-) diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/__init__.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/__init__.py index d0b765fc4..426bbfa6f 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/compositions/__init__.py +++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/__init__.py @@ -4,19 +4,21 @@ radial_basis_function, AdditiveMixtureReward, MultiplicativeMixtureReward) -from .generic import (TrackingQuantityReward, +from .generic import (SurviveReward, + TrackingQuantityReward, TrackingActuatedJointPositionsReward, - SurviveReward, DriftTrackingQuantityTermination, ShiftTrackingQuantityTermination, MechanicalSafetyTermination, - PowerConsumptionTermination) + MechanicalPowerConsumptionTermination, + ShiftTrackingMotorPositionsTermination) from .locomotion import (TrackingBaseHeightReward, TrackingBaseOdometryVelocityReward, TrackingCapturePointReward, TrackingFootPositionsReward, TrackingFootOrientationsReward, TrackingFootForceDistributionReward, + DriftTrackingBaseOdometryPoseTermination, MinimizeAngularMomentumReward, 
MinimizeFrictionReward, BaseRollPitchTermination, @@ -30,6 +32,9 @@ "radial_basis_function", "AdditiveMixtureReward", "MultiplicativeMixtureReward", + "SurviveReward", + "MinimizeFrictionReward", + "MinimizeAngularMomentumReward", "TrackingQuantityReward", "TrackingActuatedJointPositionsReward", "TrackingBaseHeightReward", @@ -38,13 +43,12 @@ "TrackingFootPositionsReward", "TrackingFootOrientationsReward", "TrackingFootForceDistributionReward", - "MinimizeAngularMomentumReward", - "MinimizeFrictionReward", - "SurviveReward", "DriftTrackingQuantityTermination", + "DriftTrackingBaseOdometryPoseTermination", "ShiftTrackingQuantityTermination", + "ShiftTrackingMotorPositionsTermination", "MechanicalSafetyTermination", - "PowerConsumptionTermination", + "MechanicalPowerConsumptionTermination", "FlyingTermination", "BaseRollPitchTermination", "BaseHeightTermination", diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py index 2c64439ea..623eaa1c6 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py @@ -18,7 +18,7 @@ from ..bases.compositions import ArrayOrScalar, ArrayLikeOrScalar from ..quantities import ( EnergyGenerationMode, StackedQuantity, UnaryOpQuantity, BinaryOpQuantity, - MultiActuatedJointKinematic, AveragePowerConsumption) + MultiActuatedJointKinematic, AverageMechanicalPowerConsumption) from .mixin import radial_basis_function @@ -170,6 +170,8 @@ def __init__(self, *, op: Callable[ [ArrayOrScalar, ArrayOrScalar], ArrayOrScalar] = sub, + post_fn: Optional[Callable[ + [ArrayOrScalar], ArrayOrScalar]] = None, is_truncation: bool = False, is_training_only: bool = False) -> None: """ @@ -200,6 +202,11 @@ def __init__(self, Group. The basic subtraction operator `operator.sub` is appropriate for Euclidean space. Optional: `operator.sub` by default. 
+ :param post_fn: Optional callable taking the difference between the true and reference drifts of the quantity as input argument and returning some post-processed value to which bound checking will be applied. None to skip post-processing entirely. Optional: None by default. :param is_truncation: Whether the episode should be considered terminated or truncated whenever the termination condition is triggered. @@ -217,6 +224,7 @@ def __init__(self, # Backup user argument(s) self.max_stack = max_stack self.op = op + self.post_fn = post_fn # Define drift of quantity stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731 @@ -233,9 +241,9 @@ def __init__(self, # Add drift quantity to the set of quantities managed by environment drift_tracking_quantity = (BinaryOpQuantity, dict( - quantity_left=delta_creator(QuantityEvalMode.TRUE), - quantity_right=delta_creator(QuantityEvalMode.REFERENCE), - op=sub)) + quantity_left=delta_creator(QuantityEvalMode.TRUE), + quantity_right=delta_creator(QuantityEvalMode.REFERENCE), + op=self._compute_drift_error)) # Call base implementation super().__init__(env, @@ -247,6 +255,20 @@ def __init__(self, is_truncation=is_truncation, is_training_only=is_training_only) + def _compute_drift_error(self, + left: np.ndarray, + right: np.ndarray) -> ArrayOrScalar: + """Compute the difference between the true and reference drift over + a given horizon, then apply some post-processing on it if requested. + + :param left: True value of the drift as a N-dimensional array. + :param right: Reference value of the drift as a N-dimensional array. 
+ """ + diff = left - right + if self.post_fn is not None: + return self.post_fn(diff) + return diff + class ShiftTrackingQuantityTermination(QuantityTermination[np.ndarray]): """Base class to derive termination condition from the shift between the @@ -346,6 +368,7 @@ def min_norm(values: np.ndarray) -> float: stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731 quantity=quantity_creator(mode), max_stack=max_stack, + mode='slice', as_array=True)) # Add drift quantity to the set of quantities managed by environment @@ -534,7 +557,7 @@ def compute(self, info: InfoType) -> bool: return is_done -class PowerConsumptionTermination(QuantityTermination): +class MechanicalPowerConsumptionTermination(QuantityTermination): """Terminate the episode immediately if the average mechanical power consumption is too high. @@ -578,7 +601,7 @@ def __init__( super().__init__( env, "termination_power_consumption", - (AveragePowerConsumption, dict( # type: ignore[arg-type] + (AverageMechanicalPowerConsumption, dict( # type: ignore[arg-type] horizon=self.horizon, generator_mode=self.generator_mode)), None, @@ -586,3 +609,58 @@ def __init__( grace_period, is_truncation=False, is_training_only=is_training_only) + + +class ShiftTrackingMotorPositionsTermination(ShiftTrackingQuantityTermination): + """Terminate the episode if the robot is not tracking the actuated joint + positions wrt some reference trajectory with expected accuracy, whatever + the timestep being considered over some fixed-size sliding time window. + + The robot must track the reference if there is no hazard, only applying + minor corrections to keep balance. Rewarding the agent for doing so is + not effective as favoring robustness remains more profitable. Indeed, it + would anticipate disturbances, lowering its current reward to maximize the + future return, primarily averting termination. Limiting the shift over a + given horizon allows for large deviations to handle strong pushes. 
+ Moreover, assuming that the agent is not able to keep track of the time + flow, which means that only the observation at the current step is provided + to the agent and no stateful network architecture such as LSTM is being + used, restricting the shift also urges to do what it takes to get back to + normal as soon as possible for fear of triggering termination, as it may + happen any time the deviation is above the maximum acceptable shift, + irrespective of its scale. + """ + def __init__(self, + env: InterfaceJiminyEnv, + thr: float, + horizon: float, + grace_period: float = 0.0, + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param thr: Maximum shift above which termination is triggered. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the shift. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + # Call base implementation + super().__init__( + env, + "termination_tracking_motor_positions", + lambda mode: (MultiActuatedJointKinematic, dict( + kinematic_level=pin.KinematicLevel.POSITION, + is_motor_side=False, + mode=mode)), + thr, + horizon, + grace_period, + is_truncation=False, + is_training_only=is_training_only) diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py index bc81ed55f..e40da0967 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py +++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py @@ -1,5 +1,6 @@ """Rewards mainly relevant for locomotion tasks on floating-base robots. 
""" +import math from functools import partial from dataclasses import dataclass from typing import Optional, Union, Sequence, Literal, Callable, cast @@ -15,15 +16,16 @@ QuantityReward) from ..quantities import ( OrientationType, MaskedQuantity, UnaryOpQuantity, FrameOrientation, - BaseRelativeHeight, BaseOdometryAverageVelocity, CapturePoint, - MultiFramePosition, MultiFootRelativeXYZQuat, + BaseRelativeHeight, BaseOdometryPose, BaseOdometryAverageVelocity, + CapturePoint, MultiFramePosition, MultiFootRelativeXYZQuat, MultiContactNormalizedSpatialForce, MultiFootNormalizedForceVertical, MultiFootCollisionDetection, AverageBaseMomentum) from ..quantities.locomotion import sanitize_foot_frame_names from ..utils import quat_difference from .generic import ( - ArrayLikeOrScalar, TrackingQuantityReward, QuantityTermination) + ArrayLikeOrScalar, TrackingQuantityReward, QuantityTermination, + DriftTrackingQuantityTermination) from .mixin import radial_basis_function @@ -604,3 +606,86 @@ def __init__(self, grace_period, is_truncation=False, is_training_only=is_training_only) + + +class DriftTrackingBaseOdometryPoseTermination( + DriftTrackingQuantityTermination): + """Terminate the episode if the current base odometry pose is drifting too + much wrt some reference trajectory that is being tracked. + + It is generally important to make sure that the robot is not deviating too + much from some reference trajectory. It sounds appealing to make sure that + the absolute error between the current and reference trajectory is bounded + at all times. However, such a condition is very restrictive, especially for + robots dealing with external disturbances or evolving on an uneven terrain. + Moreover, when it comes to infinite-horizon trajectories in particular, eg + periodic motions, avoiding drifting away over time involves being able to + sense the absolute position of the robot in world frame via exteroceptive + navigation sensors such as depth cameras or LIDARs. 
This kind of advanced + sensor may not be available, thereby making the objective out of reach. + Still, in the case of legged locomotion, what really matters is tracking + accurately a nominal limit cycle as long as doing so does not compromise + local stability. If it does, the agent is expected to make every effort + to recover balance as fast as possible before going back to the nominal + limit cycle, without trying to catch up with the ensuing drift since the + exact absolute odometry pose in world frame is of little interest. See + `DriftTrackingQuantityTermination` documentation for details. + """ + def __init__(self, + env: InterfaceJiminyEnv, + max_position_err: float, + max_orientation_err: float, + horizon: float, + grace_period: float = 0.0, + *, + is_training_only: bool = False) -> None: + """ + :param env: Base or wrapped jiminy environment. + :param max_position_err: + Maximum drift error in translation (X, Y) in world plane above + which termination is triggered. + :param max_orientation_err: + Maximum drift error in orientation (yaw,) in world plane above + which termination is triggered. + :param horizon: Horizon over which values of the quantity will be + stacked before computing the drift. + :param grace_period: Grace period effective only at the very beginning + of the episode, during which the latter is bound + to continue whatever happens. + Optional: 0.0 by default. + :param is_training_only: Whether the termination condition should be + completely by-passed if the environment is in + evaluation mode. + Optional: False by default. + """ + # Define jit-able method for computing translation and rotation errors + @nb.jit(nopython=True, cache=True) + def compute_se2_double_geodesic_distance( + diff: np.ndarray) -> np.ndarray: + """Compute the errors between two odometry poses in the Cartesian + space (R^2, SO(2)), ie considering the translational and rotational + parts independently. 
+ + :param diff: Element-wise difference between two odometry poses as + a 1D array gathering the 3 components (X, Y, Yaw). + + :returns: Pair (data_err_pos, data_err_rot) gathering the L2-norm + of the difference in translation and rotation as a 1D array. + """ + diff_x, diff_y, diff_yaw = diff + error_position = math.sqrt(diff_x ** 2 + diff_y ** 2) + error_orientation = math.fabs(diff_yaw) + return np.array([error_position, error_orientation]) + + # Call base implementation + super().__init__( + env, + "termination_tracking_base_odom", + lambda mode: (BaseOdometryPose, dict(mode=mode)), + None, + np.array([max_position_err, max_orientation_err]), + horizon, + grace_period, + post_fn=compute_se2_double_geodesic_distance, + is_truncation=False, + is_training_only=is_training_only) diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py index 76f181766..91deec684 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py +++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py @@ -164,8 +164,7 @@ class MultiplicativeMixtureReward(MixtureReward): def __init__(self, env: InterfaceJiminyEnv, name: str, - components: Sequence[AbstractReward] - ) -> None: + components: Sequence[AbstractReward]) -> None: """ :param env: Base or wrapped jiminy environment. :param name: Desired name of the reward. 
diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py index ba2f7c653..f0d961aaa 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/__init__.py @@ -18,7 +18,7 @@ MultiFrameCollisionDetection, AverageFrameXYZQuat, AverageFrameRollPitch, - AveragePowerConsumption, + AverageMechanicalPowerConsumption, FrameSpatialAverageVelocity, MultiActuatedJointKinematic) from .locomotion import (BaseOdometryPose, @@ -63,7 +63,7 @@ 'MultiContactNormalizedSpatialForce', 'AverageFrameXYZQuat', 'AverageFrameRollPitch', - 'AveragePowerConsumption', + 'AverageMechanicalPowerConsumption', 'FrameSpatialAverageVelocity', 'BaseSpatialAverageVelocity', 'BaseOdometryAverageVelocity', diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py index 92dfab7c2..4f6fc1b97 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py @@ -1696,7 +1696,7 @@ class EnergyGenerationMode(Enum): @dataclass(unsafe_hash=True) -class AveragePowerConsumption(InterfaceQuantity[float]): +class AverageMechanicalPowerConsumption(InterfaceQuantity[float]): """Average mechanical power consumption by all the motors over a sliding time window. 
""" @@ -1795,7 +1795,7 @@ def _compute_power(generator_mode: EnergyGenerationMode, op=partial(_compute_power, self.generator_mode))), max_stack=self.max_stack, as_array=True, - mode='zeros'))), + mode='slice'))), auto_refresh=False) def refresh(self) -> float: diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py index 4fdb35643..92090a46c 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py @@ -167,6 +167,11 @@ class BaseOdometryPose(AbstractQuantity[np.ndarray]): The odometry pose fully specifies the position and heading of the robot in 2D world plane. As such, it comprises the linear translation (X, Y) and the rotation around Z axis (namely rate of change of Yaw Euler angle). + Mathematically, one is supposed to rely on se2 Lie Algebra for performing + operations on odometry poses such as differentiation. In practice, the + double geodesic metric space is used instead to prevent coupling between + the linear and angular parts by considering them independently. Strictly + speaking, it corresponds to the cartesian space (R^2 x SO(2)). 
""" def __init__(self, diff --git a/python/gym_jiminy/unit_py/test_quantities.py b/python/gym_jiminy/unit_py/test_quantities.py index 01317b75d..b33a0ce33 100644 --- a/python/gym_jiminy/unit_py/test_quantities.py +++ b/python/gym_jiminy/unit_py/test_quantities.py @@ -34,7 +34,7 @@ BaseOdometryAverageVelocity, BaseRelativeHeight, AverageBaseMomentum, - AveragePowerConsumption, + AverageMechanicalPowerConsumption, CenterOfMass, CapturePoint, ZeroMomentPoint) @@ -707,13 +707,13 @@ def test_power_consumption(self): EnergyGenerationMode.LOST_GLOBAL, EnergyGenerationMode.PENALIZE): env.quantities["mean_power_consumption"] = ( - AveragePowerConsumption, dict( + AverageMechanicalPowerConsumption, dict( horizon=0.2, generator_mode=mode)) quantity = env.quantities.registry["mean_power_consumption"] env.reset(seed=0) - total_power_stack = [0.0,] * quantity.max_stack + total_power_stack = [0.0,] encoder_data = env.robot.sensor_measurements["EncoderSensor"] _, motor_velocities = encoder_data for _ in range(8): diff --git a/python/gym_jiminy/unit_py/test_rewards.py b/python/gym_jiminy/unit_py/test_rewards.py index d8a382291..805634dd2 100644 --- a/python/gym_jiminy/unit_py/test_rewards.py +++ b/python/gym_jiminy/unit_py/test_rewards.py @@ -66,8 +66,7 @@ def test_tracking(self): (TrackingFootOrientationsReward, 2.0), (TrackingFootForceDistributionReward, 2.0)): reward = reward_class(self.env, cutoff=cutoff) - quantity_true = reward.quantity.requirements['value_left'] - quantity_ref = reward.quantity.requirements['value_right'] + quantity = reward.quantity self.env.reset(seed=0) action = 0.5 * self.env.action_space.sample() @@ -77,15 +76,15 @@ def test_tracking(self): with np.testing.assert_raises(AssertionError): np.testing.assert_allclose( - quantity_true.get(), quantity_ref.get()) + quantity.value_left, quantity.value_right) if isinstance(reward, TrackingBaseHeightReward): np.testing.assert_allclose( - quantity_true.get(), self.env.robot_state.q[2]) + quantity.value_left, 
self.env.robot_state.q[2]) gamma = - np.log(CUTOFF_ESP) / cutoff ** 2 - value = np.exp(- gamma * np.sum((reward.quantity.op( - quantity_true.get(), quantity_ref.get())) ** 2)) + value = np.exp(- gamma * np.sum((quantity.op( + quantity.value_left, quantity.value_right)) ** 2)) assert value > 0.01 np.testing.assert_allclose(reward(terminated, {}), value) diff --git a/python/gym_jiminy/unit_py/test_terminations.py b/python/gym_jiminy/unit_py/test_terminations.py index 805ce3b2b..67d4be3db 100644 --- a/python/gym_jiminy/unit_py/test_terminations.py +++ b/python/gym_jiminy/unit_py/test_terminations.py @@ -22,7 +22,9 @@ MechanicalSafetyTermination, FlyingTermination, ImpactForceTermination, - PowerConsumptionTermination) + MechanicalPowerConsumptionTermination, + DriftTrackingBaseOdometryPoseTermination, + ShiftTrackingMotorPositionsTermination) class TerminationConditions(unittest.TestCase): @@ -319,13 +321,32 @@ def test_flying(self): break assert terminated ^ is_valid + def test_drift_tracking_base_odom(self): + MAX_POS_ERROR, MAX_ROT_ERROR = 0.1, 0.2 + termination = DriftTrackingBaseOdometryPoseTermination( + self.env, MAX_POS_ERROR,MAX_ROT_ERROR, 1.0) + quantity = termination.quantity + + self.env.reset(seed=0) + action = self.env.action_space.sample() + for _ in range(20): + _, _, terminated, _, _ = self.env.step(action) + if terminated: + break + terminated, truncated = termination({}) + diff = quantity.value_left - quantity.value_right + is_valid = np.linalg.norm(diff[:2]) <= MAX_POS_ERROR + is_valid &= np.abs(diff[2]) <= MAX_ROT_ERROR + assert terminated ^ is_valid + def test_misc(self): """ TODO: Write documentation """ for termination in ( BaseHeightTermination(self.env, 0.6), ImpactForceTermination(self.env, 1.0), - PowerConsumptionTermination(self.env, 400.0, 1.0),): + MechanicalPowerConsumptionTermination(self.env, 400.0, 1.0), + ShiftTrackingMotorPositionsTermination(self.env, 0.4, 0.5),): self.env.reset(seed=0) self.env.eval() action = 
self.env.action_space.sample() diff --git a/python/jiminy_py/src/jiminy_py/viewer/meshcat/recorder.py b/python/jiminy_py/src/jiminy_py/viewer/meshcat/recorder.py index 55200cb22..ccca76ad5 100644 --- a/python/jiminy_py/src/jiminy_py/viewer/meshcat/recorder.py +++ b/python/jiminy_py/src/jiminy_py/viewer/meshcat/recorder.py @@ -350,7 +350,7 @@ def start_video_recording(self, """ TODO: Write documentation. """ self._send_request( - "start_record", message=f"{fps}|{width}|{height}", timeout=10.0) + "start_record", message=f"{fps}|{width}|{height}", timeout=15.0) self.is_recording = True def add_video_frame(self) -> None: