Skip to content

Commit

Permalink
Fix mean normalization method (#149)
Browse files Browse the repository at this point in the history
* Just normalize known pixels

* Add mask to evaluation

* Fix typo and add newlines

* Add changelog
  • Loading branch information
FeGeyer authored Feb 23, 2023
1 parent 3965057 commit 8bdea11
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 27 deletions.
2 changes: 2 additions & 0 deletions docs/changes/149.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- calculate nomalization only on non-zero pixels
- fix typo in rescaling operation
36 changes: 22 additions & 14 deletions radionets/dl_framework/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import torch
import numpy as np
import kornia as K
from radionets.dl_framework.model import save_model
from radionets.dl_framework.utils import _maybe_item
from fastai.callback.core import Callback, CancelBackwardException
from pathlib import Path

import kornia as K
import matplotlib.pyplot as plt
import numpy as np
import torch
from fastai.callback.core import Callback, CancelBackwardException

from radionets.dl_framework.model import save_model
from radionets.dl_framework.utils import _maybe_item, get_ifft_torch
from radionets.evaluation.plotting import create_OrBu
from radionets.evaluation.utils import (
load_data,
get_images,
eval_model,
make_axes_nice,
check_vmin_vmax,
eval_model,
get_ifft,
get_images,
load_data,
load_pretrained_model,
make_axes_nice,
)
from radionets.evaluation.plotting import create_OrBu
from radionets.dl_framework.utils import get_ifft_torch

OrBu = create_OrBu()

Expand Down Expand Up @@ -217,9 +218,16 @@ def before_batch(self):
x[:, 1] *= 1 / torch.amax(torch.abs(x[:, 1]), dim=(-2, -1), keepdim=True)
y[:, 0] *= 1 / torch.amax(x[:, 0], dim=(-2, -1), keepdim=True)
y[:, 1] *= 1 / torch.amax(torch.abs(x[:, 1]), dim=(-2, -1), keepdim=True)

elif self.mode == "mean":
x[:, 0] = self.normalize(x[:, 0], self.mean_real, self.std_real)
x[:, 1] = self.normalize(x[:, 1], self.mean_imag, self.std_imag)
x[:, 0][x[:, 0] != 0] = self.normalize(
x[:, 0][x[:, 0] != 0], self.mean_real, self.std_real
)

x[:, 1][x[:, 1] != 0] = self.normalize(
x[:, 1][x[:, 1] != 0], self.mean_imag, self.std_imag
)

y[:, 0] = self.normalize(y[:, 0], self.mean_real, self.std_real)
y[:, 1] = self.normalize(y[:, 1], self.mean_imag, self.std_imag)

Expand Down
29 changes: 16 additions & 13 deletions radionets/evaluation/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from pathlib import Path

import h5py
import numpy as np
from radionets.dl_framework.model import load_pre_model
from radionets.dl_framework.data import load_data
import radionets.dl_framework.architecture as architecture
import torch
import torch.nn.functional as F
from numba import set_num_threads, vectorize
from torch.utils.data import DataLoader
import h5py
from numba import vectorize, set_num_threads
from pathlib import Path

import radionets.dl_framework.architecture as architecture
from radionets.dl_framework.data import load_data
from radionets.dl_framework.model import load_pre_model


def source_list_collate(batch):
Expand Down Expand Up @@ -717,12 +719,13 @@ def apply_normalization(img_test, norm_dict):
updated dictionary
"""
if norm_dict and "mean_real" in norm_dict:
img_test[:, 0] = (img_test[:, 0] - norm_dict["mean_real"]) / norm_dict[
"std_real"
]
img_test[:, 1] = (img_test[:, 1] - norm_dict["mean_imag"]) / norm_dict[
"std_imag"
]
img_test[:, 0][img_test[:, 0] != 0] = (
img_test[:, 0][img_test[:, 0] != 0] - norm_dict["mean_real"]
) / norm_dict["std_real"]

img_test[:, 1][img_test[:, 1] != 0] = (
img_test[:, 1][img_test[:, 1] != 0] - norm_dict["mean_imag"]
) / norm_dict["std_imag"]

elif norm_dict and "max_scaling" in norm_dict:
max_factors_real = torch.amax(img_test[:, 0], dim=(-2, -1), keepdim=True)
Expand Down Expand Up @@ -757,7 +760,7 @@ def rescale_normalization(pred, norm_dict):
"""
if norm_dict and "mean_real" in norm_dict:
pred[:, 0] = pred[:, 0] * norm_dict["std_real"] + norm_dict["mean_real"]
pred[:, 0] = pred[:, 0] * norm_dict["std_imag"] + norm_dict["mean_imag"]
pred[:, 1] = pred[:, 1] * norm_dict["std_imag"] + norm_dict["mean_imag"]

elif norm_dict and "max_scaling" in norm_dict:
pred[:, 0] *= norm_dict["max_factors_real"]
Expand Down

0 comments on commit 8bdea11

Please sign in to comment.