-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
68 lines (50 loc) · 2.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import cv2
import random
import torch
from torch.backends import cudnn
def gpu_manage(config):
    """Configure the visible GPUs and seed the random number generators.

    Args:
        config: Options object. Reads ``cuda``, ``gpu_ids`` and ``manualSeed``.
            When ``cuda`` is truthy, ``gpu_ids`` is rewritten in place to
            ``0..len(gpu_ids)-1``. When ``manualSeed`` is ``None``, it is
            replaced in place by a random integer in ``[1, 10000]``.
    """
    if config.cuda:
        # Restrict the process to the requested physical GPUs. The visible
        # devices are then numbered from 0, so remap gpu_ids to match.
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, config.gpu_ids))
        config.gpu_ids = list(range(len(config.gpu_ids)))

    # The variable is only guaranteed to exist when the branch above ran.
    # Plain indexing would raise KeyError on CPU-only runs.
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible_devices is not None:
        print(visible_devices)

    if config.manualSeed is None:
        config.manualSeed = random.randint(1, 10000)
    print('Random Seed: ', config.manualSeed)
    random.seed(config.manualSeed)
    torch.manual_seed(config.manualSeed)
    if config.cuda:
        torch.cuda.manual_seed_all(config.manualSeed)

    # NOTE: benchmark mode lets cuDNN choose algorithms by timing them. This
    # can make GPU results non-deterministic even with the seeds set above.
    cudnn.benchmark = True

    if torch.cuda.is_available() and not config.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
def save_image(out_dir, x, num, epoch, filename=None):
    """Write one image into a per-epoch subdirectory of ``out_dir``.

    The image goes to ``<out_dir>/epoch_<epoch:04d>/<filename>``. Without a
    ``filename`` it goes to ``<out_dir>/epoch_<epoch:04d>/test_<num:04d>.png``.
    The epoch directory is created if it does not exist.

    Args:
        out_dir: Root output directory.
        x: Image handed directly to ``cv2.imwrite``. Presumably an HxW or
            HxWxC array in OpenCV's BGR channel order; confirm with callers.
        num: Index used to build the default file name.
        epoch: Epoch number used to name the subdirectory.
        filename: Optional explicit file name; overrides the default.

    Returns:
        The boolean returned by ``cv2.imwrite`` (False when the image could
        not be written). Callers that ignore the result are unaffected.
    """
    test_dir = os.path.join(out_dir, 'epoch_{0:04d}'.format(epoch))
    if filename is None:
        filename = 'test_{0:04d}.png'.format(num)
    test_path = os.path.join(test_dir, filename)
    # exist_ok avoids the check-then-create race (FileExistsError) when
    # several processes save into the same epoch directory concurrently.
    os.makedirs(test_dir, exist_ok=True)
    return cv2.imwrite(test_path, x)
def checkpoint(config, epoch, gen, dis):
    """Save generator and discriminator weights for ``epoch``.

    Writes ``gen_model_epoch_<epoch>.pth`` and ``dis_model_epoch_<epoch>.pth``
    under ``<config.out_dir>/models`` and creates that directory if needed.
    Each file holds the module's ``state_dict()``, saved with ``torch.save``.

    Args:
        config: Options object; only ``out_dir`` is read.
        epoch: Epoch number embedded in the file names.
        gen: Generator; any object exposing ``state_dict()``.
        dis: Discriminator; any object exposing ``state_dict()``.
    """
    model_dir = os.path.join(config.out_dir, 'models')
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(model_dir, exist_ok=True)
    net_gen_model_out_path = os.path.join(model_dir, 'gen_model_epoch_{}.pth'.format(epoch))
    net_dis_model_out_path = os.path.join(model_dir, 'dis_model_epoch_{}.pth'.format(epoch))
    torch.save(gen.state_dict(), net_gen_model_out_path)
    torch.save(dis.state_dict(), net_dis_model_out_path)
    print("Checkpoint saved to {}".format(model_dir))
def make_manager(job_dir='.job'):
    """Create the job-counter file and reset the counter to 0.

    Creates ``job_dir`` if it is missing, then writes ``'0'`` to
    ``<job_dir>/job.txt``. Any existing count is discarded.

    Args:
        job_dir: Directory holding the counter file. Defaults to ``'.job'``,
            relative to the current working directory.
    """
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(job_dir, exist_ok=True)
    with open(os.path.join(job_dir, 'job.txt'), 'w') as f:
        f.write('0')
def job_increment(job_dir='.job'):
    """Return the current job number and store the next one.

    Reads the integer in ``<job_dir>/job.txt``, writes back that value plus
    one, and returns the value that was read.

    Args:
        job_dir: Directory holding the counter file. Defaults to ``'.job'``,
            relative to the current working directory.

    Returns:
        The job number stored before the increment.

    Raises:
        FileNotFoundError: If the counter file does not exist.
        ValueError: If the file does not contain an integer.
    """
    path = os.path.join(job_dir, 'job.txt')
    with open(path, 'r') as f:
        n_job = int(f.read())
    # Write to a temp file and swap it in atomically. Truncating job.txt in
    # place could leave it empty if interrupted, and the next int() would fail.
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        f.write(str(n_job + 1))
    os.replace(tmp_path, path)
    return n_job