Skip to content

Commit

Permalink
[gym/common] Fix filtering and normalization wrappers.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Apr 28, 2024
1 parent 242f925 commit 28cc703
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 15 deletions.
4 changes: 4 additions & 0 deletions python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ def zeros(space: gym.Space[DataNestedT],
if enforce_bounds:
value = clip(value, space)
return value
if not isinstance(space, gym.Space):
raise ValueError(
"All spaces must derived from `gym.Space`, including tuple and "
"dict containers.")
raise NotImplementedError(
f"Space of type {type(space)} is not supported.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def _copy_filtered(data: SpaceOrDataT,

# Convert parent container to mutable dictionary
parent_type = type(parent)
if parent_type in (dict, OrderedDict):
continue
type_filtered_nodes.append(parent_type)
if parent_type in (list, dict, OrderedDict):
continue
if issubclass(parent_type, gym.Space):
parent = parent.spaces
if issubclass_mapping(parent_type):
Expand All @@ -85,17 +85,21 @@ def _copy_filtered(data: SpaceOrDataT,
for path in path_filtered_leaves):
break
*key_nested_parent, key_leaf = key_nested[:(i + 1)]
parent = reduce(getitem, key_nested_parent, out)
del parent[key_leaf]
try:
parent = reduce(getitem, key_nested_parent, out)
del parent[key_leaf]
except KeyError:
# Some nested keys may have been deleted previously
pass

# Restore original parent container types
parent_type_it = iter(type_filtered_nodes)
for key_nested, _ in out_flat:
parent_type_it = iter(type_filtered_nodes[::-1])
for key_nested, _ in out_flat[::-1]:
if key_nested not in path_filtered_leaves:
continue
for i in range(1, len(key_nested) + 1):
for i in range(1, len(key_nested) + 1)[::-1]:
# Extract parent container
*key_nested_parent, _ = key_nested
*key_nested_parent, _ = key_nested[:i]
if key_nested_parent:
*key_nested_container, key_parent = key_nested_parent
container = reduce(getitem, key_nested_container, out)
Expand All @@ -104,20 +108,19 @@ def _copy_filtered(data: SpaceOrDataT,
parent = out

# Restore original container type if not already done
parent_type = type(parent)
if parent_type in (dict, OrderedDict):
continue
parent_type = next(parent_type_it)
if isinstance(parent, parent_type):
continue
if issubclass_mapping(parent_type):
parent = parent_type(parent)
parent = parent_type(tuple(parent.items()))
elif issubclass_sequence(parent_type):
parent = parent_type(tuple(parent.values()))

# Re-assign output data structure
if key_nested_parent:
container[key_parent] = parent_type(data)
container[key_parent] = parent
else:
out = data
out = parent
return out


Expand Down Expand Up @@ -169,7 +172,7 @@ def _initialize_observation_space(self) -> None:
It gathers a subset of all the leaves of the original observation space
without any further processing.
"""
self.observation = _copy_filtered(
self.observation_space = _copy_filtered(
self.env.observation_space, self.path_filtered_leaves)

def transform_observation(self) -> None:
Expand Down
38 changes: 38 additions & 0 deletions python/gym_jiminy/unit_py/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
""" TODO: Write documentation
"""
import unittest
from functools import reduce

from gym_jiminy.envs import AtlasPDControlJiminyEnv
from gym_jiminy.common.wrappers import (
FilterObservation,
NormalizeAction,
NormalizeObservation,
FlattenAction,
FlattenObservation,
)


class Wrappers(unittest.TestCase):
""" TODO: Write documentation
"""
def test_filter_normalize_flatten_wrappers(self):

Check notice on line 19 in python/gym_jiminy/unit_py/test_wrappers.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

python/gym_jiminy/unit_py/test_wrappers.py#L19

Missing function or method docstring
env = reduce(
lambda env, wrapper: wrapper(env),
(
NormalizeObservation,
NormalizeAction,
FlattenObservation,
FlattenAction
),
FilterObservation(
AtlasPDControlJiminyEnv(debug=False),
nested_filter_keys=(
("states", "pd_controller"),
("measurements", "EncoderSensor"),
("features", "mahony_filter"),
),
),
)
env.reset()
env.step(env.action)

0 comments on commit 28cc703

Please sign in to comment.