Skip to content

Commit

Permalink
Unify api.contact_model in api.contact
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 10, 2025
1 parent 62adf79 commit f584892
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 106 deletions.
98 changes: 98 additions & 0 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Cross, Transform
from jaxsim.rbda.contacts import SoftContacts

from .common import VelRepr

Expand Down Expand Up @@ -526,3 +527,100 @@ def compute_O_J̇_WC_I(
)

return O_J̇_WC


@jax.jit
@js.common.named_scope
def link_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces: jtp.MatrixLike | None = None,
joint_torques: jtp.VectorLike | None = None,
) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
"""
Compute the 6D contact forces of all links of the model in inertial representation.
Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D external forces to apply to the links expressed in inertial representation
joint_torques:
The joint torques acting on the joints.
Returns:
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
expressed in inertial representation.
"""

# Compute the contact forces for each collidable point with the active contact model.
W_f_C, aux_dict = model.contact_model.compute_contact_forces(
model=model,
data=data,
**(
dict(link_forces=link_forces, joint_force_references=joint_torques)
if not isinstance(model.contact_model, SoftContacts)
else {}
),
)

# Compute the 6D forces applied to the links equivalent to the forces applied
# to the frames associated to the collidable points.
W_f_L = link_forces_from_contact_forces(
model=model, data=data, contact_forces=W_f_C
)

return W_f_L, aux_dict


@staticmethod
def link_forces_from_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
contact_forces: jtp.MatrixLike,
) -> jtp.Matrix:
"""
Compute the link forces from the contact forces.
Args:
model: The robot model considered by the contact model.
data: The data of the considered model.
contact_forces: The contact forces computed by the contact model.
Returns:
The 6D contact forces applied to the links and expressed in the frame of
the velocity representation of data.
"""

# Get the object storing the contact parameters of the model.
contact_parameters = model.kin_dyn_parameters.contact_parameters

# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
contact_parameters.indices_of_enabled_collidable_points
)

# Convert the contact forces to a JAX array.
W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())

# Construct the vector defining the parent link index of each collidable point.
# We use this vector to sum the 6D forces of all collidable points rigidly
# attached to the same link.
parent_link_index_of_collidable_points = jnp.array(
contact_parameters.body, dtype=int
)[indices_of_enabled_collidable_points]

# Create the mask that associate each collidable point to their parent link.
# We use this mask to sum the collidable points to the right link.
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
model.number_of_links()
)

# Sum the forces of all collidable points rigidly attached to a body.
# Since the contact forces W_f_C are expressed in the world frame,
# we don't need any coordinate transformation.
W_f_L = mask.T @ W_f_C

return W_f_L
105 changes: 0 additions & 105 deletions src/jaxsim/api/contact_model.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2057,7 +2057,7 @@ def step(

# Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
# with the terrain.
W_f_L_terrain, aux_dict = js.contact_model.link_contact_forces(
W_f_L_terrain, aux_dict = js.contact.link_contact_forces(
model=model,
data=data,
link_forces=W_f_L_external,
Expand Down

0 comments on commit f584892

Please sign in to comment.