From 5a36616d39473253fdb4b7c2528c2ca09ec08da8 Mon Sep 17 00:00:00 2001
From: Felix Geyer
Date: Thu, 23 Feb 2023 15:49:33 +0100
Subject: [PATCH] Outsource repeating preprocessing steps in `train_inspection` (#148)

* Outsource some functionalities

* Add docstrings

* Rename function

* Add changelog
---
 docs/changes/148.maintenance.rst          |   1 +
 radionets/evaluation/train_inspection.py  | 295 +++++------------------
 radionets/evaluation/utils.py             |  94 ++++++++
 3 files changed, 158 insertions(+), 232 deletions(-)
 create mode 100644 docs/changes/148.maintenance.rst

diff --git a/docs/changes/148.maintenance.rst b/docs/changes/148.maintenance.rst
new file mode 100644
index 00000000..09316f9b
--- /dev/null
+++ b/docs/changes/148.maintenance.rst
@@ -0,0 +1 @@
+Outsource preprocessing steps in `train_inspection.py`
diff --git a/radionets/evaluation/train_inspection.py b/radionets/evaluation/train_inspection.py
index dc198baa..9172c2ed 100644
--- a/radionets/evaluation/train_inspection.py
+++ b/radionets/evaluation/train_inspection.py
@@ -1,51 +1,55 @@
+from pathlib import Path
+
 import click
-import torch
 import numpy as np
-from pathlib import Path
+import torch
+import torch.nn.functional as F
+from pytorch_msssim import ms_ssim
+from tqdm import tqdm
+
 from radionets.dl_framework.data import load_data
+from radionets.evaluation.blob_detection import calc_blobs, crop_first_component
+from radionets.evaluation.contour import area_of_contour
+from radionets.evaluation.dynamic_range import calc_dr
+from radionets.evaluation.jet_angle import calc_jet_angle
 from radionets.evaluation.plotting import (
-    visualize_with_fourier_diff,
-    visualize_with_fourier,
-    plot_results,
-    visualize_source_reconstruction,
-    histogram_jet_angles,
-    histogram_dynamic_ranges,
-    histogram_ms_ssim,
-    histogram_mean_diff,
+    hist_point,
     histogram_area,
+    histogram_dynamic_ranges,
     histogram_gan_sources,
+    histogram_jet_angles,
+    histogram_mean_diff,
+    histogram_ms_ssim,
     plot_contour,
-    hist_point,
     plot_length_point,
-    visualize_uncertainty,
+    plot_results,
     visualize_sampled_unc,
+    visualize_source_reconstruction,
+    visualize_uncertainty,
+    visualize_with_fourier,
+    visualize_with_fourier_diff,
 )
+from radionets.evaluation.pointsources import flux_comparison
 from radionets.evaluation.utils import (
+    apply_normalization,
+    apply_symmetry,
     create_databunch,
     create_sampled_databunch,
-    reshape_2d,
-    load_pretrained_model,
-    get_images,
     eval_model,
     get_ifft,
-    save_pred,
+    get_images,
+    load_pretrained_model,
+    mergeDictionary,
+    preprocessing,
+    process_prediction,
     read_pred,
-    apply_symmetry,
+    rescale_normalization,
+    reshape_2d,
     sample_images,
-    mergeDictionary,
     sampled_dataset,
-    apply_normalization,
-    rescale_normalization,
+    save_pred,
+    sym_new,
 )
-from radionets.evaluation.jet_angle import calc_jet_angle
-from radionets.evaluation.dynamic_range import calc_dr
-from radionets.evaluation.blob_detection import calc_blobs, crop_first_component
-from radionets.evaluation.contour import area_of_contour
-from radionets.evaluation.pointsources import flux_comparison
-from pytorch_msssim import ms_ssim
-from tqdm import tqdm
-import torch.nn.functional as F
-from radionets.evaluation.utils import sym_new
 
 
 def create_predictions(conf):
@@ -378,38 +382,15 @@ def create_uncertainty_plots(conf, num_images=3, rand=False):
 
 
 def evaluate_viewing_angle(conf):
-    # create DataLoader
-    loader = create_databunch(
-        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
-    )
-    model_path = conf["model_path"]
-    out_path = Path(model_path).parent / "evaluation"
-    out_path.mkdir(parents=True, exist_ok=True)
-
-    img_size = loader.dataset[0][0][0].shape[-1]
-    model, norm_dict = load_pretrained_model(
-        conf["arch_name"], conf["model_path"], img_size
-    )
-    if conf["model_path_2"] != "none":
-        model_2, norm_dict = load_pretrained_model(
-            conf["arch_name_2"], conf["model_path_2"], img_size
-        )
-
+    model, model_2, loader, norm_dict, out_path = preprocessing(conf)
     alpha_truths = []
     alpha_preds = []
 
     # iterate trough DataLoader
     for i, (img_test, img_true) in enumerate(tqdm(loader)):
-        img_test, norm_dict = apply_normalization(img_test, norm_dict)
-        pred = eval_model(img_test, model)
-        pred = rescale_normalization(pred, norm_dict)
-        if conf["model_path_2"] != "none":
-            pred_2 = eval_model(img_test, model_2)
-            pred_2 = rescale_normalization(pred_2, norm_dict)
-            pred = torch.cat((pred, pred_2), dim=1)
-
-        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
-        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
+        ifft_pred, ifft_truth = process_prediction(
+            conf, img_test, img_true, norm_dict, model, model_2
+        )
 
         m_truth, n_truth, alpha_truth = calc_jet_angle(torch.tensor(ifft_truth))
         m_pred, n_pred, alpha_pred = calc_jet_angle(torch.tensor(ifft_pred))
@@ -431,38 +412,15 @@ def evaluate_viewing_angle(conf):
 
 
 def evaluate_dynamic_range(conf):
-    # create Dataloader
-    loader = create_databunch(
-        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
-    )
-    model_path = conf["model_path"]
-    out_path = Path(model_path).parent / "evaluation"
-    out_path.mkdir(parents=True, exist_ok=True)
-
-    img_size = loader.dataset[0][0][0].shape[-1]
-    model, norm_dict = load_pretrained_model(
-        conf["arch_name"], conf["model_path"], img_size
-    )
-    if conf["model_path_2"] != "none":
-        model_2, norm_dict = load_pretrained_model(
-            conf["arch_name_2"], conf["model_path_2"], img_size
-        )
-
+    model, model_2, loader, norm_dict, out_path = preprocessing(conf)
     dr_truths = np.array([])
    dr_preds = np.array([])
 
     # iterate trough DataLoader
     for i, (img_test, img_true) in enumerate(tqdm(loader)):
-        img_test, norm_dict = apply_normalization(img_test, norm_dict)
-        pred = eval_model(img_test, model)
-        pred = rescale_normalization(pred, norm_dict)
-        if conf["model_path_2"] != "none":
-            pred_2 = eval_model(img_test, model_2)
-            pred_2 = rescale_normalization(pred_2, norm_dict)
-            pred = torch.cat((pred, pred_2), dim=1)
-
-        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
-        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
+        ifft_pred, ifft_truth = process_prediction(
+            conf, img_test, img_true, norm_dict, model, model_2
+        )
 
         dr_truth, dr_pred, _, _ = calc_dr(ifft_truth, ifft_pred)
         dr_truths = np.append(dr_truths, dr_truth)
@@ -482,48 +440,18 @@ def evaluate_dynamic_range(conf):
 
 
 def evaluate_ms_ssim(conf):
-    # create DataLoader
-    loader = create_databunch(
-        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
-    )
-    model_path = conf["model_path"]
-    out_path = Path(model_path).parent / "evaluation"
-    out_path.mkdir(parents=True, exist_ok=True)
-
-    img_size = loader.dataset[0][0][0].shape[-1]
-    model, norm_dict = load_pretrained_model(
-        conf["arch_name"], conf["model_path"], img_size
-    )
-    if conf["model_path_2"] != "none":
-        model_2, norm_dict = load_pretrained_model(
-            conf["arch_name_2"], conf["model_path_2"], img_size
-        )
-
+    model, model_2, loader, norm_dict, out_path = preprocessing(conf)
     vals = []
 
     # iterate trough DataLoader
     for i, (img_test, img_true) in enumerate(tqdm(loader)):
-        img_test, norm_dict = apply_normalization(img_test, norm_dict)
-        pred = eval_model(img_test, model)
-        pred = rescale_normalization(pred, norm_dict)
-        if conf["model_path_2"] != "none":
-            pred_2 = eval_model(img_test, model_2)
-            pred_2 = rescale_normalization(pred_2, norm_dict)
-            pred = torch.cat((pred, pred_2), dim=1)
-
-        # apply symmetry
-        if pred.shape[-1] == 128:
-            img_dict = {"truth": img_true, "pred": pred}
-            img_dict = apply_symmetry(img_dict)
-            img_true = img_dict["truth"]
-            pred = img_dict["pred"]
-
-        ifft_truth = torch.tensor(get_ifft(img_true, amp_phase=conf["amp_phase"]))
-        ifft_pred = torch.tensor(get_ifft(pred, amp_phase=conf["amp_phase"]))
+        ifft_pred, ifft_truth = process_prediction(
+            conf, img_test, img_true, norm_dict, model, model_2
+        )
 
         val = ms_ssim(
-            ifft_pred.unsqueeze(1),
-            ifft_truth.unsqueeze(1),
+            torch.tensor(ifft_pred).unsqueeze(1),
+            torch.tensor(ifft_truth).unsqueeze(1),
             data_range=1,
             win_size=7,
             size_average=False,
@@ -538,37 +466,14 @@ def evaluate_ms_ssim(conf):
 
 
 def evaluate_mean_diff(conf):
-    # create DataLoader
-    loader = create_databunch(
-        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
-    )
-    model_path = conf["model_path"]
-    out_path = Path(model_path).parent / "evaluation"
-    out_path.mkdir(parents=True, exist_ok=True)
-
-    img_size = loader.dataset[0][0][0].shape[-1]
-    model, norm_dict = load_pretrained_model(
-        conf["arch_name"], conf["model_path"], img_size
-    )
-    if conf["model_path_2"] != "none":
-        model_2, norm_dict = load_pretrained_model(
-            conf["arch_name_2"], conf["model_path_2"], img_size
-        )
-
+    model, model_2, loader, norm_dict, out_path = preprocessing(conf)
     vals = []
 
     # iterate trough DataLoader
     for i, (img_test, img_true) in enumerate(tqdm(loader)):
-        img_test, norm_dict = apply_normalization(img_test, norm_dict)
-        pred = eval_model(img_test, model)
-        pred = rescale_normalization(pred, norm_dict)
-        if conf["model_path_2"] != "none":
-            pred_2 = eval_model(img_test, model_2)
-            pred_2 = rescale_normalization(pred_2, norm_dict)
-            pred = torch.cat((pred, pred_2), dim=1)
-
-        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
-        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
+        ifft_pred, ifft_truth = process_prediction(
+            conf, img_test, img_true, norm_dict, model, model_2
+        )
 
         for pred, truth in zip(ifft_pred, ifft_truth):
             blobs_pred, blobs_truth = calc_blobs(pred, truth)
@@ -711,44 +616,14 @@ def evaluate_area_sampled(conf):
 
 
 def evaluate_area(conf):
-    # create DataLoader
-    loader = create_databunch(
-        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
-    )
-    model_path = conf["model_path"]
-    out_path = Path(model_path).parent / "evaluation"
-    out_path.mkdir(parents=True, exist_ok=True)
-
-    img_size = loader.dataset[0][0][0].shape[-1]
-    model, norm_dict = load_pretrained_model(
-        conf["arch_name"], conf["model_path"], img_size
-    )
-    if conf["model_path_2"] != "none":
-        model_2, norm_dict = load_pretrained_model(
-            conf["arch_name_2"], conf["model_path_2"], img_size
-        )
-
+    model, model_2, loader, norm_dict, out_path = preprocessing(conf)
     vals = []
 
     # iterate trough DataLoader
     for i, (img_test, img_true) in enumerate(tqdm(loader)):
-        img_test, norm_dict = apply_normalization(img_test, norm_dict)
-        pred = eval_model(img_test, model)
-        pred = rescale_normalization(pred, norm_dict)
-        if conf["model_path_2"] != "none":
-            pred_2 = eval_model(img_test, model_2)
-            pred_2 = rescale_normalization(pred_2, norm_dict)
-            pred = torch.cat((pred, pred_2), dim=1)
-
-        # apply symmetry
-        if pred.shape[-1] == 128:
-            img_dict = {"truth": img_true, "pred": pred}
-            img_dict = apply_symmetry(img_dict)
-            img_true = img_dict["truth"]
-            pred = img_dict["pred"]
-
-        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
-        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
+        ifft_pred, ifft_truth = process_prediction(
+            conf, img_test, img_true, norm_dict, model, model_2
+        )
 
         for pred, truth in zip(ifft_pred, ifft_truth):
             val = area_of_contour(pred, truth)
@@ -768,37 +643,14 @@ def evaluate_area(conf):
 
 
 def evaluate_point(conf):
-    # create DataLoader
-    loader = create_databunch(
-        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
-    )
-    model_path = conf["model_path"]
-    out_path = Path(model_path).parent / "evaluation"
-    out_path.mkdir(parents=True, exist_ok=True)
-
-    img_size = loader.dataset[0][0][0].shape[-1]
-    model, norm_dict = load_pretrained_model(
-        conf["arch_name"], conf["model_path"], img_size
-    )
-    if conf["model_path_2"] != "none":
-        model_2, norm_dict = load_pretrained_model(
-            conf["arch_name_2"], conf["model_path_2"], img_size
-        )
-
+    model, model_2, loader, norm_dict, out_path = preprocessing(conf)
     vals = []
     lengths = []
 
     for i, (img_test, img_true, source_list) in enumerate(tqdm(loader)):
-        img_test, norm_dict = apply_normalization(img_test, norm_dict)
-        pred = eval_model(img_test, model)
-        pred = rescale_normalization(pred, norm_dict)
-        if conf["model_path_2"] != "none":
-            pred_2 = eval_model(img_test, model_2)
-            pred_2 = rescale_normalization(pred_2, norm_dict)
-            pred = torch.cat((pred, pred_2), dim=1)
-
-        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
-        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
+        ifft_pred, ifft_truth = process_prediction(
+            conf, img_test, img_true, norm_dict, model, model_2
+        )
 
         fluxes_pred, fluxes_truth, length = flux_comparison(
             ifft_pred, ifft_truth, source_list
@@ -819,23 +671,9 @@ def evaluate_point(conf):
 
 
 def evaluate_gan_sources(conf):
-    # create DataLoader
-    loader = create_databunch(
-        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
-    )
-    model_path = conf["model_path"]
-    out_path = Path(model_path).parent / "evaluation"
-    out_path.mkdir(parents=True, exist_ok=True)
+    model, model_2, loader, norm_dict, out_path = preprocessing(conf)
 
     img_size = loader.dataset[0][0][0].shape[-1]
-    model, norm_dict = load_pretrained_model(
-        conf["arch_name"], conf["model_path"], img_size
-    )
-    if conf["model_path_2"] != "none":
-        model_2, norm_dict = load_pretrained_model(
-            conf["arch_name_2"], conf["model_path_2"], img_size
-        )
-
     ratios = []
     num_zeros = []
     above_zeros = []
     below_zeros = []
     atols = [1e-4, 1e-3, 1e-2, 1e-1]
 
     for i, (img_test, img_true) in enumerate(tqdm(loader)):
-        img_test, norm_dict = apply_normalization(img_test, norm_dict)
-        pred = eval_model(img_test, model)
-        pred = rescale_normalization(pred, norm_dict)
-        if conf["model_path_2"] != "none":
-            pred_2 = eval_model(img_test, model_2)
-            pred_2 = rescale_normalization(pred_2, norm_dict)
-            pred = torch.cat((pred, pred_2), dim=1)
-
-        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
-        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
+        ifft_pred, ifft_truth = process_prediction(
+            conf, img_test, img_true, norm_dict, model, model_2
+        )
 
         diff = ifft_pred - ifft_truth
 
diff --git a/radionets/evaluation/utils.py b/radionets/evaluation/utils.py
index 16874ad1..59ff5d65 100644
--- a/radionets/evaluation/utils.py
+++ b/radionets/evaluation/utils.py
@@ -767,3 +767,97 @@ def rescale_normalization(pred, norm_dict):
         pred[:, 1] *= norm_dict["max_factors_imag"]
 
     return pred
+
+
+def preprocessing(conf):
+    """
+    Performs the preprocessing steps needed by the evaluation methods that
+    analyze the whole test dataset.
+
+    Parameters
+    ----------
+    conf : dictionary
+        config file containing the settings
+
+    Returns
+    -------
+    model : architecture
+        model initialized with save file
+    model_2 : architecture
+        second model initialized with save file, None if not configured
+    loader : torch.utils.data.DataLoader
+        feeds the data batch-wise
+    norm_dict : dictionary
+        dict containing the normalization factors
+    out_path : Path object
+        path to the evaluation folder
+    """
+    # create DataLoader
+    loader = create_databunch(
+        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
+    )
+    model_path = conf["model_path"]
+    out_path = Path(model_path).parent / "evaluation"
+    out_path.mkdir(parents=True, exist_ok=True)
+
+    img_size = loader.dataset[0][0][0].shape[-1]
+    model, norm_dict = load_pretrained_model(
+        conf["arch_name"], conf["model_path"], img_size
+    )
+
+    # Load the second model if the two channels were trained separately
+    model_2 = None
+    if conf["model_path_2"] != "none":
+        model_2, norm_dict = load_pretrained_model(
+            conf["arch_name_2"], conf["model_path_2"], img_size
+        )
+
+    return model, model_2, loader, norm_dict, out_path
+
+
+def process_prediction(conf, img_test, img_true, norm_dict, model, model_2):
+    """
+    Applies the normalization, gets and rescales a prediction, and performs
+    the inverse Fourier transformation.
+
+    Parameters
+    ----------
+    conf : dictionary
+        config file containing the settings
+    img_test : torch.Tensor
+        input image for the network
+    img_true : torch.Tensor
+        true image
+    norm_dict : dictionary
+        dict containing the normalization factors
+    model : architecture
+        model initialized with save file
+    model_2 : architecture
+        second model initialized with save file, None if not used
+
+    Returns
+    -------
+    ifft_pred : ndarray
+        predicted source in image space
+    ifft_truth : ndarray
+        true source in image space
+    """
+    img_test, norm_dict = apply_normalization(img_test, norm_dict)
+    pred = eval_model(img_test, model)
+    pred = rescale_normalization(pred, norm_dict)
+    if model_2 is not None:
+        pred_2 = eval_model(img_test, model_2)
+        pred_2 = rescale_normalization(pred_2, norm_dict)
+        pred = torch.cat((pred, pred_2), dim=1)
+
+    # apply symmetry
+    if pred.shape[-1] == 128:
+        img_dict = {"truth": img_true, "pred": pred}
+        img_dict = apply_symmetry(img_dict)
+        img_true = img_dict["truth"]
+        pred = img_dict["pred"]
+
+    ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
+    ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
+
+    return ifft_pred, ifft_truth
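
Usage sketch: after this patch, every evaluation function in `train_inspection.py` follows the same pattern, one call to `preprocessing(conf)` for setup and one call to `process_prediction(...)` per batch. The snippet below is not part of the patch; the function name `evaluate_metric_sketch` is hypothetical, the metric step mirrors `evaluate_dynamic_range` above, and a `conf` dictionary with the keys used throughout the patch is assumed.

import numpy as np
from tqdm import tqdm

from radionets.evaluation.dynamic_range import calc_dr
from radionets.evaluation.utils import preprocessing, process_prediction


def evaluate_metric_sketch(conf):
    # One-time setup: DataLoader, model(s), normalization factors, output folder.
    model, model_2, loader, norm_dict, out_path = preprocessing(conf)

    dr_truths = np.array([])
    dr_preds = np.array([])

    for img_test, img_true in tqdm(loader):
        # Normalize the input, predict (optionally with a second single-channel
        # model), rescale, apply symmetry for 128x128 maps, and transform the
        # result back to image space.
        ifft_pred, ifft_truth = process_prediction(
            conf, img_test, img_true, norm_dict, model, model_2
        )

        dr_truth, dr_pred, _, _ = calc_dr(ifft_truth, ifft_pred)
        dr_truths = np.append(dr_truths, dr_truth)
        dr_preds = np.append(dr_preds, dr_pred)

    return dr_truths, dr_preds, out_path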