Update framework for the use of real/imag #146

Merged (29 commits, Feb 21, 2023)
Commits (29)
5516084  Modify functions to work for real_imag (FeGeyer, Feb 6, 2023)
26ef7a4  Don't use hardtanh anymore (FeGeyer, Feb 6, 2023)
bdf888f  Fix some tests (FeGeyer, Feb 7, 2023)
c6c96ef  Add area sampled to script (FeGeyer, Feb 8, 2023)
afeb848  Change plotting for real imag (FeGeyer, Feb 15, 2023)
1a56897  Added creation and saving of normalization factors (FeGeyer, Feb 15, 2023)
b8d5791  Execute normalize only for training (FeGeyer, Feb 15, 2023)
7fe50b6  Add model_name to prediction file (FeGeyer, Feb 16, 2023)
fcb54c8  Revert use of sampling functions (FeGeyer, Feb 16, 2023)
a2cc2ce  Update plot titles (FeGeyer, Feb 16, 2023)
4c91914  Fix tests (FeGeyer, Feb 16, 2023)
d66a760  Update evaluation functions for half images (FeGeyer, Feb 16, 2023)
03909b0  Add docstring (FeGeyer, Feb 16, 2023)
b695f24  Fix ms ssim plotting (FeGeyer, Feb 17, 2023)
cecc575  Add normalization for area and ms ssim (FeGeyer, Feb 17, 2023)
d6af92b  Implement another normalize method (FeGeyer, Feb 17, 2023)
9ba4a82  Add keyword (FeGeyer, Feb 17, 2023)
31ba0e6  Revert changes (FeGeyer, Feb 17, 2023)
3694da1  Add correct keyword (FeGeyer, Feb 17, 2023)
a27c2e3  Fix saving (FeGeyer, Feb 17, 2023)
b88de69  Add img_size to tests (FeGeyer, Feb 17, 2023)
aec2fc3  Fix deprecation warning (FeGeyer, Feb 17, 2023)
6df9175  Fix for the tests to run (FeGeyer, Feb 17, 2023)
a1df787  Revert hardcoded inverse normalisation (FeGeyer, Feb 17, 2023)
edd9900  Add docu (FeGeyer, Feb 17, 2023)
3aaa208  Fix saving (FeGeyer, Feb 20, 2023)
115a7c8  Delete remaining check_vmin_vmax (FeGeyer, Feb 20, 2023)
51e09aa  Add loading of normalization factors (FeGeyer, Feb 20, 2023)
ef46a80  Add rescaling (FeGeyer, Feb 21, 2023)
4 changes: 4 additions & 0 deletions docs/changes/146.feature.rst
@@ -0,0 +1,4 @@
- Add normalization callback with two different techniques
- Update plotting routines for real/imag images
- Update evaluate_area and evaluate_ms_ssim for half images
- Add evaluate_ms_ssim for sampled images
5 changes: 5 additions & 0 deletions docs/changes/146.maintenance.rst
@@ -0,0 +1,5 @@
- Add the model name to the prediction and sampling files
- Delete unnecessary pad_unsqueeze function
- Add amp_phase keyword to sample_images
- Fix deprecation warning in sampling.py
- Add image size to test_evaluation.py routines
1 change: 1 addition & 0 deletions examples/default_train_config.toml
@@ -21,6 +21,7 @@ norm_path = "none"
[general]
fourier = true
amp_phase = true
normalize = false
source_list = false
arch_name = "filter_deep"
loss_func = "splitted_L1"
4 changes: 2 additions & 2 deletions radionets/dl_framework/architectures/res_exp.py
@@ -166,8 +166,8 @@ def forward(self, x):
         x = self.final(x)

         x0 = x[:, 0].reshape(-1, 1, s // 2 + 1, s)
-        x0 = self.relu(x0)
-        x1 = self.hardtanh(x[:, 1]).reshape(-1, 1, s // 2 + 1, s)
+        # x0 = self.relu(x0)
+        x1 = x[:, 1].reshape(-1, 1, s // 2 + 1, s)

         return torch.cat([x0, x1], dim=1)

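With the ReLU and hardtanh clamps gone, the head returns both channels unbounded: real and imaginary parts of a half-plane spectrum with s // 2 + 1 rows, matching an rfft2 layout. Below is a minimal sketch of how such an output could be turned back into an image; the tensor shapes are hypothetical and the irfft2 orientation is an assumption based on that layout, not code from this PR:

import torch

s = 64
out = torch.randn(8, 2, (s // 2 + 1) * s)     # hypothetical network output
x0 = out[:, 0].reshape(-1, 1, s // 2 + 1, s)  # real part, half plane
x1 = out[:, 1].reshape(-1, 1, s // 2 + 1, s)  # imaginary part, half plane

spec = torch.complex(x0, x1).squeeze(1)       # (batch, s // 2 + 1, s)
# transpose so the halved axis is last, as torch.fft.irfft2 expects
img = torch.fft.irfft2(spec.transpose(-2, -1), s=(s, s))
print(img.shape)                              # torch.Size([8, 64, 64])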
33 changes: 33 additions & 0 deletions radionets/dl_framework/callbacks.py
@@ -194,6 +194,39 @@ def before_batch(self):
        self.learn.yb = [y]


class Normalize(Callback):
    _order = 4

    def __init__(self, conf):
        self.mode = conf["normalize"]
        if self.mode == "mean":
            self.mean_real = conf["norm_factors"]["mean_real"]
            self.mean_imag = conf["norm_factors"]["mean_imag"]
            self.std_real = conf["norm_factors"]["std_real"]
            self.std_imag = conf["norm_factors"]["std_imag"]

    def normalize(self, x, m, s):
        return (x - m) / s

    def before_batch(self):
        x = self.xb[0].clone()
        y = self.yb[0].clone()

        if self.mode == "max":
            x[:, 0] *= 1 / torch.amax(x[:, 0], dim=(-2, -1), keepdim=True)
            x[:, 1] *= 1 / torch.amax(torch.abs(x[:, 1]), dim=(-2, -1), keepdim=True)
            y[:, 0] *= 1 / torch.amax(x[:, 0], dim=(-2, -1), keepdim=True)
            y[:, 1] *= 1 / torch.amax(torch.abs(x[:, 1]), dim=(-2, -1), keepdim=True)
        elif self.mode == "mean":
            x[:, 0] = self.normalize(x[:, 0], self.mean_real, self.std_real)
            x[:, 1] = self.normalize(x[:, 1], self.mean_imag, self.std_imag)
            y[:, 0] = self.normalize(y[:, 0], self.mean_real, self.std_real)
            y[:, 1] = self.normalize(y[:, 1], self.mean_imag, self.std_imag)

        self.learn.xb = [x]
        self.learn.yb = [y]


class SaveTempCallback(Callback):
    _order = 95
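The mode comes from the new normalize config entry; judging by the gating in learner.py below, "none" (or false) skips the callback, while "max" and "mean" select the two techniques above. For reference, a minimal, self-contained sketch of what the two modes compute outside of fastai; the tensors and factor values are made up, and the sketch takes both scale factors from the input before anything is modified:

import torch

x = torch.randn(4, 2, 65, 128)   # fake (batch, real/imag channel, h, w) input
y = torch.randn(4, 2, 65, 128)   # fake target

# "max" mode: per-image scaling with the input's maxima,
# both factors read from x up front
scale_real = torch.amax(x[:, 0], dim=(-2, -1), keepdim=True)
scale_imag = torch.amax(torch.abs(x[:, 1]), dim=(-2, -1), keepdim=True)
x_max = torch.stack([x[:, 0] / scale_real, x[:, 1] / scale_imag], dim=1)
y_max = torch.stack([y[:, 0] / scale_real, y[:, 1] / scale_imag], dim=1)

# "mean" mode: channel-wise (value - mean) / std with dataset-wide factors
mean_real, std_real = 0.1, 0.9   # hypothetical norm_factors
mean_imag, std_imag = 0.0, 1.1
x_mean = torch.stack(
    [(x[:, 0] - mean_real) / std_real, (x[:, 1] - mean_imag) / std_imag], dim=1
)
y_mean = torch.stack(
    [(y[:, 0] - mean_real) / std_real, (y[:, 1] - mean_imag) / std_imag], dim=1
)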
3 changes: 3 additions & 0 deletions radionets/dl_framework/learner.py
@@ -7,6 +7,7 @@
    SwitchLoss,
    CudaCallback,
    CometCallback,
    Normalize,
)
from fastai.optimizer import Adam
from fastai.learner import Learner

@@ -77,6 +78,8 @@ def define_learner(data, arch, train_conf, lr_find=False, plot_loss=False):
        ]
    )

    if not lr_find and not plot_loss and train_conf["normalize"] != "none":
        cbfs.extend([Normalize(train_conf)])
    # get loss func
    if train_conf["loss_func"] == "feature_loss":
        loss_func = loss_functions.init_feature_loss()
16 changes: 16 additions & 0 deletions radionets/dl_framework/model.py
@@ -155,6 +155,8 @@ def load_pre_model(learn, pre_path, visualize=False, plot_loss=False):

    if visualize:
        learn.load_state_dict(checkpoint["model"])
        if "norm_dict" in checkpoint:
            return checkpoint["norm_dict"]
    elif plot_loss:
        learn.avg_loss.loss_train = checkpoint["train_loss"]
        learn.avg_loss.loss_valid = checkpoint["valid_loss"]

@@ -171,6 +173,19 @@


def save_model(learn, model_path):
    if hasattr(learn, "normalize"):
        if hasattr(learn.normalize, "mean_real"):
            norm_dict = {
                "mean_real": learn.normalize.mean_real,
                "mean_imag": learn.normalize.mean_imag,
                "std_real": learn.normalize.std_real,
                "std_imag": learn.normalize.std_imag,
            }
        else:
            norm_dict = {"max_scaling": 0}
    else:
        norm_dict = {}

    torch.save(
        {
            "model": learn.model.state_dict(),

@@ -181,6 +196,7 @@ def save_model(learn, model_path):
            "train_loss": learn.avg_loss.loss_train,
            "valid_loss": learn.avg_loss.loss_valid,
            "lrs": learn.avg_loss.lrs,
            "norm_dict": norm_dict,
        },
        model_path,
    )
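A short sketch of the resulting checkpoint round trip: save_model stores the factors alongside the weights, and load_pre_model with visualize=True hands them back. The file name and stand-in values are hypothetical:

import torch

# Saving (condensed): the factors ride along in the checkpoint dict
checkpoint = {
    "model": {},   # learn.model.state_dict() in the real code
    "norm_dict": {"mean_real": 0.1, "mean_imag": 0.0,
                  "std_real": 0.9, "std_imag": 1.1},
}
torch.save(checkpoint, "model.tar")

# Loading for evaluation: the stored factors are returned to the caller
checkpoint = torch.load("model.tar")
norm_dict = checkpoint.get("norm_dict", {})
print(norm_dict)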
3 changes: 3 additions & 0 deletions radionets/dl_training/scripts/start_training.py
@@ -8,6 +8,7 @@
    define_arch,
    pop_interrupt,
    end_training,
    get_normalisation_factors,
)
from radionets.dl_framework.learner import define_learner
from radionets.dl_framework.model import load_pre_model

@@ -67,6 +68,8 @@ def main(configuration_path, mode):
    )

    if mode == "train":
        if train_conf["normalize"] == "mean":
            train_conf["norm_factors"] = get_normalisation_factors(data)
        # check out path and look for existing model files
        check_outpath(train_conf["model_path"], train_conf)

34 changes: 34 additions & 0 deletions radionets/dl_training/utils.py
@@ -1,4 +1,6 @@
import sys
import torch
from tqdm import tqdm
import click
from pathlib import Path
from radionets.dl_framework.data import load_data, DataBunch, get_dls

@@ -38,6 +40,7 @@ def read_config(config):

    train_conf["fourier"] = config["general"]["fourier"]
    train_conf["amp_phase"] = config["general"]["amp_phase"]
    train_conf["normalize"] = config["general"]["normalize"]
    train_conf["arch_name"] = config["general"]["arch_name"]
    train_conf["loss_func"] = config["general"]["loss_func"]
    train_conf["num_epochs"] = config["general"]["num_epochs"]

@@ -108,3 +111,34 @@ def end_training(learn, train_conf):

    # Plot loss
    plot_loss(learn, Path(train_conf["model_path"]))


def get_normalisation_factors(data):
    mean_real = []
    mean_imag = []
    std_real = []
    std_imag = []

    for inp, true in tqdm(data.train_ds):
        mean_batch_imag = inp[1].mean()
        mean_batch_real = inp[0].mean()
        std_batch_imag = inp[1].std()
        std_batch_real = inp[0].std()
        mean_real.append(mean_batch_real)
        mean_imag.append(mean_batch_imag)
        std_real.append(std_batch_real)
        std_imag.append(std_batch_imag)

    mean_real = torch.tensor(mean_real).mean()
    mean_imag = torch.tensor(mean_imag).mean()
    std_real = torch.tensor(std_real).std()
    std_imag = torch.tensor(std_imag).std()

    norm_factors = {
        "mean_real": mean_real,
        "mean_imag": mean_imag,
        "std_real": std_real,
        "std_imag": std_imag,
    }

    return norm_factors
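The same aggregation on a toy dataset, with a made-up list standing in for data.train_ds. Note what the function computes: the mean factors are averages of the per-image means, while the std factors are the standard deviation of the per-image stds:

import torch

# Toy stand-in for data.train_ds: 16 (input, target) pairs,
# inputs with real/imag channels of shape (2, 65, 128)
fake_train_ds = [(torch.randn(2, 65, 128), None) for _ in range(16)]

mean_real = torch.stack([inp[0].mean() for inp, _ in fake_train_ds]).mean()
mean_imag = torch.stack([inp[1].mean() for inp, _ in fake_train_ds]).mean()
std_real = torch.stack([inp[0].std() for inp, _ in fake_train_ds]).std()
std_imag = torch.stack([inp[1].std() for inp, _ in fake_train_ds]).std()

norm_factors = {
    "mean_real": mean_real,
    "mean_imag": mean_imag,
    "std_real": std_real,
    "std_imag": std_imag,
}
print({k: float(v) for k, v in norm_factors.items()})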
63 changes: 42 additions & 21 deletions radionets/evaluation/plotting.py
@@ -16,7 +16,6 @@
 from radionets.evaluation.utils import (
     check_vmin_vmax,
     make_axes_nice,
-    pad_unsqueeze,
     reshape_2d,
 )
 from radionets.simulations.utils import adjust_outpath
@@ -235,28 +234,25 @@ def visualize_with_fourier(
         im6 = ax6.imshow(imag_truth, cmap="RdBu", vmin=-np.pi, vmax=np.pi)
         make_axes_nice(fig, ax6, im6, r"Phase Truth", phase=True)
     else:
-        a = check_vmin_vmax(inp_real)
-        im1 = ax1.imshow(inp_real, cmap="RdBu", vmin=-a, vmax=a)
+        im1 = ax1.imshow(inp_real, cmap="RdBu")
         make_axes_nice(fig, ax1, im1, r"Real Input")

-        a = check_vmin_vmax(real_truth)
-        im2 = ax2.imshow(real_pred, cmap="RdBu", vmin=-a, vmax=a)
+        im2 = ax2.imshow(real_pred, cmap="RdBu")
         make_axes_nice(fig, ax2, im2, r"Real Prediction")

-        a = check_vmin_vmax(real_truth)
-        im3 = ax3.imshow(real_truth, cmap="RdBu", vmin=-a, vmax=a)
+        im3 = ax3.imshow(real_truth, cmap="RdBu")
         make_axes_nice(fig, ax3, im3, r"Real Truth")

-        a = check_vmin_vmax(inp_imag)
-        im4 = ax4.imshow(inp_imag, cmap="RdBu", vmin=-a, vmax=a)
+        im4 = ax4.imshow(inp_imag, cmap="RdBu")
         make_axes_nice(fig, ax4, im4, r"Imaginary Input")

-        a = check_vmin_vmax(imag_truth)
-        im5 = ax5.imshow(imag_pred, cmap="RdBu", vmin=-np.pi, vmax=np.pi)
+        im5 = ax5.imshow(imag_pred, cmap="RdBu")
         make_axes_nice(fig, ax5, im5, r"Imaginary Prediction")

-        a = check_vmin_vmax(imag_truth)
-        im6 = ax6.imshow(imag_truth, cmap="RdBu", vmin=-np.pi, vmax=np.pi)
+        im6 = ax6.imshow(imag_truth, cmap="RdBu")
         make_axes_nice(fig, ax6, im6, r"Imaginary Truth")

     ax1.set_ylabel(r"Pixels")
@@ -316,6 +312,26 @@
         )
         make_axes_nice(fig, ax6, im6, r"Phase Difference", phase_diff=True)

+    else:
+        im1 = ax1.imshow(real_pred, cmap="inferno")
+        make_axes_nice(fig, ax1, im1, r"Real Prediction")
+
+        im2 = ax2.imshow(real_truth, cmap="inferno")
+        make_axes_nice(fig, ax2, im2, "Real Truth")
+
+        a = check_vmin_vmax(real_pred - real_truth)
+        im3 = ax3.imshow(real_pred - real_truth, cmap=OrBu, vmin=-a, vmax=a)
+        make_axes_nice(fig, ax3, im3, r"Real Difference")
+
+        im4 = ax4.imshow(imag_pred, cmap=OrBu)
+        make_axes_nice(fig, ax4, im4, r"Imaginary Prediction")
+
+        im5 = ax5.imshow(imag_truth, cmap=OrBu)
+        make_axes_nice(fig, ax5, im5, r"Imaginary Truth")
+
+        im6 = ax6.imshow(imag_pred - imag_truth, cmap=OrBu)
+        make_axes_nice(fig, ax6, im6, r"Imaginary Difference")

     ax1.set_ylabel(r"Pixels")
     ax4.set_ylabel(r"Pixels")
     ax4.set_xlabel(r"Pixels")
@@ -369,11 +385,16 @@ def visualize_source_reconstruction(
         plot_box(ax2, num_boxes, corners[0])

     if msssim:
-        ifft_truth = pad_unsqueeze(torch.tensor(ifft_truth).unsqueeze(0))
-        ifft_pred = pad_unsqueeze(torch.tensor(ifft_pred).unsqueeze(0))
-        val = ms_ssim(ifft_pred, ifft_truth, data_range=ifft_truth.max())
-
-        ax1.plot([], [], " ", label=f"ms ssim: {val:.2f}")
+        val = ms_ssim(
+            torch.tensor(ifft_pred).unsqueeze(0).unsqueeze(0),
+            torch.tensor(ifft_truth).unsqueeze(0).unsqueeze(0),
+            data_range=1,
+            win_size=7,
+            size_average=False,
+        )
+        val = val.numpy()[0]
+        ax1.plot([], [], " ", label=f"MS-SSIM: {val:.2f}")
         ax1.legend(loc="best")

     outpath = str(out_path) + f"/fft_pred_{i}.{plot_format}"
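A standalone sketch of the new MS-SSIM call, assuming the ms_ssim used here is pytorch_msssim's (the import sits outside this hunk); the random images are placeholders. The two unsqueezes build the expected (N, C, H, W) shape, and the reduced win_size=7 keeps the four downsampling stages of MS-SSIM valid for images of at least roughly 97 px per side:

import torch
from pytorch_msssim import ms_ssim  # assumed to match the ms_ssim above

pred = torch.rand(128, 128)    # stand-in for ifft_pred
truth = torch.rand(128, 128)   # stand-in for ifft_truth

val = ms_ssim(
    pred.unsqueeze(0).unsqueeze(0),
    truth.unsqueeze(0).unsqueeze(0),
    data_range=1,
    win_size=7,
    size_average=False,        # returns one value per image
)
print(val.numpy()[0])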
@@ -423,19 +444,19 @@
         2, 2, sharey=True, sharex=True, figsize=(12, 10)
     )

-    im1 = ax1.imshow(true_phase, cmap=OrBu, vmin=-np.pi, vmax=np.pi)
+    im1 = ax1.imshow(true_phase, cmap=OrBu)

-    im2 = ax2.imshow(pred_phase, cmap=OrBu, vmin=-np.pi, vmax=np.pi)
+    im2 = ax2.imshow(pred_phase, cmap=OrBu)

     im3 = ax3.imshow(unc_phase)

     a = check_vmin_vmax(true_phase - pred_phase)
     im4 = ax4.imshow(true_phase - pred_phase, cmap=OrBu, vmin=-a, vmax=a)

-    make_axes_nice(fig, ax1, im1, r"Simulation", phase=True)
-    make_axes_nice(fig, ax2, im2, r"Predicted $\mu$", phase=True)
+    make_axes_nice(fig, ax1, im1, r"Simulation")
+    make_axes_nice(fig, ax2, im2, r"Predicted $\mu$")
     make_axes_nice(fig, ax3, im3, r"Predicted $\sigma^2$", unc=True)
-    make_axes_nice(fig, ax4, im4, r"Difference", phase_diff=True)
+    make_axes_nice(fig, ax4, im4, r"Difference")

     ax1.set_ylabel(r"pixels")
     ax3.set_ylabel(r"pixels")
@@ -640,7 +661,7 @@ def histogram_ms_ssim(msssim, out_path, plot_format="png"):
     fig, (ax1) = plt.subplots(1, figsize=(6, 4))
     ax1.hist(
         msssim.numpy(),
-        51,
+        80,
         color="darkorange",
         linewidth=3,
         histtype="step",