diff --git a/fastMONAI/vision_core.py b/fastMONAI/vision_core.py index eb9a65b..99fa1a2 100644 --- a/fastMONAI/vision_core.py +++ b/fastMONAI/vision_core.py @@ -76,13 +76,13 @@ def _multi_channel(image_paths: list, reorder: bool, resample: list, dtype, only tensor = torch.stack([img.data[0] for _, img, _ in image_data], dim=0) if only_tensor: - dtype(tensor) + return dtype(tensor) input_img.set_data(tensor) return org_img, input_img, org_size -# %% ../nbs/01_vision_core.ipynb 9 +# %% ../nbs/01_vision_core.ipynb 8 def med_img_reader( file_path: (str, Path), dtype=torch.Tensor, @@ -118,15 +118,15 @@ def med_img_reader( return org_img, input_img, org_size -# %% ../nbs/01_vision_core.ipynb 12 +# %% ../nbs/01_vision_core.ipynb 10 class MetaResolver(type(torch.Tensor), metaclass=BypassNewMeta): '''A class to bypass metaclass conflict: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/batch.html ''' pass -# %% ../nbs/01_vision_core.ipynb 13 -class MedBase(torch.Tensor, metaclass=MetaResolver): +# %% ../nbs/01_vision_core.ipynb 11 +class MedBase(torch.Tensor, metaclass=MetaResolver): '''A class that represents an image object. Metaclass casts x to this class if it is of type cls._bypass_type.''' _bypass_type=torch.Tensor @@ -180,12 +180,12 @@ def show(self, ctx=None, channel=0, indices=None, anatomical_plane=0, **kwargs): def __repr__(self): return f'{self.__class__.__name__} mode={self.mode} size={"x".join([str(d) for d in self.size])}' -# %% ../nbs/01_vision_core.ipynb 14 +# %% ../nbs/01_vision_core.ipynb 12 class MedImage(MedBase): '''Subclass of MedBase that represents an image object.''' pass -# %% ../nbs/01_vision_core.ipynb 15 +# %% ../nbs/01_vision_core.ipynb 13 class MedMask(MedBase): '''Subclass of MedBase that represents an mask object.''' _show_args = {'alpha':0.5, 'cmap':'tab20'} diff --git a/nbs/01_vision_core.ipynb b/nbs/01_vision_core.ipynb index fd14e75..9b9ff91 100644 --- a/nbs/01_vision_core.ipynb +++ b/nbs/01_vision_core.ipynb @@ -134,36 +134,12 @@ " tensor = torch.stack([img.data[0] for _, img, _ in image_data], dim=0)\n", " \n", " if only_tensor: \n", - " dtype(tensor) \n", + " return dtype(tensor) \n", "\n", " input_img.set_data(tensor)\n", " return org_img, input_img, org_size\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# def med_img_reader(file_path:(str, Path), # Image path\n", - "# dtype=torch.Tensor, # Datatype (MedImage, MedMask, torch.Tensor)\n", - "# reorder:bool=False, # Whether to reorder the data to be closest to canonical (RAS+) orientation.\n", - "# resample:list=None, # Whether to resample image to different voxel sizes and image dimensions.\n", - "# only_tensor:bool=True # Whether to return only image tensor\n", - "# ):\n", - "# '''Load and preprocess medical image'''\n", - " \n", - "# if isinstance(file_path, str) and ';' in file_path:\n", - "# return _multi_channel(file_path.split(';'), reorder, resample, dtype, only_tensor)\n", - "\n", - "# org_img, input_img, org_size = _load_and_preprocess(file_path, reorder, resample, dtype)\n", - "\n", - "# if only_tensor: return dtype(input_img.data.type(torch.float)) \n", - " \n", - "# return org_img, input_img, org_size" - ] - }, { "cell_type": "code", "execution_count": null, @@ -207,48 +183,6 @@ " return org_img, input_img, org_size" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def read_medical_image(\n", - " file_path: (str, Path),\n", - " dtype=torch.Tensor,\n", - " reorder: bool = False,\n", - " resample: list = None,\n", - " only_tensor: bool = True\n", - "):\n", - " \"\"\"Loads and preprocesses a medical image.\n", - "\n", - " Args:\n", - " file_path: Path to the image. Can be a string or a Path object.\n", - " dtype: Datatype for the return value. Defaults to torch.Tensor.\n", - " reorder: Whether to reorder the data to be closest to canonical \n", - " (RAS+) orientation. Defaults to False.\n", - " resample: Whether to resample image to different voxel sizes and \n", - " image dimensions. Defaults to None.\n", - " only_tensor: Whether to return only image tensor. Defaults to True.\n", - "\n", - " Returns:\n", - " The preprocessed image. Returns only the image tensor if \n", - " only_tensor is True, otherwise returns original image, \n", - " preprocessed image, and original size.\n", - " \"\"\"\n", - " if isinstance(file_path, str) and ';' in file_path:\n", - " return _multi_channel(\n", - " file_path.split(';'), reorder, resample, dtype, only_tensor)\n", - "\n", - " org_img, input_img, org_size = _load_and_preprocess(\n", - " file_path, reorder, resample, dtype)\n", - "\n", - " if only_tensor:\n", - " return dtype(input_img.data.type(torch.float))\n", - "\n", - " return org_img, input_img, org_size" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -277,7 +211,7 @@ "outputs": [], "source": [ "#| export\n", - "class MedBase(torch.Tensor, metaclass=MetaResolver):\n", + "class MedBase(torch.Tensor, metaclass=MetaResolver): \n", " '''A class that represents an image object. Metaclass casts x to this class if it is of type cls._bypass_type.'''\n", "\n", " _bypass_type=torch.Tensor\n",