Allow ApplyTransformToPointsd to receive a sequence of refer keys #8063

Merged
14 commits merged on Sep 4, 2024
64 changes: 37 additions & 27 deletions monai/transforms/utility/array.py
@@ -1764,6 +1764,30 @@ def __init__(
self.invert_affine = invert_affine
self.affine_lps_to_ras = affine_lps_to_ras

def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
"""
Compute the final affine transformation matrix to apply to the point data.

Args:
affine: 3x3 or 4x4 affine transformation matrix.
applied_affine: affine transformation matrix that has already been applied to the point data, if any.

Returns:
Final affine transformation matrix.
"""

affine = convert_data_type(affine, dtype=torch.float64)[0]

if self.affine_lps_to_ras:
affine = orientation_ras_lps(affine)

if self.invert_affine:
affine = linalg_inv(affine)
if applied_affine is not None:
affine = affine @ applied_affine

return affine

def transform_coordinates(
self, data: torch.Tensor, affine: torch.Tensor | None = None
) -> tuple[torch.Tensor, dict]:
@@ -1780,35 +1804,25 @@ def transform_coordinates(
Transformed coordinates.
"""
data = convert_to_tensor(data, track_meta=get_track_meta())
# applied_affine is the affine transformation matrix that has already been applied to the point data
applied_affine = getattr(data, "affine", None)

if affine is None and self.invert_affine:
raise ValueError("affine must be provided when invert_affine is True.")

# applied_affine is the affine transformation matrix that has already been applied to the point data
applied_affine: torch.Tensor | None = getattr(data, "affine", None)
affine = applied_affine if affine is None else affine
affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine
original_affine: torch.Tensor = affine
if self.affine_lps_to_ras:
affine = orientation_ras_lps(affine)
if affine is None:
raise ValueError("affine must be provided if data does not have an affine matrix.")

# the final affine transformation matrix that will be applied to the point data
_affine: torch.Tensor = affine
if self.invert_affine:
_affine = linalg_inv(affine)
if applied_affine is not None:
# consider the affine transformation already applied to the data in the world space
# and compute delta affine
_affine = _affine @ linalg_inv(applied_affine)
out = apply_affine_to_points(data, _affine, dtype=self.dtype)
final_affine = self._compute_final_affine(affine, applied_affine)
out = apply_affine_to_points(data, final_affine, dtype=self.dtype)

extra_info = {
"invert_affine": self.invert_affine,
"dtype": get_dtype_string(self.dtype),
"image_affine": original_affine, # record for inverse operation
"image_affine": affine,
"affine_lps_to_ras": self.affine_lps_to_ras,
}
xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine)

xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine)
meta_info = TraceableTransform.track_transform_meta(
data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
)
@@ -1834,16 +1848,12 @@ def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):

def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
# Create inverse transform
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"]
affine = transform[TraceKeys.EXTRA_INFO]["image_affine"]
affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"]
inverse_transform = ApplyTransformToPoints(
dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
dtype=transform[TraceKeys.EXTRA_INFO]["dtype"],
invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"],
affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"],
)
# Apply inverse
with inverse_transform.trace_transform(False):
data = inverse_transform(data, affine)
data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])

return data
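
As a rough standalone sketch of what the new `_compute_final_affine` helper composes (an optional inversion plus any affine already applied to the points; the LPS/RAS flip is omitted here), the snippet below uses plain PyTorch. The `compose_final_affine` name and the example tensors are invented for illustration and are not part of this diff.

from __future__ import annotations

import torch


def compose_final_affine(
    affine: torch.Tensor,
    applied_affine: torch.Tensor | None = None,
    invert_affine: bool = True,
) -> torch.Tensor:
    # Work in float64, as the transform does, so the inverse stays numerically stable.
    affine = affine.to(torch.float64)
    if invert_affine:
        # Invert the image affine so world-space points map back into image space.
        affine = torch.linalg.inv(affine)
    if applied_affine is not None:
        # Fold in an affine that has already been applied to the point data.
        affine = affine @ applied_affine.to(torch.float64)
    return affine


# Example: a 2x scaling image affine; its inverse maps world coordinates to image indices.
image_affine = torch.tensor([[2.0, 0, 0, 0], [0, 2.0, 0, 0], [0, 0, 1.0, 0], [0, 0, 0, 1.0]])
points_h = torch.tensor([[[2.0, 4.0, 6.0, 1.0]]], dtype=torch.float64)  # homogeneous (C, N, 4) coordinates
print(points_h @ compose_final_affine(image_affine).T)  # tensor([[[1., 2., 6., 1.]]])
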
28 changes: 15 additions & 13 deletions monai/transforms/utility/dictionary.py
@@ -1758,8 +1758,9 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
refer_key: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived.
refer_keys: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived. It can also be a
sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
@@ -1782,31 +1783,32 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
def __init__(
self,
keys: KeysCollection,
refer_key: str | None = None,
refer_keys: KeysCollection | None = None,
dtype: DtypeLike | torch.dtype = torch.float64,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.refer_key = refer_key
self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys))
self.converter = ApplyTransformToPoints(
dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)

def __call__(self, data: Mapping[Hashable, torch.Tensor]):
d = dict(data)
if self.refer_key is not None:
if self.refer_key in d:
refer_data = d[self.refer_key]
else:
raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
else:
refer_data = None
affine = getattr(refer_data, "affine", refer_data)
for key in self.key_iterator(d):
for key, refer_key in self.key_iterator(d, self.refer_keys):
coords = d[key]
affine = None # represents using affine given in constructor
if refer_key is not None:
if refer_key in d:
refer_data = d[refer_key]
else:
raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")

# use the "affine" member of refer_data, or refer_data itself, as the affine matrix
affine = getattr(refer_data, "affine", refer_data)
d[key] = self.converter(coords, affine)
return d
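
To show how the new `refer_keys` sequence pairs up with `keys`, here is a hedged usage sketch; the key names, image shapes, and affine values are made up for the example and assume `ApplyTransformToPointsd` and `MetaTensor` are importable as shown.

import torch
from monai.data import MetaTensor
from monai.transforms import ApplyTransformToPointsd

# Two images with different affines, and one point set per image.
affine_1 = torch.tensor([[2.0, 0, 0, 0], [0, 2.0, 0, 0], [0, 0, 1.0, 0], [0, 0, 0, 1.0]])
affine_2 = torch.tensor([[1.0, 0, 0, 10], [0, 1.0, 0, -4], [0, 0, 1.0, 0], [0, 0, 0, 1.0]])
data = {
    "image_1": MetaTensor(torch.zeros(1, 8, 8, 8), affine=affine_1),
    "image_2": MetaTensor(torch.zeros(1, 8, 8, 8), affine=affine_2),
    "point_1": torch.tensor([[[2.0, 4.0, 6.0]]]),
    "point_2": torch.tensor([[[14.0, 16.0, 18.0]]]),
}

# Each entry of refer_keys supplies the affine for the matching entry of keys,
# so point_1 uses image_1's affine and point_2 uses image_2's affine.
transform = ApplyTransformToPointsd(
    keys=["point_1", "point_2"],
    refer_keys=["image_1", "image_2"],
    invert_affine=True,  # map world-space points back into each image's space
)
out = transform(data)
print(out["point_1"], out["point_2"])
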

136 changes: 94 additions & 42 deletions tests/test_apply_transform_to_pointsd.py
@@ -30,72 +30,90 @@
POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])
POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])
POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])
AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])

TEST_CASES = [
[MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], # use image affine
[None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], # use point affine
[None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], # use input affine
[None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE], # use input affine
[
MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(DATA_2D, affine=AFFINE_1),
POINT_2D_WORLD,
None,
True,
False,
POINT_2D_IMAGE,
],
True,
POINT_2D_IMAGE_RAS,
], # test affine_lps_to_ras
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], # use refer_data itself
[
None,
MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(DATA_3D, affine=AFFINE_2),
MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),
None,
False,
False,
POINT_2D_WORLD,
POINT_3D_WORLD,
],
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
]
TEST_CASES_SEQUENCE = [
[
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[POINT_2D_WORLD, POINT_3D_WORLD],
None,
MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
False,
True,
False,
POINT_2D_WORLD,
],
["image_1", "image_2"],
[POINT_2D_IMAGE, POINT_3D_IMAGE],
], # use image affine
[
MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_2D_WORLD,
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[POINT_2D_WORLD, POINT_3D_WORLD],
None,
True,
True,
POINT_2D_IMAGE_RAS,
],
["image_1", "image_2"],
[POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS],
], # test affine_lps_to_ras
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_3D_WORLD,
(None, None),
[MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
None,
False,
False,
None,
[POINT_2D_WORLD, POINT_3D_WORLD],
], # use point affine
[
(None, None),
[POINT_2D_WORLD, POINT_2D_WORLD],
AFFINE_1,
True,
False,
POINT_3D_IMAGE,
],
["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
None,
[POINT_2D_IMAGE, POINT_2D_IMAGE],
], # use input affine
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
None,
False,
False,
POINT_3D_WORLD,
],
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_3D_WORLD,
None,
True,
True,
POINT_3D_IMAGE_RAS,
["image_1", "image_2"],
[POINT_2D_WORLD, POINT_3D_WORLD],
],
]

TEST_CASES_WRONG = [
[POINT_2D_WORLD, True, None],
[POINT_2D_WORLD.unsqueeze(0), False, None],
[POINT_3D_WORLD[..., 0:1], False, None],
[POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])],
[POINT_2D_WORLD, True, None, None],
[POINT_2D_WORLD.unsqueeze(0), False, None, None],
[POINT_3D_WORLD[..., 0:1], False, None, None],
[POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None],
[POINT_3D_WORLD, False, None, "image"],
[POINT_3D_WORLD, False, None, []],
]


@@ -107,10 +125,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
"point": points,
"affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
}
refer_key = "image" if (image is not None and image != "affine") else image
refer_keys = "image" if (image is not None and image != "affine") else image
transform = ApplyTransformToPointsd(
keys="point",
refer_key=refer_key,
refer_keys=refer_keys,
dtype=torch.int64,
affine=affine,
invert_affine=invert_affine,
@@ -122,11 +140,45 @@
invert_out = transform.inverse(output)
self.assertTrue(torch.allclose(invert_out["point"], points))

@parameterized.expand(TEST_CASES_SEQUENCE)
def test_transform_coordinates_sequences(
self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
):
data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]}
keys = ["point_1", "point_2"]
transform = ApplyTransformToPointsd(
keys=keys,
refer_keys=refer_keys,
dtype=torch.int64,
affine=affine,
invert_affine=invert_affine,
affine_lps_to_ras=affine_lps_to_ras,
)
output = transform(data)

self.assertTrue(torch.allclose(output["point_1"], expected_output[0]))
self.assertTrue(torch.allclose(output["point_2"], expected_output[1]))
invert_out = transform.inverse(output)
self.assertTrue(torch.allclose(invert_out["point_1"], points[0]))

@parameterized.expand(TEST_CASES_WRONG)
def test_wrong_input(self, input, invert_affine, affine):
transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine)
with self.assertRaises(ValueError):
transform({"point": input})
def test_wrong_input(self, input, invert_affine, affine, refer_keys):
if refer_keys == []:
with self.assertRaises(ValueError):
ApplyTransformToPointsd(
keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
)
else:
transform = ApplyTransformToPointsd(
keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
)
data = {"point": input}
if refer_keys == "image":
with self.assertRaises(KeyError):
transform(data)
else:
with self.assertRaises(ValueError):
transform(data)


if __name__ == "__main__":