Skip to content

Commit

Permalink
[gym/common] More generic 'FilterObservation', 'StackedJiminyEnv' wrappers. Speedup 'StackedJiminyEnv' wrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Apr 22, 2024
1 parent d8e3f74 commit 294f1e2
Show file tree
Hide file tree
Showing 10 changed files with 367 additions and 281 deletions.
4 changes: 2 additions & 2 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def step(self, # type: ignore[override]

# Make sure that the pipeline has not changed since last reset
env_derived = (
self.unwrapped._env_derived) # type: ignore[attr-defined]
self.unwrapped.derived) # type: ignore[attr-defined]
if env_derived is not self:
raise RuntimeError(
"Pipeline environment has changed. Please call 'reset' "
Expand Down Expand Up @@ -472,7 +472,7 @@ def refresh_observation(self, measurement: EngineObsType) -> None:
:param measurement: Low-level measure from the environment to process
to get higher-level observation.
"""
# Get environment observation
# Refresh environment observation
self.env.refresh_observation(measurement)

# Update observed features if necessary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def __init__(self,
"""
:param update_ratio: Ratio between the update period of the controller
and the one of the subsequent controller. -1 to
match the simulation timestep of the environment.
match the environment step `env.step_dt`.
Optional: -1 by default.
:param order: Derivative order of the action. It accepts position or
velocity (respectively 0 or 1).
Expand Down
22 changes: 14 additions & 8 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import tempfile
from copy import deepcopy
from collections import OrderedDict
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from functools import partial
from typing import (
Dict, Any, List, cast, no_type_check, Optional, Tuple, Callable, Union,
SupportsFloat, Iterator, Generic, Sequence, Mapping as MappingT,
MutableMapping as MutableMappingT)
SupportsFloat, Iterator, Generic, Sequence as SequenceT,
Mapping as MappingT, MutableMapping as MutableMappingT)

import numpy as np
from gymnasium import spaces
Expand Down Expand Up @@ -85,7 +85,7 @@

