Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym_jiminy/common] Add 'QuantityObserver' block. #838

Merged
merged 9 commits into from
Nov 27, 2024
412 changes: 280 additions & 132 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py

Large diffs are not rendered by default.

62 changes: 57 additions & 5 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ class InterfaceQuantity(Generic[ValueT], metaclass=ABCMeta):
.. warning::
The user is responsible for implementing the dunder methods `__eq__`
and `__hash__` that characterize identical quantities. This property is
used internally by `QuantityManager` to synchronize cache between
them. It is advised to use decorator `@dataclass(unsafe_hash=True)` for
used internally by `QuantityManager` to synchronize cache between them.
It is advised to use decorator `@dataclass(unsafe_hash=True)` for
convenience, but it can also be done manually.
"""

Expand Down Expand Up @@ -862,6 +862,9 @@ def __init__(self,
# Ordered set of named reference trajectories as a dictionary
self.registry: OrderedDict[str, Trajectory] = OrderedDict()

# Whether the dataset is locked, ie no traj can be added/discarded
self._lock = False

# Name of the trajectory that is currently selected
self._name = ""

Expand Down Expand Up @@ -920,6 +923,12 @@ def add(self, name: str, trajectory: Trajectory) -> None:
overwriting it by mistake.
:param trajectory: Trajectory instance to register.
"""
# Make sure that the dataset is not locked
if self._lock:
raise RuntimeError(
"Trajectory dataset already locked. Impossible to add any "
"trajectory.")

