From d7a2928bea6f7c691d995cbaa5d3b84a6fd9b3d7 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Sat, 6 Jul 2024 19:54:51 +0200 Subject: [PATCH] [gym/common] Fix quantity hash collision issue in quantity manager. --- .../gym_jiminy/common/bases/pipeline.py | 24 ++- .../gym_jiminy/common/bases/quantities.py | 190 ++++++++++-------- .../gym_jiminy/common/compositions/mixin.py | 19 +- .../gym_jiminy/common/quantities/generic.py | 25 ++- .../gym_jiminy/common/quantities/manager.py | 50 +++-- .../gym_jiminy/common/quantities/transform.py | 7 +- python/gym_jiminy/unit_py/test_quantities.py | 7 +- 7 files changed, 195 insertions(+), 127 deletions(-) diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py index 6aaab7262..dfaebb6dc 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py @@ -448,9 +448,10 @@ def refresh_observation(self, measurement: EngineObsType) -> None: self.env.refresh_observation(measurement) def has_terminated(self, info: InfoType) -> Tuple[bool, bool]: - """Determine whether the episode is over, because a terminal state of - the underlying MDP has been reached or an aborting condition outside - the scope of the MDP has been triggered. + """Determine whether the practitioner is instructed to stop the ongoing + episode on the spot because a termination condition has been triggered, + either coming from the based environment or from the ad-hoc termination + conditions that has been plugged on top of it. At each step of the wrapped environment, all its termination conditions will be evaluated sequentially until one of them eventually gets @@ -465,6 +466,9 @@ def has_terminated(self, info: InfoType) -> Tuple[bool, bool]: This method is called after `refresh_observation`, so that the internal buffer 'observation' is up-to-date. + .. seealso:: + See `InterfaceJiminyEnv.has_terminated` documentation for details. + :param info: Dictionary of extra information for monitoring. :returns: terminated and truncated flags. @@ -492,7 +496,19 @@ def compute_command(self, action: ActT, command: np.ndarray) -> None: self.env.compute_command(action, command) def compute_reward(self, terminated: bool, info: InfoType) -> float: - """ TODO: Write documentation. + """Compute the total reward, ie the sum of the original reward from the + wrapped environment with the ad-hoc reward components that has been + plugged on top of it. + + .. seealso:: + See `InterfaceController.compute_reward` documentation for details. + + :param terminated: Whether the episode has reached the terminal state + of the MDP at the current step. This flag can be + used to compute a specific terminal reward. + :param info: Dictionary of extra information for monitoring. + + :returns: Aggregated reward for the current step. """ # Compute base reward reward = self.env.compute_reward(terminated, info) diff --git a/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py b/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py index a2d832acc..4bbf936f8 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py +++ b/python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py @@ -20,10 +20,10 @@ from collections import OrderedDict from collections.abc import MutableSet from dataclasses import dataclass, replace -from functools import partial, wraps +from functools import wraps from typing import ( Any, Dict, List, Optional, Tuple, Generic, TypeVar, Type, Iterator, - Iterable, Callable, Literal, ClassVar, cast, TYPE_CHECKING) + Collection, Callable, Literal, ClassVar, TYPE_CHECKING) import numpy as np @@ -51,7 +51,7 @@ class WeakMutableCollection(MutableSet, Generic[ValueT]): __slots__ = ("_callback", "_weakrefs") def __init__(self, callback: Optional[Callable[[ - "WeakMutableCollection[ValueT]", ReferenceType[ValueT] + "WeakMutableCollection[ValueT]", ReferenceType ], None]] = None) -> None: """ :param callback: Callback that will be triggered every time an element @@ -59,9 +59,9 @@ def __init__(self, callback: Optional[Callable[[ Optional: None by default. """ self._callback = callback - self._weakrefs: List[ReferenceType[ValueT]] = [] + self._weakrefs: List[ReferenceType] = [] - def __callback__(self, ref: ReferenceType[ValueT]) -> None: + def __callback__(self, ref: ReferenceType) -> None: """Internal method that will be called every time an element must be discarded from the containers, either because it was requested by the user or because no strong reference to the value exists anymore. @@ -128,21 +128,31 @@ def discard(self, value: ValueT) -> None: class QuantityStateMachine(IntEnum): + """Specify the current state of a given (unique) quantity, which determines + the steps to perform for retrieving its current value. + """ + IS_RESET = 0 - """ TODO: Write documentation. + """The quantity at hands has just been reset. The quantity must first be + initialized, then refreshed and finally stored in cached before to retrieve + its value. """ IS_INITIALIZED = 1 - """ TODO: Write documentation. + """The quantity at hands has been initialized but never evaluated for the + current robot state. Its value must still be refreshed and stored in cache + before to retrieve it. """ IS_CACHED = 2 - """ TODO: Write documentation. + """The quantity at hands has been evaluated and its value stored in cache. + As such, its value can be retrieve from cache directly. """ # Define proxies for fast lookup -_IS_RESET, _IS_INITIALIZED, _IS_CACHED = QuantityStateMachine +_IS_RESET, _IS_INITIALIZED, _IS_CACHED = ( # pylint: disable=invalid-name + QuantityStateMachine) class SharedCache(Generic[ValueT]): @@ -159,7 +169,7 @@ class SharedCache(Generic[ValueT]): __slots__ = ( "_value", "_weakrefs", "_owner", "_auto_refresh", "sm_state", "owners") - owners: WeakMutableCollection["InterfaceQuantity[ValueT]"] + owners: Collection["InterfaceQuantity[ValueT]"] """Owners of the shared buffer, ie quantities relying on it to store the result of their evaluation. This information may be useful for determining the most efficient computation path overall. @@ -191,7 +201,7 @@ def __init__(self) -> None: # Define callback to reset part of the computation graph whenever a # quantity owning the cache gets garbage collected, namely all # quantities that may assume at some point the existence of this - # deleted owner to find the adjust their computation path. + # deleted owner to adjust their computation path. def _callback( self: WeakMutableCollection["InterfaceQuantity[ValueT]"], ref: ReferenceType[ # pylint: disable=unused-argument @@ -200,61 +210,87 @@ def _callback( for owner in self: # Stop going up in parent chain if dynamic computation graph # update is disable for efficiency. - while owner.allow_update_graph and owner.parent is not None: + while (owner.allow_update_graph and + owner.parent is not None and owner.parent.has_cache): owner = owner.parent - owner.reset(reset_tracking=True, - ignore_auto_refresh=True, - update_graph=True) + owner.reset(reset_tracking=True) # Initialize weak reference to owning quantities self._weakrefs = WeakMutableCollection(_callback) # Maintain alive owning quantities upon reset - self.owners: Iterable["InterfaceQuantity[ValueT]"] = self._weakrefs + self.owners = self._weakrefs self._owner: Optional["InterfaceQuantity[ValueT]"] = None def add(self, owner: "InterfaceQuantity[ValueT]") -> None: - """ TODO: Write documentation. + """Add a given quantity instance to the set of co-owners associated + with the shared cache at hands. + + .. warning:: + All shared cache co-owners must be instances of the same unique + quantity. An exception will be thrown if an attempt is made to add + a quantity instance that does not satisfy this condition. + + :param owner: Quantity instance to be added to the set of co-owners. """ + # Make sure that the quantity is not already part of the co-owners + if owner in self.owners: + raise ValueError( + "The specified quantity instance is already an owner of this " + "shared cache.") + + # Make sure that the new owner is consistent with the others if any + if any(owner != _owner for _owner in self._weakrefs): + raise ValueError( + "Quantity instance inconsistent with already existing shared " + "cache owners.") + # Add quantity instance to shared cache owners self._weakrefs.add(owner) # Refresh owners - if self.sm_state is QuantityStateMachine.IS_RESET: - self.owners = self._weakrefs - else: - self.owners.append(owner) + if self.sm_state is not QuantityStateMachine.IS_RESET: + self.owners = tuple(self._weakrefs) def discard(self, owner: "InterfaceQuantity[ValueT]") -> None: - """ TODO: Write documentation. + """Remove a given quantity instance from the set of co-owners associated + with the shared cache at hands. + + :param owner: Quantity instance to be removed from the set of co-owners. """ + # Make sure that the quantity is part of the co-owners + if not owner in self.owners: + raise ValueError( + "The specified quantity instance is not an owner of this " + "shared cache.") + + # Restore "dynamic" owner list as it may be involved in quantity reset + self.owners = self._weakrefs + # Remove quantity instance from shared cache owners self._weakrefs.discard(owner) - # Refresh owners - if self.sm_state is QuantityStateMachine.IS_RESET: - self.owners = self._weakrefs - else: - # Keep tracking the quantity instance being used in computations, - # aka 'self._owner', even if it is no longer an actual shared cache - # owner. This is necessary because updating it would require - # resetting the state machine, which is not an option as it would - # mess up with quantities storing history since initialization. - for i, _owner in enumerate(self.owners): - if owner is _owner: - del self.owners[i] - break + # Refresh owners. + # Note that one must keep tracking the quantity instance being used in + # computations, aka 'self._owner', even if it is no longer an actual + # shared cache owner. This is necessary because updating it would + # require resetting the state machine, which is not an option as it + # would mess up with quantities storing history since initialization. + if self.sm_state is not QuantityStateMachine.IS_RESET: + self.owners = tuple(self._weakrefs) - def reset(self, *, - ignore_auto_refresh: bool = False, + def reset(self, + *, ignore_auto_refresh: bool = False, reset_state_machine: bool = False) -> None: """Clear value stored in cache if any. :param ignore_auto_refresh: Whether to skip automatic refresh of all co-owner quantities of this shared cache. Optional: False by default. - - # TODO: Write documentation. + :param reset_state_machine: Whether to reset completely the state + machine of the underlying quantity, ie not + considering it initialized anymore. + Optional: False by default. """ # Clear cache if self.sm_state is _IS_CACHED: @@ -287,13 +323,13 @@ def get(self) -> ValueT: # Get value already stored if self.sm_state is _IS_CACHED: # return cast(ValueT, self._value) - return self._value + return self._value # type: ignore[return-value] # Evaluate quantity try: if self.sm_state is _IS_RESET: # Cache the list of owning quantities - self.owners = list(self._weakrefs) + self.owners = tuple(self._weakrefs) # Stick to the first owning quantity systematically owner = self.owners[0] @@ -306,7 +342,7 @@ def get(self) -> ValueT: # Get first owning quantity systematically # assert self._owner is not None - owner = self._owner + owner = self._owner # type: ignore[assignment] # Make sure that the state has been refreshed if owner._force_update_state: @@ -358,8 +394,8 @@ class InterfaceQuantity(ABC, Generic[ValueT]): the quantity can be reset at any point in time to re-compute the optimal computation path, typically after deletion or addition of some other node to its dependent sub-graph. When this happens, the quantity gets reset on - the spot, which is not always acceptable, hence the capability to disable - this feature. + the spot, even if a simulation is already running. This is not always + acceptable, hence the capability to disable this feature at class-level. """ def __init__(self, @@ -494,7 +530,7 @@ def is_active(self, any_cache_owner: bool = False) -> bool: same cache) is considered sufficient. Optional: False by default. """ - if not any_cache_owner or not self.has_cache: + if not any_cache_owner or self._cache is None: return self._is_active return any(owner._is_active for owner in self._cache.owners) @@ -510,7 +546,7 @@ def get(self) -> ValueT: This method is not meant to be overloaded. """ # Delegate getting value to shared cache if available - if self.has_cache: + if self._cache is not None: # Get value value = self._cache.get() @@ -537,8 +573,7 @@ def get(self) -> ValueT: def reset(self, reset_tracking: bool = False, - ignore_auto_refresh: bool = False, - update_graph: bool = False) -> None: + *, ignore_other_instances: bool = False) -> None: """Consider that the quantity must be re-initialized before being evaluated once again. @@ -556,62 +591,47 @@ def reset(self, :param reset_tracking: Do not consider this quantity as active anymore until the `get` method gets called once again. Optional: False by default. - :param ignore_auto_refresh: Whether to skip automatic refresh of all - co-owner quantities of this shared cache. - Optional: False by default. - :param update_graph: If true, then the quantity will be reset if and - only if dynamic computation graph update is - allowed as prescribed by class attribute - `allow_update_graph`. If false, then it will be - reset no matter what. + :param ignore_other_instances: + Whether to skip reset of intermediary quantities as well as any + shared cache co-owner quantity instances. + Optional: False by default. """ # Make sure that auto-refresh can be honored - if (not ignore_auto_refresh and self.auto_refresh and - not self.has_cache): + if self.auto_refresh and not self.has_cache: raise RuntimeError( "Automatic refresh enabled but no shared cache is available. " "Please add one before calling this method.") # Reset all requirements first - for quantity in self.requirements.values(): - quantity.reset(reset_tracking, ignore_auto_refresh, update_graph) + if not ignore_other_instances: + for quantity in self.requirements.values(): + quantity.reset(reset_tracking, ignore_other_instances=False) - # Skip reset if dynamic computation graph update if appropriate - if update_graph and not self.allow_update_graph: + # Skip reset if dynamic computation graph update is not allowed + if self.env.is_simulation_running and not self.allow_update_graph: return - # No longer consider this exact instance as active if requested - # FIXME: Should be moved before ? + # No longer consider this exact instance as active if reset_tracking: self._is_active = False # No longer consider this exact instance as initialized self._is_initialized = False - # More work must to be done if shared cache is available + # More work must to be done if shared cache if appropriate if self.has_cache: - # Early return if already reset to avoid infinite loop. - # FIXME: You still want to reset the owners ! - if ignore_auto_refresh and ( - self.cache.sm_state is QuantityStateMachine.IS_RESET): - self.cache.owners = self.cache._weakrefs - return - - # Invalidate cache before looping over all identical properties. - # Note that auto-refresh must be ignored to avoid infinite loop. - self.cache.reset(ignore_auto_refresh=True, - reset_state_machine=True) - # Reset all identical quantities. # Note that auto-refresh will be done afterward if requested. - for owner in self.cache.owners: - owner.reset(reset_tracking=reset_tracking, - ignore_auto_refresh=True, - update_graph=update_graph) - - # Reset shared cache one last time but without ignore auto refresh - if not ignore_auto_refresh: - self.cache.reset(ignore_auto_refresh=False) + if not ignore_other_instances: + for owner in self.cache.owners: + if owner is not self: + owner.reset(reset_tracking=reset_tracking, + ignore_other_instances=True) + + # Reset shared cache + self.cache.reset( + ignore_auto_refresh=not self.env.is_simulation_running, + reset_state_machine=True) def initialize(self) -> None: """Initialize internal buffers. diff --git a/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py b/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py index bb26c8585..d3e5ea57b 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py +++ b/python/gym_jiminy/common/gym_jiminy/common/compositions/mixin.py @@ -111,16 +111,15 @@ def __init__(self, raise ValueError("'order' must be strictly positive or 'inf'.") # Make sure that the weight sequence is consistent with the components - if weights is None or len(weights) != len(components): + if len(weights) != len(components): raise ValueError( "Exactly one weight per reward component must be specified.") # Filter out components whose weight are zero - if weights is not None: - weights, components = zip(*( - (weight, reward) - for weight, reward in zip(weights, components) - if weight > 0.0)) + weights, components = zip(*( + (weight, reward) + for weight, reward in zip(weights, components) + if weight > 0.0)) # Determine whether the cumulative reward is normalized scale = 0.0 @@ -141,11 +140,11 @@ def __init__(self, # Backup user-arguments self.order = order - self.weights = tuple(weights) if weights is not None else weights + self.weights = tuple(weights) # Jit-able method computing the weighted sum of reward components @nb.jit(nopython=True, cache=True, fastmath=True) - def weighted_norm(weights: Optional[Tuple[float, ...]], + def weighted_norm(weights: Tuple[float, ...], order: Union[int, float, Literal['inf']], values: Tuple[Optional[float], ...] ) -> Optional[float]: @@ -155,8 +154,8 @@ def weighted_norm(weights: Optional[Tuple[float, ...]], This method returns `None` if no reward component has been evaluated. - :param weights: Optional sequence of weights for each reward - component, with same ordering as 'components'. + :param weights: Sequence of weights for each reward component, with + same ordering as 'components'. :param order: Order of the L^p-norm. :param values: Sequence of scalar value for reward components that has been evaluated, `None` otherwise, with the same diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py index 1d5d7a3ab..9def585df 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py @@ -90,9 +90,11 @@ def aggregate_frame_names(quantity: InterfaceQuantity) -> Tuple[ # First, order all multi-frame quantities by decreasing length frame_names_chunks: List[Tuple[str, ...]] = [] for owner in quantities: - if owner.parent.is_active(any_cache_owner=False): - if isinstance(owner.parent, MultiFrameQuantity): - frame_names_chunks.append(owner.parent.frame_names) + parent = owner.parent + assert parent is not None + if parent.is_active(any_cache_owner=False): + if isinstance(parent, MultiFrameQuantity): + frame_names_chunks.append(parent.frame_names) # Next, process ordered multi-frame quantities sequentially. # For each of them, we first check if its set of frames is completely @@ -132,9 +134,11 @@ def aggregate_frame_names(quantity: InterfaceQuantity) -> Tuple[ # Otherwise, we just move to the next quantity. frame_name_chunks: List[str] = [] for owner in quantities: - if owner.parent.is_active(any_cache_owner=False): - if isinstance(owner.parent, FrameQuantity): - frame_name_chunks.append(owner.parent.frame_name) + parent = owner.parent + assert parent is not None + if parent.is_active(any_cache_owner=False): + if isinstance(parent, FrameQuantity): + frame_name_chunks.append(parent.frame_name) frame_name = frame_name_chunks[-1] if frame_name not in frame_names: frame_names.append(frame_name) @@ -285,7 +289,8 @@ class OrientationType(IntEnum): # Define proxies for fast lookup -_MATRIX, _EULER, _QUATERNION, _ANGLE_AXIS = OrientationType +_MATRIX, _EULER, _QUATERNION, _ANGLE_AXIS = ( # pylint: disable=invalid-name + OrientationType) @dataclass(unsafe_hash=True) @@ -495,6 +500,9 @@ def initialize(self) -> None: # Force re-initializing shared data if the active set has changed if not was_active: + # Must reset the tracking for shared computation systematically, + # just in case the optimal computation path has changed to the + # point that relying on batched quantity is no longer relevant. self.data.reset(reset_tracking=True) def refresh(self) -> np.ndarray: @@ -576,9 +584,6 @@ def initialize(self) -> None: # Force re-initializing shared data if the active set has changed if not was_active: - # Must reset the tracking for shared computation systematically, - # just in case the optimal computation path has changed to the - # point that relying on batched quantity is no longer relevant. self.data.reset(reset_tracking=True) def refresh(self) -> np.ndarray: diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py index 1dc06a537..4f8d689e1 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py @@ -53,8 +53,13 @@ def __init__(self, env: InterfaceJiminyEnv) -> None: # This is necessary because using a quantity as key directly would # prevent its garbage collection, hence breaking automatic reset of # computation tracking for all quantities sharing its cache. - self._caches: Dict[ - Tuple[Type[InterfaceQuantity], int], SharedCache] = {} + # In case of dataclasses, their hash is the same as if it was obtained + # using `hash(dataclasses.astuple(quantity))`. This is clearly not + # unique, as all it requires to be the same is being built from the + # same nested ordered arguments. To get around this issue, we need to + # store (key, value) pairs in a list. + self._caches: List[Tuple[ + Tuple[Type[InterfaceQuantity], int], SharedCache]] = [] # Instantiate trajectory database. # Note that this quantity is not added to the global registry to avoid @@ -77,9 +82,7 @@ def reset(self, reset_tracking: bool = False) -> None: """ # Reset all quantities sequentially for quantity in self.registry.values(): - quantity.reset( - reset_tracking, - ignore_auto_refresh=not self.env.is_simulation_running) + quantity.reset(reset_tracking) def clear(self) -> None: """Clear internal cache of quantities to force re-evaluating them the @@ -90,8 +93,8 @@ def clear(self) -> None: environment has changed (ie either the agent or world itself), thereby invalidating the value currently stored in cache if any. """ - for cache in self._caches.values(): - cache.reset() + for _, cache in self._caches: + cache.reset(ignore_auto_refresh=not self.env.is_simulation_running) def add_trajectory(self, name: str, trajectory: Trajectory) -> None: """Add a new reference trajectory to the database synchronized between @@ -182,9 +185,24 @@ def _build_quantity( # Set a shared cache entry for all quantities involved in computations quantities_all = [top_quantity] while quantities_all: + # Deal with the first quantity in the process queue quantity = quantities_all.pop() + + # Get already available cache entry if any, otherwise create it key = (type(quantity), hash(quantity)) - quantity.cache = self._caches.setdefault(key, SharedCache()) + for cache_key, cache in self._caches: + if key == cache_key: + owner, *_ = cache.owners + if quantity == owner: + break + else: + cache = SharedCache() + self._caches.append((key, cache)) + + # Set shared cache of the quantity + quantity.cache = cache + + # Add all the requirements of the new quantity in the process queue quantities_all += quantity.requirements.values() return top_quantity @@ -235,13 +253,21 @@ def __delitem__(self, name: str) -> None: :param name: Name of the managed quantity to be discarded. It will raise an exception if the specified name does not exists. """ - # Remove shared cache entry for all quantities involved in computations + # Remove shared cache entries for the quantity and its requirements. + # Note that done top-down rather than bottom-up, otherwise reset of + # required quantities no longer having shared cache will be triggered + # automatically by parent quantities following computation graph + # tracking reset whenever a shared cache co-owner is removed. quantities_all = [self.registry.pop(name)] while quantities_all: - quantity = quantities_all.pop() - if len(tuple(quantity.cache.owners)) == 1: - del self._caches[(type(quantity), hash(quantity))] + quantity = quantities_all.pop(0) + cache = quantity.cache quantity.cache = None # type: ignore[assignment] + if len(cache.owners) == 0: + for i, (_, _cache) in enumerate(self._caches): + if cache is _cache: + del self._caches[i] + break quantities_all += quantity.requirements.values() def __iter__(self) -> Iterator[str]: diff --git a/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py b/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py index 34d238d67..793a850f8 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py +++ b/python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from typing import ( Any, Optional, Sequence, Tuple, TypeVar, Union, Generic, ClassVar, - Callable, Literal, overload, cast) + Callable, Literal, List, overload, cast) from typing_extensions import TypeAlias import numpy as np @@ -171,7 +171,8 @@ def initialize(self) -> None: # Full the queue with zero if necessary if self.mode == 'zeros': for _ in range(self.max_stack): - self._value_list.append(np.zeros_like(value)) + self._value_list.append( + np.zeros_like(value)) # type: ignore[arg-type] # Allocate stack memory if necessary if self.as_array: @@ -198,7 +199,7 @@ def refresh(self) -> OtherValueT: "environment after adding this quantity.") value = self.quantity.get() if isinstance(value, np.ndarray): - value_list.append(value.copy()) + value_list.append(value.copy()) # type: ignore[arg-type] else: value_list.append(deepcopy(value)) if len(value_list) > self.max_stack: diff --git a/python/gym_jiminy/unit_py/test_quantities.py b/python/gym_jiminy/unit_py/test_quantities.py index b1451792d..38beddcf7 100644 --- a/python/gym_jiminy/unit_py/test_quantities.py +++ b/python/gym_jiminy/unit_py/test_quantities.py @@ -1,5 +1,6 @@ """ TODO: Write documentation """ +import sys import math import unittest @@ -194,7 +195,7 @@ def test_discard(self): assert len(quantities['rpy_2'].data.cache.owners) == 1 del quantity_manager['rpy_2'] - for (cls, _), cache in quantity_manager._caches.items(): + for (cls, _), cache in quantity_manager._caches: assert len(cache.owners) == (cls is DatasetTrajectoryQuantity) def test_env(self): @@ -247,7 +248,7 @@ def test_stack_api(self): (3, True, "zeros")): quantity_creator = (StackedQuantity, dict( quantity=(MultiFootRelativeXYZQuat, {}), - max_stack=max_stack, + max_stack=max_stack or sys.maxsize, as_array=as_array, mode=mode)) env.quantities["xyzquat_stack"] = quantity_creator @@ -257,7 +258,7 @@ def test_stack_api(self): if as_array: assert isinstance(value, np.ndarray) else: - assert isinstance(value, tuple) + assert isinstance(value, list) for i in range(1, (max_stack or 5) + 2): num_stack = max_stack or i if mode == "slice":