-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
45 lines (33 loc) · 1.58 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import torch
import wandb
import pytorch_lightning as pl
class ImageVisCallback(pl.Callback):
    """Log validation example mosaics to Weights & Biases after each validation run.

    For up to ``max_samples`` batches drawn from the supplied validation
    dataloader, the first item of the batch is run through the model. A
    mosaic (input on top, prediction in the middle, target at the bottom,
    concatenated along dim -2) is then logged as a ``wandb.Image`` through
    ``trainer.logger.experiment.log``.

    NOTE(review): assumes ``trainer.logger`` is a W&B logger whose
    ``experiment`` accepts ``wandb.Image`` values. Also assumes each batch is a
    dict with ``'x'`` and ``'y'`` tensors laid out as (batch, channel, H, W),
    with input and prediction sharing H/W-compatible shapes. TODO confirm
    against the dataset/model.
    """

    def __init__(self, val_Dataloader, max_samples=10):
        """
        Args:
            val_Dataloader: Iterable of validation batches (e.g. a
                ``torch.utils.data.DataLoader``). It is re-iterated from the
                start on every validation end.
            max_samples: Maximum number of examples logged per validation run.
        """
        super().__init__()
        self.valLoader = val_Dataloader
        self.max_samples = max_samples

    @staticmethod
    def _first_example(batch_tensor, model):
        """Return item 0 of ``batch_tensor`` as a batch of one, on ``model.device``, restricted to ``model.output_channels``."""
        imgs = batch_tensor[0].to(device=model.device).unsqueeze(0)
        return imgs[:, model.output_channels, :, :]

    def on_validation_end(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
        """Log up to ``max_samples`` prediction mosaics, then the trainer's global step."""
        # Keep the original OS-dependent key. The raw string yields the same
        # literal backslash without relying on the invalid "\e" escape.
        key_template = "val/examples{}" if os.name != "nt" else r"val\examples{}"

        # no_grad: this is visualisation only. Without it, the forward pass
        # builds an autograd graph, and grad-requiring tensors reach wandb.Image.
        with torch.no_grad():
            # zip with range stops after max_samples batches, or earlier if
            # the loader is shorter, without raising StopIteration.
            for i, batch in zip(range(self.max_samples), self.valLoader):
                imgs = self._first_example(batch['x'], model)
                imgsY = self._first_example(batch['y'], model)
                upresed = model(imgs)
                mosaics = torch.cat([imgs, upresed, imgsY], dim=-2)
                caption = "Image {}: Top: Low Res, Middle: High Res Prediction, Bottom: High Res Truth".format(i)
                trainer.logger.experiment.log({
                    # caption must be passed by keyword: the second positional
                    # parameter of wandb.Image is `mode`, not `caption`.
                    key_template.format(i): [
                        wandb.Image(mosaic, caption=caption) for mosaic in mosaics
                    ],
                })

        trainer.logger.experiment.log({
            "global_step": trainer.global_step  # This will make sure wandb gets the epoch/step right
        })