-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
63 lines (54 loc) · 2.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Train script for UNIT model."""
from tqdm import tqdm
from img2img.cfg import unit as cfg
from img2img.data import get_loader
from img2img.models.unit.trainer import UNIT_Trainer
from img2img.utils.unit import save_some_examples
from img2img.utils import prepare_sub_directories
def trainfn(trainer, train_loader):
    """Run one training epoch of the UNIT model.

    For every batch from ``train_loader`` the discriminators are updated first
    and then the generators, both through ``trainer``. After the whole epoch
    the trainer's learning-rate schedule is advanced once.

    Args:
        trainer: Object exposing ``dis_update``, ``gen_update`` and
            ``update_learning_rate`` (a ``UNIT_Trainer`` in this script).
        train_loader: Iterable yielding ``(images_a, images_b)`` pairs whose
            items support ``.to(device)`` (presumably image tensors from
            domains A and B -- TODO confirm against ``get_loader``).
    """
    # tqdm only adds a progress bar; leave=True keeps the finished bar on
    # screen so each epoch's progress stays visible in the console log.
    for images_a, images_b in tqdm(train_loader, leave=True):
        images_a = images_a.to(cfg.DEVICE)
        images_b = images_b.to(cfg.DEVICE)
        # Discriminator step first, then generator step, on the same batch.
        trainer.dis_update(images_a, images_b)
        trainer.gen_update(images_a, images_b)
    # NOTE(review): the schedule is stepped once per epoch here (the former
    # per-iteration call was disabled). Confirm that the scheduler step size
    # configured in UNIT_Trainer is expressed in epochs, not iterations.
    trainer.update_learning_rate()
def main() -> int:
    """Train the UNIT model end to end.

    Builds the trainer and the train/validation loaders from ``cfg``,
    optionally resumes from the weights directory, then for each of
    ``cfg.NUM_EPOCHS`` epochs trains, writes validation examples, and
    checkpoints periodically (and always after the final epoch).

    Returns:
        0, used as the process exit status.
    """
    trainer = UNIT_Trainer()
    trainer.to(cfg.DEVICE)

    train_loader = get_loader(
        root_dir=cfg.TRAIN_DATASET_PATH,
        dataset_type=cfg.CHOSEN_DATASET,
        batch_size=cfg.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.NUM_WORKERS,
    )
    # Validation order is fixed (shuffle=False) so sample grids are comparable
    # across epochs.
    val_loader = get_loader(
        root_dir=cfg.VAL_DATASET_PATH,
        dataset_type=cfg.CHOSEN_DATASET,
        batch_size=cfg.VAL_BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.NUM_WORKERS,
    )

    path = cfg.OUT_PATH / f"unit_{cfg.CHOSEN_DATASET.value.stem}"
    weights_dir, val_dir = prepare_sub_directories(path)

    # TODO: copy the config file into the output directory for reproducibility.
    # TODO: apply mixed precision (torch.cuda.amp.autocast).

    if cfg.LOAD_MODEL:
        # NOTE(review): anything resume() returns (e.g. where training stopped)
        # is discarded, so the epoch counter below restarts at 0 -- confirm
        # this is intended.
        trainer.resume(weights_dir)

    save_every = 5  # checkpoint interval, in epochs
    last_epoch = cfg.NUM_EPOCHS - 1
    for epoch in range(cfg.NUM_EPOCHS):
        print(f"Epoch: {epoch}")
        trainfn(trainer=trainer, train_loader=train_loader)
        save_some_examples(
            trainer=trainer,
            val_loader=val_loader,
            epoch=epoch,
            dir_path=val_dir,
        )
        # Save on the regular interval AND after the last epoch; otherwise a
        # run whose final epoch is not a multiple of `save_every` silently
        # loses its last epochs of training.
        if cfg.SAVE_MODEL and (epoch % save_every == 0 or epoch == last_epoch):
            trainer.save(weights_dir, epoch)
    return 0
if __name__ == "__main__":
    # Run training and hand main()'s integer result to the interpreter as
    # the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)