Skip to content

Commit

Permalink
[gym/common] Fix quantity hash collision issue in quantity manager.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Jul 7, 2024
1 parent 8e47617 commit 74b016b
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 127 deletions.
24 changes: 20 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,10 @@ def refresh_observation(self, measurement: EngineObsType) -> None:
self.env.refresh_observation(measurement)

def has_terminated(self, info: InfoType) -> Tuple[bool, bool]:
"""Determine whether the episode is over, because a terminal state of
the underlying MDP has been reached or an aborting condition outside
the scope of the MDP has been triggered.
"""Determine whether the practitioner is instructed to stop the ongoing
episode on the spot because a termination condition has been triggered,
either coming from the based environment or from the ad-hoc termination
conditions that has been plugged on top of it.
At each step of the wrapped environment, all its termination conditions
will be evaluated sequentially until one of them eventually gets
Expand All @@ -465,6 +466,9 @@ def has_terminated(self, info: InfoType) -> Tuple[bool, bool]:
This method is called after `refresh_observation`, so that the
internal buffer 'observation' is up-to-date.
.. seealso::
See `InterfaceJiminyEnv.has_terminated` documentation for details.
:param info: Dictionary of extra information for monitoring.
:returns: terminated and truncated flags.
Expand Down Expand Up @@ -492,7 +496,19 @@ def compute_command(self, action: ActT, command: np.ndarray) -> None:
self.env.compute_command(action, command)

def compute_reward(self, terminated: bool, info: InfoType) -> float:
""" TODO: Write documentation.
"""Compute the total reward, ie the sum of the original reward from the
wrapped environment with the ad-hoc reward components that has been
plugged on top of it.
.. seealso::
See `InterfaceController.compute_reward` documentation for details.
:param terminated: Whether the episode has reached the terminal state
of the MDP at the current step. This flag can be
used to compute a specific terminal reward.
:param info: Dictionary of extra information for monitoring.
:returns: Aggregated reward for the current step.
"""
# Compute base reward
reward = self.env.compute_reward(terminated, info)
Expand Down
190 changes: 105 additions & 85 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from collections import OrderedDict
from collections.abc import MutableSet
from dataclasses import dataclass, replace
from functools import partial, wraps
from functools import wraps
from typing import (
Any, Dict, List, Optional, Tuple, Generic, TypeVar, Type, Iterator,
Iterable, Callable, Literal, ClassVar, cast, TYPE_CHECKING)
Collection, Callable, Literal, ClassVar, TYPE_CHECKING)

import numpy as np

Expand Down Expand Up @@ -51,17 +51,17 @@ class WeakMutableCollection(MutableSet, Generic[ValueT]):
__slots__ = ("_callback", "_weakrefs")

def __init__(self, callback: Optional[Callable[[
"WeakMutableCollection[ValueT]", ReferenceType[ValueT]
"WeakMutableCollection[ValueT]", ReferenceType
], None]] = None) -> None:
"""
:param callback: Callback that will be triggered every time an element
is discarded from the container.
Optional: None by default.
"""
self._callback = callback
self._weakrefs: List[ReferenceType[ValueT]] = []
self._weakrefs: List[ReferenceType] = []

def __callback__(self, ref: ReferenceType[ValueT]) -> None:
def __callback__(self, ref: ReferenceType) -> None:
"""Internal method that will be called every time an element must be
discarded from the containers, either because it was requested by the
user or because no strong reference to the value exists anymore.
Expand Down Expand Up @@ -128,21 +128,31 @@ def discard(self, value: ValueT) -> None:


class QuantityStateMachine(IntEnum):
"""Specify the current state of a given (unique) quantity, which determines
the steps to perform for retrieving its current value.
"""

IS_RESET = 0
""" TODO: Write documentation.
"""The quantity at hands has just been reset. The quantity must first be
initialized, then refreshed and finally stored in cached before to retrieve
its value.
"""

IS_INITIALIZED = 1
""" TODO: Write documentation.
"""The quantity at hands has been initialized but never evaluated for the
current robot state. Its value must still be refreshed and stored in cache
before to retrieve it.
"""

