Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gym/common] Add generic stacked quantity wrapper. #783

Merged
merged 5 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=missing-module-docstring

from .quantity import (SharedCache,
QuantityCreator,
from .quantity import (QuantityCreator,
SharedCache,
AbstractQuantity)
from .interfaces import (DT_EPS,
ObsT,
Expand All @@ -27,8 +27,8 @@


__all__ = [
'SharedCache',
'QuantityCreator',
'SharedCache',
'AbstractQuantity',
'DT_EPS',
'ObsT',
Expand Down
47 changes: 29 additions & 18 deletions python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def compute_command(self, action: ActT, command: BaseActT) -> None:

def compute_reward(self,
terminated: bool, # pylint: disable=unused-argument
truncated: bool, # pylint: disable=unused-argument
info: InfoType # pylint: disable=unused-argument
) -> float:
"""Compute the reward related to a specific control block.
Expand All @@ -159,9 +158,6 @@ def compute_reward(self,
:param terminated: Whether the episode has reached the terminal state
of the MDP at the current step. This flag can be
used to compute a specific terminal reward.
:param truncated: Whether a truncation condition outside the scope of
the MDP has been satisfied at the current step. This
flag can be used to adapt the reward.
:param info: Dictionary of extra information for monitoring.

:returns: Aggregated reward for the current step.
Expand All @@ -182,6 +178,7 @@ class InterfaceJiminyEnv(
"""Observer plus controller interface for both generic pipeline blocks,
including environments.
"""

metadata: Dict[str, Any] = {
"render_modes": (
['rgb_array'] + (['human'] if is_display_available() else []))
Expand All @@ -194,6 +191,16 @@ class InterfaceJiminyEnv(
sensor_measurements: SensorMeasurementStackMap
is_simulation_running: npt.NDArray[np.bool_]

num_steps: npt.NDArray[np.int64]
"""Number of simulation steps that have been performed since the last reset of
the base environment.

.. note::
The counter is incremented before updating the observation at the end
of the step, and consequently, before evaluating the reward and the
termination conditions.
"""

quantities: "QuantityManager"

action: ActT
Expand Down Expand Up @@ -250,13 +257,30 @@ def _observer_handle(self,
:param v: Current extended velocity vector of the robot.
:param sensor_measurements: Current sensor data.
"""
# Early return if no simulation is running
if not self.is_simulation_running:
return

# Reset the quantity manager.
# In principle, the internal cache of quantities should be cleared each
# time the state of the robot and/or its derivative changes. This is
# hard to do because there is no way to detect this specifically for the
# time being. However, `_observer_handle` is never called twice in the
# exact same state by the engine, so resetting quantities at the
# beginning of the method should cover most cases. Yet, quantities
# cannot be used reliably in the definition of profile forces because
# they are always updated before the controller gets called, no matter
# if either one or the other is time-continuous. Hacking the internal
# dynamics to clear quantities does not address this issue either.
self.quantities.clear()

# Refresh the observation if not already done but only if a simulation
# is already running. It would be pointless to refresh the observation
# at this point since the controller will be called multiple times at
# start. Besides, it would defeat the purpose of `_initialize_buffers`,
# which is supposed to be executed before `refresh_observation` is
# called for the first time in an episode.
if not self.__is_observation_refreshed and self.is_simulation_running:
if not self.__is_observation_refreshed:
measurement = self.__measurement
measurement["t"][()] = t
measurement["states"]["agent"]["q"] = q
Expand Down Expand Up @@ -303,19 +327,6 @@ def _controller_handle(self,

:returns: Motors torques to apply on the robot.
"""
# Reset the quantity manager.
# In principle, the internal cache of quantities should be cleared not
# each time the state of the robot and/or its derivative changes. This
# is hard to do because there is no way to detect this specifically at
# the time being. However, `_controller_handle` is never called twice
# in the exact same state by the engine, so resetting quantities at the
# beginning of the method should cover most cases. Yet, quantities
# cannot be used reliably in the definition of profile forces because
# they are always updated before the controller gets called, no matter
# if either one or the other is time-continuous. Hacking the internal
# dynamics to clear quantities does not address this issue either.
self.quantities.clear()

# Refresh the observation
self._observer_handle(t, q, v, sensor_measurements)

Expand Down
10 changes: 4 additions & 6 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self,
self.robot_state = env.robot_state
self.sensor_measurements = env.sensor_measurements
self.is_simulation_running = env.is_simulation_running
self.num_steps = env.num_steps

# Backup the parent environment
self.env = env
Expand Down Expand Up @@ -254,7 +255,7 @@ def step(self, # type: ignore[override]
reward = float(reward)
if not math.isnan(reward):
try:
reward += self.compute_reward(terminated, truncated, info)
reward += self.compute_reward(terminated, info)
except NotImplementedError:
pass

Expand Down Expand Up @@ -736,11 +737,8 @@ def compute_command(self, action: ActT, command: np.ndarray) -> None:
# the right period.
self.env.compute_command(self.env.action, command)

def compute_reward(self,
terminated: bool,
truncated: bool,
info: InfoType) -> float:
return self.controller.compute_reward(terminated, truncated, info)
def compute_reward(self, terminated: bool, info: InfoType) -> float:
return self.controller.compute_reward(terminated, info)


class BaseTransformObservation(
Expand Down
39 changes: 31 additions & 8 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
quantities only once per step, and gathering similar quantities in a large
batch to leverage vectorization of math instructions.
"""
import re
import weakref
from weakref import ReferenceType
from abc import ABC, abstractmethod
Expand All @@ -26,8 +27,6 @@

ValueT = TypeVar('ValueT')

QuantityCreator = Tuple[Type["AbstractQuantity"], Dict[str, Any]]


class WeakMutableCollection(MutableSet, Generic[ValueT]):
"""Mutable unordered list container storing weak reference to objects.
Expand Down Expand Up @@ -115,7 +114,7 @@ class SharedCache(Generic[ValueT]):
This implementation is not thread safe.
"""

owners: WeakMutableCollection["AbstractQuantity"]
owners: WeakMutableCollection["AbstractQuantity[ValueT]"]
"""Owners of the shared buffer, ie quantities relying on it to store the
result of their evaluation. This information may be useful for determining
the most efficient computation path overall.
Expand Down Expand Up @@ -164,9 +163,15 @@ def has_value(self) -> bool:
def reset(self) -> None:
"""Clear value stored in cache if any.
"""
# Clear cache
self._value = None
self._has_value = False

# Refresh all owner quantities for which auto refresh has been enabled
for owner in self.owners:
if owner.auto_refresh:
owner.get()

def set(self, value: ValueT) -> None:
"""Set value in cache, silently overriding the existing value if any.

Expand Down Expand Up @@ -218,7 +223,8 @@ class AbstractQuantity(ABC, Generic[ValueT]):
def __init__(self,
env: InterfaceJiminyEnv,
parent: Optional["AbstractQuantity"],
requirements: Dict[str, QuantityCreator]) -> None:
requirements: Dict[str, "QuantityCreator"],
auto_refresh: bool) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param parent: Higher-level quantity from which this quantity is a
Expand All @@ -227,11 +233,21 @@ def __init__(self,
depends for its evaluation, as a dictionary
whose keys are tuple gathering their respective
class and all their constructor keyword-arguments
except the environment 'env'.
except the environment 'env' and the parent 'parent'.
:param auto_refresh: Whether this quantity must be refreshed
automatically as soon as its shared cache has been
cleared if specified, otherwise this does nothing.
"""
# Backup some of user argument(s)
self.env = env
self.parent = parent
self.auto_refresh = auto_refresh

# Make sure that all requirement names would be valid as property
requirement_names = requirements.keys()
if any(re.match('[^A-Za-z0-9_]', name) for name in requirement_names):
raise ValueError("The name of all quantity requirements should be "
"ASCII alphanumeric characters plus underscore.")

# Instantiate intermediary quantities if any
self.requirements: Dict[str, AbstractQuantity] = {
Expand All @@ -252,14 +268,14 @@ class and all their constructor keyword-arguments
# Whether the quantity must be re-initialized
self._is_initialized: bool = False

# Add getter of all intermediary quantities dynamically.
# Add getter for all intermediary quantities dynamically.
# This approach is hacky but much faster than any other official
# approach, i.e. implementing a custom `__getattribute__` method or even
# worse a custom `__getattr__` method.
def get_value(name: str, quantity: AbstractQuantity) -> Any:
return quantity.requirements[name].get()

for name in self.requirements.keys():
for name in requirement_names:
setattr(type(self), name, property(partial(get_value, name)))

def __getattr__(self, name: str) -> Any:
Expand All @@ -274,7 +290,11 @@ def __getattr__(self, name: str) -> Any:

:param name: Name of the requested quantity.
"""
return self.__getattribute__('requirements')[name].get()
try:
return self.__getattribute__('requirements')[name].get()
except KeyError as e:
raise AttributeError(
f"'{type(self)}' object has no attribute '{name}'") from e

def __dir__(self) -> List[str]:
"""Attribute lookup.
Expand Down Expand Up @@ -452,3 +472,6 @@ def refresh(self) -> ValueT:
"""Evaluate this quantity based on the agent state at the end of the
current agent step.
"""


QuantityCreator = Tuple[Type[AbstractQuantity[ValueT]], Dict[str, Any]]
Loading
Loading