Skip to content

Commit

Permalink
Dev: refactorize
Browse files Browse the repository at this point in the history
  • Loading branch information
skaliy committed Jul 10, 2023
1 parent 5cb8e17 commit 1fe9061
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 92 deletions.
2 changes: 1 addition & 1 deletion fastMONAI/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.3.1"
6 changes: 2 additions & 4 deletions fastMONAI/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,10 @@
'fastMONAI/vision_inference.py'),
'fastMONAI.vision_inference._to_original_orientation': ( 'vision_inference.html#_to_original_orientation',
'fastMONAI/vision_inference.py'),
'fastMONAI.vision_inference.find_similar_size_labels': ( 'vision_inference.html#find_similar_size_labels',
'fastMONAI/vision_inference.py'),
'fastMONAI.vision_inference.inference': ( 'vision_inference.html#inference',
'fastMONAI/vision_inference.py'),
'fastMONAI.vision_inference.pred_postprocess': ( 'vision_inference.html#pred_postprocess',
'fastMONAI/vision_inference.py')},
'fastMONAI.vision_inference.refine_binary_pred_mask': ( 'vision_inference.html#refine_binary_pred_mask',
'fastMONAI/vision_inference.py')},
'fastMONAI.vision_loss': { 'fastMONAI.vision_loss.CustomLoss': ( 'vision_loss_functions.html#customloss',
'fastMONAI/vision_loss.py'),
'fastMONAI.vision_loss.CustomLoss.__call__': ( 'vision_loss_functions.html#customloss.__call__',
Expand Down
53 changes: 31 additions & 22 deletions fastMONAI/vision_inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_vision_inference.ipynb.

# %% auto 0
__all__ = ['inference', 'pred_postprocess', 'find_similar_size_labels']
__all__ = ['inference', 'refine_binary_pred_mask']

# %% ../nbs/06_vision_inference.ipynb 1
import numpy as np
Expand Down Expand Up @@ -56,27 +56,36 @@ def inference(learn_inf, reorder, resample, fn:(Path,str)='', save_path:(str,Pat
return org_img

# %% ../nbs/06_vision_inference.ipynb 7
def pred_postprocess(pred_mask, remove_size=10437, percentage=0.2):
'''Remove small objects from predicted mask. (TODO:refactorize)'''
small_objects = remove_size*percentage
labeled_mask, ncomponents = label(pred_mask)
labeled_mask = remove_small_objects(labeled_mask, min_size=small_objects)
return np.where(labeled_mask>0, 1., 0.)
def refine_binary_pred_mask(
pred_mask,
remove_size: (int, float) = None,
percentage: float = 0.2,
verbose: bool = False
):
"""Removes small objects from the predicted binary mask.
# %% ../nbs/06_vision_inference.ipynb 8
def find_similar_size_labels(labeled_mask, size_threshold=0.8):
"""
Find labels of components in a labeled mask that are of similar size
to the largest component.
Args:
pred_mask: The predicted mask from which small objects are to be removed.
remove_size: The size under which objects are considered 'small'.
percentage: The percentage of the remove_size to be used as threshold.
Defaults to 0.2.
verbose: If True, print the number of components. Defaults to False.
Returns:
The processed mask with small objects removed.
"""

sizes = np.bincount(labeled_mask.ravel())
max_label = sizes[1:].argmax() + 1
threshold_size = size_threshold * sizes[max_label]
similar_size_labels = [
label for label, size in enumerate(sizes[1:], start=1)
if size >= threshold_size
]
labeled_mask, n_components = label(pred_mask)

if verbose:
print(n_components)

if remove_size is None:
sizes = np.bincount(labeled_mask.ravel())
max_label = sizes[1:].argmax() + 1
remove_size = sizes[max_label]

small_objects_threshold = remove_size * percentage
processed_mask = remove_small_objects(
labeled_mask, min_size=small_objects_threshold)

return max_label, similar_size_labels
return np.where(processed_mask > 0, 1., 0.)
63 changes: 32 additions & 31 deletions nbs/06_vision_inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -114,43 +114,44 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d6000a96-56fb-4966-b4f4-83dcb51684e9",
"id": "a64a3407-4b97-4b1c-933c-d4a316dbff94",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def pred_postprocess(pred_mask, remove_size=10437, percentage=0.2): \n",
" '''Remove small objects from predicted mask. (TODO:refactorize)'''\n",
" small_objects = remove_size*percentage \n",
" labeled_mask, ncomponents = label(pred_mask)\n",
" labeled_mask = remove_small_objects(labeled_mask, min_size=small_objects)\n",
" \n",
" return np.where(labeled_mask>0, 1., 0.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f491d6e0-c8d6-460c-9a13-17561aa64ac4",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def find_similar_size_labels(labeled_mask, size_threshold=0.8):\n",
" \"\"\"\n",
" Find labels of components in a labeled mask that are of similar size \n",
" to the largest component.\n",
"def refine_binary_pred_mask(\n",
" pred_mask,\n",
" remove_size: (int, float) = None,\n",
" percentage: float = 0.2,\n",
" verbose: bool = False\n",
"):\n",
" \"\"\"Removes small objects from the predicted binary mask.\n",
"\n",
" Args:\n",
" pred_mask: The predicted mask from which small objects are to be removed.\n",
" remove_size: The size under which objects are considered 'small'.\n",
" percentage: The percentage of the remove_size to be used as threshold. \n",
" Defaults to 0.2.\n",
" verbose: If True, print the number of components. Defaults to False.\n",
"\n",
" Returns:\n",
" The processed mask with small objects removed.\n",
" \"\"\"\n",
" \n",
" sizes = np.bincount(labeled_mask.ravel())\n",
" max_label = sizes[1:].argmax() + 1\n",
" threshold_size = size_threshold * sizes[max_label]\n",
" similar_size_labels = [\n",
" label for label, size in enumerate(sizes[1:], start=1)\n",
" if size >= threshold_size\n",
" ]\n",
" labeled_mask, n_components = label(pred_mask)\n",
"\n",
" if verbose:\n",
" print(n_components)\n",
"\n",
" if remove_size is None:\n",
" sizes = np.bincount(labeled_mask.ravel())\n",
" max_label = sizes[1:].argmax() + 1\n",
" remove_size = sizes[max_label]\n",
"\n",
" small_objects_threshold = remove_size * percentage\n",
" processed_mask = remove_small_objects(\n",
" labeled_mask, min_size=small_objects_threshold)\n",
"\n",
" return max_label, similar_size_labels"
" return np.where(processed_mask > 0, 1., 0.)"
]
}
],
Expand Down
24 changes: 6 additions & 18 deletions research/endometrical_cancer/02-ec-inference.ipynb

Large diffs are not rendered by default.

10 changes: 1 addition & 9 deletions research/endometrical_cancer/inference_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from fastMONAI.vision_all import *
from huggingface_hub import snapshot_download
from scipy.ndimage import label

def find_similar_size_labels(labeled_mask, size_threshold=0.8):
"""
Expand Down Expand Up @@ -51,15 +50,8 @@ def find_similar_size_labels(labeled_mask, size_threshold=0.8):
#pred_items
org_img, input_img, org_size = med_img_reader(img_path, reorder=reorder, resample=resample, only_tensor=False)

#Predict with ensemble
mask_data = inference(learner, reorder=reorder, resample=resample, org_img=org_img, input_img=input_img, org_size=org_size).data
labeled_mask, ncomponents = label(mask_data.numpy())
if ncomponents > 1:
max_label, similar_size_labels = find_similar_size_labels(labeled_mask)
if len(similar_size_labels) == 1:
mask_data = torch.Tensor(np.where(labeled_mask == max_label, 1, 0))
else:
print(save_path)
mask_data = refine_binary_pred_mask(mask_data, percentage=0.8)

if "".join(org_img.orientation) == 'LSA':
mask_data = mask_data.permute(0,1,3,2)
Expand Down
10 changes: 6 additions & 4 deletions research/spine/01-spine-segmentation-inference.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions research/spine/inference_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import argparse

from fastMONAI.vision_all import *
from IPython.display import clear_output
from huggingface_hub import snapshot_download

# Parse command line arguments
Expand Down Expand Up @@ -48,7 +47,7 @@
mask_data = torch.where(mask_data > 0.5, 1., 0.)

# Apply postprocessing to remove small objects from the binary mask
mask_data = torch.Tensor(pred_postprocess(mask_data))
mask_data = refine_binary_pred_mask(mask_data, remove_size=10437, percentage=0.2)

# Set the data of the mask object to the processed mask data
mask.set_data(mask_data)
Expand Down
2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
### Python Library ###
lib_name = fastMONAI
min_python = 3.7
version = 0.3.0
version = 0.3.1
### OPTIONAL ###

requirements = fastai==2.7.12 monai==1.2.0 torchio==0.18.91 xlrd>=1.2.0 scikit-image==0.19.3 huggingface-hub gdown
Expand Down

0 comments on commit 1fe9061

Please sign in to comment.