forked from krasserm/super-resolution
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path run_LAPSRN.py
75 lines (54 loc) · 2.03 KB
/
run_LAPSRN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import matplotlib.pyplot as plt
from data import DIV2K
from model.lapsrn import lapsrn
from train import LapsrnTrainer
# --- Configuration -----------------------------------------------------------

# Number of residual blocks.
# NOTE(review): `depth` is never passed to `lapsrn(...)` below, so changing it
# currently has no effect — confirm whether lapsrn() accepts a depth argument.
depth = 16
# Super-resolution factor
scale = 4
# Downgrade operator used when loading the DIV2K low-resolution images
downgrade = 'bicubic'

# Build the LapSRN model with attention when True; set to False to train the
# variant without attention.
attention = True
# Run name handed to the trainer. Derived from `attention` so that a run
# without attention is not labelled "lapsrn_attention".
model_name = 'lapsrn_attention' if attention else 'lapsrn'

# Location of model weights (needed for demo)
# NOTE(review): neither this directory nor the checkpoint directory below
# depends on `attention`, so both variants share (and may overwrite) the same
# files — consider per-variant directories.
weights_dir = 'weights/lapsrn'
weights_file = os.path.join(weights_dir, 'lapsrn_weights.h5')
os.makedirs(weights_dir, exist_ok=True)

# --- Data --------------------------------------------------------------------

div2k_train = DIV2K(scale=scale, subset='train', downgrade=downgrade)
div2k_valid = DIV2K(scale=scale, subset='valid', downgrade=downgrade)

# Training: batches of 16 with random transforms enabled.
# Validation: batch size 1, no random transforms, repeat_count=1.
train_ds = div2k_train.dataset(batch_size=16, random_transform=True)
valid_ds = div2k_valid.dataset(batch_size=1, random_transform=False, repeat_count=1)

# --- Training ----------------------------------------------------------------

trainer = LapsrnTrainer(model=lapsrn(attention=attention),
                        checkpoint_dir='.ckpt/lapsrn')

# Train LAPSRN model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
print("training LAPSRN model...")
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000,
              evaluate_every=1000,
              save_best_only=True,
              model_name=model_name)

# --- Evaluation --------------------------------------------------------------

# Restore from checkpoint with highest PSNR
trainer.restore()

print("evaluating LAPSRN model...")
# Evaluate model on full validation set
psnrv = trainer.evaluate(valid_ds)
ssimv = trainer.evaluate2(valid_ds)
# `.3f` prints the metrics with three decimal places.
print(f'PSNR = {psnrv.numpy():.3f}')
print(f'SSIM = {ssimv.numpy():.3f}')

# Save weights to separate location (needed for demo)
trainer.model.save_weights(weights_file)
# Build a fresh model with the same configuration and load the exported
# weights into it; the demo below runs inference with this `model`.
model = lapsrn(attention=attention)
model.load_weights(weights_file)
from model import resolve_single
from utils import load_image, plot_sample

print("resolving LAPSRN model...")


def resolve_and_plot(lr_image_path, output_path="lapsrn_sample"):
    """Super-resolve one low-resolution image and save an LR/SR comparison plot.

    Args:
        lr_image_path: Path of the low-resolution input image, read with
            ``load_image``.
        output_path: Destination passed to ``plt.savefig``. Defaults to
            ``"lapsrn_sample"``; without an extension matplotlib chooses the
            format from ``rcParams["savefig.format"]`` and appends it.
    """
    lr = load_image(lr_image_path)
    # Inference uses the module-level `model` that was loaded from
    # `weights_file` earlier in the script.
    sr = resolve_single(model, lr)
    # NOTE(review): presumably plot_sample draws into the current pyplot
    # figure — verify against utils.plot_sample.
    plot_sample(lr, sr)
    plt.savefig(output_path)
    # Close the current figure after saving so repeated calls do not keep
    # accumulating open figures.
    plt.close()


resolve_and_plot('demo/0869x4-crop.png')