Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: forward kinematics w.r.t. differentiable kinematic parameters #32

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
157 changes: 129 additions & 28 deletions src/pytorch_kinematics/chain.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from functools import lru_cache
from typing import Optional, Sequence
from typing import Collection, Optional, Sequence, Tuple, Union

import numpy as np
import torch

import pytorch_kinematics.transforms as tf
from pytorch_kinematics import jacobian
from pytorch_kinematics.frame import Frame, Link, Joint
from pytorch_kinematics.transforms.parameterized_transform import ParameterizedTransform
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_44, axis_and_d_to_pris_matrix


Expand Down Expand Up @@ -189,7 +190,7 @@ def _find_link_recursive(name, frame) -> Optional[Link]:
@staticmethod
def _get_joints(frame, exclude_fixed=True):
joints = []
if exclude_fixed and frame.joint.joint_type != "fixed":
if not exclude_fixed or frame.joint.joint_type != "fixed":
joints.append(frame.joint)
for child in frame.children:
joints.extend(Chain._get_joints(child))
Expand Down Expand Up @@ -293,49 +294,91 @@ def get_link_names(self):
def get_frame_indices(self, *frame_names):
return torch.tensor([self.frame_to_idx[n] for n in frame_names], dtype=torch.long, device=self.device)

def forward_kinematics(self, th, frame_indices: Optional = None):
def _get_jnt_transform(self, th) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute forward kinematics for the given joint values.
compute all joint transforms at once first in order to handle multiple joint types without branching,
we create all possible transforms for all joint types and then select the appropriate one for each joint.
Args:
th: The joint configuration to use

Returns: A tuple of revolute, prismatic joint transforms
"""
axes_expanded = self.axes.unsqueeze(0).repeat(th.shape[0], 1, 1).to(th)
return axis_and_angle_to_matrix_44(axes_expanded, th), axis_and_d_to_pris_matrix(axes_expanded, th)

def forward_kinematics(self,
th: Optional[torch.Tensor] = None,
joint_offsets: Optional[Union[torch.Tensor, tf.Transform3d]] = None,
link_offsets: Optional[Union[torch.Tensor, tf.Transform3d]] = None,
frame_indices: Optional = None):
"""
Compute forward kinematics for the given any combination of joint values, joint offsets, and link offsets.

Args:
th: A dict, list, numpy array, or torch tensor of joints values. Possibly batched.
joint_offsets: A Transform3d object or a tensor of shape (N, 4, 4) where N is the number of joints.
If provided, overrides the joint offsets in the chain.
link_offsets: A Transform3d object or a tensor of shape (N, 4, 4) where N is the number of joints.
If provided, overrides the link offsets in the chain.
frame_indices: A list of frame indices to compute transforms for. If None, all frames are computed.
Use `get_frame_indices` to convert from frame names to frame indices.

Returns:
A dict of Transform3d objects for each frame.

