Skip to content

Commit

Permalink
[gym/common] Add generic stacked quantity wrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed May 1, 2024
1 parent c76b729 commit 320e48e
Show file tree
Hide file tree
Showing 9 changed files with 417 additions and 170 deletions.
43 changes: 29 additions & 14 deletions python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class InterfaceJiminyEnv(
"""Observer plus controller interface for both generic pipeline blocks,
including environments.
"""

metadata: Dict[str, Any] = {
"render_modes": (
['rgb_array'] + (['human'] if is_display_available() else []))
Expand All @@ -190,6 +191,16 @@ class InterfaceJiminyEnv(
sensor_measurements: SensorMeasurementStackMap
is_simulation_running: npt.NDArray[np.bool_]

num_steps: npt.NDArray[np.int64]
"""Number of simulation steps that has been performed since last reset of
the base environment.
.. note::
The counter is incremented before updating the observation at the end
of the step, and consequently, before evaluating the reward and the
termination conditions.
"""

quantities: "QuantityManager"

action: ActT
Expand Down Expand Up @@ -246,13 +257,30 @@ def _observer_handle(self,
:param v: Current extended velocity vector of the robot.
:param sensor_measurements: Current sensor data.
"""
# Early return if no simulation is running
if not self.is_simulation_running:
return

# Reset the quantity manager.
# In principle, the internal cache of quantities should be cleared each
# time the state of the robot and/or its derivative changes. This is
# hard to do because there is no way to detect this specifically at the
# time being. However, `_observer_handle` is never called twice in the
# exact same state by the engine, so resetting quantities at the
# beginning of the method should cover most cases. Yet, quantities
# cannot be used reliably in the definition of profile forces because
# they are always updated before the controller gets called, no matter
# if either one or the other is time-continuous. Hacking the internal
# dynamics to clear quantities does not address this issue either.
self.quantities.clear()

# Refresh the observation if not already done but only if a simulation
# is already running. It would be pointless to refresh the observation
# at this point since the controller will be called multiple times at
# start. Besides, it would defeat the purpose `_initialize_buffers`,
# that is supposed to be executed before `refresh_observation` is being
# called for the first time of an episode.
if not self.__is_observation_refreshed and self.is_simulation_running:
if not self.__is_observation_refreshed:
measurement = self.__measurement
measurement["t"][()] = t
measurement["states"]["agent"]["q"] = q
Expand Down Expand Up @@ -299,19 +327,6 @@ def _controller_handle(self,
:returns: Motors torques to apply on the robot.
"""
# Reset the quantity manager.
# In principle, the internal cache of quantities should be cleared each
# time the state of the robot and/or its derivative changes. This is
# hard to do because there is no way to detect this specifically at the
# time being. However, `_controller_handle` is never called twice in
# the exact same state by the engine, so resetting quantities at the
# beginning of the method should cover most cases. Yet, quantities
# cannot be used reliably in the definition of profile forces because
# they are always updated before the controller gets called, no matter
# if either one or the other is time-continuous. Hacking the internal
# dynamics to clear quantities does not address this issue either.
self.quantities.clear()

# Refresh the observation
self._observer_handle(t, q, v, sensor_measurements)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self,
self.robot_state = env.robot_state
self.sensor_measurements = env.sensor_measurements
self.is_simulation_running = env.is_simulation_running
self.num_steps = env.num_steps

# Backup the parent environment
self.env = env
Expand Down
9 changes: 5 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@

ValueT = TypeVar('ValueT')

QuantityCreator = Tuple[Type["AbstractQuantity[ValueT]"], Dict[str, Any]]


class WeakMutableCollection(MutableSet, Generic[ValueT]):
"""Mutable unordered list container storing weak reference to objects.
Expand Down Expand Up @@ -224,7 +222,7 @@ class AbstractQuantity(ABC, Generic[ValueT]):
def __init__(self,
env: InterfaceJiminyEnv,
parent: Optional["AbstractQuantity"],
requirements: Dict[str, QuantityCreator],
requirements: Dict[str, "QuantityCreator"],
auto_refresh: bool) -> None:
"""
:param env: Base or wrapped jiminy environment.
Expand All @@ -234,7 +232,7 @@ def __init__(self,
depends for its evaluation, as a dictionary
whose keys are tuple gathering their respective
class and all their constructor keyword-arguments
except the environment 'env'.
except environment 'env' and parent 'parent.
:param auto_refresh: Whether this quantity must be refreshed
automatically as soon as its shared cache has been
cleared if specified, otherwise this does nothing.
Expand Down Expand Up @@ -463,3 +461,6 @@ def refresh(self) -> ValueT:
"""Evaluate this quantity based on the agent state at the end of the
current agent step.
"""


QuantityCreator = Tuple[Type[AbstractQuantity[ValueT]], Dict[str, Any]]
9 changes: 5 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
build_contains,
get_fieldnames,
register_variables)
from ..bases import (ObsT,
from ..bases import (DT_EPS,
ObsT,
ActT,
InfoType,
SensorMeasurementStackMap,
Expand Down Expand Up @@ -235,7 +236,7 @@ def __init__(self,
self.total_reward = 0.0

# Number of simulation steps performed
self.num_steps = -1
self.num_steps = np.array(-1, dtype=np.int64)
self._num_steps_beyond_terminate: Optional[int] = None

# Initialize the interfaces through multiple inheritance
Expand Down Expand Up @@ -715,7 +716,7 @@ def reset(self, # type: ignore[override]
self.simulator.reset(remove_all_forces=False)

# Reset some internal buffers
self.num_steps = 0
self.num_steps[()] = 0
self._num_steps_beyond_terminate = None

# Create a new log file
Expand Down Expand Up @@ -962,7 +963,7 @@ def step(self, # type: ignore[override]
terminated, truncated = self.has_terminated(self._info)
truncated = (
truncated or not self.is_simulation_running or
self.stepper_state.t >= self.simulation_duration_max)
self.stepper_state.t + DT_EPS > self.simulation_duration_max)

# Check if stepping after done and if it is an undefined behavior
if self._num_steps_beyond_terminate is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# pylint: disable=missing-module-docstring

from .manager import QuantityManager
from .generic import AverageFrameSpatialVelocity, FrameEulerAngles
from .generic import (FrameEulerAngles,
FrameXYZQuat,
AverageFrameSpatialVelocity)
from .locomotion import CenterOfMass, ZeroMomentPoint


__all__ = [
'QuantityManager',
'AverageFrameSpatialVelocity',
'FrameEulerAngles',
'FrameXYZQuat',
'AverageFrameSpatialVelocity',
'CenterOfMass',
'ZeroMomentPoint',
]
Loading

0 comments on commit 320e48e

Please sign in to comment.