forked from MrThetaIII/Video-Summerization-using-LSTM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: classify.py
64 lines (50 loc) · 2.2 KB
/
classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch.autograd import Variable
from dataset import Video
from spatial_transforms import (Compose, Normalize, Scale, CenterCrop, ToTensor)
from temporal_transforms import LoopPadding
from Prepare_dataset import Prepare_dataset
data_driver = Prepare_dataset()
def classify_video(video_dir, video_name, class_names, model, opt):
    """Run `model` over the clips of one video and collect per-clip results.

    Parameters
    ----------
    video_dir : str
        Directory containing the extracted frames handed to `Video`.
    video_name : str
        Name used both in the output dict and to look up annotations
        via the module-level `data_driver`.
    class_names : sequence
        Index-to-label mapping used when ``opt.mode == 'score'``.
    model : torch.nn.Module
        Network applied to each batch of clips.
    opt : object
        Options namespace; fields read here: mode, sample_size, mean,
        sample_duration, batch_size, n_threads.

    Returns
    -------
    dict
        {'video', 'clips', 'total_frames', 'annotations_count'} where each
        entry of 'clips' holds the clip's frame segment, its importance
        annotation, and either label+scores ('score' mode) or features
        ('feature' mode).
    """
    assert opt.mode in ['score', 'feature']

    spatial_transform = Compose([
        Scale(opt.sample_size),
        CenterCrop(opt.sample_size),
        ToTensor(),
        Normalize(opt.mean, [1, 1, 1]),
    ])
    temporal_transform = LoopPadding(opt.sample_duration)
    data = Video(video_dir,
                 spatial_transform=spatial_transform,
                 temporal_transform=temporal_transform,
                 sample_duration=opt.sample_duration)
    data_loader = torch.utils.data.DataLoader(
        data, batch_size=opt.batch_size, shuffle=False,
        num_workers=opt.n_threads, pin_memory=True)

    video_outputs = []
    video_segments = []
    # Inference only: disable autograd once around the whole loop. The
    # original wrapped inputs in the deprecated torch.autograd.Variable,
    # which is a no-op on modern torch (and removed in recent releases);
    # likewise `.cpu().data` is replaced by plain `.cpu()` — under
    # no_grad the tensor carries no graph to detach.
    with torch.no_grad():
        for inputs, segments in data_loader:
            outputs = model(inputs)
            video_outputs.append(outputs.cpu())
            video_segments.append(segments)

    video_outputs = torch.cat(video_outputs)
    video_segments = torch.cat(video_segments)

    # The last frame index of the last segment serves as the frame total.
    total_frames = video_segments[-1].tolist()[-1]
    annotations = data_driver.get_annotations(video_name)

    results = {
        'video': video_name,
        'clips': [],
        'total_frames': total_frames,
        'annotations_count': len(annotations),
    }

    # Per-clip argmax over class scores (only used in 'score' mode).
    _, max_indices = video_outputs.max(dim=1)
    for i in range(video_outputs.size(0)):
        # NOTE(review): assumes len(annotations) >= number of clips;
        # an IndexError here means annotations/clips are misaligned.
        clip_results = {
            'segment': video_segments[i].tolist(),
            'importance': annotations[i],
        }
        if opt.mode == 'score':
            clip_results['label'] = class_names[max_indices[i]]
            clip_results['scores'] = video_outputs[i].tolist()
        elif opt.mode == 'feature':
            clip_results['features'] = video_outputs[i].tolist()
        results['clips'].append(clip_results)

    return results