Skip to content

Commit

Permalink
[gym/common] Add composition pipeline wrapper. Support specifying rew…
Browse files Browse the repository at this point in the history
…ard in pipeline config.
  • Loading branch information
duburcqa committed May 6, 2024
1 parent 182894f commit d165464
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 55 deletions.
7 changes: 4 additions & 3 deletions python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from .quantities import (QuantityCreator,
SharedCache,
AbstractQuantity)
from .reward import (AbstractReward,
BaseQuantityReward,
BaseMixtureReward)
from .compositions import (AbstractReward,

Check notice on line 17 in python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py#L17

'.compositions.ComposedJiminyEnv' imported but unused (F401)
BaseQuantityReward,
BaseMixtureReward,
ComposedJiminyEnv)
from .blocks import (BlockStateT,
InterfaceBlock,
BaseObserverBlock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
it greatly reduces code duplication and bugs.
"""
from abc import ABC, abstractmethod
from typing import Sequence, Callable, Optional, Tuple, TypeVar
from typing import Sequence, Callable, Optional, Tuple, TypeVar, Generic

import numpy as np

from ..bases import InterfaceJiminyEnv, QuantityCreator, InfoType
from .interfaces import ObsT, ActT, InfoType, EngineObsType, InterfaceJiminyEnv
from .quantities import QuantityCreator
from .pipeline import BasePipelineWrapper


ValueT = TypeVar('ValueT')
Expand Down Expand Up @@ -271,48 +273,48 @@ class BaseMixtureReward(AbstractReward):
single one.
"""

rewards: Tuple[AbstractReward, ...]
components: Tuple[AbstractReward, ...]
"""List of all the reward components that must be aggregated together.
"""

def __init__(self,
env: InterfaceJiminyEnv,
name: str,
rewards: Sequence[AbstractReward],
components: Sequence[AbstractReward],
reduce_fn: Callable[
[Sequence[Optional[float]]], Optional[float]],
is_normalized: bool) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the total reward.
:param rewards: Sequence of reward components to aggregate.
:param components: Sequence of reward components to aggregate.
:param reduce_fn: Transform function responsible for aggregating all
the reward components that were evaluated. Typical
examples are cumulative product and weighted sum.
:param is_normalized: Whether the reward is guaranteed to be normalized
after applying reduction function `reduce_fn`.
"""
# Make sure that at least one reward component has been specified
if not rewards:
if not components:
raise ValueError(
"At least one reward component must be specified.")

# Make sure that all reward components share the same environment
env = rewards[0].env
for reward in rewards[1:]:
for reward in components:
if env is not reward.env:
raise ValueError(
"All reward components must share the same environment.")

# Backup some user argument(s)
self.rewards = tuple(rewards)
self.components = tuple(components)
self._reduce_fn = reduce_fn
self._is_normalized = is_normalized

# Call base implementation
super().__init__(env, name)

# Determine whether the reward mixture is terminal
is_terminal = {reward.is_terminal for reward in self.rewards}
is_terminal = {reward.is_terminal for reward in self.components}
self._is_terminal: Optional[bool] = None
if len(is_terminal) == 1:
self._is_terminal = next(iter(is_terminal))
Expand All @@ -335,9 +337,13 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]:
"""Evaluate each individual reward component for the current state of
the environment, then aggregate them in one.
"""
# Early return depending on whether the reward and state are terminal
if self.is_terminal is not None and self.is_terminal ^ terminated:
return None

# Compute all reward components
values = []
for reward in self.rewards:
for reward in self.components:
# Evaluate reward
reward_info: InfoType = {}
value: Optional[float] = reward(terminated, reward_info)
Expand All @@ -354,3 +360,75 @@ def compute(self, terminated: bool, info: InfoType) -> Optional[float]:
reward_total = self._reduce_fn(values)

return reward_total


class ComposedJiminyEnv(
BasePipelineWrapper[ObsT, ActT, ObsT, ActT],
Generic[ObsT, ActT]):
"""Plug ad-hoc reward components and termination conditions to the
wrapped environment.
.. note::
This wrapper derives from `BasePipelineWrapper`, and such as, it is
considered as internal unlike `gym.Wrapper`. This means that it will be
taken into account when calling `evaluate` or `play_interactive` on the
wrapped environment.
"""
def __init__(self,
env: InterfaceJiminyEnv[ObsT, ActT],
*,
reward: AbstractReward) -> None:
# Make sure that the reward is linked to this environment
assert env is reward.env

# Backup user argument(s)
self.reward = reward

# Initialize base class
super().__init__(env)

# Bind observation and action of the base environment
assert self.observation_space.contains(self.env.observation)
assert self.action_space.contains(self.env.action)
self.observation = self.env.observation
self.action = self.env.action

def _initialize_action_space(self) -> None:
"""Configure the action space.
It simply copy the action space of the wrapped environment.
"""
self.action_space = self.env.action_space

def _initialize_observation_space(self) -> None:
"""Configure the observation space.
It simply copy the observation space of the wrapped environment.
"""
self.observation_space = self.env.observation_space

def refresh_observation(self, measurement: EngineObsType) -> None:
"""Compute high-level features based on the current wrapped
environment's observation.
It simply forwards the observation computed by the wrapped environment
without any processing.
:param measurement: Low-level measure from the environment to process
to get higher-level observation.
"""
self.env.refresh_observation(measurement)

def compute_command(self, action: ActT, command: np.ndarray) -> None:
"""Compute the motors efforts to apply on the robot.
It simply forwards the command computed by the wrapped environment
without any processing.
:param action: High-level target to achieve by means of the command.
:param command: Lower-level command to updated in-place.
"""
self.env.compute_command(action, command)

def compute_reward(self, terminated: bool, info: InfoType) -> float:
return self.reward(terminated, info)
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def compute_reward(self,
By default, it returns 0.0 without extra information no matter what.
The user is expected to provide an appropriate reward on its own,
either by overloading this method or by wrapping the environment with
`ComposeReward` for modular environment pipeline design.
`ComposedJiminyEnv` for modular environment pipeline design.
:param terminated: Whether the episode has reached the terminal state
of the MDP at the current step. This flag can be
Expand Down
10 changes: 6 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,9 @@ def __init__(self,

# Make sure that the environment is either some `ObservedJiminyEnv` or
# `ControlledJiminyEnv` block, or the base environment directly.
if isinstance(env, BasePipelineWrapper) and not isinstance(
env, (ObservedJiminyEnv, ControlledJiminyEnv)):
from gym_jiminy.common.bases.compositions import ComposedJiminyEnv

Check notice on line 367 in python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py#L367

Import outside toplevel (gym_jiminy.common.bases.compositions.ComposedJiminyEnv)
if isinstance(env, BasePipelineWrapper) and not isinstance(env, (
ObservedJiminyEnv, ControlledJiminyEnv, ComposedJiminyEnv)):
raise TypeError(
"Observers can only be added on top of another observer, "
"controller, or a base environment itself.")
Expand Down Expand Up @@ -586,8 +587,9 @@ def __init__(self,

# Make sure that the environment is either some `ObservedJiminyEnv` or
# `ControlledJiminyEnv` block, or the base environment directly.
if isinstance(env, BasePipelineWrapper) and not isinstance(
env, (ObservedJiminyEnv, ControlledJiminyEnv)):
from gym_jiminy.common.bases.compositions import ComposedJiminyEnv
if isinstance(env, BasePipelineWrapper) and not isinstance(env, (
ObservedJiminyEnv, ControlledJiminyEnv, ComposedJiminyEnv)):
raise TypeError(
"Controllers can only be added on top of another observer, "
"controller, or a base environment itself.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,30 @@ class AdditiveMixtureReward(BaseMixtureReward):
"""

def __init__(self,
env: InterfaceJiminyEnv,
name: str,
rewards: Sequence[AbstractReward],
components: Sequence[AbstractReward],
weights: Optional[Sequence[float]] = None) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the total reward.
:param rewards: Sequence of rewards to aggregate.
:param components: Sequence of reward components to aggregate.
:param weights: Sequence of weights associated with each reward
components, with the same ordering as 'rewards'.
components, with the same ordering as 'components'.
Optional: 1.0 for all reward components by default.
"""
# Handling of default arguments
if weights is None:
weights = (1.0,) * len(rewards)
weights = (1.0,) * len(components)

# Make sure that the weight sequence is consistent with the rewards
if len(weights) != len(rewards):
# Make sure that the weight sequence is consistent with the components
if len(weights) != len(components):
raise ValueError(
"Exactly one weight per reward component must be specified.")

# Determine whether the cumulative reward is normalized
weight_total = 0.0
for weight, reward in zip(weights, rewards):
for weight, reward in zip(weights, components):
if not reward.is_normalized:
LOGGER.warning(
"Reward '%s' is not normalized. Aggregating rewards that "
Expand All @@ -99,7 +101,7 @@ def __init__(self,
self.weights = weights

# Call base implementation
super().__init__(name, rewards, self._reduce, is_normalized)
super().__init__(env, name, components, self._reduce, is_normalized)

def _reduce(self, values: Sequence[Optional[float]]) -> Optional[float]:
"""Compute the weighted sum of all the reward components that has been
Expand All @@ -109,7 +111,7 @@ def _reduce(self, values: Sequence[Optional[float]]) -> Optional[float]:
:param values: Sequence of scalar value for reward components that has
been evaluated, `None` otherwise, with the same ordering
as 'rewards'.
as 'components'.
:returns: Scalar value if at least one of the reward component has been
evaluated, `None` otherwise.
Expand Down Expand Up @@ -144,18 +146,20 @@ class MultiplicativeMixtureReward(BaseMixtureReward):
"""

def __init__(self,
env: InterfaceJiminyEnv,
name: str,
rewards: Sequence[AbstractReward]
components: Sequence[AbstractReward]
) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the reward.
:param rewards: Sequence of rewards to aggregate.
:param components: Sequence of reward components to aggregate.
"""
# Determine whether the cumulative reward is normalized
is_normalized = all(reward.is_normalized for reward in rewards)
is_normalized = all(reward.is_normalized for reward in components)

# Call base implementation
super().__init__(name, rewards, self._reduce, is_normalized)
super().__init__(env, name, components, self._reduce, is_normalized)

def _reduce(self, values: Sequence[Optional[float]]) -> Optional[float]:
"""Compute the product of all the reward components that has been
Expand All @@ -165,7 +169,7 @@ def _reduce(self, values: Sequence[Optional[float]]) -> Optional[float]:
:param values: Sequence of scalar value for reward components that has
been evaluated, `None` otherwise, with the same ordering
as 'rewards'.
as 'components'.
:returns: Scalar value if at least one of the reward component has been
evaluated, `None` otherwise.
Expand Down
Loading

0 comments on commit d165464

Please sign in to comment.