"""
def get_ith_transform(offset, i):
if isinstance(offset, torch.Tensor):
return offset[:, i, ...]
return offset[i]

if frame_indices is None:
frame_indices = self.get_all_frame_indices()

th = self.ensure_tensor(th)
th = torch.atleast_2d(th)
if isinstance(joint_offsets, tf.Transform3d):
joint_offsets = joint_offsets.get_matrix().view(-1, len(self.joint_offsets), 4, 4)
if isinstance(link_offsets, tf.Transform3d):
link_offsets = link_offsets.get_matrix().view(-1, len(self.link_offsets), 4, 4)

if th is joint_offsets is link_offsets is None:
raise ValueError("Must provide at least one of th, joint_offsets, or link_offsets.")
if th is not None:
th = self.ensure_tensor(th)
th = torch.atleast_2d(th)
b = th.shape[0]
to_this = th
elif joint_offsets is not None:
b = joint_offsets.shape[0]
to_this = joint_offsets
else:
b = link_offsets.shape[0]
to_this = link_offsets

b = th.shape[0]
axes_expanded = self.axes.unsqueeze(0).repeat(b, 1, 1)
# initialize default values
if th is None:
th = torch.zeros([b, self.n_joints]).to(to_this)
if joint_offsets is None:
joint_offsets = self.joint_offsets
if link_offsets is None:
link_offsets = self.link_offsets

# compute all joint transforms at once first
# in order to handle multiple joint types without branching, we create all possible transforms
# for all joint types and then select the appropriate one for each joint.
rev_jnt_transform = axis_and_angle_to_matrix_44(axes_expanded, th)
pris_jnt_transform = axis_and_d_to_pris_matrix(axes_expanded, th)
rev_jnt_transform, pris_jnt_transform = self._get_jnt_transform(th)

frame_transforms = {}
b = th.shape[0]
for frame_idx in frame_indices:
frame_transform = torch.eye(4).to(th).unsqueeze(0).repeat(b, 1, 1)
frame_transform = torch.eye(4).to(to_this).unsqueeze(0).repeat(b, 1, 1)

# iterate down the list and compose the transform
for chain_idx in self.parents_indices[frame_idx.item()]:
if chain_idx.item() in frame_transforms:
frame_transform = frame_transforms[chain_idx.item()]
else:
link_offset_i = self.link_offsets[chain_idx]
link_offset_i = get_ith_transform(link_offsets, chain_idx)
if link_offset_i is not None:
frame_transform = frame_transform @ link_offset_i

joint_offset_i = self.joint_offsets[chain_idx]
joint_offset_i = get_ith_transform(joint_offsets, chain_idx)

if joint_offset_i is not None:
frame_transform = frame_transform @ joint_offset_i

Expand Down Expand Up @@ -458,35 +501,93 @@ def _generate_serial_chain_recurse(root_frame, end_frame_name):
return [child] + frames
return None

@classmethod
def from_joint_transforms(cls,
transforms: tf.Transform3d,
link_offsets: Optional[tf.Transform3d] = None,
joint_names: Optional[Collection[str]] = None,
link_names: Optional[Collection[str]] = None,
joint_types: Optional[Collection[str]] = None,
**kwargs
):
"""
Create a serial chain with zero link offsets and joint offsets according to the input.

Assumes that frame 0 is the root frame that contains an empty link and a fixed "world" joint. Accordingly, frame
i is aligned with joint i, which moves. All joint axes are aligned with the z-axis of the joint frame.
Args:
transforms: A transform that represents a matrix of shape (N, 4, 4) where N is the number of joints.
link_offsets: A transform that represents a matrix of shape (N, 4, 4) where N is the number of joints.
Optional. If None, all link offsets are assumed to be zero.
joint_names: The names of the joints. If None, the joints are named "joint_0", "joint_1", etc.
link_names: The names of the links. If None, the links are named "link_0", "link_1", etc.
joint_types: The types of the joints. If None, the joints are assumed to be revolute.
"""
device = kwargs.get('device', transforms.device)
dtype = kwargs.get('dtype', transforms.dtype)

transforms = transforms.to(device=device, dtype=dtype)
joint_offsets = transforms.get_matrix()
assert len(joint_offsets.shape) == 3, "Expected a 3D matrix of shape (N, 4, 4)."
n = joint_offsets.shape[0]
if link_offsets is None:
link_offsets = [None] * n
else:
link_offsets = link_offsets.to(device=device, dtype=dtype)
if joint_names is None:
joint_names = [f"joint_{i+1}" for i in range(n)]
if link_names is None:
link_names = [f"link_{i+1}" for i in range(n)]
if joint_types is None:
joint_types = ['revolute' for _ in range(n)]
root_frame = Frame(name="world")
root_frame.link = Link(name="world")
root_frame.joint = Joint(name="world", joint_type="fixed")
children = []
for (i, link, joint, joint_type) in reversed(list(zip(range(n), link_names, joint_names, joint_types))):
frame = Frame(name=f"{link}")
frame.link = Link(name=link, offset=link_offsets[i])
frame.joint = Joint(name=joint, offset=transforms[i], joint_type=joint_type)
frame.children = children
children = [frame]
root_frame.children = children
return cls(Chain(root_frame, **kwargs), link_names[-1], root_frame_name="world")

def jacobian(self, th, locations=None):
if locations is not None:
locations = tf.Transform3d(pos=locations)
return jacobian.calc_jacobian(self, th, tool=locations)

def forward_kinematics(self, th, end_only: bool = True):
def forward_kinematics(self,
th: Optional[torch.Tensor] = None,
joint_offsets: Optional[torch.Tensor] = None,
link_offsets: Optional[torch.Tensor] = None,
end_only: bool = True):
""" Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
frame_indices, th = self.convert_serial_inputs_to_chain_inputs(th, end_only)
if end_only:
frame_indices = self.get_frame_indices(self._serial_frames[-1].name)
else:
# pass through default behavior for frame indices being None, which is currently
# to return all frames.
frame_indices = None
if th is not None:
th = self.convert_serial_inputs_to_chain_inputs(th)

