Skip to content

Commit

Permalink
Dev: minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
skaliy committed Jul 10, 2023
1 parent 565734c commit 5cb8e17
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 75 deletions.
14 changes: 7 additions & 7 deletions fastMONAI/vision_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ def _multi_channel(image_paths: list, reorder: bool, resample: list, dtype, only
tensor = torch.stack([img.data[0] for _, img, _ in image_data], dim=0)

if only_tensor:
dtype(tensor)
return dtype(tensor)

input_img.set_data(tensor)
return org_img, input_img, org_size


# %% ../nbs/01_vision_core.ipynb 9
# %% ../nbs/01_vision_core.ipynb 8
def med_img_reader(
file_path: (str, Path),
dtype=torch.Tensor,
Expand Down Expand Up @@ -118,15 +118,15 @@ def med_img_reader(

return org_img, input_img, org_size

# %% ../nbs/01_vision_core.ipynb 12
# %% ../nbs/01_vision_core.ipynb 10
class MetaResolver(type(torch.Tensor), metaclass=BypassNewMeta):
'''A class to bypass metaclass conflict:
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/batch.html
'''
pass

# %% ../nbs/01_vision_core.ipynb 13
class MedBase(torch.Tensor, metaclass=MetaResolver):
# %% ../nbs/01_vision_core.ipynb 11
class MedBase(torch.Tensor, metaclass=MetaResolver):
'''A class that represents an image object. Metaclass casts x to this class if it is of type cls._bypass_type.'''

_bypass_type=torch.Tensor
Expand Down Expand Up @@ -180,12 +180,12 @@ def show(self, ctx=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
def __repr__(self):
return f'{self.__class__.__name__} mode={self.mode} size={"x".join([str(d) for d in self.size])}'

# %% ../nbs/01_vision_core.ipynb 14
# %% ../nbs/01_vision_core.ipynb 12
class MedImage(MedBase):
'''Subclass of MedBase that represents an image object.'''
pass

# %% ../nbs/01_vision_core.ipynb 15
# %% ../nbs/01_vision_core.ipynb 13
class MedMask(MedBase):
'''Subclass of MedBase that represents an mask object.'''
_show_args = {'alpha':0.5, 'cmap':'tab20'}
70 changes: 2 additions & 68 deletions nbs/01_vision_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,36 +134,12 @@
" tensor = torch.stack([img.data[0] for _, img, _ in image_data], dim=0)\n",
" \n",
" if only_tensor: \n",
" dtype(tensor) \n",
" return dtype(tensor) \n",
"\n",
" input_img.set_data(tensor)\n",
" return org_img, input_img, org_size\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# def med_img_reader(file_path:(str, Path), # Image path\n",
"# dtype=torch.Tensor, # Datatype (MedImage, MedMask, torch.Tensor)\n",
"# reorder:bool=False, # Whether to reorder the data to be closest to canonical (RAS+) orientation.\n",
"# resample:list=None, # Whether to resample image to different voxel sizes and image dimensions.\n",
"# only_tensor:bool=True # Whether to return only image tensor\n",
"# ):\n",
"# '''Load and preprocess medical image'''\n",
" \n",
"# if isinstance(file_path, str) and ';' in file_path:\n",
"# return _multi_channel(file_path.split(';'), reorder, resample, dtype, only_tensor)\n",
"\n",
"# org_img, input_img, org_size = _load_and_preprocess(file_path, reorder, resample, dtype)\n",
"\n",
"# if only_tensor: return dtype(input_img.data.type(torch.float)) \n",
" \n",
"# return org_img, input_img, org_size"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -207,48 +183,6 @@
" return org_img, input_img, org_size"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_medical_image(\n",
" file_path: (str, Path),\n",
" dtype=torch.Tensor,\n",
" reorder: bool = False,\n",
" resample: list = None,\n",
" only_tensor: bool = True\n",
"):\n",
" \"\"\"Loads and preprocesses a medical image.\n",
"\n",
" Args:\n",
" file_path: Path to the image. Can be a string or a Path object.\n",
" dtype: Datatype for the return value. Defaults to torch.Tensor.\n",
" reorder: Whether to reorder the data to be closest to canonical \n",
" (RAS+) orientation. Defaults to False.\n",
" resample: Whether to resample image to different voxel sizes and \n",
" image dimensions. Defaults to None.\n",
" only_tensor: Whether to return only image tensor. Defaults to True.\n",
"\n",
" Returns:\n",
" The preprocessed image. Returns only the image tensor if \n",
" only_tensor is True, otherwise returns original image, \n",
" preprocessed image, and original size.\n",
" \"\"\"\n",
" if isinstance(file_path, str) and ';' in file_path:\n",
" return _multi_channel(\n",
" file_path.split(';'), reorder, resample, dtype, only_tensor)\n",
"\n",
" org_img, input_img, org_size = _load_and_preprocess(\n",
" file_path, reorder, resample, dtype)\n",
"\n",
" if only_tensor:\n",
" return dtype(input_img.data.type(torch.float))\n",
"\n",
" return org_img, input_img, org_size"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -277,7 +211,7 @@
"outputs": [],
"source": [
"#| export\n",
"class MedBase(torch.Tensor, metaclass=MetaResolver):\n",
"class MedBase(torch.Tensor, metaclass=MetaResolver): \n",
" '''A class that represents an image object. Metaclass casts x to this class if it is of type cls._bypass_type.'''\n",
"\n",
" _bypass_type=torch.Tensor\n",
Expand Down

0 comments on commit 5cb8e17

Please sign in to comment.