Skip to content

Commit

Permalink
update meta tensor api (#4131)
Browse files Browse the repository at this point in the history
* update meta tensor api

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* update based on comments

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Apr 14, 2022
1 parent ccac5ff commit 137164d
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
resolve_writer,
)
from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer
from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from .meta_obj import MetaObj, get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from .meta_tensor import MetaTensor
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand Down
3 changes: 2 additions & 1 deletion monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ class MetaObj:
"""

_meta: dict
def __init__(self):
    """Initialize per-instance metadata.

    Populates ``self._meta`` with the default metadata dict returned by
    ``self.get_default_meta()`` (provided by the class/subclass), replacing
    the previous class-level ``_meta`` attribute with instance state.
    """
    self._meta: dict = self.get_default_meta()

@staticmethod
def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
Expand Down
31 changes: 15 additions & 16 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,34 +62,33 @@ class MetaTensor(MetaObj, torch.Tensor):

@staticmethod
def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor:
    """Construct the underlying tensor and view it as ``cls``.

    ``affine`` and ``meta`` are accepted here only so the signature matches
    ``__init__`` (which consumes them); they play no role in allocation.
    Extra ``*args``/``**kwargs`` are forwarded to ``torch.as_tensor``.
    """
    # Build (or reuse) a plain tensor from ``x``, then re-type it as the
    # requested subclass without copying data.
    base = torch.as_tensor(x, *args, **kwargs)
    return base.as_subclass(cls)  # type: ignore

def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None:
"""
If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it.
Else, use the default value. Similar for the affine, except this could come from
four places.
Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.
"""
out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore
super().__init__()
# set meta
if meta is not None:
out.meta = meta
self.meta = meta
elif isinstance(x, MetaObj):
out.meta = x.meta
else:
out.meta = out.get_default_meta()
self.meta = x.meta
# set the affine
if affine is not None:
if "affine" in out.meta:
warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.")
out.affine = affine
elif "affine" in out.meta:
if "affine" in self.meta:
warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
self.affine = affine
elif "affine" in self.meta:
pass # nothing to do
elif isinstance(x, MetaTensor):
out.affine = x.affine
self.affine = x.affine
else:
out.affine = out.get_default_affine()
out.affine = out.affine.to(out.device)

return out
self.affine = self.get_default_affine()
self.affine = self.affine.to(self.device)

def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
super()._copy_attr(attribute, input_objs, default_fn, deep_copy)
Expand All @@ -113,8 +112,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
ret.affine = ret.affine.to(ret.device)
return ret

def get_default_affine(self) -> torch.Tensor:
return torch.eye(4, device=self.device)
def get_default_affine(self, dtype=torch.float64) -> torch.Tensor:
    """Return the default affine: a 4x4 identity matrix.

    Args:
        dtype: dtype of the returned affine (default ``torch.float64``).

    Returns:
        A ``4 x 4`` identity tensor placed on this tensor's device.
    """
    # Identity affine, materialized directly on the owning tensor's device.
    identity = torch.eye(4, dtype=dtype)
    return identity.to(device=self.device)

def as_tensor(self) -> torch.Tensor:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class TestMetaTensor(unittest.TestCase):
@staticmethod
def get_im(shape=None, dtype=None, device=None):
if shape is None:
shape = shape = (1, 10, 8)
shape = (1, 10, 8)
affine = torch.randint(0, 10, (4, 4))
meta = {"fname": rand_string()}
t = torch.rand(shape)
Expand Down

0 comments on commit 137164d

Please sign in to comment.