Skip to content

Commit

Permalink
[gym_jiminy/common] Add 'QuantityObserver' block. (#838)
Browse files Browse the repository at this point in the history
* [python/dynamics] Take into account stride offset when trajectory time is wrapping.
* [gym_jiminy/common] Fix trajectory file never closed if loading fails.
* [gym_jiminy/common] Add locking mechanism to trajectory dataset.
* [gym_jiminy/common] More robust pipeline registration mechanism.
* [gym_jiminy/common] Add composition wrapper before observer-controller blocks.
* [gym_jiminy/common] Enable env composition to augment the observation space with trajectory reference.
* [gym_jiminy/common] Add 'QuantityObserver' block.
* [gym_jiminy/common] Fix 'AdditiveMixtureReward' for 'order=inf'.
* [gym_jiminy/common] Support string representation of enums in pipeline config.
  • Loading branch information
duburcqa authored Nov 27, 2024
1 parent 526b909 commit 5f98eba
Show file tree
Hide file tree
Showing 21 changed files with 776 additions and 402 deletions.
412 changes: 280 additions & 132 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py

Large diffs are not rendered by default.

62 changes: 57 additions & 5 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ class InterfaceQuantity(Generic[ValueT], metaclass=ABCMeta):
.. warning::
The user is responsible for implementing the dunder methods `__eq__`
and `__hash__` that characterize identical quantities. This property is
used internally by `QuantityManager` to synchronize cache between
them. It is advised to use decorator `@dataclass(unsafe_hash=True)` for
used internally by `QuantityManager` to synchronize cache between them.
It is advised to use decorator `@dataclass(unsafe_hash=True)` for
convenience, but it can also be done manually.
"""

Expand Down Expand Up @@ -862,6 +862,9 @@ def __init__(self,
# Ordered set of named reference trajectories as a dictionary
self.registry: OrderedDict[str, Trajectory] = OrderedDict()

# Whether the dataset is locked, ie no traj can be added/discarded
self._lock = False

# Name of the trajectory that is currently selected
self._name = ""

Expand Down Expand Up @@ -920,6 +923,12 @@ def add(self, name: str, trajectory: Trajectory) -> None:
overwriting it by mistake.
:param trajectory: Trajectory instance to register.
"""
# Make sure that the dataset is not locked
if self._lock:
raise RuntimeError(
"Trajectory dataset already locked. Impossible to add any "
"trajectory.")

# Make sure that no trajectory with the exact same name already exists
if name in self.registry:
raise KeyError(
Expand Down Expand Up @@ -950,6 +959,12 @@ def discard(self, name: str) -> None:
:param name: Name of the trajectory to discard.
"""
# Make sure that the dataset is not locked
if self._lock:
raise RuntimeError(
"Trajectory dataset already locked. Impossible to discard any "
"trajectory.")

# Un-select trajectory if it corresponds to the discarded one
if self._name == name:
self._trajectory = None
Expand All @@ -958,14 +973,46 @@ def discard(self, name: str) -> None:
# Delete trajectory for global registry
del self.registry[name]

@sync
def clear(self) -> None:
    """Remove all the trajectories from the dataset and un-select the one
    that is currently selected, if any.

    This applies to the local internal registry of all instances sharing
    the same cache as this quantity.

    :raises RuntimeError: If the dataset has been locked beforehand.
    """
    # A locked dataset must be left untouched
    if self._lock:
        raise RuntimeError(
            "Trajectory dataset already locked. Impossible to clear the "
            "dataset.")

    # Drop every registered trajectory at once
    self.registry.clear()

    # Nothing is left to select, so reset the current selection
    self._trajectory = None
    self._name = ""

def __iter__(self) -> Iterator[Trajectory]:
"""Iterate over all the trajectories in the dataset.
"""
return iter(self.registry.values())

def __bool__(self) -> bool:
    """Whether the dataset contains at least one trajectory.

    The dataset evaluates to `False` when its registry is empty, and to
    `True` otherwise.
    """
    return len(self.registry) > 0

@sync
def select(self,
name: str,
mode: Literal['raise', 'wrap', 'clip'] = 'raise') -> None:
"""Jointly select a trajectory in the internal registry of all
instances sharing the same cache as this quantity.
"""Select an existing trajectory from the database shared and synchronized
across all managed quantities.
:param name: Name of the trajectory to discard.
.. note::
There is no way to select a different reference trajectory for
individual quantities at the time being.
:param name: Name of the trajectory to select.
:param mode: Specifies how to deal with query times that are out of the
time interval of the trajectory. See `Trajectory.get`
documentation for details.
Expand All @@ -984,6 +1031,11 @@ def select(self,
# Un-initialize quantity when the selected trajectory changes
self.reset(reset_tracking=False)

def lock(self) -> None:
    """Forbid adding/discarding trajectories to the dataset from now on.

    Once the flag is set, `add`, `discard` and `clear` raise `RuntimeError`
    instead of modifying the registry.

    .. note::
        NOTE(review): no method resetting this flag is visible in this part
        of the file, so locking looks irreversible - confirm.
    """
    # Flag checked by every method that mutates the registry
    self._lock = True

@property
def name(self) -> str:
"""Name of the trajectory that is currently selected.
Expand Down
2 changes: 2 additions & 0 deletions python/gym_jiminy/common/gym_jiminy/common/blocks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# pylint: disable=missing-module-docstring

from .quantity_observer import QuantityObserver
from .mahony_filter import MahonyFilter
from .motor_safety_limit import MotorSafetyLimit
from .proportional_derivative_controller import PDController, PDAdapter
from .deformation_estimator import DeformationEstimator


__all__ = [
'QuantityObserver',
'MahonyFilter',
'MotorSafetyLimit',
'PDController',
Expand Down
115 changes: 115 additions & 0 deletions python/gym_jiminy/common/gym_jiminy/common/blocks/quantity_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Implementation of a generic quantity observer block compatible with
gym_jiminy reinforcement learning pipeline environment design.
"""
from collections import OrderedDict
from typing import Any, Type, TypeVar, cast

import numpy as np
import gymnasium as gym

from jiminy_py import tree

from ..bases import (
BaseObs, BaseAct, BaseObserverBlock, InterfaceJiminyEnv, InterfaceQuantity)
from ..utils import DataNested, build_copyto


ValueT = TypeVar('ValueT', bound=DataNested)


def get_space(data: DataNested) -> gym.Space[DataNested]:
    """Infer space from a given value.

    Mappings are converted to `gym.spaces.Dict`, sequences to
    `gym.spaces.Tuple`, and `np.ndarray` leaves to unbounded `gym.spaces.Box`
    spaces having the same shape and dtype. Nested containers are processed
    recursively.

    .. warning::
        Beware that space inference is lossy. Firstly, one cannot discriminate
        between `gym.spaces.Box` and other non-container spaces, e.g.
        `gym.spaces.Discrete` or `gym.spaces.MultiBinary`. Because of this
        limitation, it is assumed that all `np.ndarray` data has been sampled
        by a `gym.spaces.Box` space. Secondly, it is impossible to determine
        the bounds of the space, so it is assumed to be unbounded.

    :param data: Any value sampled from a given space.

    :returns: Space inferred from the structure of the provided value.

    :raises TypeError: If a leaf of the nested data structure is not a
                       `np.ndarray`.
    """
    data_type = type(data)

    # Mapping containers: infer the space of each field recursively
    if tree.issubclass_mapping(data_type):
        return gym.spaces.Dict(OrderedDict([
            (field, get_space(value))
            for field, value in data.items()]))  # type: ignore[union-attr]

    # Sequence containers: infer the space of each element recursively
    if tree.issubclass_sequence(data_type):
        return gym.spaces.Tuple([get_space(value) for value in data])

    # Leaves must be arrays. An explicit exception is raised rather than
    # relying on 'assert', which is stripped when running with '-O' and would
    # then fail later with an obscure error.
    if not isinstance(data, np.ndarray):
        raise TypeError(
            "Impossible to infer space from leaf of type "
            f"'{data_type.__name__}'. Only 'np.ndarray' is supported.")

    # Bounds cannot be recovered from a single sample: assume unbounded
    return gym.spaces.Box(
        low=float("-inf"),
        high=float("inf"),
        shape=data.shape,
        dtype=data.dtype.type)


class QuantityObserver(BaseObserverBlock[ValueT, None, BaseObs, BaseAct]):
    """Add a given pre-defined quantity to the observation of the environment.

    The quantity is registered in the quantity manager of the environment
    under the name of the block, and its value is exposed as the observation
    of the block.

    .. warning::
        The observation space of a quantity must be invariant. Yet, nothing
        prevents the shape of the quantity from changing dynamically. As a
        result, it is up to the user to make sure that this does not occur in
        practice, otherwise it will raise an exception.
    """
    def __init__(self,
                 name: str,
                 env: InterfaceJiminyEnv[BaseObs, BaseAct],
                 quantity: Type[InterfaceQuantity[ValueT]],
                 *,
                 update_ratio: int = 1,
                 **kwargs: Any) -> None:
        """
        :param name: Name of the block. It is also the key under which the
                     quantity is registered in `env.quantities`.
        :param env: Environment to connect with.
        :param quantity: Type of the quantity.
        :param update_ratio: Ratio between the update period of the observer
                             and the one of the subsequent observer. -1 to
                             match the simulation timestep of the environment.
                             Optional: `1` by default.
        :param kwargs: Additional arguments that will be forwarded to the
                       constructor of the quantity.
        """
        # Add the quantity to the environment
        env.quantities[name] = (quantity, kwargs)

        # Define proxy for fast access.
        # NOTE(review): this is done before calling the base constructor,
        # presumably because the latter triggers
        # `_initialize_observation_space`, which reads `self.data` - confirm.
        self.data = env.quantities.registry[name]

        # Initialize the observer
        super().__init__(name, env, update_ratio)

        # Try to bind the memory of the quantity to the observation.
        # Note that there is no guarantee that the quantity will be updated
        # in-place without dynamic memory allocation, so it needs to be checked
        # at run-time systematically and copy the value if necessary.
        self.observation = self.data.get()

        # Define specialized copyto operator for efficiency.
        # This is necessary because there is no guarantee that the quantity
        # will be updated in-place without dynamic memory allocation.
        self._copyto_observation = build_copyto(self.observation)

    def __del__(self) -> None:
        """Remove the quantity from the environment when the block is
        garbage collected, on a best-effort basis.
        """
        try:
            # NOTE(review): `self.env` and `self.name` are presumably set by
            # the base class constructor - confirm. If construction failed
            # before that point, the lookup raises and is silently ignored.
            del self.env.quantities[self.name]
        except Exception:   # pylint: disable=broad-except
            # This method must not fail under any circumstances
            pass

    def _initialize_observation_space(self) -> None:
        """Define the observation space of the block based on the current
        value of the quantity.
        """
        # Let us infer the observation space from the value of the quantity.
        # Note that it is always valid to fetch the value of a quantity, even
        # if no simulation is running.
        self.observation_space = cast(
            gym.Space[ValueT], get_space(self.data.get()))

    def refresh_observation(self, measurement: BaseObs) -> None:
        """Update the observation of the block with the current value of the
        quantity.

        :param measurement: Not used by this implementation.
        """
        # Evaluate the quantity
        value = self.data.get()

        # Update the observation in-place in case of dynamic memory allocation.
        # Copying is skipped when the quantity returned the very same object
        # as the one already bound to the observation.
        if self.observation is not value:
            self._copyto_observation(value)
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def weighted_norm(weights: Tuple[float, ...],
if any_value:
total = max(total, weight * value)
else:
total = value
total = weight * value
else:
total += weight * math.pow(value, order)
any_value = True
Expand Down
Loading

0 comments on commit 5f98eba

Please sign in to comment.