From b0247acd287d7033e238c111a044e0888bfbe366 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Thu, 18 Apr 2024 10:00:34 -0400 Subject: [PATCH 01/11] added standardized CLI options --- gandlf_anonymizer | 5 +++-- gandlf_configGenerator | 3 ++- gandlf_constructCSV | 5 +++-- gandlf_patchMiner | 3 ++- gandlf_preprocess | 3 ++- gandlf_recoverConfig | 3 ++- 6 files changed, 14 insertions(+), 8 deletions(-) diff --git a/gandlf_anonymizer b/gandlf_anonymizer index 2d6141c89..294ab9aa5 100644 --- a/gandlf_anonymizer +++ b/gandlf_anonymizer @@ -38,6 +38,7 @@ def main(): ) parser.add_argument( "-o", + "--outputfile", "--outputFile", metavar="", type=str, @@ -47,12 +48,12 @@ def main(): args = parser.parse_args() # check for required parameters - this is needed here to keep the cli clean - for param_none_check in [args.inputDir, args.outputFile]: + for param_none_check in [args.inputDir, args.outputfile]: if param_none_check is None: sys.exit("ERROR: Missing required parameter:", param_none_check) inputDir = os.path.normpath(args.inputDir) - outputFile = os.path.normpath(args.outputFile) + outputFile = os.path.normpath(args.outputfile) if os.path.isfile(args.config): config = yaml.safe_load(open(args.config, "r")) else: diff --git a/gandlf_configGenerator b/gandlf_configGenerator index 26ddca70b..1c0ea7db3 100644 --- a/gandlf_configGenerator +++ b/gandlf_configGenerator @@ -28,6 +28,7 @@ if __name__ == "__main__": ) parser.add_argument( "-o", + "--outputdir", "--output", metavar="", type=str, @@ -37,6 +38,6 @@ if __name__ == "__main__": args = parser.parse_args() - config_generator(args.config, args.strategy, args.output) + config_generator(args.config, args.strategy, args.outputdir) print("Finished.") diff --git a/gandlf_constructCSV b/gandlf_constructCSV index a61f322ea..bbad3deb8 100644 --- a/gandlf_constructCSV +++ b/gandlf_constructCSV @@ -40,6 +40,7 @@ def main(): ) parser.add_argument( "-o", + "--outputfile", "--outputFile", metavar="", type=str, @@ -60,13 +61,13 @@ def main(): for param_none_check in [ args.inputDir, args.channelsID, - args.outputFile, + args.outputfile, ]: if param_none_check is None: sys.exit("ERROR: Missing required parameter:", param_none_check) inputDir = os.path.normpath(args.inputDir) - outputFile = os.path.normpath(args.outputFile) + outputFile = os.path.normpath(args.outputfile) channelsID = args.channelsID labelID = args.labelID relativizePathsToOutput = args.relativizePaths diff --git a/gandlf_patchMiner b/gandlf_patchMiner index 1cb63ec2d..daff4c4b8 100644 --- a/gandlf_patchMiner +++ b/gandlf_patchMiner @@ -24,6 +24,7 @@ if __name__ == "__main__": ) parser.add_argument( "-o", + "--outputdir", "--output_path", dest="output_path", default=None, @@ -41,6 +42,6 @@ if __name__ == "__main__": args = parser.parse_args() - patch_extraction(args.input_path, args.output_path, args.config) + patch_extraction(args.input_path, args.outputdir, args.config) print("Finished.") diff --git a/gandlf_preprocess b/gandlf_preprocess index 1522f8607..991d737b2 100644 --- a/gandlf_preprocess +++ b/gandlf_preprocess @@ -31,6 +31,7 @@ if __name__ == "__main__": ) parser.add_argument( "-o", + "--outputdir", "--output", metavar="", type=str, @@ -70,7 +71,7 @@ if __name__ == "__main__": preprocess_and_save( args.inputdata, args.config, - args.output, + args.outputdir, args.labelPad, args.applyaugs, args.cropzero, diff --git a/gandlf_recoverConfig b/gandlf_recoverConfig index 37653fe04..464a97586 100644 --- a/gandlf_recoverConfig +++ b/gandlf_recoverConfig @@ -33,6 +33,7 @@ if __name__ == "__main__": ) parser.add_argument( "-o", + "--outputfile", "--outputFile", metavar="", type=str, @@ -46,5 +47,5 @@ if __name__ == "__main__": else: search_dir = args.modeldir - result = recover_config(search_dir, args.outputFile) + result = recover_config(search_dir, args.outputfile) assert result, "Config file recovery failed." From ca49b4ca71988de850791f3691be3ef0e924e9d7 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 09:57:43 -0500 Subject: [PATCH 02/11] putting comments and added parameter to get rmse --- GANDLF/metrics/synthesis.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index ba1b4113e..f29933da7 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -39,15 +39,21 @@ def structural_similarity_index( return ssim_idx.mean() -def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: +def mean_squared_error( + prediction: torch.Tensor, target: torch.Tensor, squared: bool = True +) -> torch.Tensor: """ Computes the mean squared error between the target and prediction. Args: prediction (torch.Tensor): The prediction tensor. target (torch.Tensor): The target tensor. + squared (bool, optional): Whether to return squared error. Defaults to True. + + Returns: + torch.Tensor: The mean squared error or its square root. """ - mse = MeanSquaredError() + mse = MeanSquaredError(squared=squared) return mse(preds=prediction, target=target) @@ -78,10 +84,9 @@ def peak_signal_noise_ratio( return psnr(preds=prediction, target=target) else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0' mse = mean_squared_error(target, prediction) - if data_range == None: # compute data_range like torchmetrics if not given - min_v = ( - 0 if torch.min(target) > 0 else torch.min(target) - ) # look at this line + if data_range is None: # compute data_range like torchmetrics if not given + # put the min value to 0 if all values are positive + min_v = 0 if torch.min(target) > 0 else torch.min(target) max_v = torch.max(target) else: min_v, max_v = data_range From 1c4afd7306a2b766f8d71a5460bd9edc151412eb Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 09:58:05 -0500 Subject: [PATCH 03/11] ensure that the brain mask and void image are treated differently --- GANDLF/cli/generate_metrics.py | 83 ++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index d484b63a0..1821aa2dd 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -302,21 +302,22 @@ def __percentile_clip( reference_tensor = ( input_tensor if reference_tensor is None else reference_tensor ) - v_min, v_max = np.percentile( - reference_tensor, [p_min, p_max] - ) # get p_min percentile and p_max percentile - + # get p_min percentile and p_max percentile + v_min, v_max = np.percentile(reference_tensor, [p_min, p_max]) # set lower bound to be 0 if strictlyPositive is enabled v_min = max(v_min, 0.0) if strictlyPositive else v_min - output_tensor = np.clip( - input_tensor, v_min, v_max - ) # clip values to percentiles from reference_tensor - output_tensor = (output_tensor - v_min) / ( - v_max - v_min - ) # normalizes values to [0;1] + # clip values to percentiles from reference_tensor + output_tensor = np.clip(input_tensor, v_min, v_max) + # normalizes values to [0;1] + output_tensor = (output_tensor - v_min) / (v_max - v_min) return output_tensor - input_df = __update_header_location_case_insensitive(input_df, "Mask", False) + # these are additional columns that could be present for synthesis tasks + for column_to_make_case_insensitive in ["Mask", "VoidImage"]: + input_df = __update_header_location_case_insensitive( + input_df, column_to_make_case_insensitive, False + ) + for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]): current_subject_id = row["SubjectID"] overall_stats_dict[current_subject_id] = {} @@ -332,6 +333,15 @@ def __percentile_clip( ) ).byte() + void_image_present = True if "VoidImage" in row else False + void_image = ( + __fix_2d_tensor(torchio.ScalarImage(row["VoidImage"]).data) + if "VoidImage" in row + else torch.from_numpy( + np.ones(target_image.numpy().shape, dtype=np.uint8) + ) + ) + # Get Infill region (we really are only interested in the infill region) output_infill = (pred_image * mask).float() gt_image_infill = (target_image * mask).float() @@ -339,9 +349,10 @@ def __percentile_clip( # Normalize to [0;1] based on GT (otherwise MSE will depend on the image intensity range) normalize = parameters.get("normalize", True) if normalize: + # use all the tissue that is not masked for normalization reference_tensor = ( - target_image * ~mask - ) # use all the tissue that is not masked for normalization + target_image * ~mask if not void_image_present else void_image + ) gt_image_infill = __percentile_clip( gt_image_infill, reference_tensor=reference_tensor, @@ -357,9 +368,9 @@ def __percentile_clip( strictlyPositive=True, ) - overall_stats_dict[current_subject_id][ - "ssim" - ] = structural_similarity_index(output_infill, gt_image_infill, mask).item() + overall_stats_dict[current_subject_id]["ssim"] = ( + structural_similarity_index(output_infill, gt_image_infill, mask).item() + ) # ncc metrics compute_ncc = parameters.get("compute_ncc", True) @@ -386,6 +397,10 @@ def __percentile_clip( output_infill, gt_image_infill ).item() + overall_stats_dict[current_subject_id]["rmse"] = mean_squared_error( + output_infill, gt_image_infill, squared=False + ).item() + overall_stats_dict[current_subject_id]["msle"] = mean_squared_log_error( output_infill, gt_image_infill ).item() @@ -400,30 +415,30 @@ def __percentile_clip( ).item() # same as above but with epsilon for robustness - overall_stats_dict[current_subject_id][ - "psnr_eps" - ] = peak_signal_noise_ratio( - output_infill, gt_image_infill, epsilon=sys.float_info.epsilon - ).item() + overall_stats_dict[current_subject_id]["psnr_eps"] = ( + peak_signal_noise_ratio( + output_infill, gt_image_infill, epsilon=sys.float_info.epsilon + ).item() + ) # only use fix data range to [0;1] if the data was normalized before if normalize: # torchmetrics PSNR but with fixed data range of 0 to 1 - overall_stats_dict[current_subject_id][ - "psnr_01" - ] = peak_signal_noise_ratio( - output_infill, gt_image_infill, data_range=(0, 1) - ).item() + overall_stats_dict[current_subject_id]["psnr_01"] = ( + peak_signal_noise_ratio( + output_infill, gt_image_infill, data_range=(0, 1) + ).item() + ) # same as above but with epsilon for robustness - overall_stats_dict[current_subject_id][ - "psnr_01_eps" - ] = peak_signal_noise_ratio( - output_infill, - gt_image_infill, - data_range=(0, 1), - epsilon=sys.float_info.epsilon, - ).item() + overall_stats_dict[current_subject_id]["psnr_01_eps"] = ( + peak_signal_noise_ratio( + output_infill, + gt_image_infill, + data_range=(0, 1), + epsilon=sys.float_info.epsilon, + ).item() + ) pprint(overall_stats_dict) if outputfile is not None: From 05f71014e7d614e707203ea9e4984f80b4c3518f Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 11:09:16 -0500 Subject: [PATCH 04/11] updated dictionary for spell checker --- .spelling/.spelling/expect.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.spelling/.spelling/expect.txt b/.spelling/.spelling/expect.txt index fe6a792b1..b4f50c2f9 100644 --- a/.spelling/.spelling/expect.txt +++ b/.spelling/.spelling/expect.txt @@ -487,6 +487,7 @@ rgbatorgb rgbtorgba rigourous Ritesh +rmse rmsprop rocm rocmdocs @@ -561,7 +562,6 @@ thresholded thresholding Thu tiatoolbox -tiffslide timepoints timm tio @@ -597,7 +597,6 @@ unittests unitwise unsqueeze upenn -Uploaing Uploded upsample upsampled @@ -725,7 +724,6 @@ zsuokb zwezggl zzokqk thirdparty -adopy Shohei crcrpar lrs From ba3ea8329c89e188f9bd79038a8adf65ab90df20 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 11:09:49 -0500 Subject: [PATCH 05/11] lint fix --- GANDLF/cli/generate_metrics.py | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index 1821aa2dd..fc0b25ff1 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -368,9 +368,9 @@ def __percentile_clip( strictlyPositive=True, ) - overall_stats_dict[current_subject_id]["ssim"] = ( - structural_similarity_index(output_infill, gt_image_infill, mask).item() - ) + overall_stats_dict[current_subject_id][ + "ssim" + ] = structural_similarity_index(output_infill, gt_image_infill, mask).item() # ncc metrics compute_ncc = parameters.get("compute_ncc", True) @@ -415,30 +415,30 @@ def __percentile_clip( ).item() # same as above but with epsilon for robustness - overall_stats_dict[current_subject_id]["psnr_eps"] = ( - peak_signal_noise_ratio( - output_infill, gt_image_infill, epsilon=sys.float_info.epsilon - ).item() - ) + overall_stats_dict[current_subject_id][ + "psnr_eps" + ] = peak_signal_noise_ratio( + output_infill, gt_image_infill, epsilon=sys.float_info.epsilon + ).item() # only use fix data range to [0;1] if the data was normalized before if normalize: # torchmetrics PSNR but with fixed data range of 0 to 1 - overall_stats_dict[current_subject_id]["psnr_01"] = ( - peak_signal_noise_ratio( - output_infill, gt_image_infill, data_range=(0, 1) - ).item() - ) + overall_stats_dict[current_subject_id][ + "psnr_01" + ] = peak_signal_noise_ratio( + output_infill, gt_image_infill, data_range=(0, 1) + ).item() # same as above but with epsilon for robustness - overall_stats_dict[current_subject_id]["psnr_01_eps"] = ( - peak_signal_noise_ratio( - output_infill, - gt_image_infill, - data_range=(0, 1), - epsilon=sys.float_info.epsilon, - ).item() - ) + overall_stats_dict[current_subject_id][ + "psnr_01_eps" + ] = peak_signal_noise_ratio( + output_infill, + gt_image_infill, + data_range=(0, 1), + epsilon=sys.float_info.epsilon, + ).item() pprint(overall_stats_dict) if outputfile is not None: From e4f68509c9a48086eb5aa8b642fde599225cc401 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 11:24:21 -0500 Subject: [PATCH 06/11] typo fix and unnecessary word removed --- .spelling/.spelling/expect.txt | 1 - GANDLF/entrypoints/hf_hub_integration.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.spelling/.spelling/expect.txt b/.spelling/.spelling/expect.txt index b4f50c2f9..0055fe5c9 100644 --- a/.spelling/.spelling/expect.txt +++ b/.spelling/.spelling/expect.txt @@ -597,7 +597,6 @@ unittests unitwise unsqueeze upenn -Uploded upsample upsampled upsampling diff --git a/GANDLF/entrypoints/hf_hub_integration.py b/GANDLF/entrypoints/hf_hub_integration.py index 353f31dfb..d0f209ad4 100644 --- a/GANDLF/entrypoints/hf_hub_integration.py +++ b/GANDLF/entrypoints/hf_hub_integration.py @@ -96,7 +96,7 @@ @click.option( "--hf-template", "-hft", - help="Adding the template path for the model card it is Required during Uploaing a model", + help="Adding the template path for the model card: it is required during model upload", default=huggingface_file_path, type=click.Path(exists=True, file_okay=True, dir_okay=False), ) From b1f0b54e9356550de4d65242247008fdf9572e97 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 13:47:31 -0500 Subject: [PATCH 07/11] updated ncc metrics --- GANDLF/cli/generate_metrics.py | 20 ++-------- GANDLF/metrics/__init__.py | 5 +-- GANDLF/metrics/synthesis.py | 72 +++++++--------------------------- 3 files changed, 20 insertions(+), 77 deletions(-) diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index fc0b25ff1..e28ece92b 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -18,10 +18,7 @@ peak_signal_noise_ratio, mean_squared_log_error, mean_absolute_error, - ncc_mean, - ncc_std, - ncc_max, - ncc_min, + ncc_metrics, ) from GANDLF.losses.segmentation import dice from GANDLF.metrics.segmentation import ( @@ -375,18 +372,9 @@ def __percentile_clip( # ncc metrics compute_ncc = parameters.get("compute_ncc", True) if compute_ncc: - overall_stats_dict[current_subject_id]["ncc_mean"] = ncc_mean( - output_infill, gt_image_infill - ) - overall_stats_dict[current_subject_id]["ncc_std"] = ncc_std( - output_infill, gt_image_infill - ) - overall_stats_dict[current_subject_id]["ncc_max"] = ncc_max( - output_infill, gt_image_infill - ) - overall_stats_dict[current_subject_id]["ncc_min"] = ncc_min( - output_infill, gt_image_infill - ) + calculated_ncc_metrics = ncc_metrics(output_infill, gt_image_infill) + for key, value in calculated_ncc_metrics.items(): + overall_stats_dict[current_subject_id][key] = value.item() # only voxels that are to be inferred (-> flat array) # these are required for mse, psnr, etc. diff --git a/GANDLF/metrics/__init__.py b/GANDLF/metrics/__init__.py index 1fc21b3fb..08d1165c1 100644 --- a/GANDLF/metrics/__init__.py +++ b/GANDLF/metrics/__init__.py @@ -35,10 +35,7 @@ peak_signal_noise_ratio, mean_squared_log_error, mean_absolute_error, - ncc_mean, - ncc_std, - ncc_max, - ncc_min, + ncc_metrics, ) import GANDLF.metrics.classification as classification import GANDLF.metrics.regression as regression diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index f29933da7..a8db74bf1 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -163,69 +163,27 @@ def __convert_to_grayscale(image: sitk.Image) -> sitk.Image: return correlation_filter.Execute(target_image, pred_image) -def ncc_mean(prediction: torch.Tensor, target: torch.Tensor) -> float: +def ncc_metrics(prediction: torch.Tensor, target: torch.Tensor) -> dict: """ - Computes normalized cross correlation mean between target and prediction. + Computes normalized cross correlation metrics between target and prediction. Args: prediction (torch.Tensor): The prediction tensor. target (torch.Tensor): The target tensor. Returns: - float: The normalized cross correlation mean. + dict: The normalized cross correlation metrics. """ - stats_filter = sitk.StatisticsImageFilter() corr_image = _get_ncc_image(target, prediction) - stats_filter.Execute(corr_image) - return stats_filter.GetMean() - - -def ncc_std(prediction: torch.Tensor, target: torch.Tensor) -> float: - """ - Computes normalized cross correlation standard deviation between target and prediction. - - Args: - prediction (torch.Tensor): The prediction tensor. - target (torch.Tensor): The target tensor. - - Returns: - float: The normalized cross correlation standard deviation. - """ - stats_filter = sitk.StatisticsImageFilter() - corr_image = _get_ncc_image(target, prediction) - stats_filter.Execute(corr_image) - return stats_filter.GetSigma() - - -def ncc_max(prediction: torch.Tensor, target: torch.Tensor) -> float: - """ - Computes normalized cross correlation maximum between target and prediction. - - Args: - prediction (torch.Tensor): The prediction tensor. - target (torch.Tensor): The target tensor. - - Returns: - float: The normalized cross correlation maximum. - """ - stats_filter = sitk.StatisticsImageFilter() - corr_image = _get_ncc_image(target, prediction) - stats_filter.Execute(corr_image) - return stats_filter.GetMaximum() - - -def ncc_min(prediction: torch.Tensor, target: torch.Tensor) -> float: - """ - Computes normalized cross correlation minimum between target and prediction. - - Args: - prediction (torch.Tensor): The prediction tensor. - target (torch.Tensor): The target tensor. - - Returns: - float: The normalized cross correlation minimum. - """ - stats_filter = sitk.StatisticsImageFilter() - corr_image = _get_ncc_image(target, prediction) - stats_filter.Execute(corr_image) - return stats_filter.GetMinimum() + stats_filter = sitk.LabelStatisticsImageFilter() + stats_filter.UseHistogramsOn() + # ensure that we are not considering zeros + onesImage = corr_image == corr_image + stats_filter.Execute(corr_image, onesImage) + return { + "mean": stats_filter.GetMean(1), + "std": stats_filter.GetSigma(1), + "max": stats_filter.GetMaximum(1), + "min": stats_filter.GetMinimum(1), + "median": stats_filter.GetMedian(1), + } From c12c37d429b585896b9cc10a56798b70480d80cf Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 13:52:41 -0500 Subject: [PATCH 08/11] addressing comment --- GANDLF/cli/generate_metrics.py | 5 +++-- GANDLF/metrics/__init__.py | 1 + GANDLF/metrics/synthesis.py | 22 ++++++++++++++++++---- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index e28ece92b..72b572e4c 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -15,6 +15,7 @@ overall_stats, structural_similarity_index, mean_squared_error, + root_mean_squared_error, peak_signal_noise_ratio, mean_squared_log_error, mean_absolute_error, @@ -385,8 +386,8 @@ def __percentile_clip( output_infill, gt_image_infill ).item() - overall_stats_dict[current_subject_id]["rmse"] = mean_squared_error( - output_infill, gt_image_infill, squared=False + overall_stats_dict[current_subject_id]["rmse"] = root_mean_squared_error( + output_infill, gt_image_infill ).item() overall_stats_dict[current_subject_id]["msle"] = mean_squared_log_error( diff --git a/GANDLF/metrics/__init__.py b/GANDLF/metrics/__init__.py index 08d1165c1..824d38512 100644 --- a/GANDLF/metrics/__init__.py +++ b/GANDLF/metrics/__init__.py @@ -32,6 +32,7 @@ from .synthesis import ( structural_similarity_index, mean_squared_error, + root_mean_squared_error, peak_signal_noise_ratio, mean_squared_log_error, mean_absolute_error, diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index a8db74bf1..973ca6ff8 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -39,8 +39,23 @@ def structural_similarity_index( return ssim_idx.mean() -def mean_squared_error( - prediction: torch.Tensor, target: torch.Tensor, squared: bool = True +def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Computes the mean squared error between the target and prediction. + + Args: + prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. + + Returns: + torch.Tensor: The mean squared error or its square root. + """ + mse = MeanSquaredError(squared=True) + return mse(preds=prediction, target=target) + + +def root_mean_squared_error( + prediction: torch.Tensor, target: torch.Tensor ) -> torch.Tensor: """ Computes the mean squared error between the target and prediction. @@ -48,12 +63,11 @@ def mean_squared_error( Args: prediction (torch.Tensor): The prediction tensor. target (torch.Tensor): The target tensor. - squared (bool, optional): Whether to return squared error. Defaults to True. Returns: torch.Tensor: The mean squared error or its square root. """ - mse = MeanSquaredError(squared=squared) + mse = MeanSquaredError(squared=False) return mse(preds=prediction, target=target) From 08dc42a6068d8117df1b7116c28eace7577b3bd9 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 13:53:51 -0500 Subject: [PATCH 09/11] ensure `ncc` gets picked up correctly --- GANDLF/cli/generate_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index 72b572e4c..591b1d0ff 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -375,7 +375,7 @@ def __percentile_clip( if compute_ncc: calculated_ncc_metrics = ncc_metrics(output_infill, gt_image_infill) for key, value in calculated_ncc_metrics.items(): - overall_stats_dict[current_subject_id][key] = value.item() + overall_stats_dict[current_subject_id][f"ncc_{key}"] = value.item() # only voxels that are to be inferred (-> flat array) # these are required for mse, psnr, etc. From 8b9a797418e98e73d32599b5e460cb557869edcc Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 14:01:23 -0500 Subject: [PATCH 10/11] putting the `ncc` in the metric calculation itself for clarity --- GANDLF/cli/generate_metrics.py | 2 +- GANDLF/metrics/synthesis.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index 591b1d0ff..72b572e4c 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -375,7 +375,7 @@ def __percentile_clip( if compute_ncc: calculated_ncc_metrics = ncc_metrics(output_infill, gt_image_infill) for key, value in calculated_ncc_metrics.items(): - overall_stats_dict[current_subject_id][f"ncc_{key}"] = value.item() + overall_stats_dict[current_subject_id][key] = value.item() # only voxels that are to be inferred (-> flat array) # these are required for mse, psnr, etc. diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index 973ca6ff8..a7d6a4523 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -195,9 +195,9 @@ def ncc_metrics(prediction: torch.Tensor, target: torch.Tensor) -> dict: onesImage = corr_image == corr_image stats_filter.Execute(corr_image, onesImage) return { - "mean": stats_filter.GetMean(1), - "std": stats_filter.GetSigma(1), - "max": stats_filter.GetMaximum(1), - "min": stats_filter.GetMinimum(1), - "median": stats_filter.GetMedian(1), + "ncc_mean": stats_filter.GetMean(1), + "ncc_std": stats_filter.GetSigma(1), + "ncc_max": stats_filter.GetMaximum(1), + "ncc_min": stats_filter.GetMinimum(1), + "ncc_median": stats_filter.GetMedian(1), } From 1e9447cf10c8aade588fa40cc7a4f21c6f3232a4 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 14:42:55 -0500 Subject: [PATCH 11/11] fixing error --- GANDLF/cli/generate_metrics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index 72b572e4c..7e7fdda49 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -375,7 +375,8 @@ def __percentile_clip( if compute_ncc: calculated_ncc_metrics = ncc_metrics(output_infill, gt_image_infill) for key, value in calculated_ncc_metrics.items(): - overall_stats_dict[current_subject_id][key] = value.item() + # we don't need the ".item()" here, since the values are already scalars + overall_stats_dict[current_subject_id][key] = value # only voxels that are to be inferred (-> flat array) # these are required for mse, psnr, etc.