Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ApplyTransformToPointsd receive a sequence of refer keys #8063

Merged
merged 14 commits into from
Sep 4, 2024
29 changes: 16 additions & 13 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,8 +1758,9 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
refer_key: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived.
refer_keys: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived. It can also be a
sequence of keys.
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
Expand All @@ -1782,31 +1783,33 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
def __init__(
self,
keys: KeysCollection,
refer_key: str | None = None,
refer_keys: KeysCollection | None = None,
dtype: DtypeLike | torch.dtype = torch.float64,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.refer_key = refer_key
self.refer_keys = ensure_tuple_rep(None, len(self.keys)) if refer_keys is None else ensure_tuple(refer_keys)
if len(self.keys) != len(self.refer_keys):
raise ValueError("refer_keys should have the same length as keys.")
self.converter = ApplyTransformToPoints(
dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)

def __call__(self, data: Mapping[Hashable, torch.Tensor]):
d = dict(data)
if self.refer_key is not None:
if self.refer_key in d:
refer_data = d[self.refer_key]
else:
raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
else:
refer_data = None
affine = getattr(refer_data, "affine", refer_data)
for key in self.key_iterator(d):
for key, refer_key in self.key_iterator(d, self.refer_keys):
coords = d[key]
if refer_key is not None:
if refer_key in d:
refer_data = d[refer_key]
else:
raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
else:
refer_data = None
affine = getattr(refer_data, "affine", refer_data)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
d[key] = self.converter(coords, affine)
return d

Expand Down
4 changes: 2 additions & 2 deletions tests/test_apply_transform_to_pointsd.py
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
"point": points,
"affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
}
refer_key = "image" if (image is not None and image != "affine") else image
refer_keys = "image" if (image is not None and image != "affine") else image
transform = ApplyTransformToPointsd(
keys="point",
refer_key=refer_key,
refer_keys=refer_keys,
dtype=torch.int64,
affine=affine,
invert_affine=invert_affine,
Expand Down
Loading