IS_CACHED = 2
""" TODO: Write documentation.
"""The quantity at hands has been evaluated and its value stored in cache.
As such, its value can be retrieve from cache directly.
"""


# Define proxies for fast lookup
_IS_RESET, _IS_INITIALIZED, _IS_CACHED = QuantityStateMachine
_IS_RESET, _IS_INITIALIZED, _IS_CACHED = ( # pylint: disable=invalid-name
QuantityStateMachine)


class SharedCache(Generic[ValueT]):
Expand All @@ -159,7 +169,7 @@ class SharedCache(Generic[ValueT]):
__slots__ = (
"_value", "_weakrefs", "_owner", "_auto_refresh", "sm_state", "owners")

owners: WeakMutableCollection["InterfaceQuantity[ValueT]"]
owners: Collection["InterfaceQuantity[ValueT]"]
"""Owners of the shared buffer, ie quantities relying on it to store the
result of their evaluation. This information may be useful for determining
the most efficient computation path overall.
Expand Down Expand Up @@ -191,7 +201,7 @@ def __init__(self) -> None:
# Define callback to reset part of the computation graph whenever a
# quantity owning the cache gets garbage collected, namely all
# quantities that may assume at some point the existence of this
# deleted owner to find the adjust their computation path.
# deleted owner to adjust their computation path.
def _callback(
self: WeakMutableCollection["InterfaceQuantity[ValueT]"],
ref: ReferenceType[ # pylint: disable=unused-argument
Expand All @@ -200,61 +210,87 @@ def _callback(
for owner in self:
# Stop going up in parent chain if dynamic computation graph
# update is disable for efficiency.
while owner.allow_update_graph and owner.parent is not None:
while (owner.allow_update_graph and
owner.parent is not None and owner.parent.has_cache):
owner = owner.parent
owner.reset(reset_tracking=True,
ignore_auto_refresh=True,
update_graph=True)
owner.reset(reset_tracking=True)

# Initialize weak reference to owning quantities
self._weakrefs = WeakMutableCollection(_callback)

# Maintain alive owning quantities upon reset
self.owners: Iterable["InterfaceQuantity[ValueT]"] = self._weakrefs
self.owners = self._weakrefs
self._owner: Optional["InterfaceQuantity[ValueT]"] = None

def add(self, owner: "InterfaceQuantity[ValueT]") -> None:
""" TODO: Write documentation.
"""Add a given quantity instance to the set of co-owners associated
with the shared cache at hands.
.. warning::
All shared cache co-owners must be instances of the same unique
quantity. An exception will be thrown if an attempt is made to add
a quantity instance that does not satisfy this condition.
:param owner: Quantity instance to add to the set of co-owners.
"""
# Make sure that the quantity is not already part of the co-owners
if id(owner) in map(id, self.owners):
raise ValueError(
"The specified quantity instance is already an owner of this "
"shared cache.")

# Make sure that the new owner is consistent with the others if any
if any(owner != _owner for _owner in self._weakrefs):
raise ValueError(
"Quantity instance inconsistent with already existing shared "
"cache owners.")

# Add quantity instance to shared cache owners
self._weakrefs.add(owner)

# Refresh owners
if self.sm_state is QuantityStateMachine.IS_RESET:
self.owners = self._weakrefs
else:
self.owners.append(owner)
if self.sm_state is not QuantityStateMachine.IS_RESET:
self.owners = tuple(self._weakrefs)

def discard(self, owner: "InterfaceQuantity[ValueT]") -> None:
""" TODO: Write documentation.
"""Remove a given quantity instance from the set of co-owners
associated with the shared cache at hands.
:param owner: Quantity instance to remove from the set of co-owners.
"""
# Make sure that the quantity is part of the co-owners
if id(owner) not in map(id, self.owners):
raise ValueError(
"The specified quantity instance is not an owner of this "
"shared cache.")

# Restore "dynamic" owner list as it may be involved in quantity reset
self.owners = self._weakrefs

# Remove quantity instance from shared cache owners
self._weakrefs.discard(owner)

# Refresh owners
if self.sm_state is QuantityStateMachine.IS_RESET:
self.owners = self._weakrefs
else:
# Keep tracking the quantity instance being used in computations,
# aka 'self._owner', even if it is no longer an actual shared cache
# owner. This is necessary because updating it would require
# resetting the state machine, which is not an option as it would
# mess up with quantities storing history since initialization.
for i, _owner in enumerate(self.owners):
if owner is _owner:
del self.owners[i]
break
# Refresh owners.
# Note that one must keep tracking the quantity instance being used in
# computations, aka 'self._owner', even if it is no longer an actual
# shared cache owner. This is necessary because updating it would
# require resetting the state machine, which is not an option as it
# would mess up with quantities storing history since initialization.
if self.sm_state is not QuantityStateMachine.IS_RESET:
self.owners = tuple(self._weakrefs)

def reset(self, *,
ignore_auto_refresh: bool = False,
def reset(self,
*, ignore_auto_refresh: bool = False,
reset_state_machine: bool = False) -> None:
"""Clear value stored in cache if any.
:param ignore_auto_refresh: Whether to skip automatic refresh of all
co-owner quantities of this shared cache.
Optional: False by default.
# TODO: Write documentation.
:param reset_state_machine: Whether to reset completely the state
machine of the underlying quantity, ie not
considering it initialized anymore.
Optional: False by default.
"""
# Clear cache
if self.sm_state is _IS_CACHED:
Expand Down Expand Up @@ -287,13 +323,13 @@ def get(self) -> ValueT:
# Get value already stored
if self.sm_state is _IS_CACHED:
# return cast(ValueT, self._value)
return self._value
return self._value # type: ignore[return-value]

# Evaluate quantity
try:
if self.sm_state is _IS_RESET:
# Cache the list of owning quantities
self.owners = list(self._weakrefs)
self.owners = tuple(self._weakrefs)

# Stick to the first owning quantity systematically
owner = self.owners[0]
Expand All @@ -306,7 +342,7 @@ def get(self) -> ValueT:

# Get first owning quantity systematically
# assert self._owner is not None
owner = self._owner
owner = self._owner # type: ignore[assignment]

# Make sure that the state has been refreshed
if owner._force_update_state:
Expand Down Expand Up @@ -358,8 +394,8 @@ class InterfaceQuantity(ABC, Generic[ValueT]):
the quantity can be reset at any point in time to re-compute the optimal
computation path, typically after deletion or addition of some other node
to its dependent sub-graph. When this happens, the quantity gets reset on
the spot, which is not always acceptable, hence the capability to disable
this feature.
the spot, even if a simulation is already running. This is not always
acceptable, hence the capability to disable this feature at class-level.
"""

def __init__(self,
Expand Down Expand Up @@ -494,7 +530,7 @@ def is_active(self, any_cache_owner: bool = False) -> bool:
same cache) is considered sufficient.
Optional: False by default.
"""
if not any_cache_owner or not self.has_cache:
if not any_cache_owner or self._cache is None:
return self._is_active
return any(owner._is_active for owner in self._cache.owners)

Expand All @@ -510,7 +546,7 @@ def get(self) -> ValueT:
This method is not meant to be overloaded.
"""
# Delegate getting value to shared cache if available
if self.has_cache:
if self._cache is not None:
# Get value
value = self._cache.get()

Expand All @@ -537,8 +573,7 @@ def get(self) -> ValueT:

def reset(self,
reset_tracking: bool = False,
ignore_auto_refresh: bool = False,
update_graph: bool = False) -> None:
*, ignore_other_instances: bool = False) -> None:
"""Consider that the quantity must be re-initialized before being
evaluated once again.
Expand All @@ -556,62 +591,47 @@ def reset(self,
:param reset_tracking: Do not consider this quantity as active anymore
until the `get` method gets called once again.
Optional: False by default.
:param ignore_auto_refresh: Whether to skip automatic refresh of all
co-owner quantities of this shared cache.
Optional: False by default.
:param update_graph: If true, then the quantity will be reset if and
only if dynamic computation graph update is
allowed as prescribed by class attribute
`allow_update_graph`. If false, then it will be
reset no matter what.
:param ignore_other_instances:
Whether to skip reset of intermediary quantities as well as any
shared cache co-owner quantity instances.
Optional: False by default.
"""
# Make sure that auto-refresh can be honored
if (not ignore_auto_refresh and self.auto_refresh and
not self.has_cache):
if self.auto_refresh and not self.has_cache:
raise RuntimeError(
"Automatic refresh enabled but no shared cache is available. "
"Please add one before calling this method.")

# Reset all requirements first
for quantity in self.requirements.values():
quantity.reset(reset_tracking, ignore_auto_refresh, update_graph)
if not ignore_other_instances:
for quantity in self.requirements.values():
quantity.reset(reset_tracking, ignore_other_instances=False)

# Skip reset if dynamic computation graph update if appropriate
if update_graph and not self.allow_update_graph:
# Skip reset if dynamic computation graph update is not allowed
if self.env.is_simulation_running and not self.allow_update_graph:
return

# No longer consider this exact instance as active if requested
# FIXME: Should be moved before ?
# No longer consider this exact instance as active
if reset_tracking:
self._is_active = False

# No longer consider this exact instance as initialized
self._is_initialized = False

# More work must to be done if shared cache is available
# More work must to be done if shared cache if appropriate
if self.has_cache:
# Early return if already reset to avoid infinite loop.
# FIXME: You still want to reset the owners !
if ignore_auto_refresh and (
self.cache.sm_state is QuantityStateMachine.IS_RESET):
self.cache.owners = self.cache._weakrefs
return

# Invalidate cache before looping over all identical properties.
# Note that auto-refresh must be ignored to avoid infinite loop.
self.cache.reset(ignore_auto_refresh=True,
reset_state_machine=True)

# Reset all identical quantities.
# Note that auto-refresh will be done afterward if requested.
for owner in self.cache.owners:
owner.reset(reset_tracking=reset_tracking,
ignore_auto_refresh=True,
update_graph=update_graph)

# Reset shared cache one last time but without ignore auto refresh
if not ignore_auto_refresh:
self.cache.reset(ignore_auto_refresh=False)
if not ignore_other_instances:
for owner in self.cache.owners:
if owner is not self:
owner.reset(reset_tracking=reset_tracking,
ignore_other_instances=True)

# Reset shared cache
self.cache.reset(
ignore_auto_refresh=not self.env.is_simulation_running,
reset_state_machine=True)

def initialize(self) -> None:
"""Initialize internal buffers.
Expand Down
Loading

0 comments on commit 74b016b

Please sign in to comment.