forked from neccam/slt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_gcn.py
60 lines (53 loc) · 1.77 KB
/
extract_gcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import lzma
import pickle
import time
import random
import os
import yaml
from signjoey.GCN.processor import load_model
from signjoey.augmentations import load_augment
import argparse
def get_files(video_folder):
    """Recursively collect every ``.pkl`` file under *video_folder*.

    Args:
        video_folder: Root directory to walk. A missing or empty path yields
            an empty list, because ``os.walk`` ignores the listing error and
            produces nothing.

    Returns:
        A list of dicts sorted by ``input_file``. Each dict has these keys:
        ``file`` (the basename), ``file_name`` (the basename with only its
        final extension removed) and ``input_file`` (the walked directory
        joined with the basename).
    """
    videos = [
        {
            "file": name,
            # splitext strips only the last extension ("a.b.pkl" -> "a.b").
            # The old split(".")[0] gave "a", so distinct inputs could
            # collide on the same output file name.
            "file_name": os.path.splitext(name)[0],
            "input_file": os.path.join(root, name),
        }
        for root, _, names in os.walk(video_folder)
        for name in names
        if name.endswith(".pkl")
    ]
    # os.walk order is filesystem-dependent; sort for reproducible runs.
    videos.sort(key=lambda entry: entry["input_file"])
    return videos
def main(params):
    """Replace each sample's keypoints with GCN features and save the result.

    Processing steps, applied to every ``.pkl`` file that ``get_files``
    finds under ``params.input_folder``:

    1. Unpickle the file into a dict (``label``).
    2. lzma-decompress and unpickle its keypoint payload, then pass it
       through ``load_augment("holistic", ...)``.
    3. Run the result through the model from ``load_model()`` on CUDA.
    4. lzma-compress the pickled feature array back into ``label["sign"]``.
    5. Write the updated dict to ``params.output_folder`` as
       ``<file_name>.pkl``.

    Args:
        params: Namespace with string attributes ``input_folder`` and
            ``output_folder``. An empty ``output_folder`` writes into the
            current working directory.
    """
    import torch  # used only for no_grad(); inference needs no autograd graph

    files = get_files(params.input_folder)
    model = load_model()
    print(len(files))

    # The CLI default is "". Create the folder only when one is given
    # (os.makedirs("") would raise).
    if params.output_folder:
        os.makedirs(params.output_folder, exist_ok=True)

    for entry in files:
        # NOTE(security): pickle.loads can execute arbitrary code; only run
        # this script on trusted data.
        with open(entry["input_file"], "rb") as f:
            label = pickle.loads(f.read())

        try:
            keypoints = load_augment("holistic", pickle.loads(lzma.decompress(label["sign"])))
        except TypeError:
            # Fallback kept as in the original: read the 'sgn' key and convert
            # the unpickled object with .numpy(). NOTE(review): presumably some
            # samples store a torch tensor under 'sgn'; confirm with the data.
            print(label["name"])
            keypoints = load_augment("holistic", pickle.loads(lzma.decompress(label["sgn"])).numpy())

        with torch.no_grad():
            features = model(keypoints.unsqueeze(0).cuda()).squeeze(0).cpu().detach().numpy()
        label["sign"] = lzma.compress(pickle.dumps(features))

        # os.path.join avoids f"{''}/name.pkl" == "/name.pkl" (the filesystem
        # root) when output_folder is left at its empty default.
        out_path = os.path.join(params.output_folder, f"{entry['file_name']}.pkl")
        with open(out_path, "wb") as f:
            f.write(pickle.dumps(label))
if __name__ == '__main__':
    # Command-line entry point: parse the two folder options, run the
    # extraction, then report the elapsed wall-clock time.
    parser = argparse.ArgumentParser()
    for flag in ("--input_folder", "--output_folder"):
        parser.add_argument(flag, type=str, default="", help="")
    # Unrecognised arguments are tolerated and discarded.
    params, _ = parser.parse_known_args()

    start = time.time()
    main(params)
    elapsed = time.time() - start
    print(f"Done in {elapsed}s")