Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym/common] Add several termination conditions. #816

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ class AbstractQuantity(InterfaceQuantity, Generic[ValueT]):
"""

mode: QuantityEvalMode
"""Specify on which state to evaluate this quantity. See `Mode`
"""Specify on which state to evaluate this quantity. See `QuantityEvalMode`
documentation for details about each mode.

.. warning::
Expand Down Expand Up @@ -912,7 +912,7 @@ class StateQuantity(InterfaceQuantity[State]):
"""

mode: QuantityEvalMode
"""Specify on which state to evaluate this quantity. See `Mode`
"""Specify on which state to evaluate this quantity. See `QuantityEvalMode`
documentation for details about each mode.

.. warning::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
SurviveReward,
DriftTrackingQuantityTermination,
ShiftTrackingQuantityTermination,
MechanicalSafetyTermination)
MechanicalSafetyTermination,
PowerConsumptionTermination)
from .locomotion import (TrackingBaseHeightReward,
TrackingBaseOdometryVelocityReward,
TrackingCapturePointReward,
Expand All @@ -21,7 +22,8 @@
BaseRollPitchTermination,
BaseHeightTermination,
FootCollisionTermination,
FlyingTermination)
FlyingTermination,
ImpactForceTermination)

__all__ = [
"CUTOFF_ESP",
Expand All @@ -42,8 +44,10 @@
"DriftTrackingQuantityTermination",
"ShiftTrackingQuantityTermination",
"MechanicalSafetyTermination",
"PowerConsumptionTermination",
"FlyingTermination",
"BaseRollPitchTermination",
"BaseHeightTermination",
"FootCollisionTermination"
"FootCollisionTermination",
"ImpactForceTermination"
]
83 changes: 72 additions & 11 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
AbstractTerminationCondition, QuantityTermination)
from ..bases.compositions import ArrayOrScalar, ArrayLikeOrScalar
from ..quantities import (
StackedQuantity, UnaryOpQuantity, BinaryOpQuantity,
MultiActuatedJointKinematic)
EnergyGenerationMode, StackedQuantity, UnaryOpQuantity, BinaryOpQuantity,
MultiActuatedJointKinematic, AveragePowerConsumption)

from .mixin import radial_basis_function

Expand Down Expand Up @@ -139,6 +139,7 @@ def __init__(self,
"reward_actuated_joint_positions",
lambda mode: (MultiActuatedJointKinematic, dict(
kinematic_level=pin.KinematicLevel.POSITION,
is_motor_side=False,
mode=mode)),
cutoff)

