
[gym/common] First draft of reward objects. (#784)
* [gym/common] Add masked quantity.
* [gym/common] Add average odometry velocity quantity.
* [gym/common] Introduce reward objects.
* [gym/common] Add locomotion rewards.
duburcqa authored May 3, 2024
1 parent a10a9a8 commit 8f90f1f
Showing 23 changed files with 863 additions and 23 deletions.
6 changes: 6 additions & 0 deletions docs/api/gym_jiminy/common/bases/quantity.rst
@@ -0,0 +1,6 @@
Quantity
========

.. automodule:: gym_jiminy.common.bases.quantity
:members:
:show-inheritance:
6 changes: 6 additions & 0 deletions docs/api/gym_jiminy/common/bases/reward.rst
@@ -0,0 +1,6 @@
Reward
======

.. automodule:: gym_jiminy.common.bases.reward
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions docs/api/gym_jiminy/common/index.rst
@@ -7,5 +7,7 @@ Gym Jiminy API
bases/index
blocks/index
envs/index
quantities/index
rewards/index
wrappers/index
utils/index
9 changes: 9 additions & 0 deletions docs/api/gym_jiminy/common/quantities/generic.rst
@@ -0,0 +1,9 @@
Generic
=======

.. automodule:: gym_jiminy.common.quantities.generic
:members:
:undoc-members:
:private-members:
:inherited-members:
:show-inheritance:
8 changes: 8 additions & 0 deletions docs/api/gym_jiminy/common/quantities/index.rst
@@ -0,0 +1,8 @@
Quantities
==========

.. toctree::
:maxdepth: 1

generic
locomotion
8 changes: 8 additions & 0 deletions docs/api/gym_jiminy/common/quantities/locomotion.rst
@@ -0,0 +1,8 @@
Locomotion
==========

.. automodule:: gym_jiminy.common.quantities.locomotion
:members:
:undoc-members:
:private-members:
:show-inheritance:
9 changes: 9 additions & 0 deletions docs/api/gym_jiminy/common/rewards/generic.rst
@@ -0,0 +1,9 @@
Generic
=======

.. automodule:: gym_jiminy.common.rewards.generic
:members:
:undoc-members:
:private-members:
:inherited-members:
:show-inheritance:
8 changes: 8 additions & 0 deletions docs/api/gym_jiminy/common/rewards/index.rst
@@ -0,0 +1,8 @@
Rewards
=======

.. toctree::
:maxdepth: 1

generic
locomotion
8 changes: 8 additions & 0 deletions docs/api/gym_jiminy/common/rewards/locomotion.rst
@@ -0,0 +1,8 @@
Locomotion
==========

.. automodule:: gym_jiminy.common.rewards.locomotion
:members:
:undoc-members:
:private-members:
:show-inheritance:
18 changes: 12 additions & 6 deletions python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py
@@ -1,8 +1,5 @@
# pylint: disable=missing-module-docstring

from .quantity import (QuantityCreator,
SharedCache,
AbstractQuantity)
from .interfaces import (DT_EPS,
ObsT,
ActT,
@@ -14,6 +11,12 @@
InterfaceObserver,
InterfaceController,
InterfaceJiminyEnv)
from .quantity import (QuantityCreator,
SharedCache,
AbstractQuantity)
from .reward import (AbstractReward,
BaseQuantityReward,
RewardCreator)
from .blocks import (BlockStateT,
InterfaceBlock,
BaseObserverBlock,
@@ -27,9 +30,6 @@


__all__ = [
'QuantityCreator',
'SharedCache',
'AbstractQuantity',
'DT_EPS',
'ObsT',
'NestedObsT',
@@ -40,15 +40,21 @@
'InfoType',
'SensorMeasurementStackMap',
'EngineObsType',
'SharedCache',
'InterfaceObserver',
'InterfaceController',
'InterfaceJiminyEnv',
'InterfaceBlock',
'AbstractQuantity',
'AbstractReward',
'BaseQuantityReward',
'BaseObserverBlock',
'BaseControllerBlock',
'BasePipelineWrapper',
'BaseTransformObservation',
'BaseTransformAction',
'ObservedJiminyEnv',
'ControlledJiminyEnv',
'QuantityCreator',
'RewardCreator'
]
@@ -224,7 +224,7 @@ def __init__(self,
env: InterfaceJiminyEnv,
parent: Optional["AbstractQuantity"],
requirements: Dict[str, "QuantityCreator"],
auto_refresh: bool) -> None:
auto_refresh: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param parent: Higher-level quantity from which this quantity is a
176 changes: 176 additions & 0 deletions python/gym_jiminy/common/gym_jiminy/common/bases/reward.py
@@ -0,0 +1,176 @@
"""This module promotes reward components as first-class objects.
Defining rewards this way allows for standardization of usual metrics. Overall,
it greatly reduces code duplication and bugs.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, TypeVar, Callable, Optional, Tuple, Type

import numpy as np

from ..bases import InterfaceJiminyEnv, QuantityCreator, InfoType


ValueT = TypeVar('ValueT')


class AbstractReward(ABC):
"""Abstract class from which all reward component must derived.
This goal of the agent is to maximize the expectation of the cumulative sum
of discounted reward over complete episodes. This holds true no matter if
its sign is always negative (aka. reward), always positive (aka. cost) or
indefinite (aka. objective).
Defining cost is allowed by not recommended. Although it encourages the
agent to achieve the task at hands as quickly as possible if success if the
only termination condition, it has the side-effect to give the opportunity
to the agent to maximize the return by killing itself whenever this is an
option, which is rarely the desired behavior. No restriction is enforced as
it may be limiting in some relevant cases, so it is up to the user to make
sure that its design makes sense overall.
"""

def __init__(self, env: InterfaceJiminyEnv) -> None:
"""
:param env: Base or wrapped jiminy environment.
"""
self.env = env

@property
@abstractmethod
def name(self) -> str:
"""Name uniquely identifying a given reward component.
"""

@property
@abstractmethod
def is_terminal(self) -> bool:
"""Whether the reward is terminal, ie only evaluated at the end of an
episode if a termination condition has been triggered.
.. note::
Truncation is not consider the same as termination. The reward to
not be evaluated in such a case, which means that it will never be
for such episodes.
"""

@property
@abstractmethod
def is_normalized(self) -> bool:
"""Whether the reward is guaranteed to be normalized, ie it is in range
[0.0, 1.0].
"""

@abstractmethod
def __call__(self, terminated: bool, info: InfoType) -> float:
"""Evaluate the reward for the current state of the environment.
"""


class BaseQuantityReward(AbstractReward):
"""Base class that makes easy easy to derive reward components from generic
quantities.
All this class does is applying some user-specified post-processing to the
value of a given multi-variate quantity to return a floating-point scalar
value, eventually normalized between 0.0 and 1.0 if desired.
"""

def __init__(self,
env: InterfaceJiminyEnv,
name: str,
quantity: QuantityCreator[ValueT],
transform_fun: Optional[Callable[[ValueT], float]],
is_normalized: bool,
is_terminal: bool) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the reward. This name will be used as the key
for storing the current value of the reward in 'info', and for adding the
underlying quantity to the set of quantities already managed by the
environment. As a result, it must be unique, otherwise an exception will be
raised.
:param quantity: Tuple gathering the class of the underlying quantity to use
as the reward after some post-processing, plus all its constructor
keyword-arguments except the environment 'env' and the parent 'parent'.
:param transform_fun: Transform function responsible for aggregating a
multi-variate quantity into a floating-point scalar value to maximize.
Typical examples are `np.min`, `np.max`, and
`lambda x: np.linalg.norm(x, ord=N)`. This function is also responsible for
rescaling the transformed quantity to the range [0.0, 1.0] if the reward is
advertised as normalized. The Radial Basis Function (RBF) kernel is the
most common choice for deriving a reward to maximize from errors based on
distance metrics (see `radial_basis_function` for details). `None` to skip
the transform entirely if unnecessary.
:param is_normalized: Whether the reward is guaranteed to be normalized
after applying the transform function, i.e. it lies in the range
[0.0, 1.0].
:param is_terminal: Whether the reward is terminal. A terminal reward will
only be evaluated at most once, at the end of each episode for which a
termination condition has been triggered. Conversely, a non-terminal reward
will be evaluated at every step except at the end of the episode. The value
0.0 is returned and 'info' is not filled whenever reward evaluation is
skipped.
"""
# Backup user argument(s)
self._name = name
self._transform_fun = transform_fun
self._is_normalized = is_normalized
self._is_terminal = is_terminal

# Call base implementation
super().__init__(env)

# Add quantity to the set of quantities managed by the environment
self.env.quantities[self.name] = quantity

# Keep track of the underlying quantity
self.quantity = self.env.quantities.registry[self.name]

@property
def name(self) -> str:
"""Name uniquely identifying every reward. It will be used to add the
underlying quantity to the ones already managed by the environment.
.. warning::
It must be prefixed by "reward_" as a risk mitigation for name
collision with some other user-defined quantity.
"""
return self._name

@property
def is_terminal(self) -> bool:
return self._is_terminal

@property
def is_normalized(self) -> bool:
return self._is_normalized

def __call__(self, terminated: bool, info: InfoType) -> float:
# Early return depending on whether the reward and state are terminal
if terminated ^ self.is_terminal:
return 0.0

# Evaluate raw quantity
value = self.env.quantities[self.name]

# Apply some post-processing if requested
if self._transform_fun is not None:
value = self._transform_fun(value)
assert np.ndim(value) == 0
if self._is_normalized and (value < 0.0 or value > 1.0):
raise ValueError(
"Reward not normalized in range [0.0, 1.0] as it ought to be.")

# Store the reward value in 'info'
info[self.name] = value

# Return the reward
return value


RewardCreator = Tuple[Type[AbstractReward], Dict[str, Any]]
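For context, here is a minimal sketch of how a concrete reward component could be built on top of `BaseQuantityReward`, combining the new `AverageOdometryVelocity` quantity with a Gaussian radial basis function transform. The helper `radial_basis_function`, the `OdometryVelocityReward` class, the quantity's constructor keyword-arguments and the cutoff value are illustrative assumptions rather than part of this commit:

```python
import numpy as np

from gym_jiminy.common.bases import BaseQuantityReward, InterfaceJiminyEnv
from gym_jiminy.common.quantities import AverageOdometryVelocity


def radial_basis_function(error: np.ndarray, cutoff: float) -> float:
    """Gaussian RBF kernel mapping an error vector to a value in (0.0, 1.0].

    Hypothetical helper: the signature actually shipped with gym_jiminy may
    differ (see `radial_basis_function` mentioned in the docstring above).
    """
    return float(np.exp(-np.dot(error, error) / (2.0 * cutoff ** 2)))


class OdometryVelocityReward(BaseQuantityReward):
    """Reward the agent for tracking a target average odometry velocity."""

    def __init__(self, env: InterfaceJiminyEnv, target: np.ndarray) -> None:
        self._target = target
        super().__init__(
            env,
            name="reward_odometry_velocity",  # "reward_" prefix per docstring
            quantity=(AverageOdometryVelocity, {}),  # assumed: no extra kwargs
            transform_fun=lambda value: radial_basis_function(
                value - self._target, cutoff=1.0),
            is_normalized=True,  # the RBF kernel maps errors into [0.0, 1.0]
            is_terminal=False)   # evaluated at every non-terminal step
```

Calling `reward(terminated, info)` at each step would then evaluate the underlying quantity through the environment's quantity manager, map the tracking error through the kernel, store the result under `info["reward_odometry_velocity"]`, and return it.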
4 changes: 3 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
@@ -1484,7 +1484,9 @@ def compute_command(self, action: ActT, command: np.ndarray) -> None:
assert isinstance(action, np.ndarray)
array_copyto(command, action)

def has_terminated(self, info: InfoType) -> Tuple[bool, bool]:
def has_terminated(self,
info: InfoType # pylint: disable=unused-argument
) -> Tuple[bool, bool]:
"""Determine whether the episode is over, because a terminal state of
the underlying MDP has been reached or an aborting condition outside
the scope of the MDP has been triggered.
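To tie the refined `has_terminated` signature to the terminal rewards introduced above, a derived environment would typically override it while preserving the `(terminated, truncated)` contract; a terminal `BaseQuantityReward` is then evaluated only when the first flag is raised. A minimal hypothetical sketch, assuming the generic environment is exported as `BaseJiminyEnv` and that the freeflyer height can be read from the robot state:

```python
from typing import Tuple

from gym_jiminy.common.bases import InfoType
from gym_jiminy.common.envs import BaseJiminyEnv


class FallSensitiveEnv(BaseJiminyEnv):
    """Hypothetical environment that also terminates when the base falls."""

    def has_terminated(self, info: InfoType) -> Tuple[bool, bool]:
        # Keep the default termination / truncation logic of the base class.
        terminated, truncated = super().has_terminated(info)

        # Additionally terminate if the freeflyer height drops below an
        # arbitrary threshold (hypothetical state layout and attribute name).
        terminated = terminated or float(self.robot_state.q[2]) < 0.3

        return terminated, truncated
```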
@@ -3,15 +3,22 @@
from .manager import QuantityManager
from .generic import (FrameEulerAngles,
FrameXYZQuat,
AverageFrameSpatialVelocity)
from .locomotion import CenterOfMass, ZeroMomentPoint
StackedQuantity,
AverageFrameSpatialVelocity,
MaskedQuantity)
from .locomotion import (AverageOdometryVelocity,
CenterOfMass,
ZeroMomentPoint)


__all__ = [
'QuantityManager',
'FrameEulerAngles',
'FrameXYZQuat',
'StackedQuantity',
'AverageFrameSpatialVelocity',
'MaskedQuantity',
'AverageOdometryVelocity',
'CenterOfMass',
'ZeroMomentPoint',
]
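As a usage note, these quantities are meant to be registered with the environment's quantity manager and evaluated on demand, sharing intermediate computations through `SharedCache`. A minimal sketch, assuming `AverageOdometryVelocity` needs no extra constructor keyword-arguments beyond those injected by the manager:

```python
import numpy as np

from gym_jiminy.common.bases import InterfaceJiminyEnv
from gym_jiminy.common.quantities import AverageOdometryVelocity


def register_odometry_velocity(env: InterfaceJiminyEnv) -> None:
    """Register the average odometry velocity under a unique key.

    The value is a tuple gathering the quantity class and its constructor
    keyword-arguments, matching the `QuantityCreator` alias.
    """
    env.quantities["odometry_velocity"] = (AverageOdometryVelocity, {})


def read_odometry_velocity(env: InterfaceJiminyEnv) -> np.ndarray:
    """Evaluate the quantity for the current simulation state, reusing any
    intermediate quantity already computed through the shared cache.
    """
    return env.quantities["odometry_velocity"]
```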
