Skip to content

Commit

Permalink
[gym/common] Add robot flying termination condition.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Jun 24, 2024
1 parent 262cf3e commit 7dffeb4
Show file tree
Hide file tree
Showing 13 changed files with 265 additions and 57 deletions.
2 changes: 1 addition & 1 deletion core/include/jiminy/core/engine/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ namespace jiminy
config["groundProfile"] = HeightmapFunction(
[](const Eigen::Vector2d & /* xy */,
double & height,
Eigen::Vector3d & normal) -> void
Eigen::Ref<Eigen::Vector3d> normal) -> void
{
height = 0.0;
normal = Eigen::Vector3d::UnitZ();
Expand Down
5 changes: 3 additions & 2 deletions core/include/jiminy/core/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ namespace jiminy
};

// Ground profile functors
using HeightmapFunction = std::function<void(
const Eigen::Vector2d & /* xy */, double & /* height */, Eigen::Vector3d & /* normal */)>;
using HeightmapFunction = std::function<void(const Eigen::Vector2d & /* xy */,
double & /* height */,
Eigen::Ref<Eigen::Vector3d> /* normal */)>;

// Flexible joints
struct FlexibilityJointConfig
Expand Down
4 changes: 3 additions & 1 deletion core/src/io/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ namespace boost::serialization
template<class Archive>
void load(Archive & /* ar */, HeightmapFunction & fun, const unsigned int /* version */)
{
fun = [](const Eigen::Vector2d & /* xy */, double & height, Eigen::Vector3d & normal)
fun = [](const Eigen::Vector2d & /* xy */,
double & height,
Eigen::Ref<Eigen::Vector3d> normal)
{
height = 0.0;
normal = Eigen::Vector3d::UnitZ();
Expand Down
14 changes: 9 additions & 5 deletions core/src/utilities/geometry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,9 @@ namespace jiminy
{
return heightmaps[0];
}
return [heightmaps](
const Eigen::Vector2d & pos, double & height, Eigen::Vector3d & normal) -> void
return [heightmaps](const Eigen::Vector2d & pos,
double & height,
Eigen::Ref<Eigen::Vector3d> normal) -> void
{
thread_local static double height_i;
thread_local static Eigen::Vector3d normal_i;
Expand All @@ -720,8 +721,9 @@ namespace jiminy
{
return heightmaps[0];
}
return [heightmaps](
const Eigen::Vector2d & pos, double & height, Eigen::Vector3d & normal) -> void
return [heightmaps](const Eigen::Vector2d & pos,
double & height,
Eigen::Ref<Eigen::Vector3d> normal) -> void
{
thread_local static double height_i;
thread_local static Eigen::Vector3d normal_i;
Expand Down Expand Up @@ -757,7 +759,9 @@ namespace jiminy
const Eigen::Rotation2D<double> rot_mat(orientation);

return [stepWidth, stepHeight, stepNumber, rot_mat, interpDelta](
const Eigen::Vector2d & pos, double & height, Eigen::Vector3d & normal) -> void
const Eigen::Vector2d & pos,
double & height,
Eigen::Ref<Eigen::Vector3d> normal) -> void
{
// Compute position in stairs reference frame
Eigen::Vector2d posRel = rot_mat.inverse() * pos;
Expand Down
13 changes: 7 additions & 6 deletions core/src/utilities/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,13 @@ namespace jiminy
template<>
HeightmapFunction convertFromJson<HeightmapFunction>(const Json::Value & /* value */)
{
return {
[](const Eigen::Vector2d & /* xy */, double & height, Eigen::Vector3d & normal) -> void
{
height = 0.0;
normal = Eigen::Vector3d::UnitZ();
}};
return {[](const Eigen::Vector2d & /* xy */,
double & height,
Eigen::Ref<Eigen::Vector3d> normal) -> void
{
height = 0.0;
normal = Eigen::Vector3d::UnitZ();
}};
}

template<>
Expand Down
4 changes: 3 additions & 1 deletion core/src/utilities/random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,9 @@ namespace jiminy
const Eigen::Rotation2D<double> rot_mat(orientation);

return [size, heightMax, interpDelta, rot_mat, sparsity, interpThr, offset, seed](
const Eigen::Vector2d & pos, double & height, Eigen::Vector3d & normal) -> void
const Eigen::Vector2d & pos,
double & height,
Eigen::Ref<Eigen::Vector3d> normal) -> void
{
// Compute the tile index and relative coordinate
Eigen::Vector2d posRel = (rot_mat * (pos + offset)).array() / size.array();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
MinimizeFrictionReward,
BaseRollPitchTermination,
BaseHeightTermination,
FootCollisionTermination)
FootCollisionTermination,
FlyingTermination)

__all__ = [
"CUTOFF_ESP",
Expand All @@ -41,6 +42,7 @@
"DriftTrackingQuantityTermination",
"ShiftTrackingQuantityTermination",
"MechanicalSafetyTermination",
"FlyingTermination",
"BaseRollPitchTermination",
"BaseHeightTermination",
"FootCollisionTermination"
Expand Down
29 changes: 13 additions & 16 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from ..bases import (
InfoType, QuantityCreator, InterfaceJiminyEnv, InterfaceQuantity,
AbstractQuantity, QuantityEvalMode, AbstractReward, QuantityReward,
QuantityEvalMode, AbstractReward, QuantityReward,
AbstractTerminationCondition, QuantityTermination)
from ..bases.compositions import ArrayOrScalar, ArrayLikeOrScalar
from ..quantities import (
Expand Down Expand Up @@ -384,16 +384,14 @@ def _compute_min_distance(self,

@dataclass(unsafe_hash=True)
class _MultiActuatedJointBoundDistance(
AbstractQuantity[Tuple[np.ndarray, np.ndarray]]):
InterfaceQuantity[Tuple[np.ndarray, np.ndarray]]):
"""Distance of the actuated joints from their respective lower and upper
mechanical stops.
"""

def __init__(self,
env: InterfaceJiminyEnv,
parent: Optional[InterfaceQuantity],
*,
mode: QuantityEvalMode = QuantityEvalMode.TRUE) -> None:
parent: Optional[InterfaceQuantity]) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param parent: Higher-level quantity from which this quantity is a
Expand All @@ -406,32 +404,31 @@ def __init__(self,
parent,
requirements=dict(
position=(MultiActuatedJointKinematic, dict(
kinematic_level=pin.KinematicLevel.POSITION))),
mode=mode,
kinematic_level=pin.KinematicLevel.POSITION,
mode=QuantityEvalMode.TRUE))),
auto_refresh=False)

