Skip to content

Commit

Permalink
[gym/common] Add foot vertical forces quantity. (#810)
Browse files Browse the repository at this point in the history
* [core] Rename 'runge_kutta_dopri5' in 'runge_kutta_dopri' for clarity.
* [gym/common] Add foot vertical forces quantity.
* [gym/common] Add foot force distribution reward.
* [gym/common] Do not pre-allocate memory for tracking op as it makes no difference.
  • Loading branch information
duburcqa committed Jun 18, 2024
1 parent e4ae9ef commit eca32f5
Show file tree
Hide file tree
Showing 21 changed files with 429 additions and 106 deletions.
2 changes: 1 addition & 1 deletion core/examples/double_pendulum/double_pendulum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ int main(int /* argc */, char * /* argv */[])
GenericConfig & worldOptions = boost::get<GenericConfig>(simuOptions.at("world"));
boost::get<Eigen::VectorXd>(worldOptions.at("gravity"))[2] = -9.81;
GenericConfig & stepperOptions = boost::get<GenericConfig>(simuOptions.at("stepper"));
boost::get<std::string>(stepperOptions.at("odeSolver")) = std::string("runge_kutta_dopri5");
boost::get<std::string>(stepperOptions.at("odeSolver")) = std::string("runge_kutta_dopri");
boost::get<double>(stepperOptions.at("tolRel")) = 1.0e-5;
boost::get<double>(stepperOptions.at("tolAbs")) = 1.0e-4;
boost::get<double>(stepperOptions.at("dtMax")) = 3.0e-3;
Expand Down
6 changes: 3 additions & 3 deletions core/include/jiminy/core/engine/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace jiminy
const std::map<std::string, ConstraintSolverType> CONSTRAINT_SOLVERS_MAP{
{"PGS", ConstraintSolverType::PGS}};

const std::set<std::string> STEPPERS{"euler_explicit", "runge_kutta_4", "runge_kutta_dopri5"};
const std::set<std::string> STEPPERS{"euler_explicit", "runge_kutta_4", "runge_kutta_dopri"};

class Robot;
class AbstractConstraintSolver;
Expand Down Expand Up @@ -306,8 +306,8 @@ namespace jiminy
GenericConfig config;
config["verbose"] = false;
config["randomSeedSeq"] = VectorX<uint32_t>::Zero(1).eval();
/// \details Must be either "runge_kutta_dopri5", "runge_kutta_4" or "euler_explicit".
config["odeSolver"] = std::string{"runge_kutta_dopri5"};
/// \details Must be either "runge_kutta_dopri", "runge_kutta_4" or "euler_explicit".
config["odeSolver"] = std::string{"runge_kutta_dopri"};
config["tolAbs"] = 1.0e-5;
config["tolRel"] = 1.0e-4;
config["dtMax"] = SIMULATION_MAX_TIMESTEP;
Expand Down
2 changes: 1 addition & 1 deletion core/src/engine/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ namespace jiminy
robots_.end(),
std::back_inserter(robots),
[](const auto & robot) { return robot.get(); });
if (engineOptions_->stepper.odeSolver == "runge_kutta_dopri5")
if (engineOptions_->stepper.odeSolver == "runge_kutta_dopri")
{
stepper_ = std::unique_ptr<AbstractStepper>(new RungeKuttaDOPRIStepper(
robotOde, robots, engineOptions_->stepper.tolAbs, engineOptions_->stepper.tolRel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,8 @@ def initialize(self) -> None:
self.env.robot_state.a,
self.env.robot_state.u,
self.env.robot_state.command,
self._f_external_batch)
self._f_external_batch,
self._constraint_lambda_batch)

def refresh(self) -> State:
"""Compute the current state depending on the mode of evaluation, and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
TrackingCapturePointReward,
TrackingFootPositionsReward,
TrackingFootOrientationsReward,
TrackingFootForceDistributionReward,
MinimizeAngularMomentumReward,
MinimizeFrictionReward)

Expand All @@ -27,6 +28,7 @@
"TrackingCapturePointReward",
"TrackingFootPositionsReward",
"TrackingFootOrientationsReward",
"TrackingFootForceDistributionReward",
"MinimizeAngularMomentumReward",
"MinimizeFrictionReward",
"SurviveReward"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Rewards mainly relevant for locomotion tasks on floating-base robots.
"""
from functools import partial
from typing import Union, Sequence, Literal
from typing import Union, Sequence, Literal, Callable, cast

import numpy as np
import pinocchio as pin
Expand All @@ -11,7 +11,7 @@
from ..quantities import (
MaskedQuantity, UnaryOpQuantity, AverageBaseOdometryVelocity, CapturePoint,
MultiFootRelativeXYZQuat, MultiContactRelativeForceTangential,
AverageBaseMomentum)
MultiFootRelativeForceVertical, AverageBaseMomentum)
from ..quantities.locomotion import sanitize_foot_frame_names
from ..utils import quat_difference

