Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym/common] Add several termination conditions. #817

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
radial_basis_function,
AdditiveMixtureReward,
MultiplicativeMixtureReward)
from .generic import (TrackingQuantityReward,
from .generic import (SurviveReward,
TrackingQuantityReward,
TrackingActuatedJointPositionsReward,
SurviveReward,
DriftTrackingQuantityTermination,
ShiftTrackingQuantityTermination,
MechanicalSafetyTermination,
PowerConsumptionTermination)
MechanicalPowerConsumptionTermination,
ShiftTrackingMotorPositionsTermination)
from .locomotion import (TrackingBaseHeightReward,
TrackingBaseOdometryVelocityReward,
TrackingCapturePointReward,
TrackingFootPositionsReward,
TrackingFootOrientationsReward,
TrackingFootForceDistributionReward,
DriftTrackingBaseOdometryPoseTermination,
MinimizeAngularMomentumReward,
MinimizeFrictionReward,
BaseRollPitchTermination,
Expand All @@ -30,6 +32,9 @@
"radial_basis_function",
"AdditiveMixtureReward",
"MultiplicativeMixtureReward",
"SurviveReward",
"MinimizeFrictionReward",
"MinimizeAngularMomentumReward",
"TrackingQuantityReward",
"TrackingActuatedJointPositionsReward",
"TrackingBaseHeightReward",
Expand All @@ -38,13 +43,12 @@
"TrackingFootPositionsReward",
"TrackingFootOrientationsReward",
"TrackingFootForceDistributionReward",
"MinimizeAngularMomentumReward",
"MinimizeFrictionReward",
"SurviveReward",
"DriftTrackingQuantityTermination",
"DriftTrackingBaseOdometryPoseTermination",
"ShiftTrackingQuantityTermination",
"ShiftTrackingMotorPositionsTermination",
"MechanicalSafetyTermination",
"PowerConsumptionTermination",
"MechanicalPowerConsumptionTermination",
"FlyingTermination",
"BaseRollPitchTermination",
"BaseHeightTermination",
Expand Down
90 changes: 84 additions & 6 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..bases.compositions import ArrayOrScalar, ArrayLikeOrScalar
from ..quantities import (
EnergyGenerationMode, StackedQuantity, UnaryOpQuantity, BinaryOpQuantity,
MultiActuatedJointKinematic, AveragePowerConsumption)
MultiActuatedJointKinematic, AverageMechanicalPowerConsumption)

from .mixin import radial_basis_function

Expand Down Expand Up @@ -170,6 +170,8 @@ def __init__(self,
*,
op: Callable[
[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar] = sub,
post_fn: Optional[Callable[
[ArrayOrScalar], ArrayOrScalar]] = None,
is_truncation: bool = False,
is_training_only: bool = False) -> None:
"""
Expand Down Expand Up @@ -200,6 +202,11 @@ def __init__(self,
Group. The basic subtraction operator `operator.sub` is
appropriate for Euclidean space.
Optional: `operator.sub` by default.
:param post_fn: Optional callable taking the difference between the
true and reference drifts of the quantity as input argument
and returning some post-processed value to which bound
checking will be applied. None to skip post-processing
entirely.
Optional: None by default.
:param is_truncation: Whether the episode should be considered
terminated or truncated whenever the termination
condition is triggered.
Expand All @@ -217,6 +224,7 @@ def __init__(self,
# Backup user argument(s)
self.max_stack = max_stack
self.op = op
self.post_fn = post_fn

# Define drift of quantity
stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731
Expand All @@ -233,9 +241,9 @@ def __init__(self,

# Add drift quantity to the set of quantities managed by environment
drift_tracking_quantity = (BinaryOpQuantity, dict(
quantity_left=delta_creator(QuantityEvalMode.TRUE),
quantity_right=delta_creator(QuantityEvalMode.REFERENCE),
op=sub))
quantity_left=delta_creator(QuantityEvalMode.TRUE),
quantity_right=delta_creator(QuantityEvalMode.REFERENCE),
op=self._compute_drift_error))

# Call base implementation
super().__init__(env,
Expand All @@ -247,6 +255,20 @@ def __init__(self,
is_truncation=is_truncation,
is_training_only=is_training_only)

def _compute_drift_error(self,
                         left: np.ndarray,
                         right: np.ndarray) -> ArrayOrScalar:
    """Evaluate the drift tracking error on which bound checking is done.

    The error is the element-wise difference `left - right`. If the user
    specified a post-processing function `post_fn`, it is applied to this
    difference and its output is returned instead.

    :param left: True value of the drift as a N-dimensional array.
    :param right: Reference value of the drift as a N-dimensional array.

    :returns: The raw difference if `self.post_fn` is None, the output of
              `self.post_fn` evaluated on that difference otherwise.
    """
    # Raw mismatch between the true and reference drifts
    error = left - right

    # Post-processing is optional. Skip it entirely if not requested.
    post_fn = self.post_fn
    return error if post_fn is None else post_fn(error)


class ShiftTrackingQuantityTermination(QuantityTermination[np.ndarray]):
"""Base class to derive termination condition from the shift between the
Expand Down Expand Up @@ -346,6 +368,7 @@ def min_norm(values: np.ndarray) -> float:
stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731
quantity=quantity_creator(mode),
max_stack=max_stack,
mode='slice',
as_array=True))