# Lower and upper bounds of the actuated joints
self.position_lower = np.array([])
self.position_upper = np.array([])
self.position_low, self.position_high = np.array([]), np.array([])

def initialize(self) -> None:
# Call base implementation
super().initialize()

# Initialize the motor position indices
# Initialize the actuated joint position indices
quantity = self.requirements["position"]
quantity.initialize()
position_indices = quantity.kinematic_indices

# Refresh mechanical joint position indices
position_limit_lower = self.robot.pinocchio_model.lowerPositionLimit
self.position_lower = position_limit_lower[position_indices]
position_limit_upper = self.robot.pinocchio_model.upperPositionLimit
self.position_upper = position_limit_upper[position_indices]
position_limit_low = self.env.robot.pinocchio_model.lowerPositionLimit
self.position_low = position_limit_low[position_indices]
position_limit_high = self.env.robot.pinocchio_model.upperPositionLimit
self.position_high = position_limit_high[position_indices]

def refresh(self) -> Tuple[np.ndarray, np.ndarray]:
return (self.position - self.position_lower,
self.position_upper - self.position)
return (self.position - self.position_low,
self.position_high - self.position)


class MechanicalSafetyTermination(AbstractTerminationCondition):
Expand Down
159 changes: 151 additions & 8 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
"""Rewards mainly relevant for locomotion tasks on floating-base robots.
"""
from functools import partial
from dataclasses import dataclass
from typing import Optional, Union, Sequence, Literal, Callable, cast

import numpy as np
import numba as nb

import jiminy_py.core as jiminy
import pinocchio as pin

from ..bases import (
InterfaceJiminyEnv, StateQuantity, QuantityEvalMode, QuantityReward)
InterfaceJiminyEnv, StateQuantity, InterfaceQuantity, QuantityEvalMode,
QuantityReward)
from ..quantities import (
OrientationType, MaskedQuantity, UnaryOpQuantity, FrameOrientation,
BaseRelativeHeight, BaseOdometryAverageVelocity, CapturePoint,
MultiFootRelativeXYZQuat, MultiContactNormalizedForceTangential,
MultiFootNormalizedForceVertical, MultiFootCollisionDetection,
AverageBaseMomentum)
MultiFramePosition, MultiFootRelativeXYZQuat,
MultiContactNormalizedForceTangential, MultiFootNormalizedForceVertical,
MultiFootCollisionDetection, AverageBaseMomentum)
from ..quantities.locomotion import sanitize_foot_frame_names
from ..utils import quat_difference