# Make sure that no trajectory with the exact same name already exists
if name in self.registry:
raise KeyError(
Expand Down Expand Up @@ -950,6 +959,12 @@ def discard(self, name: str) -> None:

:param name: Name of the trajectory to discard.
"""
# Make sure that the dataset is not locked
if self._lock:
raise RuntimeError(
"Trajectory dataset already locked. Impossible to discard any "
"trajectory.")

# Un-select trajectory if it corresponds to the discarded one
if self._name == name:
self._trajectory = None
Expand All @@ -958,14 +973,46 @@ def discard(self, name: str) -> None:
# Delete trajectory for global registry
del self.registry[name]

@sync
def clear(self) -> None:
"""Clear the trajectory dataset from the local internal registry of all
instances sharing the same cache as this quantity.
"""
# Make sure that the dataset is not locked
if self._lock:
raise RuntimeError(
"Trajectory dataset already locked. Impossible to clear the "
"dataset.")

# Un-select trajectory
self._trajectory = None
self._name = ""

# Delete the whole registry
self.registry.clear()

def __iter__(self) -> Iterator[Trajectory]:
"""Iterate over all the trajectories in the dataset.
"""
return iter(self.registry.values())

def __bool__(self) -> bool:
"""Whether the dataset of trajectory is currently empty.
"""
return bool(self.registry)

@sync
def select(self,
name: str,
mode: Literal['raise', 'wrap', 'clip'] = 'raise') -> None:
"""Jointly select a trajectory in the internal registry of all
instances sharing the same cache as this quantity.
"""Select an existing trajectory from the database shared synchronized
all managed quantities.

:param name: Name of the trajectory to discard.
.. note::
There is no way to select a different reference trajectory for
individual quantities at the time being.

:param name: Name of the trajectory to select.
:param mode: Specifies how to deal with query time of are out of the
time interval of the trajectory. See `Trajectory.get`
documentation for details.
Expand All @@ -984,6 +1031,11 @@ def select(self,
# Un-initialize quantity when the selected trajectory changes
self.reset(reset_tracking=False)

def lock(self) -> None:
"""Forbid adding/discarding trajectories to the dataset from now on.
"""
self._lock = True

@property
def name(self) -> str:
"""Name of the trajectory that is currently selected.
Expand Down
2 changes: 2 additions & 0 deletions python/gym_jiminy/common/gym_jiminy/common/blocks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# pylint: disable=missing-module-docstring

from .quantity_observer import QuantityObserver
from .mahony_filter import MahonyFilter
from .motor_safety_limit import MotorSafetyLimit
from .proportional_derivative_controller import PDController, PDAdapter
from .deformation_estimator import DeformationEstimator


__all__ = [
'QuantityObserver',
'MahonyFilter',
'MotorSafetyLimit',
'PDController',
Expand Down
115 changes: 115 additions & 0 deletions python/gym_jiminy/common/gym_jiminy/common/blocks/quantity_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Implementation of Mahony filter block compatible with gym_jiminy
reinforcement learning pipeline environment design.
"""
from collections import OrderedDict
from typing import Any, Type, TypeVar, cast

import numpy as np
import gymnasium as gym

from jiminy_py import tree

from ..bases import (
BaseObs, BaseAct, BaseObserverBlock, InterfaceJiminyEnv, InterfaceQuantity)
from ..utils import DataNested, build_copyto


ValueT = TypeVar('ValueT', bound=DataNested)


def get_space(data: DataNested) -> gym.Space[DataNested]:
"""Infer space from a given value.

.. warning::
Beware that space inference is lossly. Firstly, one cannot discriminate
between `gym.spaces.Box` and other non-container spaces, e.g.
`gym.spaces.Discrete` or `gym.spaces.MultiBinary`. Because of this
limitation, it is assumed that all `np.ndarray` data has been sampled
by a `gym.spaces.Box` space. Secondly, it is impossible to determine
the bounds of the space, so it is assumed to be unbounded.

:param value: Any value sampled from a given space.
"""
data_type = type(data)
if tree.issubclass_mapping(data_type):
return gym.spaces.Dict(OrderedDict([
(field, get_space(value))
for field, value in data.items()])) # type: ignore[union-attr]
if tree.issubclass_sequence(data_type):
return gym.spaces.Tuple([get_space(value) for value in data])
assert isinstance(data, np.ndarray)
return gym.spaces.Box(
low=float("-inf"),
high=float("inf"),
shape=data.shape,
dtype=data.dtype.type)


class QuantityObserver(BaseObserverBlock[ValueT, None, BaseObs, BaseAct]):
"""Add a given pre-defined quantity to the observation of the environment.

.. warning::
The observation space of a quantity must be invariant. Yet, nothing
prevent the shape of the quantity to change dynamically. As a result,
it is up to user to make sure that does not occur in practice,
otherwise it will raise an exception.
"""
def __init__(self,
name: str,
env: InterfaceJiminyEnv[BaseObs, BaseAct],
quantity: Type[InterfaceQuantity[ValueT]],
*,
update_ratio: int = 1,
**kwargs: Any) -> None:
"""
:param name: Name of the block.
:param env: Environment to connect with.
:param quantity: Type of the quantity.
:param update_ratio: Ratio between the update period of the observer
and the one of the subsequent observer. -1 to
match the simulation timestep of the environment.
Optional: `1` by default.
:param kwargs: Additional arguments that will be forwarded to the
constructor of the quantity.
"""
# Add the quantity to the environment
env.quantities[name] = (quantity, kwargs)

# Define proxy for fast access
self.data = env.quantities.registry[name]

# Initialize the observer
super().__init__(name, env, update_ratio)

# Try to bind the memory of the quantity to the observation.
# Note that there is no guarantee that the quantity will be updated
# in-place without dynamic memory allocation, so it needs to be checked
# at run-time systematically and copy the value if necessary.
self.observation = self.data.get()

# Define specialized copyto operator for efficiency.
# This is necessary because there is no guarantee that the quantity
# will be updated in-place without dynamic memory allocation.
self._copyto_observation = build_copyto(self.observation)

def __del__(self) -> None:
try:
del self.env.quantities[self.name]
except Exception: # pylint: disable=broad-except
# This method must not fail under any circumstances
pass

def _initialize_observation_space(self) -> None:
# Let us infer the observation space from the value of the quantity.
# Note that it is always valid to fetch the value of a quantity, even
# if no simulation is running.
self.observation_space = cast(
gym.Space[ValueT], get_space(self.data.get()))

def refresh_observation(self, measurement: BaseObs) -> None:
# Evaluate the quantity
value = self.data.get()

# Update the observation in-place in case of dynamic memory allocation
if self.observation is not value:
self._copyto_observation(value)
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def weighted_norm(weights: Tuple[float, ...],
if any_value:
total = max(total, weight * value)
else:
total = value
total = weight * value
else:
total += weight * math.pow(value, order)
any_value = True
Expand Down
Loading