Skip to content

Commit

Permalink
[gym/common] Add relative foot odom pose shift tracking termination c…
Browse files Browse the repository at this point in the history
…onditions. (#820)

* [gym/common] Add relative foot odometry pose shift tracking termination conditions.
* [gym/common] Add unit test checking that observation wrappers preserve key ordering.
  • Loading branch information
duburcqa authored Jun 29, 2024
1 parent 80bfd36 commit 0443f50
Show file tree
Hide file tree
Showing 17 changed files with 338 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def __init__(self,
name: str,
components: Sequence[AbstractReward],
reduce_fn: Callable[
[Sequence[Optional[float]]], Optional[float]],
[Tuple[Optional[float], ...]], Optional[float]],
is_normalized: bool) -> None:
"""
:param env: Base or wrapped jiminy environment.
Expand Down Expand Up @@ -378,7 +378,7 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]:
values.append(value)

# Aggregate all reward components in one
reward_total = self._reduce_fn(values)
reward_total = self._reduce_fn(tuple(values))

return reward_total

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
TrackingFootPositionsReward,
TrackingFootOrientationsReward,
TrackingFootForceDistributionReward,
DriftTrackingBaseOdometryPoseTermination,
DriftTrackingBaseOdometryPositionTermination,
DriftTrackingBaseOdometryOrientationTermination,
ShiftTrackingFootOdometryPositionsTermination,
ShiftTrackingFootOdometryOrientationsTermination,
MinimizeAngularMomentumReward,
MinimizeFrictionReward,
BaseRollPitchTermination,
BaseHeightTermination,
FallingTermination,
FootCollisionTermination,
FlyingTermination,
ImpactForceTermination)
Expand All @@ -44,14 +47,17 @@
"TrackingFootOrientationsReward",
"TrackingFootForceDistributionReward",
"DriftTrackingQuantityTermination",
"DriftTrackingBaseOdometryPoseTermination",
"DriftTrackingBaseOdometryPositionTermination",
"DriftTrackingBaseOdometryOrientationTermination",
"ShiftTrackingQuantityTermination",
"ShiftTrackingMotorPositionsTermination",
"ShiftTrackingFootOdometryPositionsTermination",
"ShiftTrackingFootOdometryOrientationsTermination",
"MechanicalSafetyTermination",
"MechanicalPowerConsumptionTermination",
"FlyingTermination",
"BaseRollPitchTermination",
"BaseHeightTermination",
"FallingTermination",
"FootCollisionTermination",
"ImpactForceTermination"
]
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __init__(self,
self.max_stack = max_stack
self.op = op

# Define jit-able minimum distance between two time series
# Jit-able method computing minimum distance between two time series
@nb.jit(nopython=True, cache=True)
def min_norm(values: np.ndarray) -> float:
"""Compute the minimum Euclidean norm over all timestamps of a
Expand Down Expand Up @@ -612,9 +612,9 @@ def __init__(


class ShiftTrackingMotorPositionsTermination(ShiftTrackingQuantityTermination):
"""Terminate the episode if the robot is not tracking the actuated joint
positions wrt some reference trajectory with expected accuracy, whatever
the timestep being considered over some fixed-size sliding time window.
"""Terminate the episode if the selected reference trajectory is not
tracked with expected accuracy regarding the actuated joint positions,
whatever the timestep being considered over some fixed-size sliding window.
The robot must track the reference if there is no hazard, only applying
minor corrections to keep balance. Rewarding the agent for doing so is
Expand Down
Loading

0 comments on commit 0443f50

Please sign in to comment.