-
Notifications
You must be signed in to change notification settings - Fork 4
/
run.py
148 lines (123 loc) · 4.94 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Entry point: invert degraded images through a pretrained generator.

Parses the experiment configuration, opens the generator checkpoint, and
builds the multiscale LPIPS loss shared by all optimization phases.
"""
# Explicit stdlib / third-party imports: `datetime`, `os`, `tqdm` and `torch`
# are used below but were previously only in scope via the star import.
import datetime
import glob
import os

import torch
import tqdm

from cli import parse_config
import benchmark
from benchmark import Task, Degradation
# Kept last so its exported names take precedence, as in the original.
from robust_unsupervised import *

# Parsed once at import time; read as a module-level global by the rest
# of the script (including `run_phase`).
config = parse_config()
benchmark.config.resolution = config.resolution

print(config.name)

# Run timestamp used to namespace the output directory; ':' is removed so
# the string is a valid path component on all platforms.
timestamp = datetime.datetime.now().isoformat(timespec="seconds").replace(":", "")

# Pretrained generator checkpoint and perceptual reconstruction loss.
G = open_generator(config.pkl_path)
loss_fn = MultiscaleLPIPS()
def run_phase(label: str, variable: Variable, lr: float):
    """Optimize `variable` for 150 NGD steps, then save diagnostic images.

    Reads the module-level globals set up by the main loop: `loss_fn`,
    `config`, `degradation`, `target`, and `ground_truth`.  A Ctrl-C
    interrupts the optimization early but still logs the current state.

    Args:
        label: phase name used for the progress bar and output filenames.
        variable: latent variable being optimized (provides `.to_image()`).
        lr: learning rate for the NGD optimizer.
    """
    opt = NGD(variable.parameters(), lr=lr)
    try:
        for _ in tqdm.tqdm(range(150), desc=label):
            prediction = variable.to_image()
            objective = loss_fn(
                degradation.degrade_prediction, prediction, target, degradation.mask
            ).mean()
            opt.zero_grad()
            objective.backward()
            opt.step()
    except KeyboardInterrupt:
        # Allow early manual stop; fall through to logging below.
        pass

    # Log results: every panel is saved with the phase label as a suffix.
    suffix = "_" + label
    pred = resize_for_logging(variable.to_image(), config.resolution)
    approx_degraded_pred = degradation.degrade_prediction(pred)
    degraded_pred = degradation.degrade_ground_truth(pred)

    # Filename -> image tensor; insertion order matches the original save order.
    panels = {
        f"pred{suffix}.png": pred,
        f"degraded_pred{suffix}.png": degraded_pred,
        # True vs. approximate degradation of the same prediction.
        f"degradation_approximation{suffix}.jpg": torch.cat(
            [approx_degraded_pred, degraded_pred]
        ),
        f"side_by_side{suffix}.jpg": torch.cat(
            [
                ground_truth,
                resize_for_logging(target, config.resolution),
                resize_for_logging(degraded_pred, config.resolution),
                pred,
            ]
        ),
        f"result{suffix}.jpg": torch.cat(
            [resize_for_logging(target, config.resolution), pred]
        ),
        # Residual panels: absolute error against target / ground truth.
        f"fidelity{suffix}.jpg": torch.cat(
            [target, degraded_pred, (target - degraded_pred).abs()]
        ),
        f"accuracy{suffix}.jpg": torch.cat(
            [ground_truth, pred, (ground_truth - pred).abs()]
        ),
    }
    for filename, image in panels.items():
        save_image(image, filename, padding=0)
if __name__ == '__main__':
    # Select the benchmark task list from the configured task-set name.
    if config.tasks == "single":
        tasks = benchmark.single_tasks
    elif config.tasks == "composed":
        tasks = benchmark.composed_tasks
    elif config.tasks == "all":
        tasks = benchmark.all_tasks
    elif config.tasks == "custom":
        # Implement your own degradation here
        class YourDegradation:
            def degrade_ground_truth(self, x):
                "The true degradation you are attempting to invert."
                raise NotImplementedError

            def degrade_prediction(self, x):
                """
                Differentiable approximation to the degradation in question.
                Can be identical to the true degradation if it is invertible.
                """
                raise NotImplementedError

        tasks = [
            benchmark.Task(
                constructor=YourDegradation,
                # These labels are just for the output folder structure
                name="your_degradation",
                category="single",
                level="M",
            )
        ]
    else:
        # ValueError (not bare Exception) and include the offending value.
        raise ValueError(f"Invalid task name: {config.tasks!r}")

    for task in tasks:
        experiment_path = (
            f"out/{config.name}/{timestamp}/{task.category}/{task.name}/{task.level}/"
        )

        # Collect every supported image under the dataset root, recursively.
        # One loop replaces four copy-pasted glob calls; sorted() makes the
        # final ordering identical to the original concatenate-then-sort.
        extensions = ("png", "jpg", "jpeg", "tif")
        image_paths = sorted(
            os.path.abspath(path)
            for ext in extensions
            for path in glob.glob(
                f"{config.dataset_path}/**/*.{ext}", recursive=True
            )
        )
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not image_paths:
            raise FileNotFoundError(
                f"No images found in {config.dataset_path!r}"
            )

        with directory(experiment_path):
            print(experiment_path)
            print(os.path.abspath(config.dataset_path))

            for j, image_path in enumerate(image_paths):
                with directory(f"inversions/{j:04d}"):
                    print(f"- {j:04d}")

                    # NOTE: `ground_truth`, `degradation` and `target` are
                    # module-level globals read by `run_phase` — keep names.
                    ground_truth = open_image(image_path, config.resolution)
                    degradation = task.init_degradation()
                    save_image(ground_truth, "ground_truth.png")

                    target = degradation.degrade_ground_truth(ground_truth)
                    save_image(target, "target.png")

                    # Coarse-to-fine inversion: W -> W+ -> W++ with
                    # progressively smaller learning rates.
                    W_variable = WVariable.sample_from(G)
                    run_phase("W", W_variable, config.global_lr_scale * 0.08)

                    Wp_variable = WpVariable.from_W(W_variable)
                    run_phase("W+", Wp_variable, config.global_lr_scale * 0.02)

                    Wpp_variable = WppVariable.from_Wp(Wp_variable)
                    run_phase("W++", Wpp_variable, config.global_lr_scale * 0.005)