-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathgenerate_videos.py
68 lines (47 loc) · 1.64 KB
/
generate_videos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import absolute_import
import torch
import torch.nn as nn
from model.networks import Generator
import cfg
import skvideo.io
import numpy as np
import os
from tqdm import tqdm
import argparse
def save_videos(path, vids, epoch, bs, start=None):
    """Write the first ``bs`` clips of ``vids`` to ``path`` as H.264 ``.mp4`` files.

    Args:
        path: Output directory. Files are named ``<index>.mp4``.
        vids: Batch of clips where each ``vids[i]`` is a 4-D tensor laid out
            as (frames, channels, height, width), as produced by ``main``.
            Values are expected to lie in [0, 1].
        epoch: Batch number. Used only for the default file numbering.
        bs: Number of clips from ``vids`` to write.
        start: Global index of the first clip in this batch. Defaults to
            ``epoch * bs``, which is the original numbering. That default is
            only correct when every batch has exactly ``bs`` clips, so pass
            it explicitly for a trailing partial batch.
    """
    if start is None:
        start = epoch * bs
    for i in range(bs):
        # (T, C, H, W) -> (T, H, W, C), the frame layout skvideo expects.
        # Scale out-of-place: for a CPU tensor, .numpy() shares memory, and an
        # in-place `*=` would silently modify the caller's tensor.
        frames = vids[i].permute(0, 2, 3, 1).cpu().numpy() * 255
        skvideo.io.vwrite(
            os.path.join(path, "%d.mp4" % (start + i)),
            frames.astype(np.uint8),
            outputdict={"-vcodec": "libx264"},
        )


def main(args):
    """Sample ``args.n`` videos from a trained Generator and save them as MP4.

    Args:
        args: Parsed CLI namespace. Reads ``n``, ``batch_size``, ``d_za``,
            ``d_zm``, ``model_path`` (a state_dict for the DataParallel-wrapped
            Generator) and ``gen_path`` (the output directory, created if
            missing).
    """
    # Prefer the first GPU. nn.DataParallel falls back to calling the wrapped
    # module directly when CUDA is unavailable.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    G = nn.DataParallel(Generator().to(device))
    # map_location lets a checkpoint saved on GPU load on the device in use.
    G.load_state_dict(torch.load(args.model_path, map_location=device))
    G.eval()

    os.makedirs(args.gen_path, exist_ok=True)

    batch_size = args.batch_size
    # Ceiling division. The former `n // batch_size + 1` ran an extra, empty
    # batch whenever n was a multiple of batch_size.
    n_batches = (args.n + batch_size - 1) // batch_size

    with torch.no_grad():
        for batch_idx in tqdm(range(n_batches)):
            start = batch_idx * batch_size
            # Only the last batch can be smaller than batch_size.
            bs = min(batch_size, args.n - start)
            # Noise is drawn on CPU and then moved, as before, so the RNG stream
            # is unchanged.
            za = torch.randn(bs, args.d_za, 1, 1, 1).to(device)
            zm = torch.randn(bs, args.d_zm, 1, 1, 1).to(device)
            vid_fake = G(za, zm)
            # Swap dims 1 and 2. The original annotation gives the result as
            # bs x 16 x 3 x 64 x 64 (batch, frames, channels, H, W).
            vid_fake = vid_fake.transpose(2, 1)
            # Min-max rescale to [0, 1] using the whole batch's range (not per
            # clip). Clamping the span avoids NaNs for a constant output.
            lo, hi = vid_fake.min(), vid_fake.max()
            vid_fake = (vid_fake - lo) / (hi - lo).clamp(min=1e-8)
            # Pass the global start index so a short final batch does not
            # overwrite files from earlier batches.
            save_videos(args.gen_path, vid_fake, batch_idx, bs, start=start)
if __name__ == '__main__':
    # Sampling parameters for video generation.
    parser = argparse.ArgumentParser(
        description="Sample videos from a trained generator and save them as .mp4 files."
    )
    parser.add_argument("--n", type=int, default=5000,
                        help="total number of videos to generate")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="number of videos sampled per forward pass")
    parser.add_argument("--d_za", type=int, default=128,
                        help="size of the za noise vector")
    parser.add_argument("--d_zm", type=int, default=10,
                        help="size of the zm noise vector")
    parser.add_argument("--model_path", type=str, default='./g3an.pth',
                        help="path to the generator state_dict checkpoint")
    # Required: without it main() would fail later in os.path.join(None, ...).
    parser.add_argument("--gen_path", type=str, required=True,
                        help="output directory for the generated .mp4 files")
    args = parser.parse_args()
    # Reject values that would otherwise divide by zero or loop over nothing.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    if args.n < 0:
        parser.error("--n must be non-negative")
    main(args)