Skip to content

Commit

Permalink
split uniform sample
Browse files Browse the repository at this point in the history
  • Loading branch information
cir7 committed Dec 1, 2022
1 parent 8f123b4 commit 6f78bac
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 251 deletions.
17 changes: 3 additions & 14 deletions configs/recognition/mvit/mvit-base-p244_u32_sthv2-rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@
file_client_args = dict(io_backend='disk')
train_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=32,
out_of_bound_opt='repeat_frame'),
dict(type='UniformSample', clip_len=32),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
Expand All @@ -51,11 +48,7 @@
]
val_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=32,
out_of_bound_opt='repeat_frame',
test_mode=True),
dict(type='UniformSample', clip_len=32, test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
Expand All @@ -64,11 +57,7 @@
]
test_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=32,
out_of_bound_opt='repeat_frame',
test_mode=True),
dict(type='UniformSample', clip_len=32, test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 224)),
dict(type='ThreeCrop', crop_size=224),
Expand Down
17 changes: 3 additions & 14 deletions configs/recognition/mvit/mvit-large-p244_u40_sthv2-rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@
file_client_args = dict(io_backend='disk')
train_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=40,
out_of_bound_opt='repeat_frame'),
dict(type='UniformSample', clip_len=40),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
Expand All @@ -53,11 +50,7 @@
]
val_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=40,
out_of_bound_opt='repeat_frame',
test_mode=True),
dict(type='UniformSample', clip_len=40, test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
Expand All @@ -66,11 +59,7 @@
]
test_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=40,
out_of_bound_opt='repeat_frame',
test_mode=True),
dict(type='UniformSample', clip_len=40, test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 224)),
dict(type='ThreeCrop', crop_size=224),
Expand Down
17 changes: 3 additions & 14 deletions configs/recognition/mvit/mvit-small-p244_u16_sthv2-rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
file_client_args = dict(io_backend='disk')
train_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=16,
out_of_bound_opt='repeat_frame'),
dict(type='UniformSample', clip_len=16),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
Expand All @@ -34,11 +31,7 @@
]
val_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=16,
out_of_bound_opt='repeat_frame',
test_mode=True),
dict(type='UniformSample', clip_len=16, test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
Expand All @@ -47,11 +40,7 @@
]
test_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='UniformSampleFrames',
clip_len=16,
out_of_bound_opt='repeat_frame',
test_mode=True),
dict(type='UniformSample', clip_len=16, test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 224)),
dict(type='ThreeCrop', crop_size=224),
Expand Down
14 changes: 7 additions & 7 deletions mmaction/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
LoadProposals, OpenCVDecode, OpenCVInit, PIMSDecode,
PIMSInit, PyAVDecode, PyAVDecodeMotionVector, PyAVInit,
RawFrameDecode, SampleAVAFrames, SampleFrames,
UniformSampleFrames, UntrimmedSampleFrames)
UniformSample, UntrimmedSampleFrames)
from .pose_loading import (GeneratePoseTarget, LoadKineticsPose,
PaddingWithLoop, PoseDecode)
PaddingWithLoop, PoseDecode, UniformSampleFrames)
from .processing import (AudioAmplify, CenterCrop, ColorJitter, Flip, Fuse,
MelSpectrogram, MultiScaleCrop, PoseCompact,
RandomCrop, RandomRescale, RandomResizedCrop, Resize,
Expand All @@ -30,9 +30,9 @@
'AudioAmplify', 'MelSpectrogram', 'AudioDecode', 'FormatAudioShape',
'LoadAudioFeature', 'AudioFeatureSelector', 'AudioDecodeInit',
'ImageDecode', 'BuildPseudoClip', 'RandomRescale', 'PIMSDecode',
'PyAVDecodeMotionVector', 'UniformSampleFrames', 'PoseDecode',
'LoadKineticsPose', 'GeneratePoseTarget', 'PIMSInit', 'FormatGCNInput',
'PaddingWithLoop', 'ArrayDecode', 'JointToBone', 'PackActionInputs',
'PackLocalizationInputs', 'ImgAug', 'TorchVisionWrapper',
'PytorchVideoWrapper', 'PoseCompact'
'PyAVDecodeMotionVector', 'UniformSample', 'UniformSampleFrames',
'PoseDecode', 'LoadKineticsPose', 'GeneratePoseTarget', 'PIMSInit',
'FormatGCNInput', 'PaddingWithLoop', 'ArrayDecode', 'JointToBone',
'PackActionInputs', 'PackLocalizationInputs', 'ImgAug',
'TorchVisionWrapper', 'PytorchVideoWrapper', 'PoseCompact'
]
130 changes: 15 additions & 115 deletions mmaction/datasets/transforms/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,15 @@ def __repr__(self):