Expand Down Expand Up @@ -347,13 +352,14 @@ class BaseHeightTermination(QuantityTermination):
"""
def __init__(self,
env: InterfaceJiminyEnv,
thr: float,
min_height: float,
grace_period: float = 0.0,
*,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param low: Lower bound below which termination is triggered.
:param min_height: Minimum height of the floating base of the robot
below which termination is triggered.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Expand All @@ -368,7 +374,7 @@ def __init__(self,
env,
"termination_base_height",
(BaseRelativeHeight, {}), # type: ignore[arg-type]
thr,
min_height,
None,
grace_period,
is_truncation=False,
Expand All @@ -391,7 +397,11 @@ def __init__(self,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param low: Lower bound below which termination is triggered.
:param security_margin:
Minimum signed distance below which termination is triggered. This
can be interpreted as inflating or deflating the geometry objects
by the safety margin depending on whether it is positive or
negative. See `MultiFootCollisionDetection` for details.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Expand All @@ -412,3 +422,136 @@ def __init__(self,
grace_period,
is_truncation=False,
is_training_only=is_training_only)


@nb.jit(nopython=True, cache=True, fastmath=True)
def _min_first_order_depth(positions: np.ndarray,
                           heights: np.ndarray,
                           normals: np.ndarray) -> float:
    """Approximate minimum distance from the ground profile among a set of
    query points.

    Internally, it uses a first order approximation assuming zero local
    curvature around each query point, i.e. the vertical offset between a
    query point and the ground is scaled by the vertical component of the
    ground normal at that location.

    :param positions: Position of all the query points whose distance from the
                      ground profile must be estimated, as a 2D array whose
                      first dimension gathers the 3 position coordinates
                      (X, Y, Z) while the second corresponds to the N
                      individual query points.
    :param heights: Height of the ground profile at the projection in world
                    plane of the N individual query points, as a 1D array.
    :param normals: Normal of the ground profile at the projection in world
                    plane of all the query points, as a 2D array whose first
                    dimension gathers the 3 coordinates (X, Y, Z) while the
                    second corresponds to the N individual query points.

    :returns: Smallest first-order distance over all the query points.
    """
    return np.min((positions[2] - heights) * normals[2])


@dataclass(unsafe_hash=True)
class _MultiContactMinGroundDistance(InterfaceQuantity[float]):
    """Minimum distance from the ground profile among all the contact points.

    .. note::
        Internally, it does not compute the exact shortest distance from the
        ground profile because it would be computationally too demanding for
        now. As a surrogate, it relies on a first order approximation assuming
        zero local curvature around all the contact points individually.

    .. warning::
        The set of contact points must not change over episodes. In addition,
        collision bodies are not supported for now.
    """

    def __init__(self,
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity]) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        """
        # Get the name of all the contact points
        contact_frame_names = env.robot.contact_frame_names

        # Call base implementation
        super().__init__(
            env,
            parent,
            requirements=dict(
                positions=(MultiFramePosition, dict(
                    frame_names=contact_frame_names,
                    mode=QuantityEvalMode.TRUE
                ))),
            auto_refresh=False)

        # Jit-compiled kernel computing the minimum first-order depth. It is
        # defined once at module level rather than re-created for every
        # instance. The attribute is kept for backward compatibility.
        self._min_depth = _min_first_order_depth

        # Reference to the heightmap function for the ongoing episode. This is
        # only a placeholder that gets replaced by the actual ground profile
        # of the engine in `initialize`.
        self._heightmap = jiminy.HeightmapFunction(lambda: None)

        # Allocate memory for the height and normal of all the contact points
        self._heights = np.zeros((len(contact_frame_names),))
        self._normals = np.zeros((3, len(contact_frame_names)), order="F")

    def initialize(self) -> None:
        # Call base implementation
        super().initialize()

        # Refresh the heightmap function from the current engine options
        engine_options = self.env.unwrapped.engine.get_options()
        self._heightmap = engine_options["world"]["groundProfile"]

    def refresh(self) -> float:
        # Query the height and normal to the ground profile for the position in
        # world plane of all the contact points. The pre-allocated buffers
        # `_heights` and `_normals` are passed as output arguments.
        jiminy.query_heightmap(self._heightmap,
                               self.positions[:2],
                               self._heights,
                               self._normals)

        # Make sure the ground normal is normalized
        # self._normals /= np.linalg.norm(self._normals, axis=0)

        # First-order distance estimation assuming no curvature
        return _min_first_order_depth(
            self.positions, self._heights, self._normals)


class FlyingTermination(QuantityTermination):
    """Discourage the agent from jumping by terminating the episode
    immediately if the robot is flying too high above the ground.

    This kind of jumping behavior is usually undesirable because it may be
    frightening for people nearby, difficult to predict and hardly repeatable.
    Moreover, such behaviors tend to transfer poorly to reality as very
    dynamic motions worsen the simulation to real gap.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 max_height: float,
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param max_height: Maximum height of the lowest contact point wrt the
                           ground above which termination is triggered.
        :param grace_period: Grace period effective only at the very beginning
                             of the episode, during which the latter is bound
                             to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is in
                                 evaluation mode.
                                 Optional: False by default.
        """
        # Call base implementation.
        # NOTE(review): `None` followed by `max_height` presumably map to the
        # lower and upper bounds expected by `QuantityTermination`, i.e. no
        # lower bound and `max_height` as upper bound - confirm against the
        # signature of the base class.
        super().__init__(
            env,
            "termination_flying",
            (_MultiContactMinGroundDistance, {}),  # type: ignore[arg-type]
            None,
            max_height,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
Loading

0 comments on commit 7dffeb4

Please sign in to comment.