class _LazyDictItemFilter(Mapping):
def __init__(self,
dict_packed: MappingT[str, Sequence[Any]],
dict_packed: MappingT[str, SequenceT[Any]],
item_index: int) -> None:
self.dict_packed = dict_packed
self.item_index = item_index
Expand Down Expand Up @@ -175,6 +175,11 @@ def __init__(self,
# Make sure that rendering mode is valid
assert render_mode in self.metadata['render_modes']

# Make sure that the simulator is single-robot
if len(simulator.robots) > 1:
raise NotImplementedError(
"Multi-robot simulation is not supported for now.")

# Backup some user arguments
self.simulator: Simulator = simulator
self._step_dt = step_dt
Expand All @@ -196,7 +201,7 @@ def __init__(self,
self.robot.sensor_measurements)

# Top-most block of the pipeline of which the environment is part
self._env_derived: InterfaceJiminyEnv = self
self.derived: InterfaceJiminyEnv = self

# Store references to the variables to register to the telemetry
self._registered_variables: MutableMappingT[
Expand Down Expand Up @@ -777,7 +782,7 @@ def reset(self, # type: ignore[override]
env_derived = reset_hook() or env
assert env_derived.unwrapped is self
env = env_derived
self._env_derived = env
self.derived = env

# Instantiate the actual controller.
# Note that a weak reference must be used to avoid circular reference.
Expand Down Expand Up @@ -925,7 +930,7 @@ def step(self, # type: ignore[override]
# Update the observer at the end of the step.
# This is necessary because, internally, it is called at the beginning
# of every integration step, during the controller update.
self._env_derived._observer_handle(
self.derived._observer_handle(
self.stepper_state.t,
self._robot_state_q,
self._robot_state_v,
Expand Down Expand Up @@ -1026,6 +1031,7 @@ def plot(self,
"""
# Call base implementation
figure = self.simulator.plot(**kwargs)
assert not isinstance(figure, Sequence)

# Extract log data
log_vars = self.simulator.log_data.get("variables", {})
Expand Down Expand Up @@ -1268,7 +1274,7 @@ def evaluate(self,
# Run the simulation
info_episode = [info]
try:
env = self._env_derived
env = self.derived
while not (terminated or truncated or (
horizon is not None and self.num_steps > horizon)):
action = policy_fn(obs, reward, terminated or truncated, info)
Expand Down
4 changes: 3 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
compute_tilt_from_quat,
swing_from_vector,
remove_twist_from_quat)
from .spaces import (DataNested,
from .spaces import (StructNested,
DataNested,
FieldNested,
ArrayOrScalar,
get_bounds,
Expand Down Expand Up @@ -60,6 +61,7 @@
'compute_tilt_from_quat',
'swing_from_vector',
'remove_twist_from_quat',
'StructNested',
'DataNested',
'FieldNested',
'ArrayOrScalar',
Expand Down
42 changes: 24 additions & 18 deletions python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def copy(data: DataNestedT) -> DataNestedT:
return cast(DataNestedT, tree.unflatten_as(data, tree.flatten(data)))


@no_type_check
def clip(data: DataNested, space: gym.Space[DataNested]) -> DataNested:
"""Clip data from `gym.Space` to make sure it is within bounds.
Expand All @@ -226,17 +227,19 @@ def clip(data: DataNested, space: gym.Space[DataNested]) -> DataNested:
:param data: Data to clip.
:param space: `gym.Space` on which to operate.
"""
# FIXME: Add support of `gym.spaces.Tuple`
if not isinstance(space, gym.spaces.Dict):
return _array_clip(data, *get_bounds(space))
assert isinstance(data, dict)

out: Dict[str, DataNested] = OrderedDict()
for field, subspace in space.spaces.items():
out[field] = clip(data[field], subspace)
return out
data_type = type(data)
if tree.issubclass_mapping(data_type):
return data_type({
field: clip(data[field], subspace)
for field, subspace in space.spaces.items()})
if tree.issubclass_sequence(data_type):
return data_type(tuple(
clip(data[i], subspace)
for i, subspace in enumerate(space.spaces)))
return _array_clip(data, *get_bounds(space))


@no_type_check
def contains(data: DataNested,
space: gym.Space[DataNested],
tol_abs: float = 0.0,
Expand All @@ -254,19 +257,21 @@ def contains(data: DataNested,
:param tol_abs: Absolute tolerance.
:param tol_rel: Relative tolerance.
"""
if not isinstance(space, gym.spaces.Dict):
return _array_contains(data, *get_bounds(space), tol_abs, tol_rel)
assert isinstance(data, dict)

return all(contains(data[field], subspace, tol_abs, tol_rel)
for field, subspace in space.spaces.items())
data_type = type(data)
if tree.issubclass_mapping(data_type):
return all(contains(data[field], subspace, tol_abs, tol_rel)

Check warning on line 262 in python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py#L262

Bad indentation. Found 7 spaces, expected 8
for field, subspace in space.spaces.items())
if tree.issubclass_sequence(data_type):
return all(contains(data[i], subspace, tol_abs, tol_rel)

Check warning on line 265 in python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py#L265

Bad indentation. Found 7 spaces, expected 8
for i, subspace in enumerate(space.spaces))
return _array_contains(data, *get_bounds(space), tol_abs, tol_rel)


@no_type_check
def build_reduce(fn: Callable[..., ValueInT],
op: Optional[Callable[[ValueOutT, ValueInT], ValueOutT]],
dataset: SequenceT[DataNested],
space: Optional[gym.spaces.Dict],
space: Optional[gym.Space[DataNested]],
arity: Optional[Literal[0, 1]],
*args: Any,
initializer: Optional[Callable[[], ValueOutT]] = None,
Expand Down Expand Up @@ -312,8 +317,9 @@ def build_reduce(fn: Callable[..., ValueInT],
reduction. This is useful when applying in-place transforms.
:param data: Pre-allocated nested data structure. Optional if the space is
provided but hardly relevant.
:param space: `gym.spaces.Dict` on which to operate. Optional iff the
nested data structure is provided.
:param space: Container space on which to operate (e.g. `gym.spaces.Dict` or
`gym.spaces.Tuple`). Optional iff the nested data structure
is provided.
:param arity: Arity of the generated callable. `None` to indicate that it
must be determined at runtime, which is slower.
:param args: Extra arguments to systematically forward as transform input
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# pylint: disable=missing-module-docstring

from .observation_filter import FilterObservation
from .observation_stack import PartialObservationStack, StackedJiminyEnv
from .observation_stack import StackedJiminyEnv
from .normalize import NormalizeAction, NormalizeObservation
from .flatten import FlattenAction, FlattenObservation


__all__ = [
'FilterObservation',
'PartialObservationStack',
'StackedJiminyEnv',
'NormalizeAction',
'NormalizeObservation',
Expand Down
Loading

0 comments on commit 294f1e2

Please sign in to comment.