From 24d528952e85cc774ea3a6d3a613b832bcf97441 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Wed, 3 Apr 2024 10:06:57 +0200 Subject: [PATCH 1/2] [gym/common] Improve support of dynamic computation graph. --- .../gym_jiminy/common/bases/quantity.py | 19 +++++++++++-------- .../gym_jiminy/common/quantities/generic.py | 13 +++++++++---- python/jiminy_py/setup.py | 2 +- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/quantity.py b/python/gym_jiminy/common/gym_jiminy/common/bases/quantity.py index 39aa7c677..dc7abc9fa 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/quantity.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/quantity.py @@ -214,13 +214,16 @@ def cache(self, cache: SharedCache[ValueT]) -> None: self._cache = cache self._has_cache = True - @property - def is_active(self) -> bool: + def is_active(self, any_cache_owner: bool = False) -> bool: """Whether this quantity is considered active, namely `initialize` has - been called at least once since previous tracking reset, either by this - exact instance or any identical quantity if shared cache is available. + been called at least once since previous tracking reset. + + :param any_owner: False to check only if this exact instance is active, + True if any of the identical quantities (sharing the + same cache) is considered sufficient. + Optional: False by default. """ - if self._cache is None: + if not any_cache_owner or self._cache is None: return self._is_active return any(owner._is_active for owner in self._cache.owners) @@ -229,8 +232,8 @@ def get(self) -> ValueT: evaluate it and store it in cache. This quantity is considered active as soon as this method has been - called at least once since previous tracking reset. The corresponding - property `is_active` will be true even before calling `initialize`. + called at least once since previous tracking reset. The method + `is_active` will be return true even before calling `initialize`. .. warning:: This method is not meant to be overloaded. @@ -289,7 +292,7 @@ def reset(self, reset_tracking: bool = False) -> None: # Reset all requirements first for quantity in self.requirements.values(): - quantity.reset() + quantity.reset(reset_tracking) # More work has to be done if shared cache is available and has value if self._has_cache: diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py index 96ac38235..96e4c9f29 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py @@ -185,10 +185,15 @@ def initialize(self) -> None: self.frame_names = {self.parent.frame_name} if self.cache: for owner in self.cache.owners: - parent = owner.parent - assert isinstance(parent, EulerAnglesFrame) - if parent.is_active: - self.frame_names.add(parent.frame_name) + # We only consider active instances of `_BatchEulerAnglesFrame` + # instead of their corresponding parent `EulerAnglesFrame`. + # This is necessary because a derived quantity may feature + # `_BatchEulerAnglesFrame` as a requirement without actually + # relying on it depending on whether it is part of the optimal + # computation path at the time being or not. + if owner.is_active(any_cache_owner=False): + assert isinstance(owner.parent, EulerAnglesFrame) + self.frame_names.add(owner.parent.frame_name) # Re-allocate memory as the number of frames is not known in advance. # Note that Fortran memory layout (column-major) is used for speed up diff --git a/python/jiminy_py/setup.py b/python/jiminy_py/setup.py index c3e626745..2694ec7e8 100644 --- a/python/jiminy_py/setup.py +++ b/python/jiminy_py/setup.py @@ -108,7 +108,7 @@ def finalize_options(self) -> None: # Panda3d is NOT supported by PyPy even if built from source. # - 1.10.12 fixes numerous bugs # - 1.10.13 crashes when generating wheels on MacOS - "panda3d>=1.10.14", + "panda3d>=1.10.13", # Photo-realistic shader for Panda3d to improve rendering of meshes. # - 0.11.X is not backward compatible. "panda3d-simplepbr==0.11.2", From 08151e36ab44a58d38ae20c4d0b19e470fc86e75 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Wed, 3 Apr 2024 14:01:58 +0200 Subject: [PATCH 2/2] [misc] Add quantity benchmark example script. --- .../gym_jiminy/examples/quantity_benchmark.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 python/gym_jiminy/examples/quantity_benchmark.py diff --git a/python/gym_jiminy/examples/quantity_benchmark.py b/python/gym_jiminy/examples/quantity_benchmark.py new file mode 100644 index 000000000..6a029635c --- /dev/null +++ b/python/gym_jiminy/examples/quantity_benchmark.py @@ -0,0 +1,59 @@ +import timeit + +import matplotlib.pyplot as plt +import gymnasium as gym + +import gym_jiminy.common.bases.quantity +from gym_jiminy.common.bases import QuantityManager +from gym_jiminy.common.quantities import EulerAnglesFrame + +# Define number of samples for benchmarking +N_SAMPLES = 20000 + +# Disable caching by forcing `SharedCache.has_value` to always return `False` +setattr(gym_jiminy.common.bases.quantity.SharedCache, + "has_value", + property(lambda self: False)) + +# Instantiate a dummy environment +env = gym.make("gym_jiminy.envs:atlas") +env.reset() +env.step(env.action) + +# Define quantity manager and add quantities to benchmark +nframes = len(env.pinocchio_model.frames) +quantity_manager = QuantityManager( + env.simulator, + { + f"rpy_{i}": (EulerAnglesFrame, dict(frame_name=frame.name)) + for i, frame in enumerate(env.pinocchio_model.frames) + }) + +# Run the benchmark for all batch size +time_per_frame_all = [] +for i in range(1, nframes): + # Reset tracking + quantity_manager.reset(reset_tracking=True) + + # Fetch all quantities once to update dynamic computation graph + for j, quantity in enumerate(quantity_manager.quantities.values()): + quantity.get() + if i == j + 1: + break + + # Extract batched data buffer of `EulerAnglesFrame` quantities + shared_data = quantity.requirements['data'] + + # Benchmark computation of batched data buffer + duration = timeit.timeit( + 'shared_data.get()', number=N_SAMPLES, globals={ + "shared_data": shared_data + }) + time_per_frame_all.append(duration / N_SAMPLES / i * 1e9) + +# Plot the result +plt.figure() +plt.plot(time_per_frame_all) +plt.xlabel("Number of frames") +plt.ylabel("Average computation time per frame (ns)") +plt.show()