# Add drift quantity to the set of quantities managed by environment
Expand Down Expand Up @@ -534,7 +557,7 @@ def compute(self, info: InfoType) -> bool:
return is_done


class PowerConsumptionTermination(QuantityTermination):
class MechanicalPowerConsumptionTermination(QuantityTermination):
"""Terminate the episode immediately if the average mechanical power
consumption is too high.

Expand Down Expand Up @@ -578,11 +601,66 @@ def __init__(
super().__init__(
env,
"termination_power_consumption",
(AveragePowerConsumption, dict( # type: ignore[arg-type]
(AverageMechanicalPowerConsumption, dict( # type: ignore[arg-type]
horizon=self.horizon,
generator_mode=self.generator_mode)),
None,
self.max_power,
grace_period,
is_truncation=False,
is_training_only=is_training_only)


class ShiftTrackingMotorPositionsTermination(ShiftTrackingQuantityTermination):
    """Termination condition triggered when the actuated joint positions are
    not tracking some reference trajectory with the expected accuracy, for
    any timestep within some fixed-size sliding time window.

    As long as there is no hazard, the robot is supposed to follow the
    reference, only applying minor corrections to keep balance. Rewarding the
    agent for doing so is not effective because favoring robustness remains
    more profitable: the agent would anticipate disturbances, lowering its
    current reward to maximize the future return, primarily averting
    termination. Bounding the shift over a given horizon still allows for
    large deviations to handle strong pushes. Moreover, assuming that the
    agent is not able to keep track of the time flow, ie only the observation
    at the current step is provided to the agent and no stateful network
    architecture such as LSTM is being used, restricting the shift also urges
    the agent to do what it takes to get back to normal as soon as possible
    for fear of triggering termination, as it may happen any time the
    deviation is above the maximum acceptable shift, irrespective of its
    scale.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 thr: float,
                 horizon: float,
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param thr: Maximum shift above which termination is triggered.
        :param horizon: Horizon over which values of the quantity will be
                        stacked before computing the shift.
        :param grace_period: Grace period effective only at the very beginning
                             of the episode, during which the latter is bound
                             to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is in
                                 evaluation mode.
                                 Optional: False by default.
        """
        # Factory building the specification of the tracked quantity, ie the
        # positions of all the actuated joints, for a given evaluation mode.
        def quantity_creator(mode):
            quantity_kwargs = {
                "kinematic_level": pin.KinematicLevel.POSITION,
                "is_motor_side": False,
                "mode": mode,
            }
            return (MultiActuatedJointKinematic, quantity_kwargs)

        # Delegate everything else to the generic shift tracking termination.
        # Triggering this condition is never flagged as a truncation.
        super().__init__(
            env,
            "termination_tracking_motor_positions",
            quantity_creator,
            thr,
            horizon,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Rewards mainly relevant for locomotion tasks on floating-base robots.
"""
import math
from functools import partial
from dataclasses import dataclass
from typing import Optional, Union, Sequence, Literal, Callable, cast
Expand All @@ -15,15 +16,16 @@
QuantityReward)
from ..quantities import (
OrientationType, MaskedQuantity, UnaryOpQuantity, FrameOrientation,
BaseRelativeHeight, BaseOdometryAverageVelocity, CapturePoint,
MultiFramePosition, MultiFootRelativeXYZQuat,
BaseRelativeHeight, BaseOdometryPose, BaseOdometryAverageVelocity,
CapturePoint, MultiFramePosition, MultiFootRelativeXYZQuat,
MultiContactNormalizedSpatialForce, MultiFootNormalizedForceVertical,
MultiFootCollisionDetection, AverageBaseMomentum)
from ..quantities.locomotion import sanitize_foot_frame_names
from ..utils import quat_difference

from .generic import (
ArrayLikeOrScalar, TrackingQuantityReward, QuantityTermination)
ArrayLikeOrScalar, TrackingQuantityReward, QuantityTermination,
DriftTrackingQuantityTermination)
from .mixin import radial_basis_function


Expand Down Expand Up @@ -604,3 +606,86 @@ def __init__(self,
grace_period,
is_truncation=False,
is_training_only=is_training_only)


