
CRF Post Processing as a MONAI Transform #2196

Closed
masadcv opened this issue May 13, 2021 · 8 comments
Labels
question Further information is requested

Comments

@masadcv
Contributor

masadcv commented May 13, 2021

Is your feature request related to a problem? Please describe.
In a number of deep-learning-based segmentation models, conditional random fields (CRFs) are used as a post-processing step that refines the network output into segmentation maps more consistent with the underlying regions of the image.

Describe the solution you'd like
MONAI already provides a CRF layer (monai.networks.blocks.CRF) that can do this. It may be beneficial to also have dictionary/array Transforms that wrap the CRF, so that it can be composed with other transforms into a post-processing pipeline.

Describe alternatives you've considered
Using MONAI's CRF layer as a separate model layer attached to the model or to its outputs. It may be more convenient to separate this out into a Transform so that it can be reused directly in post-processing steps.

Additional context
As an example, the following is an initial prototype to give an idea of how this may be approached:

import torch

from monai.networks.blocks import CRF
from monai.transforms import Transform


class ApplyCRFPostProcd(Transform):
    """Dictionary transform: refine the `unary` logits with a CRF guided by the `pairwise` tensor (e.g. the image)."""

    def __init__(
        self,
        unary: str,
        pairwise: str,
        post_proc_label: str = 'postproc',
        iterations: int = 5,
        bilateral_weight: float = 3.0,
        gaussian_weight: float = 1.0,
        bilateral_spatial_sigma: float = 5.0,
        bilateral_color_sigma: float = 0.5,
        gaussian_spatial_sigma: float = 5.0,
        compatibility_kernel_range: float = 1,
        device=torch.device('cpu'),
    ):
        self.unary = unary                      # key of the logits to refine
        self.pairwise = pairwise                # key of the reference tensor (usually the image)
        self.post_proc_label = post_proc_label  # key under which the CRF output is stored
        self.device = device

        # Arguments are passed positionally, in the order of the CRF
        # constructor of the MONAI version this prototype targets.
        self.crf_layer = CRF(
            iterations,
            bilateral_weight,
            gaussian_weight,
            bilateral_spatial_sigma,
            bilateral_color_sigma,
            gaussian_spatial_sigma,
            compatibility_kernel_range,
        )

    def __call__(self, data):
        d = dict(data)  # the input must be a mapping, e.g. {'logits': ..., 'image': ...}
        unary_term = d[self.unary].float().to(self.device)
        pairwise_term = d[self.pairwise].float().to(self.device)
        d[self.post_proc_label] = self.crf_layer(unary_term, pairwise_term)
        return d

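Example usage of the above as a post-processing step:

from monai.transforms import Compose, SqueezeDimd, ToNumpyd

post_transforms = Compose([
    ApplyCRFPostProcd(unary='logits', pairwise='image', post_proc_label='pred'),
    SqueezeDimd(keys='pred', dim=0),  # remove the leading (batch) dimension
    ToNumpyd(keys='pred'),
])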

Please let me know your thoughts: does it make sense to have this as a Transform? If so, I am happy to work on it.

@tvercaut
Member

Should this issue be merged with #315?

@Nic-Ma Nic-Ma added the question Further information is requested label May 14, 2021
@wyli
Contributor

wyli commented May 17, 2021

issue moved into #315

@wyli wyli closed this as completed May 17, 2021
@MasalaKimchi

I've been trying to use the CRF post-processing class you wrote above in my post-processing pipeline and am running into errors. Could you help me troubleshoot, or give some guidance on how to apply this transform correctly?

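import torch
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    AsDiscrete, AsDiscreted, Compose, CropForegroundd,
    EnsureChannelFirstd, LoadImaged, ScaleIntensityRanged,
)

# model, roi, sw_batch_size and test_files are defined earlier in the notebook.
test_org_transforms = Compose(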
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(keys=["image"], a_min=-250, a_max=600, b_min=0.0, b_max=1.0, clip=True,),
        CropForegroundd(keys=["image","label"], source_key="image"),
    ]
)

test_org_ds = CacheDataset(data=test_files[:5], transform=test_org_transforms)
test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=8)

post_transforms = Compose(
    [
        ApplyCRFPostProcd(unary="pred", pairwise="image", post_proc_label="post_pred"),
        AsDiscreted(keys="post_pred", argmax=True, to_onehot=2),
    ]
)

dice_metric_post = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=True, to_onehot=2)

model.eval()

with torch.no_grad():
    for test_data in test_org_loader:
        test_inputs, test_labels = (test_data["image"].cuda(), test_data["label"].cuda())
        roi_size = (roi, roi, roi)
        sw_batch_size = sw_batch_size
        with torch.cuda.amp.autocast():
            test_outputs = sliding_window_inference(test_inputs, (roi,roi,roi), sw_batch_size, model, overlap=0.25)
        
        test_labels_convert = [post_label(i) for i in decollate_batch(test_labels)]
        test_output_convert = [post_pred(i) for i in decollate_batch(test_outputs)]

        # test_output_post= [post_transforms(i) for i in decollate_batch(test_outputs)]

        dice_metric(y_pred=test_output_convert, y=test_labels_convert)
        # dice_metric_post(y_pred=test_output_post, y=test_labels_convert)

    mean_dice_val = dice_metric.aggregate().item()
    # mean_dice_val_post = dice_metric_post.aggregate().item()
    dice_metric.reset()
    # dice_metric_post.reset()

print(f"Mean Dice: {mean_dice_val:.4f}")
# print(f"Mean Dice Post: {mean_dice_val_post:.4f}")

Then the following error occurs:

ValueError                                Traceback (most recent call last)
File c:\Users\user\Anaconda3\envs\vit\lib\site-packages\monai\transforms\transform.py:102, in apply_transform(transform, data, map_items, unpack_items, log_stats)
    101         return [_apply_transform(transform, item, unpack_items) for item in data]
--> 102     return _apply_transform(transform, data, unpack_items)
    103 except Exception as e:
    104     # if in debug mode, don't swallow exception so that the breakpoint
    105     # appears where the exception was raised.

File c:\Users\user\Anaconda3\envs\vit\lib\site-packages\monai\transforms\transform.py:66, in _apply_transform(transform, parameters, unpack_parameters)
     64     return transform(*parameters)
---> 66 return transform(parameters)

Cell In[49], line 35, in ApplyCRFPostProcd.__call__(self, data)
     34 def __call__(self, data):
---> 35     d = dict(data)
     36     unary_term = d[self.unary].float().to(self.device)

ValueError: dictionary update sequence element #0 has length 274; 2 is required

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[54], line 18
     15 test_labels_convert = [post_label(i) for i in decollate_batch(test_labels)]
...
    127     else:
    128         _log_stats(data=data)
--> 129 raise RuntimeError(f"applying transform {transform}") from e

RuntimeError: applying transform <__main__.ApplyCRFPostProcd object at 0x0000017E186032E0>

@wyli
Contributor

wyli commented Apr 4, 2023

The transform requires a dictionary input, but for i in decollate_batch(test_outputs) yields MetaTensors. You can create a dictionary from each item and then call the transform on it, e.g. post_transforms(dict_i).
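The ValueError in the traceback above comes from the first line of __call__, d = dict(data). dict() expects an iterable of key/value pairs; iterating over an array-like yields its sub-arrays, so the first one (274 elements here) is rejected because it is not a pair. The sketch below reproduces the message with nested Python lists standing in for the MetaTensor (an assumption made to keep it dependency-free), then shows the fix of wrapping the item in a dictionary first. dict_transform is a hypothetical stand-in for ApplyCRFPostProcd.__call__, with the CRF call replaced by a copy.

```python
def dict_transform(data):
    """Hypothetical stand-in for ApplyCRFPostProcd.__call__ (no CRF, just key handling)."""
    d = dict(data)  # same first step as the prototype transform
    d["post_pred"] = d["pred"]  # placeholder for the CRF output
    return d


# A decollated prediction is an array-like, not a mapping.
# Nested lists stand in for a (channels, 274, ...) MetaTensor here.
pred_item = [[0.0] * 274, [1.0] * 274]

error_message = ""
try:
    dict_transform(pred_item)
except ValueError as err:
    error_message = str(err)

print(error_message)  # dictionary update sequence element #0 has length 274; 2 is required

# Fix: build a dictionary with the keys the transform expects, then call it.
image_item = [[0.5] * 274]
result = dict_transform({"pred": pred_item, "image": image_item})
print(sorted(result))  # ['image', 'post_pred', 'pred']
```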

@MasalaKimchi

Thank you so much for the prompt response.
I forgot to mention above that the code runs fine as posted; the error only occurs when I uncomment the commented-out lines, e.g.

test_output_post= [post_transforms(i) for i in decollate_batch(test_outputs)]

Following your advice, I changed it to:

test_data["pred"] = test_outputs 
test_output_post = [post_transforms(i) for i in decollate_batch(test_data)]

However, this leads to a different error:

OptionalImportError                       Traceback (most recent call last)
File c:\Users\user\Anaconda3\envs\vit\lib\site-packages\monai\transforms\transform.py:102, in apply_transform(transform, data, map_items, unpack_items, log_stats)
    101         return [_apply_transform(transform, item, unpack_items) for item in data]
--> 102     return _apply_transform(transform, data, unpack_items)
    103 except Exception as e:
    104     # if in debug mode, don't swallow exception so that the breakpoint
    105     # appears where the exception was raised.

File c:\Users\user\Anaconda3\envs\vit\lib\site-packages\monai\transforms\transform.py:66, in _apply_transform(transform, parameters, unpack_parameters)
     64     return transform(*parameters)
---> 66 return transform(parameters)

Cell In[12], line 38, in ApplyCRFPostProcd.__call__(self, data)
     37 pairwise_term = d[self.pairwise].float().to(self.device)
---> 38 d[self.post_proc_label] = self.crf_layer(unary_term, pairwise_term)
     39 return d

File c:\Users\user\Anaconda3\envs\vit\lib\site-packages\torch\nn\modules\module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used

File c:\Users\user\Anaconda3\envs\vit\lib\site-packages\monai\networks\blocks\crf.py:97, in CRF.forward(self, input_tensor, reference_tensor)
...
    127     else:
    128         _log_stats(data=data)
--> 129 raise RuntimeError(f"applying transform {transform}") from e

RuntimeError: applying transform <__main__.ApplyCRFPostProcd object at 0x00000147FA1C9D30>
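The change above (test_data["pred"] = test_outputs, then decollate_batch(test_data)) does address the original ValueError: this traceback gets past dict(data) and the key lookups and fails inside CRF.forward instead. Decollating the whole batch dictionary yields one dictionary per sample, which is the input the dictionary transform expects. MONAI's decollate_batch also handles tensors, metadata and nested structures; decollate_dict below is only a simplified, hypothetical stand-in (not a MONAI API) that splits a dict of equal-length batched lists, to show the resulting structure.

```python
from typing import Any, Dict, List, Sequence


def decollate_dict(batch: Dict[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Split {key: [item_0, item_1, ...]} into [{key: item_0}, {key: item_1}, ...].

    Simplified stand-in for decollating a dictionary batch: every key must
    hold a sequence whose first axis is the batch dimension.
    """
    sizes = {len(value) for value in batch.values()}
    if len(sizes) != 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
    (batch_size,) = sizes
    return [{key: value[i] for key, value in batch.items()} for i in range(batch_size)]


# Batch of size 2; strings stand in for per-sample tensors.
test_data = {"image": ["img0", "img1"], "label": ["lab0", "lab1"]}
test_data["pred"] = ["out0", "out1"]  # mirrors test_data["pred"] = test_outputs

items = decollate_dict(test_data)
print(items[0])  # {'image': 'img0', 'label': 'lab0', 'pred': 'out0'}
```

Each per-sample dictionary carries every key from the batch, so a dictionary transform configured with unary="pred" and pairwise="image" can find both inputs.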

@wyli
Contributor

wyli commented Apr 4, 2023

The OptionalImportError suggests that the CRF module hasn't been compiled properly. Please follow https://docs.monai.io/en/latest/installation.html#option-2-editable-installation
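Before rebuilding, it can help to check whether the compiled extension is importable at all in the active environment. This is a minimal sketch: has_module is a hypothetical helper built on importlib, and the extension name monai._C is an assumption about how MONAI names its compiled module; verify it against your installed version.

```python
import importlib


def has_module(name: str) -> bool:
    """Return True if `name` can be imported, False on any ImportError.

    ImportError also covers ModuleNotFoundError and failed native loads
    (e.g. "DLL load failed" on Windows).
    """
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    return True


if __name__ == "__main__":
    # Assumption: MONAI's compiled C++/CUDA extension is the module "monai._C".
    print("monai installed:", has_module("monai"))
    print("compiled extension available:", has_module("monai._C"))
```

If monai imports but monai._C does not, the Python package is installed without the compiled extension.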

@MasalaKimchi

Sorry to bother you again. Could you tell me why this happens in the VS Code terminal?

(vit) C:\Users\user\Desktop\MMH_Vit\MONAI>BUILD_MONAI=1 python setup.py develop
'BUILD_MONAI' is not recognized as an internal or external command,
operable program or batch file.

@wyli
Contributor

wyli commented Apr 4, 2023

The command has only been tested with Bash; I'm not sure how to do the same in the Windows command prompt. Before you spend time compiling the module, you could also try

the Docker image releases at https://hub.docker.com/r/projectmonai/monai/tags, which include a compiled version,

or Google Colab, running BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=monai
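The BUILD_MONAI=1 command prefix is Bash syntax, which is why cmd.exe rejects it. In cmd.exe the variable is normally set on its own line first (set BUILD_MONAI=1, then python setup.py develop), and in PowerShell with $env:BUILD_MONAI = "1". A shell-independent alternative is to launch the same commands from Python with the variable added only to the child process environment. This is a minimal sketch: the script name build_monai.py and its develop/pip arguments are hypothetical, and it does not check that a C++ compiler or CUDA toolkit is available.

```python
import os
import subprocess
import sys
from typing import Dict, List, Optional


def build_env(base: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Copy the environment (os.environ by default) and set BUILD_MONAI=1 in the copy."""
    env = dict(os.environ if base is None else base)
    env["BUILD_MONAI"] = "1"
    return env


def develop_command() -> List[str]:
    """Equivalent of `python setup.py develop`, using the current interpreter."""
    return [sys.executable, "setup.py", "develop"]


def pip_git_command() -> List[str]:
    """Equivalent of `pip install git+https://github.com/Project-MONAI/MONAI#egg=monai`."""
    return [sys.executable, "-m", "pip", "install",
            "git+https://github.com/Project-MONAI/MONAI#egg=monai"]


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python build_monai.py develop   (run from the root of a MONAI checkout)
    #        python build_monai.py pip
    commands = {"develop": develop_command(), "pip": pip_git_command()}
    subprocess.run(commands[sys.argv[1]], env=build_env(), check=True)
```

Passing env= to subprocess.run leaves the parent shell's environment untouched, so nothing needs to be unset afterwards.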

5 participants