@TRANSFORMS.register_module()
class UniformSampleFrames(BaseTransform):
"""Uniformly sample frames from the video.
class UniformSample(BaseTransform):
"""Uniformly sample frames from the video. Currently used for Something-
Something V2 dataset. Modified from
https://github.com/facebookresearch/SlowFast/blob/64a
bcc90ccfdcbb11cf91d6e525bed60e92a8796/slowfast/datasets/ssv2.py#L159.
To sample an n-frame clip from the video. UniformSampleFrames basically
divides the video into n segments of equal length and randomly samples one
frame from each segment. To make the testing results reproducible, a
random seed is set during testing, to make the sampling results
deterministic.
frame from each segment.
Required keys:
Expand All @@ -292,113 +293,23 @@ class UniformSampleFrames(BaseTransform):
num_clips (int): Number of clips to be sampled. Default: 1.
test_mode (bool): Store True when building test or validation dataset.
Default: False.
seed (int): The random seed used during test time. Default: 255.
out_of_bound_opt (str): The way to deal with out of bounds frame
indexes. Available options are 'loop', 'repeat_frame'.
Default: 'loop'.
"""

def __init__(self,
clip_len: int,
num_clips: int = 1,
test_mode: bool = False,
seed: int = 255,
out_of_bound_opt: str = 'loop') -> None:
test_mode: bool = False) -> None:

self.clip_len = clip_len
self.num_clips = num_clips
self.test_mode = test_mode
self.seed = seed
self.out_of_bound_opt = out_of_bound_opt
assert self.out_of_bound_opt in ['loop', 'repeat_frame']

def _get_train_clips(self, num_frames: int):
"""Uniformly sample indices for training clips.
Args:
num_frames (int): The number of frames.
"""

assert self.num_clips == 1
if num_frames < self.clip_len:
start = np.random.randint(0, num_frames)
inds = np.arange(start, start + self.clip_len)
elif self.clip_len <= num_frames < 2 * self.clip_len:
basic = np.arange(self.clip_len)
inds = np.random.choice(
self.clip_len + 1, num_frames - self.clip_len, replace=False)
offset = np.zeros(self.clip_len + 1, dtype=np.int32)
offset[inds] = 1
offset = np.cumsum(offset)
inds = basic + offset[:-1]
else:
bids = np.array([
i * num_frames // self.clip_len
for i in range(self.clip_len + 1)
])
bsize = np.diff(bids)
bst = bids[:self.clip_len]
offset = np.random.randint(bsize)
inds = bst + offset
return inds

def _get_test_clips(self, num_frames: int):
"""Uniformly sample indices for testing clips.

Args:
num_frames (int): The number of frames.
"""

np.random.seed(self.seed)
if num_frames < self.clip_len:
# Then we use a simple strategy
if num_frames < self.num_clips:
start_inds = list(range(self.num_clips))
else:
start_inds = [
i * num_frames // self.num_clips
for i in range(self.num_clips)
]
inds = np.concatenate(
[np.arange(i, i + self.clip_len) for i in start_inds])
elif self.clip_len <= num_frames < self.clip_len * 2:
all_inds = []
for i in range(self.num_clips):
basic = np.arange(self.clip_len)
inds = np.random.choice(
self.clip_len + 1,
num_frames - self.clip_len,
replace=False)
offset = np.zeros(self.clip_len + 1, dtype=np.int32)
offset[inds] = 1
offset = np.cumsum(offset)
inds = basic + offset[:-1]
all_inds.append(inds)
inds = np.concatenate(all_inds)
else:
bids = np.array([
i * num_frames // self.clip_len
for i in range(self.clip_len + 1)
])
bsize = np.diff(bids)
bst = bids[:self.clip_len]
all_inds = []
for i in range(self.num_clips):
offset = np.random.randint(bsize)
all_inds.append(bst + offset)
inds = np.concatenate(all_inds)
return inds

def _get_repeat_sample_clips(self, num_frames: int) -> np.array:
"""Repeat sample when video is shorter than clip_len Modified from
https://github.com/facebookresearch/SlowFast/blob/64ab
cc90ccfdcbb11cf91d6e525bed60e92a8796/slowfast/datasets/ssv2.py#L159.
When video frames is shorter than target clip len, this strategy would
repeat sample frame, rather than loop sample in 'loop' mode.
In test mode, this strategy would sample the middle frame of each
segment, rather than set a random seed, and therefore only support
sample 1 clip.
def _get_sample_clips(self, num_frames: int) -> np.array:
"""When video frames is shorter than target clip len, this strategy
would repeat sample frame, rather than loop sample in 'loop' mode. In
test mode, this strategy would sample the middle frame of each segment,
rather than set a random seed, and therefore only support sample 1
clip.
Args:
num_frames (int): Total number of frame in the video.
Expand All @@ -421,17 +332,7 @@ def _get_repeat_sample_clips(self, num_frames: int) -> np.array:
def transform(self, results: dict):
num_frames = results['total_frames']

if self.out_of_bound_opt == 'loop':
if self.test_mode:
inds = self._get_test_clips(num_frames)
else:
inds = self._get_train_clips(num_frames)
inds = np.mod(inds, num_frames)
elif self.out_of_bound_opt == 'repeat_frame':
inds = self._get_repeat_sample_clips(num_frames)
else:
raise ValueError('Illegal out_of_bound option.')

inds = self._get_sample_clips(num_frames)
start_index = results['start_index']
inds = inds + start_index

Expand All @@ -445,8 +346,7 @@ def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'clip_len={self.clip_len}, '
f'num_clips={self.num_clips}, '
f'test_mode={self.test_mode}, '
f'seed={self.seed})')
f'test_mode={self.test_mode}')
return repr_str


Expand Down
Loading

0 comments on commit 6f78bac

Please sign in to comment.