From 71b887347ed07509762c5c22ec0fb25fdd5b2b83 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 4 Jun 2024 16:39:41 +0200 Subject: [PATCH] [gym/common] Restrict usage of error-prone '__getattr__' fallback in pipelines. --- .../gym_jiminy/common/bases/interfaces.py | 23 +++++++++- .../gym_jiminy/common/bases/pipeline.py | 45 ++++++++++++------- .../common/gym_jiminy/common/envs/generic.py | 31 +++++++------ .../common/wrappers/observation_stack.py | 4 ++ .../toolbox/wrappers/frame_rate_limiter.py | 1 - .../unit_py/test_pipeline_control.py | 6 ++- .../unit_py/test_pipeline_design.py | 1 + 7 files changed, 77 insertions(+), 34 deletions(-) diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py b/python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py index f42efc7f7..d01559d78 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/interfaces.py @@ -17,6 +17,7 @@ from ..utils import DataNested if TYPE_CHECKING: + from ..envs.generic import BaseJiminyEnv from ..quantities import QuantityManager @@ -220,6 +221,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Call super to allow mixing interfaces through multiple inheritance super().__init__(*args, **kwargs) + # Define convenience proxy for quantity manager + self.quantities = self.unwrapped.quantities + def _setup(self) -> None: """Configure the observer-controller. @@ -335,9 +339,24 @@ def _controller_handle(self, # '_controller_handle' as it is never called more often than necessary. self.__is_observation_refreshed = False + def stop(self) -> None: + """Stop the episode immediately without waiting for a termination or + truncation condition to be satisfied. + + .. note:: + This method is mainly intended for data analysis and debugging. + Stopping the episode is necessary to log the final state, otherwise + it will be missing from plots and viewer replay. Moreover, sensor + data will not be available during replay using object-oriented + method `replay`. Helper method `play_logs_data` must be preferred + to replay an episode that cannot be stopped at the time being. + """ + self.simulator.stop() + @property - def unwrapped(self) -> "InterfaceJiminyEnv": - """Base environment of the pipeline. + def unwrapped(self) -> "BaseJiminyEnv": + """The "underlying environment at the basis of the pipeline from which + this environment is part of. """ return self diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py index c1cc53caf..176a6363a 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py @@ -10,13 +10,14 @@ eventually already wrapped, so that it appears as a black-box environment. """ import math +import logging from weakref import ref from copy import deepcopy from abc import abstractmethod from collections import OrderedDict from typing import ( Dict, Any, List, Optional, Tuple, Union, Generic, TypeVar, SupportsFloat, - Callable, cast) + Callable, cast, TYPE_CHECKING) import numpy as np @@ -37,6 +38,8 @@ from .blocks import BaseControllerBlock, BaseObserverBlock from ..utils import DataNested, is_breakpoint, zeros, build_copyto, copy +if TYPE_CHECKING: + from ..envs.generic import BaseJiminyEnv OtherObsT = TypeVar('OtherObsT', bound=DataNested) @@ -46,6 +49,9 @@ TransformedActT = TypeVar('TransformedActT', bound=DataNested) +LOGGER = logging.getLogger(__name__) + + class BasePipelineWrapper( InterfaceJiminyEnv[ObsT, ActT], Generic[ObsT, ActT, BaseObsT, BaseActT]): @@ -101,7 +107,17 @@ def __getattr__(self, name: str) -> Any: It enables to get access to the attribute and methods of the wrapped environment directly without having to do it through `env`. + + .. warning:: + This fallback incurs a significant runtime overhead. As such, it + must only be used for debug and manual analysis between episodes. + Calling this method if a simulation is already running would + trigger a warning to avoid relying on it by mistake. """ + if self.is_simulation_running: + LOGGER.warning( + "Relying on fallback attribute getter is inefficient and " + "strongly discouraged at runtime.") return getattr(self.__getattribute__('env'), name) def __dir__(self) -> List[str]: @@ -143,9 +159,7 @@ def np_random(self, value: np.random.Generator) -> None: self.env.np_random = value @property - def unwrapped(self) -> InterfaceJiminyEnv: - """Base environment of the pipeline. - """ + def unwrapped(self) -> "BaseJiminyEnv": return self.env.unwrapped @property @@ -236,8 +250,7 @@ def step(self, # type: ignore[override] self._copyto_action(action) # Make sure that the pipeline has not change since last reset - env_derived = ( - self.unwrapped.derived) # type: ignore[attr-defined] + env_derived = self.unwrapped.derived if env_derived is not self: raise RuntimeError( "Pipeline environment has changed. Please call 'reset' " @@ -532,14 +545,14 @@ def __init__(self, # Register the observer's internal state and feature to the telemetry if state is not None: try: - self.env.register_variable( # type: ignore[attr-defined] + self.unwrapped.register_variable( 'state', state, None, self.observer.name) except ValueError: pass - self.env.register_variable('feature', # type: ignore[attr-defined] - self.observer.observation, - self.observer.fieldnames, - self.observer.name) + self.unwrapped.register_variable('feature', + self.observer.observation, + self.observer.fieldnames, + self.observer.name) def _setup(self) -> None: """Configure the wrapper. @@ -750,14 +763,14 @@ def __init__(self, # Register the controller's internal state and target to the telemetry if state is not None: try: - self.env.register_variable( # type: ignore[attr-defined] + self.unwrapped.register_variable( 'state', state, None, self.controller.name) except ValueError: pass - self.env.register_variable('action', # type: ignore[attr-defined] - self.action, - self.controller.fieldnames, - self.controller.name) + self.unwrapped.register_variable('action', + self.action, + self.controller.fieldnames, + self.controller.name) def _setup(self) -> None: """Configure the wrapper. diff --git a/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py b/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py index bce0eb64e..397583edb 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/envs/generic.py @@ -100,6 +100,12 @@ class BaseJiminyEnv(InterfaceJiminyEnv[ObsT, ActT], to implement one. It has been designed to be highly flexible and easy to customize by overloading it to fit the vast majority of users' needs. """ + + derived: "InterfaceJiminyEnv" + """Top-most block from which this environment is part of when leveraging + modular pipeline design capability. + """ + def __init__(self, simulator: Simulator, step_dt: float, @@ -186,8 +192,8 @@ def __init__(self, self.sensor_measurements: SensorMeasurementStackMap = OrderedDict( self.robot.sensor_measurements) - # Top-most block of the pipeline to which the environment is part of - self.derived: InterfaceJiminyEnv = self + # Top-most block of the pipeline is the environment itself by default + self.derived = self # Store references to the variables to register to the telemetry self._registered_variables: MutableMappingT[ @@ -215,6 +221,9 @@ def __init__(self, self.num_steps = np.array(-1, dtype=np.int64) self._num_steps_beyond_terminate: Optional[int] = None + # Initialize a quantity manager for later use + self.quantities = QuantityManager(self) + # Initialize the interfaces through multiple inheritance super().__init__() # Do not forward extra arguments, if any @@ -233,9 +242,6 @@ def __init__(self, "`BaseJiminyEnv.compute_command` must be overloaded in case " "of custom action spaces.") - # Initialize a quantity manager for later use - self.quantities = QuantityManager(self) - # Define specialized operators for efficiency. # Note that a partial view of observation corresponding to measurement # must be extracted since only this one must be updated during refresh. @@ -599,8 +605,8 @@ def reset(self, # type: ignore[override] if seed is not None: self._initialize_seed(seed) - # Stop the simulator - self.simulator.stop() + # Stop the episode if one is still running + self.stop() # Remove external forces, if any self.simulator.remove_all_forces() @@ -854,7 +860,7 @@ def step(self, # type: ignore[override] self.simulator.step(self.step_dt) except Exception: # Stop the simulation before raising the exception - self.simulator.stop() + self.stop() raise # Make sure there is no 'nan' value in observation @@ -1023,8 +1029,8 @@ def replay(self, **kwargs: Any) -> None: kwargs['close_backend'] = not self.simulator.is_viewer_available # Stop any running simulation before replay if `has_terminated` is True - if self.is_simulation_running and any(self.has_terminated({})): - self.simulator.stop() + if any(self.has_terminated({})): + self.stop() with viewer_lock: # Call render before replay in order to take into account custom @@ -1135,8 +1141,7 @@ def _interact(key: Optional[str] = None) -> bool: # Stop the simulation to unlock the robot. # It will enable to display contact forces for replay. - if self.simulator.is_simulation_running: - self.simulator.stop() + self.stop() # Disable play interactive mode flag and restore training flag self._is_interactive = False @@ -1213,7 +1218,7 @@ def evaluate(self, action = policy_fn(obs, reward, terminated or truncated, info) obs, reward, terminated, truncated, info = env.step(action) info_episode.append(info) - self.simulator.stop() + self.stop() except KeyboardInterrupt: pass diff --git a/python/gym_jiminy/common/gym_jiminy/common/wrappers/observation_stack.py b/python/gym_jiminy/common/gym_jiminy/common/wrappers/observation_stack.py index e700a63b1..6499d0c4b 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/wrappers/observation_stack.py +++ b/python/gym_jiminy/common/gym_jiminy/common/wrappers/observation_stack.py @@ -146,6 +146,10 @@ def __init__(self, # Whether the stack has been shifted to the left since last update self._was_stack_shifted = True + # Bind action of the base environment + assert self.action_space.contains(self.env.action) + self.action = self.env.action + def _initialize_action_space(self) -> None: """Configure the action space. """ diff --git a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/frame_rate_limiter.py b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/frame_rate_limiter.py index 828c16a67..0567f49a8 100644 --- a/python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/frame_rate_limiter.py +++ b/python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/frame_rate_limiter.py @@ -44,7 +44,6 @@ def __init__(self, # pylint: disable=unused-argument self.human_only = human_only # Extract proxies for convenience - assert isinstance(env.unwrapped, BaseJiminyEnv) self._step_dt_rel = env.unwrapped.step_dt / speed_ratio # Buffer to keep track of the last time `step` method was called diff --git a/python/gym_jiminy/unit_py/test_pipeline_control.py b/python/gym_jiminy/unit_py/test_pipeline_control.py index c34202b7d..103d0131f 100644 --- a/python/gym_jiminy/unit_py/test_pipeline_control.py +++ b/python/gym_jiminy/unit_py/test_pipeline_control.py @@ -49,6 +49,7 @@ def _test_pid_standing(self): # Run the simulation while self.env.stepper_state.t < 9.0: self.env.step(action) + self.env.stop() # Export figure fd, pdf_path = mkstemp(prefix="plot_", suffix=".pdf") @@ -263,6 +264,7 @@ def test_pd_controller(self): env.unwrapped._height_neutral = float("-inf") while env.stepper_state.t < 2.0: env.step(0.2 * env.action_space.sample()) + env.stop() # Extract the target position and velocity of a single motor adapter_name, controller_name = adapter.name, controller.name @@ -284,9 +286,9 @@ def test_pd_controller(self): command_vel[(update_ratio-1)::update_ratio], atol=TOLERANCE) np.testing.assert_allclose( - target_accel_diff, target_accel[1:], atol=TOLERANCE) + target_accel_diff[:-1], target_accel[1:-1], atol=TOLERANCE) np.testing.assert_allclose( - target_vel_diff, target_vel[1:], atol=TOLERANCE) + target_vel_diff[:-1], target_vel[1:-1], atol=TOLERANCE) # Make sure that the position and velocity targets are within bounds motor = env.robot.motors[-1] diff --git a/python/gym_jiminy/unit_py/test_pipeline_design.py b/python/gym_jiminy/unit_py/test_pipeline_design.py index 0c39e76aa..973a18112 100644 --- a/python/gym_jiminy/unit_py/test_pipeline_design.py +++ b/python/gym_jiminy/unit_py/test_pipeline_design.py @@ -246,6 +246,7 @@ def configure_telemetry() -> InterfaceJiminyEnv: env.reset(seed=0, options=dict(reset_hook=configure_telemetry)) env.step(env.action) + env.stop() controller = env.env.env.env.controller assert isinstance(controller, PDController)