diff --git a/.spelling/.spelling/expect.txt b/.spelling/.spelling/expect.txt
index fe6a792b1..0055fe5c9 100644
--- a/.spelling/.spelling/expect.txt
+++ b/.spelling/.spelling/expect.txt
@@ -487,6 +487,7 @@ rgbatorgb
 rgbtorgba
 rigourous
 Ritesh
+rmse
 rmsprop
 rocm
 rocmdocs
@@ -561,7 +562,6 @@ thresholded
 thresholding
 Thu
 tiatoolbox
-tiffslide
 timepoints
 timm
 tio
@@ -597,8 +597,6 @@ unittests
 unitwise
 unsqueeze
 upenn
-Uploaing
-Uploded
 upsample
 upsampled
 upsampling
@@ -725,7 +723,6 @@ zsuokb
 zwezggl
 zzokqk
 thirdparty
-adopy
 Shohei
 crcrpar
 lrs
diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py
index d484b63a0..7e7fdda49 100644
--- a/GANDLF/cli/generate_metrics.py
+++ b/GANDLF/cli/generate_metrics.py
@@ -15,13 +15,11 @@
     overall_stats,
     structural_similarity_index,
     mean_squared_error,
+    root_mean_squared_error,
     peak_signal_noise_ratio,
     mean_squared_log_error,
     mean_absolute_error,
-    ncc_mean,
-    ncc_std,
-    ncc_max,
-    ncc_min,
+    ncc_metrics,
 )
 from GANDLF.losses.segmentation import dice
 from GANDLF.metrics.segmentation import (
@@ -302,21 +300,22 @@ def __percentile_clip(
             reference_tensor = (
                 input_tensor if reference_tensor is None else reference_tensor
             )
-            v_min, v_max = np.percentile(
-                reference_tensor, [p_min, p_max]
-            )  # get p_min percentile and p_max percentile
-
+            # get p_min percentile and p_max percentile
+            v_min, v_max = np.percentile(reference_tensor, [p_min, p_max])
             # set lower bound to be 0 if strictlyPositive is enabled
             v_min = max(v_min, 0.0) if strictlyPositive else v_min
-            output_tensor = np.clip(
-                input_tensor, v_min, v_max
-            )  # clip values to percentiles from reference_tensor
-            output_tensor = (output_tensor - v_min) / (
-                v_max - v_min
-            )  # normalizes values to [0;1]
+            # clip values to percentiles from reference_tensor
+            output_tensor = np.clip(input_tensor, v_min, v_max)
+            # normalizes values to [0;1]
+            output_tensor = (output_tensor - v_min) / (v_max - v_min)
             return output_tensor
 
-        input_df = __update_header_location_case_insensitive(input_df, "Mask", False)
+        # these are additional columns that could be present for synthesis tasks
+        for column_to_make_case_insensitive in ["Mask", "VoidImage"]:
+            input_df = __update_header_location_case_insensitive(
+                input_df, column_to_make_case_insensitive, False
+            )
+
         for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
             current_subject_id = row["SubjectID"]
             overall_stats_dict[current_subject_id] = {}
@@ -332,6 +331,15 @@ def __percentile_clip(
                 )
             ).byte()
 
+            void_image_present = True if "VoidImage" in row else False
+            void_image = (
+                __fix_2d_tensor(torchio.ScalarImage(row["VoidImage"]).data)
+                if "VoidImage" in row
+                else torch.from_numpy(
+                    np.ones(target_image.numpy().shape, dtype=np.uint8)
+                )
+            )
+
             # Get Infill region (we really are only interested in the infill region)
             output_infill = (pred_image * mask).float()
             gt_image_infill = (target_image * mask).float()
@@ -339,9 +347,10 @@ def __percentile_clip(
             # Normalize to [0;1] based on GT (otherwise MSE will depend on the image intensity range)
             normalize = parameters.get("normalize", True)
             if normalize:
+                # use all the tissue that is not masked for normalization
                 reference_tensor = (
-                    target_image * ~mask
-                )  # use all the tissue that is not masked for normalization
+                    target_image * ~mask if not void_image_present else void_image
+                )
                 gt_image_infill = __percentile_clip(
                     gt_image_infill,
                     reference_tensor=reference_tensor,
@@ -364,18 +373,10 @@ def __percentile_clip(
             # ncc metrics
             compute_ncc = parameters.get("compute_ncc", True)
             if compute_ncc:
-                overall_stats_dict[current_subject_id]["ncc_mean"] = ncc_mean(
-                    output_infill, gt_image_infill
-                )
-                overall_stats_dict[current_subject_id]["ncc_std"] = ncc_std(
-                    output_infill, gt_image_infill
-                )
-                overall_stats_dict[current_subject_id]["ncc_max"] = ncc_max(
-                    output_infill, gt_image_infill
-                )
-                overall_stats_dict[current_subject_id]["ncc_min"] = ncc_min(
-                    output_infill, gt_image_infill
-                )
+                calculated_ncc_metrics = ncc_metrics(output_infill, gt_image_infill)
+                for key, value in calculated_ncc_metrics.items():
+                    # we don't need the ".item()" here, since the values are already scalars
+                    overall_stats_dict[current_subject_id][key] = value
 
             # only voxels that are to be inferred (-> flat array)
             # these are required for mse, psnr, etc.
@@ -386,6 +387,10 @@ def __percentile_clip(
                 output_infill, gt_image_infill
             ).item()
 
+            overall_stats_dict[current_subject_id]["rmse"] = root_mean_squared_error(
+                output_infill, gt_image_infill
+            ).item()
+
             overall_stats_dict[current_subject_id]["msle"] = mean_squared_log_error(
                 output_infill, gt_image_infill
             ).item()
diff --git a/GANDLF/entrypoints/hf_hub_integration.py b/GANDLF/entrypoints/hf_hub_integration.py
index 353f31dfb..d0f209ad4 100644
--- a/GANDLF/entrypoints/hf_hub_integration.py
+++ b/GANDLF/entrypoints/hf_hub_integration.py
@@ -96,7 +96,7 @@
 @click.option(
     "--hf-template",
     "-hft",
-    help="Adding the template path for the model card it is Required during Uploaing a model",
+    help="Adding the template path for the model card: it is required during model upload",
     default=huggingface_file_path,
     type=click.Path(exists=True, file_okay=True, dir_okay=False),
 )
diff --git a/GANDLF/metrics/__init__.py b/GANDLF/metrics/__init__.py
index 1fc21b3fb..824d38512 100644
--- a/GANDLF/metrics/__init__.py
+++ b/GANDLF/metrics/__init__.py
@@ -32,13 +32,11 @@
 from .synthesis import (
     structural_similarity_index,
     mean_squared_error,
+    root_mean_squared_error,
     peak_signal_noise_ratio,
     mean_squared_log_error,
     mean_absolute_error,
-    ncc_mean,
-    ncc_std,
-    ncc_max,
-    ncc_min,
+    ncc_metrics,
 )
 import GANDLF.metrics.classification as classification
 import GANDLF.metrics.regression as regression
diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py
index ba1b4113e..a7d6a4523 100644
--- a/GANDLF/metrics/synthesis.py
+++ b/GANDLF/metrics/synthesis.py
@@ -46,8 +46,28 @@ def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.
     Args:
         prediction (torch.Tensor): The prediction tensor.
         target (torch.Tensor): The target tensor.
+
+    Returns:
+        torch.Tensor: The mean squared error.
     """
-    mse = MeanSquaredError()
+    mse = MeanSquaredError(squared=True)
+    return mse(preds=prediction, target=target)
+
+
+def root_mean_squared_error(
+    prediction: torch.Tensor, target: torch.Tensor
+) -> torch.Tensor:
+    """
+    Computes the root mean squared error between the target and prediction.
+
+    Args:
+        prediction (torch.Tensor): The prediction tensor.
+        target (torch.Tensor): The target tensor.
+
+    Returns:
+        torch.Tensor: The root mean squared error.
+    """
+    mse = MeanSquaredError(squared=False)
     return mse(preds=prediction, target=target)
 
 
@@ -78,10 +98,9 @@ def peak_signal_noise_ratio(
         return psnr(preds=prediction, target=target)
     else:  # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0'
         mse = mean_squared_error(target, prediction)
-        if data_range == None:  # compute data_range like torchmetrics if not given
-            min_v = (
-                0 if torch.min(target) > 0 else torch.min(target)
-            )  # look at this line
+        if data_range is None:  # compute data_range like torchmetrics if not given
+            # put the min value to 0 if all values are positive
+            min_v = 0 if torch.min(target) > 0 else torch.min(target)
             max_v = torch.max(target)
         else:
             min_v, max_v = data_range
@@ -158,69 +177,27 @@ def __convert_to_grayscale(image: sitk.Image) -> sitk.Image:
     return correlation_filter.Execute(target_image, pred_image)
 
 
-def ncc_mean(prediction: torch.Tensor, target: torch.Tensor) -> float:
-    """
-    Computes normalized cross correlation mean between target and prediction.
-
-    Args:
-        prediction (torch.Tensor): The prediction tensor.
-        target (torch.Tensor): The target tensor.
-
-    Returns:
-        float: The normalized cross correlation mean.
-    """
-    stats_filter = sitk.StatisticsImageFilter()
-    corr_image = _get_ncc_image(target, prediction)
-    stats_filter.Execute(corr_image)
-    return stats_filter.GetMean()
-
-
-def ncc_std(prediction: torch.Tensor, target: torch.Tensor) -> float:
-    """
-    Computes normalized cross correlation standard deviation between target and prediction.
-
-    Args:
-        prediction (torch.Tensor): The prediction tensor.
-        target (torch.Tensor): The target tensor.
-
-    Returns:
-        float: The normalized cross correlation standard deviation.
-    """
-    stats_filter = sitk.StatisticsImageFilter()
-    corr_image = _get_ncc_image(target, prediction)
-    stats_filter.Execute(corr_image)
-    return stats_filter.GetSigma()
-
-
-def ncc_max(prediction: torch.Tensor, target: torch.Tensor) -> float:
-    """
-    Computes normalized cross correlation maximum between target and prediction.
-
-    Args:
-        prediction (torch.Tensor): The prediction tensor.
-        target (torch.Tensor): The target tensor.
-
-    Returns:
-        float: The normalized cross correlation maximum.
-    """
-    stats_filter = sitk.StatisticsImageFilter()
-    corr_image = _get_ncc_image(target, prediction)
-    stats_filter.Execute(corr_image)
-    return stats_filter.GetMaximum()
-
-
-def ncc_min(prediction: torch.Tensor, target: torch.Tensor) -> float:
+def ncc_metrics(prediction: torch.Tensor, target: torch.Tensor) -> dict:
     """
-    Computes normalized cross correlation minimum between target and prediction.
+    Computes normalized cross correlation metrics between target and prediction.
 
     Args:
         prediction (torch.Tensor): The prediction tensor.
         target (torch.Tensor): The target tensor.
 
     Returns:
-        float: The normalized cross correlation minimum.
+        dict: The normalized cross correlation metrics.
     """
-    stats_filter = sitk.StatisticsImageFilter()
     corr_image = _get_ncc_image(target, prediction)
-    stats_filter.Execute(corr_image)
-    return stats_filter.GetMinimum()
+    stats_filter = sitk.LabelStatisticsImageFilter()
+    stats_filter.UseHistogramsOn()
+    # use an all-ones label image so every voxel (including zeros) is counted under label 1
+    onesImage = corr_image == corr_image
+    stats_filter.Execute(corr_image, onesImage)
+    return {
+        "ncc_mean": stats_filter.GetMean(1),
+        "ncc_std": stats_filter.GetSigma(1),
+        "ncc_max": stats_filter.GetMaximum(1),
+        "ncc_min": stats_filter.GetMinimum(1),
+        "ncc_median": stats_filter.GetMedian(1),
+    }
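
A minimal usage sketch (not part of the patch) of the two public functions this diff introduces, ncc_metrics and root_mean_squared_error, both re-exported from GANDLF.metrics per the __init__.py hunk above. The tensor shapes are illustrative assumptions; in generate_metrics.py the inputs are the masked infill tensors loaded via torchio, so real volumes may need reshaping to match what the synthesis pipeline produces.

import torch

from GANDLF.metrics import ncc_metrics, root_mean_squared_error

# hypothetical prediction / ground-truth volumes (shapes chosen only for illustration)
prediction = torch.rand(1, 32, 32, 32)
target = torch.rand(1, 32, 32, 32)

# scalar tensor from torchmetrics' MeanSquaredError(squared=False); .item() yields a float
rmse = root_mean_squared_error(prediction, target).item()

# dict of plain floats keyed ncc_mean / ncc_std / ncc_max / ncc_min / ncc_median
ncc = ncc_metrics(prediction, target)

print(f"rmse={rmse:.4f}, ncc_mean={ncc['ncc_mean']:.4f}")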