Expand All @@ -164,7 +165,7 @@ def __init__(self,
[QuantityEvalMode], QuantityCreator[ArrayOrScalar]],
low: Optional[ArrayLikeOrScalar],
high: Optional[ArrayLikeOrScalar],
max_stack: int,
horizon: float,
grace_period: float = 0.0,
*,
op: Callable[
Expand All @@ -187,9 +188,8 @@ def __init__(self,
'env' and 'parent'.
:param low: Lower bound below which termination is triggered.
:param high: Upper bound above which termination is triggered.
:param max_stack: Horizon over which values of the quantity will be
stacked if desired. 1 to disable.
Optional: 1 by default.
:param horizon: Horizon over which values of the quantity will be
stacked before computing the drift.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Expand All @@ -211,6 +211,9 @@ def __init__(self,
"""
# pylint: disable=unnecessary-lambda-assignment

# Convert horizon in stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)

# Backup user argument(s)
self.max_stack = max_stack
self.op = op
Expand Down Expand Up @@ -265,7 +268,7 @@ def __init__(self,
quantity_creator: Callable[
[QuantityEvalMode], QuantityCreator[ArrayOrScalar]],
thr: float,
max_stack: int,
horizon: float,
grace_period: float = 0.0,
*,
op: Callable[[np.ndarray, np.ndarray], np.ndarray] = sub,
Expand All @@ -287,9 +290,8 @@ def __init__(self,
'env' and 'parent'.
:param thr: Termination is triggered if the shift exceeds this
threshold.
:param max_stack: Horizon over which values of the quantity will be
stacked if desired. 1 to disable.
Optional: 1 by default.
:param horizon: Horizon over which values of the quantity will be
stacked before computing the shift.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Expand All @@ -315,6 +317,9 @@ def __init__(self,
"""
# pylint: disable=unnecessary-lambda-assignment

# Convert horizon in stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)

# Backup user argument(s)
self.max_stack = max_stack
self.op = op
Expand Down Expand Up @@ -405,6 +410,7 @@ def __init__(self,
requirements=dict(
position=(MultiActuatedJointKinematic, dict(
kinematic_level=pin.KinematicLevel.POSITION,
is_motor_side=False,
mode=QuantityEvalMode.TRUE))),
auto_refresh=False)

Expand Down Expand Up @@ -487,7 +493,8 @@ def __init__(self,
_MultiActuatedJointBoundDistance, {})
self.env.quantities["_".join((self.name, "velocity"))] = (
MultiActuatedJointKinematic, dict(
kinematic_level=pin.KinematicLevel.VELOCITY))
kinematic_level=pin.KinematicLevel.VELOCITY,
is_motor_side=False))

# Keep track of the underlying quantities
registry = self.env.quantities.registry
Expand Down Expand Up @@ -525,3 +532,57 @@ def compute(self, info: InfoType) -> bool:
(position_delta_high < self.position_margin) &
(velocity > self.velocity_max))
return is_done