class DriftTrackingBaseOdometryPoseTermination(
        DriftTrackingQuantityTermination):
    """Terminate the episode if the current base odometry pose is drifting too
    much wrt some reference trajectory that is being tracked.

    It is generally important to make sure that the robot is not deviating too
    much from some reference trajectory. It sounds appealing to make sure that
    the absolute error between the current and reference trajectory is bounded
    at all time. However, such a condition is very restrictive, especially for
    robots dealing with external disturbances or evolving on an uneven terrain.
    Moreover, when it comes to infinite-horizon trajectories in particular, eg
    periodic motions, avoiding drifting away over time involves being able to
    sense the absolute position of the robot in world frame via exteroceptive
    navigation sensors such as depth cameras or LIDARs. This kind of advanced
    sensor may not be available, thereby making the objective out of reach.
    Still, in the case of legged locomotion, what really matters is tracking
    accurately a nominal limit cycle as long as doing so does not compromise
    local stability. If it does, then the agent is expected to make every
    effort to recover balance as fast as possible before going back to the
    nominal limit cycle, without trying to catch up with the ensuing drift
    since the exact absolute odometry pose in world frame is of little
    interest. See `DriftTrackingQuantityTermination` documentation for details.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 max_position_err: float,
                 max_orientation_err: float,
                 horizon: float,
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param max_position_err:
            Maximum drift error in translation (X, Y) in world plane above
            which termination is triggered.
        :param max_orientation_err:
            Maximum drift error in orientation (yaw,) in world plane above
            which termination is triggered.
        :param horizon: Horizon over which values of the quantity will be
                        stacked before computing the drift.
        :param grace_period: Grace period effective only at the very beginning
                             of the episode, during which the latter is bound
                             to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is in
                                 evaluation mode.
                                 Optional: False by default.
        """
        # Define jit-able method for computing translation and rotation errors
        @nb.jit(nopython=True, cache=True)
        def compute_se2_double_geodesic_distance(
                diff: np.ndarray) -> np.ndarray:
            """Compute the errors between two odometry poses in the Cartesian
            space (R^2, SO(2)), ie considering the translational and rotational
            parts independently.

            :param diff: Element-wise difference between two odometry poses as
                         a 1D array gathering the 3 components (X, Y, Yaw).

            :returns: 1D array (error_position, error_orientation) gathering
                      the L2-norm of the difference in translation and the
                      absolute value of the difference in rotation.
            """
            diff_x, diff_y, diff_yaw = diff
            error_position = math.sqrt(diff_x ** 2 + diff_y ** 2)
            # NOTE(review): the yaw difference is used as-is, without being
            # wrapped to [-pi, pi]. Presumably acceptable as long as the drift
            # stays small over the horizon - TODO confirm.
            error_orientation = math.fabs(diff_yaw)
            return np.array([error_position, error_orientation])

        # Call base implementation.
        # The name must be specific to this termination condition. It used to
        # be "termination_tracking_motor_positions", which was a copy-paste
        # leftover from `ShiftTrackingMotorPositionsTermination`, making both
        # conditions indistinguishable by name.
        # No lower bound is specified (None) since both post-processed errors
        # are non-negative by construction. Only upper bounds are relevant.
        super().__init__(
            env,
            "termination_tracking_base_odometry_pose",
            lambda mode: (BaseOdometryPose, dict(mode=mode)),
            None,
            np.array([max_position_err, max_orientation_err]),
            horizon,
            grace_period,
            post_fn=compute_se2_double_geodesic_distance,
            is_truncation=False,
            is_training_only=is_training_only)
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ class MultiplicativeMixtureReward(MixtureReward):
def __init__(self,
env: InterfaceJiminyEnv,
name: str,
components: Sequence[AbstractReward]
) -> None:
components: Sequence[AbstractReward]) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the reward.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
MultiFrameCollisionDetection,
AverageFrameXYZQuat,
AverageFrameRollPitch,
AveragePowerConsumption,
AverageMechanicalPowerConsumption,
FrameSpatialAverageVelocity,
MultiActuatedJointKinematic)
from .locomotion import (BaseOdometryPose,
Expand Down Expand Up @@ -63,7 +63,7 @@
'MultiContactNormalizedSpatialForce',
'AverageFrameXYZQuat',
'AverageFrameRollPitch',
'AveragePowerConsumption',
'AverageMechanicalPowerConsumption',
'FrameSpatialAverageVelocity',
'BaseSpatialAverageVelocity',
'BaseOdometryAverageVelocity',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,7 @@ class EnergyGenerationMode(Enum):


@dataclass(unsafe_hash=True)
class AveragePowerConsumption(InterfaceQuantity[float]):
class AverageMechanicalPowerConsumption(InterfaceQuantity[float]):
"""Average mechanical power consumption by all the motors over a sliding
time window.
"""
Expand Down Expand Up @@ -1795,7 +1795,7 @@ def _compute_power(generator_mode: EnergyGenerationMode,
op=partial(_compute_power, self.generator_mode))),
max_stack=self.max_stack,
as_array=True,
mode='zeros'))),
mode='slice'))),
auto_refresh=False)

def refresh(self) -> float:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ class BaseOdometryPose(AbstractQuantity[np.ndarray]):
The odometry pose fully specifies the position and heading of the robot in
2D world plane. As such, it comprises the linear translation (X, Y) and
the rotation around Z axis (namely the Yaw Euler angle).
Mathematically, one is supposed to rely on se2 Lie Algebra for performing
operations on odometry poses such as differentiation. In practice, the
double geodesic metric space is used instead to prevent coupling between
the linear and angular parts by considering them independently. Strictly
speaking, it corresponds to the Cartesian space (R^2 x SO(2)).
"""

def __init__(self,
Expand Down
Loading
Loading