-
Notifications
You must be signed in to change notification settings - Fork 53
/
test.py
38 lines (33 loc) · 1.64 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import shutil
import sys

# Put this script's own directory first on sys.path so the sibling modules
# vfi_utils / vfi_models below resolve from here.
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

import numpy as np
import PIL
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as transform

from vfi_utils import load_file_from_github_release
from vfi_models import gmfss_fortuna, ifrnet, ifunet, m2m, rife, sepconv, amt, xvfi, cain, flavr
def _load_frame(path):
    """Load an image file as a float32 tensor of shape (1, H, W, 3) scaled to [0, 1]."""
    # The context manager closes the file handle Pillow opens lazily.
    with PIL.Image.open(path) as img:
        rgb = np.array(img.convert("RGB")).astype(np.float32) / 255.0
    return torch.from_numpy(rgb).unsqueeze(0)


def _to_pil(frame):
    """Convert one output frame tensor (values nominally in [0, 1]) to a uint8 PIL image."""
    # .detach().cpu() makes .numpy() safe even if the node returns a CUDA tensor;
    # on a CPU tensor both calls are no-ops.
    pixels = np.clip((frame.detach().cpu() * 255).numpy(), 0, 255).astype(np.uint8)
    return PIL.Image.fromarray(pixels)


# Smoke test: run GMFSS Fortuna interpolation on a short demo clip for the first
# two entries of the node's ckpt_name options and dump the results to disk.
frame_0 = _load_frame("demo_frames/anime0.png")
frame_1 = _load_frame("demo_frames/anime1.png")

# Start from a clean output directory so stale frames from earlier runs don't linger.
if os.path.exists("test_result"):
    shutil.rmtree("test_result")

# Pillow's GIF `duration` is in milliseconds; the previous value `1/3` was truncated
# to a 0 ms frame delay. 333 ms assumes "1/3 of a second" was intended — TODO confirm.
GIF_FRAME_DURATION_MS = 333

vfi_node_class = gmfss_fortuna.GMFSS_Fortuna_VFI()
for i, ckpt_name in enumerate(vfi_node_class.INPUT_TYPES()["required"]["ckpt_name"][0][:2]):
    # Four-frame input clip alternating between the two demo frames, moved to the GPU.
    clip = torch.cat([frame_0, frame_1, frame_0, frame_1], dim=0).cuda()
    # `multipler` (sic) is passed as-is; presumably it matches the node's vfi() parameter name.
    result = vfi_node_class.vfi(ckpt_name, clip, multipler=4, batch_size=2)[0]
    print(result.shape)
    print(f"Generated {result.size(0)} frames")
    frames = [_to_pil(frame) for frame in result]
    print(result[0].shape)

    video_dir = os.path.join("test_result", f"video{i}")
    os.makedirs(video_dir, exist_ok=True)
    for j, frame in enumerate(frames):
        frame.save(os.path.join(video_dir, f"{j}.jpg"))

    gif_path = os.path.join("test_result", f"video{i}.gif")
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        optimize=True,
        duration=GIF_FRAME_DURATION_MS,
        loop=0,
    )
    # os.startfile exists only on Windows; elsewhere just report where the GIF went.
    if hasattr(os, "startfile"):
        os.startfile(gif_path)
    else:
        print(f"Saved {gif_path}")