Skip to content

Commit

Permalink
[python|gym] Add constraint lambda multipliers to state and extracted …
Browse files Browse the repository at this point in the history
…trajectory.
  • Loading branch information
duburcqa committed Jun 16, 2024
1 parent cc4bc75 commit 819dbfa
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 27 deletions.
61 changes: 53 additions & 8 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import jiminy_py.core as jiminy
from jiminy_py.core import ( # pylint: disable=no-name-in-module
multi_array_copyto)
array_copyto, multi_array_copyto)
from jiminy_py.dynamics import State, Trajectory, update_quantities
import pinocchio as pin

Expand Down Expand Up @@ -936,10 +936,19 @@ def __init__(self,
self.pinocchio_data = env.robot.pinocchio_data

# State for which the quantity must be evaluated
self._f_external_slices: Tuple[np.ndarray, ...] = ()
self._f_external_list: Tuple[np.ndarray, ...] = ()
self._f_external_slices: List[np.ndarray] = []
self._f_external_list: List[np.ndarray] = []
self.state = State(t=np.nan, q=np.array([]))

# Persistent buffer storing all lambda multipliers for efficiency
self._constraint_lambda_all = np.array({})

# Slices in stacked lambda multiplier flat vector
self._constraint_lambda_slices: List[np.ndarray] = []

# Lambda multipliers of all the constraints individually
self._constraint_lambda_list: List[np.ndarray] = []

def initialize(self) -> None:
# Refresh robot and pinocchio proxies for co-owners of shared cache.
# Note that automatic refresh is not sufficient to guarantee that
Expand Down Expand Up @@ -972,13 +981,14 @@ def initialize(self) -> None:
# The quantity will be considered initialized and active at this point.
super().initialize()

# State for which the quantity must be evaluated
# Define the state for which the quantity must be evaluated
if self.mode == QuantityEvalMode.TRUE:
# Refresh mapping from robot state to quantity buffer
if not self.env.is_simulation_running:
raise RuntimeError("No simulation running. Impossible to "
"initialize this quantity.")
self._f_external_list = tuple(
f_ext.vector for f_ext in self.env.robot_state.f_external)
self._f_external_list = [
f_ext.vector for f_ext in self.env.robot_state.f_external]
if self._f_external_list:
f_external_batch = np.stack(self._f_external_list, axis=0)
else:
Expand All @@ -991,7 +1001,37 @@ def initialize(self) -> None:
self.env.robot_state.u,
self.env.robot_state.command,
f_external_batch)
self._f_external_slices = tuple(f_external_batch)
self._f_external_slices = list(f_external_batch)
else:
# Allocate memory for lambda vector
self._constraint_lambda_all = np.zeros(
(len(self.robot.log_constraint_fieldnames),))

# Refresh mapping from lambda multipliers to corresponding slice
self._constraint_lambda_list.clear()
self._constraint_lambda_slices.clear()
constraint_lookup_pairs = tuple(
(f"Constraint{registry_type}", registry)
for registry_type, registry in (
("BoundJoints", self.robot.constraints.bounds_joints),
("ContactFrames", self.robot.constraints.contact_frames),
("CollisionBodies", {
name: constraint for constraints in (
self.robot.constraints.collision_bodies)
for name, constraint in constraints.items()}),
("User", self.robot.constraints.user)))
i = 0
while i < len(self.robot.log_constraint_fieldnames):
fieldname = self.robot.log_constraint_fieldnames[i]
for registry_type, registry in constraint_lookup_pairs:
if fieldname.startswith(registry_type):
break
constraint_name = fieldname[len(registry_type):-1]
constraint = registry[constraint_name]
self._constraint_lambda_list.append(constraint.lambda_c)
self._constraint_lambda_slices.append(
self._constraint_lambda_all[i:(i + constraint.size)])
i += constraint.size

