Skip to content

Commit

Permalink
[gym/common] Remove error-prone, confusing and slow '__getattr__' fal…
Browse files Browse the repository at this point in the history
…lback in pipelines.
  • Loading branch information
duburcqa committed Jun 4, 2024
1 parent 2f5adea commit 1cc15aa
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..utils import DataNested
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv
from ..quantities import QuantityManager


Expand Down Expand Up @@ -220,6 +221,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Call super to allow mixing interfaces through multiple inheritance
super().__init__(*args, **kwargs)

# Define convenience proxy for quantity manager
self.quantities = self.unwrapped.quantities

def _setup(self) -> None:
"""Configure the observer-controller.
Expand Down Expand Up @@ -336,8 +340,9 @@ def _controller_handle(self,
self.__is_observation_refreshed = False

@property
def unwrapped(self) -> "InterfaceJiminyEnv":
"""Base environment of the pipeline.
def unwrapped(self) -> "BaseJiminyEnv":
"""The "underlying environment at the basis of the pipeline from which
this environment is part of.
"""
return self

Expand Down
47 changes: 15 additions & 32 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from collections import OrderedDict
from typing import (
Dict, Any, List, Optional, Tuple, Union, Generic, TypeVar, SupportsFloat,
Callable, cast)
Callable, cast, TYPE_CHECKING)

import numpy as np

Expand All @@ -37,6 +37,8 @@
from .blocks import BaseControllerBlock, BaseObserverBlock

from ..utils import DataNested, is_breakpoint, zeros, build_copyto, copy
if TYPE_CHECKING:
from ..envs.generic import BaseJiminyEnv


OtherObsT = TypeVar('OtherObsT', bound=DataNested)
Expand Down Expand Up @@ -96,22 +98,6 @@ def __init__(self,
# may be overwritten by derived classes afterward.
self._copyto_action: Callable[[ActT], None] = lambda action: None

def __getattr__(self, name: str) -> Any:
"""Convenient fallback attribute getter.
It enables to get access to the attribute and methods of the wrapped
environment directly without having to do it through `env`.
"""
return getattr(self.__getattribute__('env'), name)

def __dir__(self) -> List[str]:
"""Attribute lookup.
It is mainly used by autocomplete feature of Ipython. It is overloaded
to get consistent autocompletion wrt `getattr`.
"""
return [*super().__dir__(), *dir(self.env)]

@property
def render_mode(self) -> Optional[str]:
"""Rendering mode of the base environment.
Expand Down Expand Up @@ -143,9 +129,7 @@ def np_random(self, value: np.random.Generator) -> None:
self.env.np_random = value

@property
def unwrapped(self) -> InterfaceJiminyEnv:
"""Base environment of the pipeline.
"""
def unwrapped(self) -> "BaseJiminyEnv":
return self.env.unwrapped

@property
Expand Down Expand Up @@ -236,8 +220,7 @@ def step(self, # type: ignore[override]
self._copyto_action(action)

# Make sure that the pipeline has not change since last reset
env_derived = (
self.unwrapped.derived) # type: ignore[attr-defined]
env_derived = self.unwrapped.derived
if env_derived is not self:
raise RuntimeError(
"Pipeline environment has changed. Please call 'reset' "
Expand Down Expand Up @@ -532,14 +515,14 @@ def __init__(self,
# Register the observer's internal state and feature to the telemetry
if state is not None:
try:
self.env.register_variable( # type: ignore[attr-defined]
self.unwrapped.register_variable(
'state', state, None, self.observer.name)
except ValueError:
pass
self.env.register_variable('feature', # type: ignore[attr-defined]
self.observer.observation,
self.observer.fieldnames,
self.observer.name)
self.unwrapped.register_variable('feature',
self.observer.observation,
self.observer.fieldnames,
self.observer.name)

def _setup(self) -> None:
"""Configure the wrapper.
Expand Down Expand Up @@ -750,14 +733,14 @@ def __init__(self,
# Register the controller's internal state and target to the telemetry
if state is not None:
try:
self.env.register_variable( # type: ignore[attr-defined]
self.unwrapped.register_variable(
'state', state, None, self.controller.name)
except ValueError:
pass
self.env.register_variable('action', # type: ignore[attr-defined]
self.action,
self.controller.fieldnames,
self.controller.name)
self.unwrapped.register_variable('action',
self.action,
self.controller.fieldnames,
self.controller.name)

def _setup(self) -> None:
"""Configure the wrapper.
Expand Down
16 changes: 11 additions & 5 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class BaseJiminyEnv(InterfaceJiminyEnv[ObsT, ActT],
to implement one. It has been designed to be highly flexible and easy to
customize by overloading it to fit the vast majority of users' needs.
"""

derived: "InterfaceJiminyEnv"
"""Top-most block from which this environment is part of when leveraging
modular pipeline design capability.
"""

def __init__(self,
simulator: Simulator,
step_dt: float,
Expand Down Expand Up @@ -186,8 +192,8 @@ def __init__(self,
self.sensor_measurements: SensorMeasurementStackMap = OrderedDict(
self.robot.sensor_measurements)

# Top-most block of the pipeline to which the environment is part of
self.derived: InterfaceJiminyEnv = self
# Top-most block of the pipeline is the environment itself by default
self.derived = self

# Store references to the variables to register to the telemetry
self._registered_variables: MutableMappingT[
Expand Down Expand Up @@ -215,6 +221,9 @@ def __init__(self,
self.num_steps = np.array(-1, dtype=np.int64)
self._num_steps_beyond_terminate: Optional[int] = None

# Initialize a quantity manager for later use
self.quantities = QuantityManager(self)

# Initialize the interfaces through multiple inheritance
super().__init__() # Do not forward extra arguments, if any

Expand All @@ -233,9 +242,6 @@ def __init__(self,
"`BaseJiminyEnv.compute_command` must be overloaded in case "
"of custom action spaces.")

# Initialize a quantity manager for later use
self.quantities = QuantityManager(self)

# Define specialized operators for efficiency.
# Note that a partial view of observation corresponding to measurement
# must be extracted since only this one must be updated during refresh.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(self, # pylint: disable=unused-argument
self.human_only = human_only

# Extract proxies for convenience
assert isinstance(env.unwrapped, BaseJiminyEnv)
self._step_dt_rel = env.unwrapped.step_dt / speed_ratio

# Buffer to keep track of the last time `step` method was called
Expand Down

0 comments on commit 1cc15aa

Please sign in to comment.