mat = super().forward_kinematics(th, frame_indices)
mat = super().forward_kinematics(th, joint_offsets=joint_offsets, link_offsets=link_offsets,
frame_indices=frame_indices)

if end_only:
return mat[self._serial_frames[-1].name]
else:
return mat

def convert_serial_inputs_to_chain_inputs(self, th, end_only: bool):
def convert_serial_inputs_to_chain_inputs(self, th: torch.Tensor):
# th = self.ensure_tensor(th)
th_b = get_batch_size(th)
th_n_joints = get_n_joints(th)
if isinstance(th, list):
th = torch.tensor(th, device=self.device, dtype=self.dtype)

if end_only:
frame_indices = self.get_frame_indices(self._serial_frames[-1].name)
else:
# pass through default behavior for frame indices being None, which is currently
# to return all frames.
frame_indices = None
if th_n_joints < self.n_joints:
# if th is only a partial list of joints, assume it's a list of joints for only the serial chain.
partial_th = th
Expand All @@ -504,4 +605,4 @@ def convert_serial_inputs_to_chain_inputs(self, th, end_only: bool):
jnt_idx = self.joint_indices[k]
if frame.joint.joint_type != 'fixed':
th[..., jnt_idx] = partial_th_i
return frame_indices, th
return th
55 changes: 55 additions & 0 deletions src/pytorch_kinematics/transforms/parameter_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Author: Jonathan Külz
# Date: 23.11.23
import torch


def mdh_to_homogeneous(mdh_parameters: torch.Tensor) -> torch.Tensor:
"""
Converts a set of MDH parameters to a homogeneous transformation matrix.

Follows Craig, Introduction to Robotics, 2005, p. 75.
:param mdh_parameters: The MDH parameters, ordered as alpha, a, d, theta.
:return: The homogeneous transformation matrix.
"""
alpha = mdh_parameters[..., 0]
a = mdh_parameters[..., 1]
d = mdh_parameters[..., 2]
theta = mdh_parameters[..., 3]

ct = torch.cos(theta)
st = torch.sin(theta)
ca = torch.cos(alpha)
sa = torch.sin(alpha)
zeros = torch.zeros_like(theta)
return torch.stack([
torch.stack([ct, -st, zeros, a], dim=-1),
torch.stack([st * ca, ct * ca, -sa, -d * sa], dim=-1),
torch.stack([st * sa, ct * sa, ca, d * ca], dim=-1),
torch.stack([zeros, zeros, zeros, torch.ones_like(theta)],
dim=-1)
], dim=-2)


def homogeneous_to_mdh(T: torch.Tensor) -> torch.Tensor:
"""
Converts a homogeneous transformation matrix to a set of MDH parameters.

Attention, this method is expensive due to an internal sanity check.
Follows Craig, Introduction to Robotics, 2005, p. 75.
:param T: The homogeneous transformation matrix.
:return: The MDH parameters.
"""
a = T[..., 0, 3]
theta = torch.atan2(-T[..., 0, 1], T[..., 0, 0])
alpha = torch.atan2(-T[..., 1, 2], T[..., 2, 2])
d = torch.empty_like(a)
use_cos = torch.isclose(torch.sin(alpha), torch.zeros_like(alpha))
d[~use_cos] = -T[~use_cos][:, 1, 3] / torch.sin(alpha[~use_cos])
d[use_cos] = T[use_cos][:, 2, 3] / torch.cos(alpha[use_cos])

parameters = torch.stack([alpha, a, d, theta], dim=-1)
if not torch.allclose(mdh_to_homogeneous(parameters), T, atol=1e-3):
raise ValueError('The given transformation is not MDH.')

return parameters
Loading