Expand Down Expand Up @@ -169,20 +169,6 @@ def __init__(self,
# Sanitize frame names corresponding to the feet of the robot
frame_names = tuple(sanitize_foot_frame_names(env, frame_names))

# Buffer storing the difference before current and reference poses
# FIXME: Is it worth it to create a temporary ?
self._diff = np.zeros((3, len(frame_names) - 1))

# Define buffered quaternion difference operator for efficiency
def quat_difference_buffered(out: np.ndarray,
q1: np.ndarray,
q2: np.ndarray) -> np.ndarray:
"""Wrapper around `quat_difference` passing buffer in and out
instead of allocating fresh memory for efficiency.
"""
quat_difference(q1, q2, out)
return out

# Call base implementation
super().__init__(
env,
Expand All @@ -194,7 +180,55 @@ def quat_difference_buffered(out: np.ndarray,
axis=0,
keys=(3, 4, 5, 6))),
cutoff,
op=partial(quat_difference_buffered, self._diff))
op=cast(Callable[
[np.ndarray, np.ndarray], np.ndarray], quat_difference))


class TrackingFootForceDistributionReward(BaseTrackingReward):
"""Reward the agent for tracking the relative vertical force in world frame
applied on each foot.
.. note::
The force is normalized by the weight of the robot rather than the
total force applied on all feet. This is important as it not only takes
into account the force distribution between the feet, but also the
overall ground contact interact force. This way, building up momentum
before jumping will be distinguished for standing still. Moreover, it
ensures that the reward is always properly defined, even if the robot
has no contact with the ground at all, which typically arises during
the flying phase of running.
.. seealso::
See `BaseTrackingReward` documentation for technical details.
"""
def __init__(self,
env: InterfaceJiminyEnv,
cutoff: float,
*,
frame_names: Union[Sequence[str], Literal['auto']] = 'auto'
) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param cutoff: Cutoff threshold for the RBF kernel transform.
:param frame_names: Name of the frames corresponding to the feet of the
robot. 'auto' to automatically detect them from the
set of contact and force sensors of the robot.
Optional: 'auto' by default.
"""
# Backup some user argument(s)
self.cutoff = cutoff

# Sanitize frame names corresponding to the feet of the robot
frame_names = tuple(sanitize_foot_frame_names(env, frame_names))

# Call base implementation
super().__init__(
env,
"reward_tracking_foot_force_distribution",
lambda mode: (MultiFootRelativeForceVertical, dict(
frame_names=frame_names,
mode=mode)),
cutoff)


class MinimizeAngularMomentumReward(BaseQuantityReward):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
AverageBaseMomentum,
MultiFootMeanXYZQuat,
MultiFootRelativeXYZQuat,
MultiFootRelativeForceVertical,
MultiContactRelativeForceTangential,
CenterOfMass,
CapturePoint,
Expand All @@ -48,6 +49,7 @@
'MultiFootMeanXYZQuat',
'MultiFootRelativeXYZQuat',
'MultiFootMeanOdometryPose',
'MultiFootRelativeForceVertical',
'MultiContactRelativeForceTangential',
'AverageFrameXYZQuat',
'AverageFrameRollPitch',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import pinocchio as pin

from ..bases import (
InterfaceJiminyEnv, InterfaceQuantity, AbstractQuantity, QuantityEvalMode)
InterfaceJiminyEnv, InterfaceQuantity, AbstractQuantity, StateQuantity,
QuantityEvalMode)
from ..utils import (
matrix_to_rpy, matrix_to_quat, quat_apply, remove_yaw_from_quat,
quat_interpolate_middle)
Expand Down Expand Up @@ -1404,7 +1405,9 @@ def __init__(self,
super().__init__(
env,
parent,
requirements={},
requirements=dict(
state=(StateQuantity, dict(
update_kinematics=False))),
mode=mode,
auto_refresh=False)

Expand Down
Loading

0 comments on commit eca32f5

Please sign in to comment.