Skip to content

Commit

Permalink
WIP replace batched
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 11, 2025
1 parent 8d7cf2a commit 31625af
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 69 deletions.
22 changes: 11 additions & 11 deletions examples/jaxsim_as_physics_engine_advanced.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@
},
"outputs": [],
"source": [
"\n",
"# Initialize the simulated time.\n",
"T = jnp.arange(start=0, stop=1.0, step=model.time_step)"
]
Expand Down Expand Up @@ -227,17 +226,18 @@
")(jnp.vstack(subkeys))\n",
"\n",
"# Reset the x and y position to a grid.\n",
"data_batch_t0 = data_batch_t0.reset_base_position(\n",
" jnp.array(\n",
"data_batch_t0 = data_batch_t0.replace(\n",
" model=model,\n",
" base_position=jnp.array(\n",
" [\n",
" jnp.linspace(-1, 1, batch_size),\n",
" jnp.linspace(-1, 1, batch_size),\n",
" data_batch_t0.base_position()[:, 2],\n",
" data_batch_t0.base_position[:, 2],\n",
" ]\n",
" ).T\n",
" ).T,\n",
")\n",
"\n",
"print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])"
"print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position[0:10])"
]
},
{
Expand Down Expand Up @@ -389,9 +389,9 @@
"\n",
" for helper, base_position, base_quaternion, joint_position in zip(\n",
" mj_model_helpers,\n",
" data_t.base_position(),\n",
" data.base_orientation(True),\n",
" data.joint_positions(),\n",
" data_t.base_position,\n",
" data.base_orientation,\n",
" data.joint_positions,\n",
" strict=True,\n",
" ):\n",
" helper.set_base_position(position=base_position)\n",
Expand Down Expand Up @@ -436,7 +436,7 @@
"toc_visible": true
},
"kernelspec": {
"display_name": "jaxsimsprint",
"display_name": "comodoGPU",
"language": "python",
"name": "python3"
},
Expand All @@ -450,7 +450,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
"version": "3.12.8"
}
},
"nbformat": 4,
Expand Down
76 changes: 54 additions & 22 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def base_orientation(self) -> jtp.Matrix:
"""

# Extract the base quaternion.
W_Q_B = self.base_quaternion.squeeze()
W_Q_B = self.base_quaternion

# Always normalize the quaternion to avoid numerical issues.
# If the active scheme does not integrate the quaternion on its manifold,
Expand All @@ -285,11 +285,8 @@ def base_velocity(self) -> jtp.Vector:
The base 6D velocity in the active representation.
"""

W_v_WB = jnp.hstack(
[
self._base_linear_velocity,
self._base_angular_velocity,
]
W_v_WB = jnp.concatenate(
[self._base_linear_velocity, self._base_angular_velocity], axis=-1
)

W_H_B = self._base_transform
Expand Down Expand Up @@ -404,6 +401,12 @@ def replace(
"""
Replace the attributes of the `JaxSimModelData` object.
"""

# Extract the batch size.
batch_size = (
self._base_transform.shape[0] if self._base_transform.ndim > 2 else 1
)

if joint_positions is None:
joint_positions = self.joint_positions
if joint_velocities is None:
Expand All @@ -421,44 +424,73 @@ def replace(
base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
translation=base_position, quaternion=base_quaternion
)
joint_transforms = model.kin_dyn_parameters.joint_transforms(
joint_positions=joint_positions, base_transform=base_transform

joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)(
joint_positions=jnp.broadcast_to(
joint_positions, (batch_size, model.dofs())
),
base_transform=jnp.broadcast_to(base_transform, (batch_size, 4, 4)),
)

if base_linear_velocity is None and base_angular_velocity is None:
base_linear_velocity = self._base_linear_velocity
base_angular_velocity = self._base_angular_velocity
base_linear_velocity_inertial = self._base_linear_velocity
base_angular_velocity_inertial = self._base_angular_velocity
else:
if base_linear_velocity is None:
base_linear_velocity = self.base_velocity[:3]
if base_angular_velocity is None:
base_angular_velocity = self.base_velocity[3:]

base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze())
base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze())

W_v_WB = JaxSimModelData.other_representation_to_inertial(
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
other_representation=self.velocity_representation,
transform=base_transform,
is_force=False,
).astype(float)
base_linear_velocity, base_angular_velocity = W_v_WB[:3], W_v_WB[3:]

link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
joint_velocities=joint_velocities,
base_linear_velocity_inertial=base_linear_velocity,
base_angular_velocity_inertial=base_angular_velocity,

