[gym/common] Enable specifying reference trajectories in pipeline config.

duburcqa committed May 13, 2024
1 parent 83c52b1 commit 7cad32f

Showing 11 changed files with 250 additions and 49 deletions.
6 changes: 4 additions & 2 deletions core/src/io/json_writer.cc
@@ -28,9 +28,11 @@ namespace jiminy
std::unique_ptr<Json::StreamWriter> writer(builder.newStreamWriter());
std::ostream output(&buffer);
writer->write(input, &output);
-device_->resize(static_cast<int64_t>(buffer.str().size()));
-device_->write(buffer.str());
+// FIXME: Use `view` to get a string_view rather than a string copy when moving to C++20
+const std::string data = buffer.str();
+device_->resize(static_cast<int64_t>(data.size()));
+device_->write(data);

device_->close();
}
17 changes: 16 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/bases/pipeline.py
@@ -19,9 +19,11 @@
Callable, cast)

import numpy as np

import gymnasium as gym
from gymnasium.core import RenderFrame
from gymnasium.envs.registration import EnvSpec
from jiminy_py.dynamics import Trajectory

from .interfaces import (DT_EPS,
ObsT,
@@ -319,7 +321,8 @@ class ComposedJiminyEnv(
def __init__(self,
env: InterfaceJiminyEnv[ObsT, ActT],
*,
-reward: AbstractReward) -> None:
+reward: Optional[AbstractReward] = None,
+trajectories: Optional[Dict[str, Trajectory]] = None) -> None:
"""
:param env: Environment to extend, possibly already wrapped.
:param reward: Reward object deriving from `AbstractReward`. It will be
@@ -329,6 +332,11 @@ def __init__(self,
the provided environment. `None` for not considering any
reward.
Optional: `None` by default.
:param trajectories: Set of named trajectories as a dictionary whose
(key, value) pairs are respectively the name of
each trajectory and the trajectory itself. `None`
for not considering any trajectory.
Optional: `None` by default.
"""
# Make sure that the unwrapped environment matches the reward one
assert reward is None or env.unwrapped is reward.env.unwrapped
@@ -339,6 +347,11 @@
# Initialize base class
super().__init__(env)

# Add reference trajectories to all managed quantities if requested
if trajectories is not None:
for name, trajectory in trajectories.items():
self.env.quantities.add_trajectory(name, trajectory)

# Bind observation and action of the base environment
assert self.observation_space.contains(self.env.observation)
assert self.action_space.contains(self.env.action)
@@ -396,6 +409,8 @@ def compute_command(self, action: ActT, command: np.ndarray) -> None:
self.env.compute_command(action, command)

def compute_reward(self, terminated: bool, info: InfoType) -> float:
if self.reward is None:
return 0.0
return self.reward(terminated, info)
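For illustration, a minimal sketch of how the new `trajectories` argument fits together with the now-optional reward, using only the API visible in this diff. `base_env` and `traj` are hypothetical placeholders, and the import path is assumed to be re-exported from `gym_jiminy.common.bases`:

from gym_jiminy.common.bases import ComposedJiminyEnv  # import path assumed

# `base_env` implements `InterfaceJiminyEnv` and `traj` is a pre-built
# `jiminy_py.dynamics.Trajectory` (both hypothetical placeholders).
env = ComposedJiminyEnv(base_env, trajectories={"walk": traj})

# The named trajectories were registered on the shared quantity manager,
# so any of them can be selected afterwards.
env.quantities.select_trajectory("walk", "wrap")

# Since no reward was composed, `compute_reward` now falls back to 0.0.
assert env.compute_reward(False, {}) == 0.0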


19 changes: 12 additions & 7 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py
@@ -762,6 +762,12 @@ def discard(self, name: str) -> None:
:param name: Name of the trajectory to discard.
"""
# Un-select trajectory if it corresponds to the discarded one
if self._name == name:
self._trajectory = None
self._name = ""

# Delete trajectory from the global registry
del self.registry[name]

@sync
@@ -799,12 +805,10 @@ def name(self) -> str:
@InterfaceQuantity.cache.setter # type: ignore[attr-defined]
def cache(self, cache: Optional[SharedCache[ValueT]]) -> None:
# Get existing registry if any and make sure it is not already out-of-sync
-registry: Optional[OrderedDict[str, Trajectory]] = None
+owner: Optional[InterfaceQuantity] = None
if cache is not None and cache.owners:
-owner: InterfaceQuantity = next(iter(cache.owners))
+owner = next(iter(cache.owners))
assert isinstance(owner, DatasetTrajectoryQuantity)
-registry = owner.registry
-name, mode = owner._name, owner._mode
if self._trajectory:
raise RuntimeError(
"Trajectory dataset not empty. Impossible to add a shared "
@@ -814,9 +818,10 @@ def cache(self, cache: Optional[SharedCache[ValueT]]) -> None:
InterfaceQuantity.cache.fset(self, cache) # type: ignore[attr-defined]

# Catch-up synchronization
-if registry is not None:
-self.registry = registry
-self.select(name, mode)
+if owner:
+self.registry = owner.registry
+if owner._trajectory is not None:
+self.select(owner._name, owner._mode)

def refresh(self) -> State:
"""Compute state of selected trajectory at current simulation time.
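The catch-up synchronization above matters when a quantity joins an already-populated shared cache: the newcomer inherits the registry and, if a trajectory is selected, replays the selection. A rough behavioural sketch under hypothetical names (`quantity_a` owns the dataset, `quantity_b` attaches later):

# `quantity_a` already holds a registry and an active selection:
#   quantity_a.registry -> {"walk": <Trajectory>}
#   quantity_a.name     -> "walk"

# Attaching `quantity_b` to the same shared cache triggers the setter
# above, provided its own dataset is still empty (otherwise it raises).
quantity_b.cache = shared_cache

# The registry is shared by reference, and the selection is replayed
# only if the first owner actually has a trajectory selected:
assert quantity_b.registry is quantity_a.registry
assert quantity_b.name == "walk"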
7 changes: 6 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/utils/__init__.py
@@ -41,7 +41,10 @@
get_fieldnames,
register_variables,
sample)
-from .pipeline import build_pipeline, load_pipeline
+from .pipeline import (save_trajectory_to_hdf5,
+load_trajectory_to_hdf5,
+build_pipeline,
+load_pipeline)


__all__ = [
@@ -86,6 +89,8 @@
'is_nan',
'get_fieldnames',
'register_variables',
'save_trajectory_to_hdf5',
'load_trajectory_to_hdf5',
'build_pipeline',
'load_pipeline'
]
191 changes: 175 additions & 16 deletions python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py
@@ -10,13 +10,20 @@
import json
import pathlib
from pydoc import locate
from dataclasses import asdict
from functools import partial
from typing import (
-Dict, Any, Optional, Union, Type, Sequence, Callable, TypedDict)
+Dict, Any, Optional, Union, Type, Sequence, Callable, TypedDict, Literal,
+cast)

import h5py
import toml
import numpy as np
import gymnasium as gym

import jiminy_py.core as jiminy
from jiminy_py.dynamics import State, Trajectory

from ..bases import (InterfaceJiminyEnv,
InterfaceBlock,
BaseControllerBlock,
@@ -53,6 +60,36 @@ class RewardConfig(TypedDict, total=False):
"""


class TrajectoriesConfig(TypedDict, total=False):
"""Store information required for adding a database of reference
trajectories to the environment.

Specifically, it comprises a set of named trajectories, given as a
dictionary whose keys are trajectory names and whose values are either the
trajectory itself or the path of a file storing its dump in HDF5 format,
plus optionally the name of the selected trajectory and its interpolation
mode.
"""

dataset: Dict[str, Union[str, Trajectory]]
"""Set of named trajectories as a dictionary.
.. note::
Both `Trajectory` objects or path (absolute or relative) are supported.
"""

name: str
"""Name of the selected trajectory if any.
This attribute can be omitted.
"""

mode: Literal['raise', 'wrap', 'clip']
"""Interpolation mode of the selected trajectory if any.
This attribute can be omitted.
"""


class EnvConfig(TypedDict, total=False):
"""Store information required for instantiating a given base environment
and compose it with some additional reward components and termination
@@ -84,6 +121,12 @@ class EnvConfig(TypedDict, total=False):
This attribute can be omitted.
"""

trajectories: TrajectoriesConfig
"""Reference trajectories configuration.
This attribute can be omitted.
"""


class BlockConfig(TypedDict, total=False):
"""Store information required for instantiating a given observation or
@@ -163,7 +206,9 @@ class LayerConfig(TypedDict, total=False):


def build_pipeline(env_config: EnvConfig,
-layers_config: Sequence[LayerConfig]
+layers_config: Sequence[LayerConfig],
+*,
+root_path: Optional[Union[str, pathlib.Path]] = None
) -> Callable[..., InterfaceJiminyEnv]:
"""Wrap together an environment inheriting from `BaseJiminyEnv` with any
number of layers, as a unified pipeline environment class inheriting from
@@ -232,15 +277,18 @@ def build_reward(env: InterfaceJiminyEnv,

# Define helper to compose the environment with reward and trajectories
def build_composition(env_creator: Callable[..., InterfaceJiminyEnv],
-reward_config: RewardConfig,
-**env_kwargs: Any) -> BasePipelineWrapper:
+reward_config: Optional[RewardConfig],
+trajectories_config: Optional[TrajectoriesConfig],
+**env_kwargs: Any) -> InterfaceJiminyEnv:
"""Helper adding reward on top of a base environment or a pipeline
using `ComposedJiminyEnv` wrapper.
:param env_creator: Callable that takes optional keyword arguments as
input and returns an pipeline or base environment.
:param reward_config: Configuration of the reward, as a dict of type
`RewardConfig`.
:param trajectories: Set of named trajectories as a dictionary. See
`ComposedJiminyEnv` documentation for details.
:param env_kwargs: Keyword arguments to forward to the constructor of
the wrapped environment. Note that it will only
overwrite the default value, so it will still be
@@ -252,10 +300,29 @@ def build_composition(env_creator: Callable[..., InterfaceJiminyEnv],
env = env_creator(**env_kwargs)

# Instantiate the reward
-reward = build_reward(env, reward_config)
-
-# Instantiate the wrapper
-return ComposedJiminyEnv(env, reward=reward)
reward = None
if reward_config is not None:
reward = build_reward(env, reward_config)

# Get trajectory dataset
trajectories: Dict[str, Trajectory] = {}
if trajectories_config is not None:
trajectories = cast(
Dict[str, Trajectory], trajectories_config["dataset"])

# Instantiate the composition wrapper if necessary
if reward or trajectories:
env = ComposedJiminyEnv(
env, reward=reward, trajectories=trajectories)

# Select the reference trajectory if specified
if trajectories_config is not None:
name = trajectories_config.get("name")
if name is not None:
mode = trajectories_config.get("mode", "raise")
env.quantities.select_trajectory(name, mode)

return env
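To make the new plumbing concrete, a direct call to `build_pipeline` with a trajectory database could look like the following sketch, where `base_env_config` stands for a hypothetical pre-existing `EnvConfig` without trajectories:

env_creator = build_pipeline(
    env_config={
        **base_env_config,  # hypothetical existing configuration
        "trajectories": {
            "dataset": {"walk": "data/walk.hdf5"},
            "name": "walk",
            "mode": "wrap",
        },
    },
    layers_config=[],
    # Required because the dataset above refers to a relative path:
    root_path="/path/to/project",
)
env = env_creator()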

# Define helper to wrap a single layer
def build_layer(env_creator: Callable[..., InterfaceJiminyEnv],
@@ -329,13 +396,33 @@ def build_layer(env_creator: Callable[..., InterfaceJiminyEnv],
pipeline_creator: Callable[..., InterfaceJiminyEnv] = partial(
env_cls, **env_config.get("kwargs", {}))

-# Compose base environment with an extra user-specified reward if any
+# Parse reward configuration
reward_config = env_config.get("reward")
if reward_config is not None:
sanitize_reward_config(reward_config)
-pipeline_creator = partial(build_composition,
-pipeline_creator,
-reward_config)

# Parse trajectory configuration
trajectories_config = env_config.get("trajectories")
if trajectories_config is not None:
trajectories = trajectories_config['dataset']
assert isinstance(trajectories, dict)
for name, path_or_traj in trajectories.items():
if isinstance(path_or_traj, Trajectory):
continue
path = pathlib.Path(path_or_traj)
if not path.is_absolute():
if root_path is None:
raise RuntimeError(
"The argument 'root_path' must be provided when "
"specifying relative trajectory paths.")
path = pathlib.Path(root_path) / path
trajectories[name] = load_trajectory_to_hdf5(path)

# Compose base environment with an extra user-specified reward
pipeline_creator = partial(build_composition,
pipeline_creator,
reward_config,
trajectories_config)

# Generate pipeline recursively
for layer_config in layers_config:
@@ -397,15 +484,87 @@
return pipeline_creator


-def load_pipeline(fullpath: str) -> Callable[..., InterfaceJiminyEnv]:
+def load_pipeline(fullpath: Union[str, pathlib.Path]
+) -> Callable[..., InterfaceJiminyEnv]:
"""Load pipeline from JSON or TOML configuration file.
:param: Fullpath of the configuration file.
"""
-file_ext = pathlib.Path(fullpath).suffix
+fullpath = pathlib.Path(fullpath)
+root_path, file_ext = fullpath.parent, fullpath.suffix
with open(fullpath, 'r') as f:
if file_ext == '.json':
-return build_pipeline(**json.load(f))
+return build_pipeline(**json.load(f), root_path=root_path)
if file_ext == '.toml':
-return build_pipeline(**toml.load(f))
+return build_pipeline(**toml.load(f), root_path=root_path)
raise ValueError("Only json and toml formats are supported.")
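A typical round trip through a configuration file might then read as follows ('pipeline.toml' is a placeholder; relative trajectory paths inside it now resolve against the parent directory of the file):

from gym_jiminy.common.utils import load_pipeline

env_creator = load_pipeline("pipeline.toml")  # placeholder path
env = env_creator()  # extra kwargs are forwarded to the base environment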


def save_trajectory_to_hdf5(trajectory: Trajectory,
fullpath: Union[str, pathlib.Path]) -> None:
"""Export a trajectory object to HDF5 format.
:param trajectory: Trajectory object to save.
:param fullpath: Fullpath of the generated HDF5 file.
"""
# Create HDF5 file
hdf_obj = h5py.File(fullpath, "w")

# Dump each state attribute that is specified, for all states at once
if trajectory.states:
state_dict = asdict(trajectory.states[0])
state_fields = tuple(
key for key, value in state_dict.items() if value is not None)
for key in state_fields:
data = np.stack([
getattr(state, key) for state in trajectory.states], axis=0)
hdf_obj.create_dataset(name=f"states/{key}", data=data)

# Dump serialized robot
robot_data = jiminy.save_to_binary(trajectory.robot)
dataset = hdf_obj.create_dataset(name="robot", data=np.array(robot_data))

# Dump whether to use the theoretical model of the robot
dataset.attrs["use_theoretical_model"] = trajectory.use_theoretical_model

# Close the HDF5 file
hdf_obj.close()


def load_trajectory_to_hdf5(fullpath: Union[str, pathlib.Path]) -> Trajectory:
"""Import a trajectory object from file in HDF5 format.
:param fullpath: Fullpath of the HDF5 file to import.
:returns: Loaded trajectory object.
"""
# Open HDF5 file
hdf_obj = h5py.File(fullpath, "r")

# Get all state attributes that are specified
states_dict = {}
if 'states' in hdf_obj.keys():
for key, value in hdf_obj['states'].items():
states_dict[key] = value[...]

# Re-construct state sequence
states = []
for args in zip(*states_dict.values()):
states.append(State(**dict(zip(states_dict.keys(), args))))

# Re-construct the serialized robot from the raw binary data.
# Null char '\0' must be added at the end to match original string length.
dataset = hdf_obj['robot']
robot_data = dataset[()]
robot_data += b'\0' * (
dataset.nbytes - len(robot_data)) # pylint: disable=no-member
robot = jiminy.load_from_binary(robot_data)

# Load whether to use the theoretical model of the robot
use_theoretical_model = dataset.attrs["use_theoretical_model"]

# Close the HDF5 file
hdf_obj.close()

# Re-construct the whole trajectory
return Trajectory(states, robot, use_theoretical_model)
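A save/load round trip built on these two helpers might look like this sketch, where `traj` is an existing `Trajectory` and the file name is a placeholder (note that, despite its name, `load_trajectory_to_hdf5` reads from HDF5):

from gym_jiminy.common.utils import (
    save_trajectory_to_hdf5, load_trajectory_to_hdf5)

save_trajectory_to_hdf5(traj, "walk.hdf5")        # dump to disk
traj_copy = load_trajectory_to_hdf5("walk.hdf5")  # load it back
assert traj_copy.use_theoretical_model == traj.use_theoretical_model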