-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[gym/common] Add average odometry velocity quantity.
- Loading branch information
Showing
10 changed files
with
229 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 36 additions & 10 deletions
46
python/gym_jiminy/common/gym_jiminy/common/rewards/locomotion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,65 @@ | ||
|
||
"""Rewards mainly relevant for locomotion tasks on floating-base robots. | ||
""" | ||
from typing import Sequence | ||
|
||
import numpy as np | ||
|
||
from ..bases import InterfaceJiminyEnv, BaseQuantityReward | ||
from ..quantities import MaskedQuantity, AverageFrameSpatialVelocity | ||
from ..quantities import AverageOdometryVelocity | ||
|
||
from .generic import radial_basis_function | ||
|
||
|
||
class OdometryVelocityReward(BaseQuantityReward): | ||
""" TODO: Write documentation. | ||
"""Reward the agent for tracking a non-stationary target odometry velocity. | ||
The error transform in a normalized reward to maximize by applying RBF | ||
kernel on the error. The reward will be 0.0 if the error cancels out | ||
completely and less than 0.01 above the user-specified cutoff threshold. | ||
""" | ||
def __init__(self, | ||
env: InterfaceJiminyEnv, | ||
target: Sequence[float], | ||
cutoff: float) -> None: | ||
""" TODO: Write documentation. | ||
""" | ||
:param target: Initial target average odometry velocity (vX, vY, vYaw). | ||
The target can be updated in necessary by calling | ||
`set_target`. | ||
:param cutoff: Cutoff threshold for the RBF kernel transform. | ||
""" | ||
# Backup some user argument(s) | ||
self.target = target | ||
self._target = np.asarray(target) | ||
self.cutoff = cutoff | ||
|
||
# Call base implementation | ||
super().__init__( | ||
env, | ||
"reward_odometry_velocity", | ||
(MaskedQuantity, dict( | ||
quantity=(AverageFrameSpatialVelocity, dict(frame_name="root_joint")), | ||
key=(0, 1, 5))), | ||
(AverageOdometryVelocity, {}), | ||
self._transform, | ||
is_normalized=True, | ||
is_terminal=False) | ||
|
||
def _transform(self, value: np.ndarray) -> np.ndarray: | ||
""" TODO: Write documentation. | ||
@property | ||
def target(self) -> np.ndarray: | ||
"""Get current target odometry velocity. | ||
""" | ||
return self._target | ||
|
||
@target.setter | ||
def target(self, target: Sequence[float]) -> None: | ||
"""Set current target odometry velocity. | ||
""" | ||
self._target = np.asarray(target) | ||
|
||
def _transform(self, value: np.ndarray) -> float: | ||
"""Apply Radial Base Function transform to the residual error between | ||
the current and target average odometry velocity. | ||
.. note:: | ||
The user must call `set_target` method before `compute_reward` to | ||
update the target odometry velocity if non-stationary. | ||
:param value: Current average odometry velocity. | ||
""" | ||
return radial_basis_function(value - self.target, self.cutoff) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.