base_linear_velocity_inertial, base_angular_velocity_inertial = (
W_v_WB[..., :3],
W_v_WB[..., 3:],
)

link_transforms, link_velocities = jax.vmap(
jaxsim.rbda.forward_kinematics_model, in_axes=(None,)
)(
model,
base_position=jnp.broadcast_to(base_position, (batch_size, 3)),
base_quaternion=jnp.broadcast_to(base_quaternion, (batch_size, 4)),
joint_positions=jnp.broadcast_to(
joint_positions, (batch_size, model.dofs())
),
joint_velocities=jnp.broadcast_to(
joint_velocities, (batch_size, model.dofs())
),
base_linear_velocity_inertial=jnp.broadcast_to(
base_linear_velocity_inertial, (batch_size, 3)
),
base_angular_velocity_inertial=jnp.broadcast_to(
base_angular_velocity_inertial, (batch_size, 3)
),
)

# Adjust the output shapes.
if batch_size == 1:
link_transforms = link_transforms.reshape(
(batch_size,) + link_transforms.shape[2:]
)
link_velocities = link_velocities.reshape(
(batch_size,) + link_velocities.shape[2:]
)

return super().replace(
_joint_positions=joint_positions,
_joint_velocities=joint_velocities,
_base_quaternion=base_quaternion,
_base_linear_velocity=base_linear_velocity,
_base_angular_velocity=base_angular_velocity,
_base_linear_velocity=base_linear_velocity_inertial,
_base_angular_velocity=base_angular_velocity_inertial,
_base_position=base_position,
_base_transform=base_transform,
_joint_transforms=joint_transforms,
Expand Down
14 changes: 6 additions & 8 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2019,14 +2019,12 @@ def step(
)

# Get the external forces in inertial-fixed representation.
W_f_L_external = jax.vmap(
lambda f_L, W_H_L: js.data.JaxSimModelData.other_representation_to_inertial(
f_L,
other_representation=data.velocity_representation,
transform=W_H_L,
is_force=True,
)
)(O_f_L_external, data._link_transforms)
W_f_L_external = js.data.JaxSimModelData.other_representation_to_inertial(
O_f_L_external,
other_representation=data.velocity_representation,
transform=data._link_transforms,
is_force=True,
)

τ_references = jnp.atleast_1d(
jnp.array(joint_force_references, dtype=float).squeeze()
Expand Down
25 changes: 9 additions & 16 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,24 +434,17 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
if not_tracing(forces) and not data.valid(model=model):
raise ValueError("The provided data is not valid for the model")

# Helper function to convert a single 6D force to the inertial representation
# considering as body the link (i.e. L_f_L and LW_f_L).
def convert_using_link_frame(
f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
) -> jtp.Matrix:

return jax.vmap(
lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
array=f_L,
other_representation=self.velocity_representation,
transform=W_H_L,
is_force=True,
)
)(f_L, W_H_L)
W_H_L = data._link_transforms

# Convert a single 6D force to the inertial representation
# considering as body the link (i.e. L_f_L and LW_f_L).
# The f_L input is either L_f_L or LW_f_L, depending on the representation.
W_H_L = data._link_transforms
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
W_f_L = JaxSimModelReferences.other_representation_to_inertial(
array=f_L,
other_representation=self.velocity_representation,
transform=W_H_L[link_idxs] if model.number_of_links() > 1 else W_H_L,
is_force=True,
)

return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/math/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def from_quaternion_and_translation(
W_Q_B = jnp.array(quaternion).astype(float)
W_p_B = jnp.array(translation).astype(float)

assert W_p_B.size == 3
assert W_Q_B.size == 4
assert W_p_B.shape[-1] == 3
assert W_Q_B.shape[-1] == 4

A_R_B = jaxlie.SO3(wxyz=W_Q_B)
A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()
Expand Down
16 changes: 6 additions & 10 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,16 +460,12 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
CW_fl_C = solution.reshape(-1, 3)

# Convert the contact forces from mixed to inertial-fixed representation.
W_f_C = jax.vmap(
lambda CW_fl_C, W_H_C: (
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=jnp.zeros(6).at[0:3].set(CW_fl_C),
transform=W_H_C,
other_representation=VelRepr.Mixed,
is_force=True,
)
),
)(CW_fl_C, W_H_C)
W_f_C = ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=jnp.zeros((W_H_C.shape[0], 6)).at[:, :3].set(CW_fl_C),
transform=W_H_C,
other_representation=VelRepr.Mixed,
is_force=True,
)

return W_f_C, {}

Expand Down

0 comments on commit 31625af

Please sign in to comment.