class PowerConsumptionTermination(QuantityTermination):
"""Terminate the episode immediately if the average mechanical power
consumption is too high.

High power consumption is undesirable as it means that the motion is
suboptimal and probably unnatural and fragile. Moreover, it helps to
accommodate hardware capability to avoid motor overheating while increasing
battery autonomy and lifespan. Finally, it may be necessary to deal with
some hardware limitations on max power drain.
"""
def __init__(
self,
env: InterfaceJiminyEnv,
max_power: float,
horizon: float,
generator_mode: EnergyGenerationMode = EnergyGenerationMode.CHARGE,
grace_period: float = 0.0,
*,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param max_power: Maximum average mechanical power consumption applied
on any of the contact points or collision bodies
above which termination is triggered.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param horizon: Horizon over which values of the quantity will be
stacked before computing the average.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
# Backup user argument(s)
self.max_power = max_power
self.horizon = horizon
self.generator_mode = generator_mode

# Call base implementation
super().__init__(
env,
"termination_power_consumption",
(AveragePowerConsumption, dict( # type: ignore[arg-type]
horizon=self.horizon,
generator_mode=self.generator_mode)),
None,
self.max_power,
grace_period,
is_truncation=False,
is_training_only=is_training_only)
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
OrientationType, MaskedQuantity, UnaryOpQuantity, FrameOrientation,
BaseRelativeHeight, BaseOdometryAverageVelocity, CapturePoint,
MultiFramePosition, MultiFootRelativeXYZQuat,
MultiContactNormalizedForceTangential, MultiFootNormalizedForceVertical,
MultiContactNormalizedSpatialForce, MultiFootNormalizedForceVertical,
MultiFootCollisionDetection, AverageBaseMomentum)
from ..quantities.locomotion import sanitize_foot_frame_names
from ..utils import quat_difference
Expand Down Expand Up @@ -50,7 +50,9 @@ def __init__(self,
"reward_tracking_base_height",
lambda mode: (MaskedQuantity, dict(
quantity=(UnaryOpQuantity, dict(
quantity=(StateQuantity, dict(mode=mode)),
quantity=(StateQuantity, dict(
update_kinematics=False,
mode=mode)),
op=lambda state: state.q)),
keys=(2,))),
cutoff)
Expand Down Expand Up @@ -291,7 +293,10 @@ def __init__(self,
super().__init__(
env,
"reward_friction",
(MultiContactNormalizedForceTangential, dict()),
(MaskedQuantity, dict(
quantity=(MultiContactNormalizedSpatialForce, dict()),
axis=0,
keys=(0, 1))),
partial(radial_basis_function, cutoff=self.cutoff, order=2),
is_normalized=True,
is_terminal=False)
Expand Down Expand Up @@ -523,10 +528,10 @@ class FlyingTermination(QuantityTermination):
"""Discourage the agent of jumping by terminating the episode immediately
if the robot is flying too high above the ground.

This kind of jumping behavior is unsually undesirable because it may be
frightning for people nearby, difficule to predict and hardly repeatable.
Moreover, they tend to transfer poorly to reality as very dynamic motions
worsen the simulation to real gap.
This kind of behavior is unsually undesirable because it may be frightning
for people nearby, damage the hardware, be difficult to predict and be
hardly repeatable. Moreover, such dynamic motions tend to transfer poorly
to reality because the simulation to real gap is worsening.
"""
def __init__(self,
env: InterfaceJiminyEnv,
Expand All @@ -536,8 +541,8 @@ def __init__(self,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param max_height: Maximum height of all the lowest contact points wrt
the groupd above which termination is triggered.
:param max_height: Maximum height of the lowest contact points wrt the
groupd above which termination is triggered.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Expand All @@ -557,3 +562,45 @@ def __init__(self,
grace_period,
is_truncation=False,
is_training_only=is_training_only)


class ImpactForceTermination(QuantityTermination):
"""Terminate the episode immediately in case of violent impact on the
ground.

Similarly to the jumping behavior, this kind of behavior is usually
undesirable. See `FlyingTermination` documentation for details.
"""
def __init__(self,
env: InterfaceJiminyEnv,
max_force: float,
grace_period: float = 0.0,
*,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param max_force: Maximum vertical force applied on any of the contact
points or collision bodies above which termination is
triggered.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param is_training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Optional: False by default.
"""
# Call base implementation
super().__init__(
env,
"termination_flying",
(MaskedQuantity, dict( # type: ignore[arg-type]
quantity=(MultiContactNormalizedSpatialForce, dict()),
axis=0,
keys=(2,))),
None,
max_force,
grace_period,
is_truncation=False,
is_training_only=is_training_only)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
UnaryOpQuantity,
BinaryOpQuantity,
MultiAryOpQuantity)
from .generic import (OrientationType,
from .generic import (EnergyGenerationMode,
OrientationType,
FrameOrientation,
FramePosition,
FrameXYZQuat,
Expand All @@ -17,6 +18,7 @@
MultiFrameCollisionDetection,
AverageFrameXYZQuat,
AverageFrameRollPitch,
AveragePowerConsumption,
FrameSpatialAverageVelocity,
MultiActuatedJointKinematic)
from .locomotion import (BaseOdometryPose,
Expand All @@ -28,7 +30,7 @@
MultiFootRelativeXYZQuat,
MultiFootMeanOdometryPose,
MultiFootNormalizedForceVertical,
MultiContactNormalizedForceTangential,
MultiContactNormalizedSpatialForce,
MultiFootCollisionDetection,
CenterOfMass,
CapturePoint,
Expand All @@ -37,6 +39,7 @@


__all__ = [
'EnergyGenerationMode',
'OrientationType',
'QuantityManager',
'StackedQuantity',
Expand All @@ -57,9 +60,10 @@
'MultiFootMeanOdometryPose',
'MultiFootNormalizedForceVertical',
'MultiFootCollisionDetection',
'MultiContactNormalizedForceTangential',
'MultiContactNormalizedSpatialForce',
'AverageFrameXYZQuat',
'AverageFrameRollPitch',
'AveragePowerConsumption',
'FrameSpatialAverageVelocity',
'BaseSpatialAverageVelocity',
'BaseOdometryAverageVelocity',
Expand Down
Loading
Loading