
Commit

Merge 4ba613f into c548788
Dai-Wenxun authored Jul 20, 2023
2 parents c548788 + 4ba613f commit b2517b3
Showing 3 changed files with 98 additions and 2 deletions.
29 changes: 29 additions & 0 deletions demo/README.md
@@ -10,6 +10,7 @@
- [SpatioTemporal Action Detection Video Demo](#spatiotemporal-action-detection-video-demo): A demo script to predict the spatiotemporal action detection result using a single video.
- [SpatioTemporal Action Detection ONNX Video Demo](#spatiotemporal-action-detection-onnx-video-demo): A demo script to predict the spatiotemporal action detection result using an ONNX file instead of building the PyTorch models.
- [Inferencer Demo](#inferencer): A demo script for fast prediction on video analysis tasks based on the unified inferencer interface.
- [Audio Demo](#audio-demo): A demo script to predict the recognition result using a single audio file.

## Modify configs through script arguments

@@ -438,3 +439,31 @@ Assume that you are located at `$MMACTION2`.
--rec tsn \
--label-file tools/data/kinetics/label_map_k400.txt
```

## Audio Demo

A demo script to predict the action recognition result using a single audio feature file.

The script [`extract_audio.py`](/tools/data/extract_audio.py) can be used to extract audio from videos, and the script [`build_audio_features.py`](/tools/data/build_audio_features.py) can be used to build audio features from the extracted audio.
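
If you just want to try the demo before running the extraction pipeline, a stand-in feature file can be created with NumPy, as in the sketch below. The `(num_frames, num_mel_bins)` shape used here is an assumption made for illustration; real inputs should be produced by `build_audio_features.py`.

```python
# A minimal sketch for producing a placeholder feature file. The
# (num_frames, num_mel_bins) = (128, 80) shape is an assumption made for
# illustration; real features must come from build_audio_features.py.
import numpy as np

dummy_feature = np.random.randn(128, 80).astype(np.float32)
np.save('audio_feature.npy', dummy_feature)
```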

```shell
python demo/demo_audio.py ${CONFIG_FILE} ${CHECKPOINT_FILE} ${AUDIO_FILE} ${LABEL_FILE} [--device ${DEVICE}]
```

Optional arguments:

- `DEVICE`: Type of device to run the demo. Allowed values are CUDA devices like `cuda:0` or `cpu`. If not specified, it will be set to `cuda:0`.

Examples:

Assume that you are located at `$MMACTION2` and have already downloaded the checkpoints to the directory `checkpoints/`, or use a checkpoint URL from `configs/` to load the corresponding checkpoint directly; it will be automatically saved to `$HOME/.cache/torch/checkpoints`.

1. Recognize an audio feature file using a TSN model, on CUDA by default.

```shell
python demo/demo_audio.py \
configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py \
https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20230702-e4642fb0.pth \
audio_feature.npy tools/data/kinetics/label_map_k400.txt
```
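
The shell demo above is a thin wrapper around the Python inference API (see `demo/demo_audio.py` below). A roughly equivalent sketch in Python, where the local checkpoint path is a placeholder (the checkpoint URL from the shell example works as well):

```python
# Sketch of the equivalent Python API usage; the checkpoint path is a
# placeholder, and the checkpoint URL from the shell example also works.
from mmengine import Config

from mmaction.apis import inference_recognizer, init_recognizer

cfg = Config.fromfile(
    'configs/recognition_audio/resnet/'
    'tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py')
model = init_recognizer(cfg, 'checkpoints/tsn_audio_feature.pth', device='cuda:0')
result = inference_recognizer(model, 'audio_feature.npy')
print(result.pred_scores.item.shape)  # per-class scores
```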
57 changes: 57 additions & 0 deletions demo/demo_audio.py
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from operator import itemgetter

import torch
from mmengine import Config, DictAction

from mmaction.apis import inference_recognizer, init_recognizer


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file/url')
    parser.add_argument('audio', help='audio file')
    parser.add_argument('label', help='label file')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    device = torch.device(args.device)
    cfg = Config.fromfile(args.config)
    cfg.merge_from_dict(args.cfg_options)
    model = init_recognizer(cfg, args.checkpoint, device=device)

    # The demo expects a pre-extracted audio feature file (.npy), not raw audio.
    if not args.audio.endswith('.npy'):
        raise NotImplementedError('Demo works on extracted audio features')
    pred_result = inference_recognizer(model, args.audio)

    # Pair each class index with its score and keep the top-5 predictions.
    pred_scores = pred_result.pred_scores.item.tolist()
    score_tuples = tuple(zip(range(len(pred_scores)), pred_scores))
    score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
    top5_label = score_sorted[:5]

    # Map class indices to human-readable labels from the label file.
    labels = open(args.label).readlines()
    labels = [x.strip() for x in labels]
    results = [(labels[k[0]], k[1]) for k in top5_label]

    print('The top-5 labels with corresponding scores are:')
    for result in results:
        print(f'{result[0]}: ', result[1])


if __name__ == '__main__':
    main()
14 changes: 12 additions & 2 deletions mmaction/apis/inference.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
 from pathlib import Path
 from typing import List, Optional, Union

@@ -80,8 +81,11 @@ def inference_recognizer(model: nn.Module,
     input_flag = None
     if isinstance(video, dict):
         input_flag = 'dict'
-    elif isinstance(video, str):
-        input_flag = 'video'
+    elif isinstance(video, str) and osp.exists(video):
+        if video.endswith('.npy'):
+            input_flag = 'audio'
+        else:
+            input_flag = 'video'
     else:
         raise RuntimeError(f'The type of argument `video` is not supported: '
                            f'{type(video)}')
@@ -90,6 +94,12 @@
         data = video
     if input_flag == 'video':
         data = dict(filename=video, label=-1, start_index=0, modality='RGB')
+    if input_flag == 'audio':
+        data = dict(
+            audio_path=video,
+            total_frames=len(np.load(video)),
+            start_index=0,
+            label=-1)
 
     data = test_pipeline(data)
     data = pseudo_collate([data])
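With this change, `inference_recognizer` dispatches on the input path: an existing file ending in `.npy` takes the new audio branch, while any other existing path is treated as a video. A usage sketch, assuming both models were created with `init_recognizer` beforehand:

```python
# Sketch of the new dispatch behavior; video_model and audio_model are
# assumed to have been initialized with init_recognizer beforehand.
from mmaction.apis import inference_recognizer

result_video = inference_recognizer(video_model, 'demo/demo.mp4')  # 'video' branch
result_audio = inference_recognizer(audio_model, 'audio_feature.npy')  # 'audio' branch
```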