def refresh(self) -> State:
"""Compute the current state depending on the mode of evaluation, and
Expand All @@ -1003,8 +1043,8 @@ def refresh(self) -> State:
else:
self.state = self.trajectory.get()

# Update all the dynamical quantities that can be given available data
if self.mode == QuantityEvalMode.REFERENCE:
# Update all dynamical quantities that can be given available data
update_quantities(
self.robot,
self.state.q,
Expand All @@ -1017,5 +1057,10 @@ def refresh(self) -> State:
update_collisions=True,
use_theoretical_model=self.trajectory.use_theoretical_model)

# Restore lagrangian multipliers of the constraints
array_copyto(self._constraint_lambda_all, self.state.lambda_c)
multi_array_copyto(
self._constraint_lambda_list, self._constraint_lambda_slices)

# Return state
return self.state
2 changes: 1 addition & 1 deletion python/gym_jiminy/unit_py/test_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_masked(self):
env.quantities["v_masked"], quantity.data[[0, 2, 4]])

def test_true_vs_reference(self):
env = gym.make("gym_jiminy.envs:atlas")
env = gym.make("gym_jiminy.envs:atlas", debug=True)

frame_names = [
frame.name for frame in env.robot.pinocchio_model.frames]
Expand Down
14 changes: 7 additions & 7 deletions python/gym_jiminy/unit_py/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ class Rewards(unittest.TestCase):
""" TODO: Write documentation
"""
def setUp(self):
self.env = gym.make("gym_jiminy.envs:atlas")

self.env.reset(seed=1)
action = self.env.action_space.sample()
env = gym.make("gym_jiminy.envs:atlas", debug=True)
env.reset(seed=1)
action = env.action_space.sample()
for _ in range(10):
self.env.step(action)
self.env.stop()
env.step(action)
env.stop()
trajectory = extract_trajectory_from_log(env.log_data)

trajectory = extract_trajectory_from_log(self.env.log_data)
self.env = gym.make("gym_jiminy.envs:atlas")
self.env.quantities.add_trajectory("reference", trajectory)
self.env.quantities.select_trajectory("reference")

Expand Down
21 changes: 18 additions & 3 deletions python/jiminy_py/src/jiminy_py/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
from bisect import bisect_right
from dataclasses import dataclass
from typing import Optional, Tuple, Sequence, Callable, Literal
from typing import Dict, Optional, Tuple, Sequence, Callable, Literal

import numpy as np

Expand Down Expand Up @@ -138,6 +138,13 @@ class State:
coordinates (Fx, Fy, Fz, Mx, My, Mz).
"""

lambda_c: Optional[np.ndarray] = None
"""Lambda multipliers associated with all the constraints, stacked as a
single flat vector. The multipliers of each individual constraint occupy
a contiguous slice of it, in the same order as the constraint fieldnames
logged for the robot.
"""


@dataclass(unsafe_hash=True)
class Trajectory:
Expand Down Expand Up @@ -201,21 +208,23 @@ def __init__(self,
self._index_prev = 0

# List of optional state fields that are provided
self._fields: Tuple[str, ...] = ()
self._has_velocity = False
self._has_acceleration = False
self._has_effort = False
self._has_command = False
self._has_external_forces = False
self._has_constraints = False
if states:
state = states[0]
self._has_velocity = state.v is not None
self._has_acceleration = state.a is not None
self._has_effort = state.u is not None
self._has_command = state.command is not None
self._has_external_forces = state.f_external is not None
self._has_constraints = state.lambda_c is not None
self._fields = tuple(
field for field in ("v", "a", "u", "command", "f_external")
field for field in (
"v", "a", "u", "command", "f_external", "lambda_c")
if getattr(state, field) is not None)

@property
Expand Down Expand Up @@ -254,6 +263,12 @@ def has_external_forces(self) -> bool:
"""
return self._has_external_forces

@property
def has_constraints(self) -> bool:
"""Whether the trajectory contains lambda multipliers of constraints.
"""
return self._has_constraints

@property
def time_interval(self) -> Tuple[float, float]:
"""Time interval of the trajectory.
Expand Down
13 changes: 5 additions & 8 deletions python/jiminy_py/src/jiminy_py/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ def extract_variables_from_log(log_vars: Dict[str, np.ndarray],
:param fieldnames: Structured fieldnames.
:param namespace: Namespace of the fieldnames. Empty string to disable.
Optional: Empty by default.
:param keep_structure: Whether to return a dictionary mapping flattened
fieldnames to values.
Optional: True by default.
:returns:
`np.ndarray` or None for each fieldname individually depending if it is
found or not.
:param as_dict: Whether to return a dictionary mapping flattened fieldnames
to values.
Optional: True by default.
"""
# Extract values from log if it exists
if as_dict:
Expand Down Expand Up @@ -238,7 +234,8 @@ def extract_trajectory_from_log(log_data: Dict[str, Any],
"acceleration",
"effort",
"command",
"f_external"):
"f_external",
"constraint"):
fieldnames = getattr(robot, f"log_{name}_fieldnames")
try:
data[name] = extract_variables_from_log(
Expand Down

0 comments on commit 819dbfa

Please sign in to comment.