Skip to content

Commit

Permalink
Make ModelDataWithVelocityRepresentation compatible with batches
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 10, 2025
1 parent 9d0c0ab commit 8d7cf2a
Showing 1 changed file with 41 additions and 43 deletions.
84 changes: 41 additions & 43 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,45 +121,45 @@ def inertial_to_other_representation(
The 6D quantity in the other representation.
"""

W_array = array.squeeze()
W_H_O = transform.squeeze()

if W_array.size != 6:
raise ValueError(W_array.size, 6)

if W_H_O.shape != (4, 4):
raise ValueError(W_H_O.shape, (4, 4))
W_array = array.reshape(-1, 6)
W_H_O = transform.reshape(-1, 4, 4)

match other_representation:

case VelRepr.Inertial:
return W_array
return W_array.reshape(array.shape[:-1] + (6,))

case VelRepr.Body:

if not is_force:
O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
O_array = O_Xv_W @ W_array
O_array = jnp.einsum("bij,bj->bi", O_Xv_W, W_array)

else:
O_Xf_W = Adjoint.from_transform(transform=W_H_O).T
O_array = O_Xf_W @ W_array
O_Xf_W = Adjoint.from_transform(transform=W_H_O)
O_array = jnp.einsum(
"bij,bj->bi", O_Xf_W.transpose(0, 2, 1), W_array
)

return O_array
return O_array.reshape(array.shape[:-1] + (6,))

case VelRepr.Mixed:
W_p_O = W_H_O[0:3, 3]
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)
W_p_O = W_H_O[:, 0:3, 3]
W_H_OW = (
jnp.array([jnp.eye(4)] * W_H_O.shape[0]).at[:, 0:3, 3].set(W_p_O)
)

if not is_force:
OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
OW_array = OW_Xv_W @ W_array
OW_array = jnp.einsum("bij,bj->bi", OW_Xv_W, W_array)

else:
OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T
OW_array = OW_Xf_W @ W_array
OW_Xf_W = Adjoint.from_transform(transform=W_H_OW)
OW_array = jnp.einsum(
"bij,bj->bi", OW_Xf_W.transpose(0, 2, 1), W_array
)

return OW_array
return OW_array.reshape(array.shape[:-1] + (6,))

case _:
raise ValueError(other_representation)
Expand Down Expand Up @@ -188,47 +188,45 @@ def other_representation_to_inertial(
The 6D quantity in the inertial-fixed representation.
"""

W_array = array.squeeze()
W_H_O = transform.squeeze()

if W_array.size != 6:
raise ValueError(W_array.size, 6)

if W_H_O.shape != (4, 4):
raise ValueError(W_H_O.shape, (4, 4))
O_array = array.reshape(-1, 6)
W_H_O = transform.reshape(-1, 4, 4)

match other_representation:
case VelRepr.Inertial:
W_array = array
return W_array
return O_array.reshape(array.shape[:-1] + (6,))

case VelRepr.Body:
O_array = array

if not is_force:
W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O)
W_array = W_Xv_O @ O_array
W_Xv_O = Adjoint.from_transform(W_H_O)
W_array = jnp.einsum("bij,bj->bi", W_Xv_O, O_array)

else:
W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T
W_array = W_Xf_O @ O_array
W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True)
W_array = jnp.einsum(
"bij,bj->bi", W_Xf_O.transpose(0, 2, 1), O_array
)

return W_array
return W_array.reshape(array.shape[:-1] + (6,))

case VelRepr.Mixed:
BW_array = array
W_p_O = W_H_O[0:3, 3]
W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O)

W_p_O = W_H_O[:, 0:3, 3]
W_H_OW = (
jnp.array([jnp.eye(4)] * W_H_O.shape[0]).at[:, 0:3, 3].set(W_p_O)
)

if not is_force:
W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW)
W_array = W_Xv_BW @ BW_array
W_Xv_BW = Adjoint.from_transform(W_H_OW)
W_array = jnp.einsum("bij,bj->bi", W_Xv_BW, O_array)

else:
W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T
W_array = W_Xf_BW @ BW_array
W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True)
W_array = jnp.einsum(
"bij,bj->bi", W_Xf_BW.transpose(0, 2, 1), O_array
)

return W_array
return W_array.reshape(array.shape[:-1] + (6,))

case _:
raise ValueError(other_representation)

0 comments on commit 8d7cf2a

Please sign in to comment.