diff --git a/gluoncv/data/__init__.py b/gluoncv/data/__init__.py index 956d26efbd..81dea767e5 100644 --- a/gluoncv/data/__init__.py +++ b/gluoncv/data/__init__.py @@ -19,6 +19,7 @@ from .recordio.detection import RecordFileDetection from .lst.detection import LstDetection from .mixup.detection import MixupDetection +from .ucf101.classification import UCF101 datasets = { 'ade20k': ADE20KSegmentation, diff --git a/gluoncv/data/transforms/__init__.py b/gluoncv/data/transforms/__init__.py index fed4478664..5e17b03a3e 100644 --- a/gluoncv/data/transforms/__init__.py +++ b/gluoncv/data/transforms/__init__.py @@ -8,3 +8,4 @@ from . import presets from .block import RandomCrop from . import pose +from . import video diff --git a/gluoncv/data/transforms/video.py b/gluoncv/data/transforms/video.py new file mode 100644 index 0000000000..5028855dc5 --- /dev/null +++ b/gluoncv/data/transforms/video.py @@ -0,0 +1,345 @@ +# pylint: disable= missing-docstring +"""Extended image transformations to video transformations.""" +from __future__ import division +import random +import numbers +import numpy as np +from mxnet import nd +from mxnet.gluon import Block + +__all__ = ['VideoToTensor', 'VideoNormalize', 'VideoRandomHorizontalFlip', 'VideoMultiScaleCrop', + 'VideoCenterCrop', 'VideoTenCrop'] + +class VideoToTensor(Block): + """Converts a video clip NDArray to a tensor NDArray. + + Converts a video clip NDArray of shape (H x W x C) in the range + [0, 255] to a float32 tensor NDArray of shape (C x H x W) in + the range [0, 1). + + Parameters + ---------- + max_intensity : float + The maximum intensity value to be divided. + + Inputs: + - **data**: input tensor with (H x W x C) shape and uint8 type. + + Outputs: + - **out**: output tensor with (C x H x W) shape and float32 type. + """ + def __init__(self, max_intensity=255.0): + super(VideoToTensor, self).__init__() + self.max_intensity = max_intensity + + def forward(self, clips): + return nd.transpose(clips, axes=(2, 0, 1)) / self.max_intensity + +class VideoNormalize(Block): + """Normalize an tensor of shape (C x H x W) with mean and standard deviation. + + Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels, + this transform normalizes each channel of the input tensor with:: + + output[i] = (input[i] - mi) / si + + If mean or std is scalar, the same value will be applied to all channels. + + Parameters + ---------- + mean : float or tuple of floats + The mean values. + std : float or tuple of floats + The standard deviation values. + + + Inputs: + - **data**: input tensor with (C x H x W) shape. + + Outputs: + - **out**: output tensor with the shape as `data`. + """ + + def __init__(self, mean, std): + super(VideoNormalize, self).__init__() + self.mean = mean + self.std = std + + def forward(self, clips): + c, _, _ = clips.shape + num_images = int(c / 3) + clip_mean = self.mean * num_images + clip_std = self.std * num_images + clip_mean = nd.array(np.asarray(clip_mean).reshape((c, 1, 1))) + clip_std = nd.array(np.asarray(clip_std).reshape((c, 1, 1))) + + return (clips - clip_mean) / clip_std + +class VideoRandomHorizontalFlip(Block): + """Randomly flip the input video clip left to right with a probability of 0.5. + + Parameters + ---------- + px : float + The probability value to flip the input tensor. + + Inputs: + - **data**: input tensor with (H x W x C) shape. + + Outputs: + - **out**: output tensor with same shape as `data`. 
+ """ + + def __init__(self, px=0): + super(VideoRandomHorizontalFlip, self).__init__() + self.px = px + + def forward(self, clips): + if random.random() < 0.5: + clips = nd.flip(clips, axis=1) + return clips + +class VideoMultiScaleCrop(Block): + """Corner cropping and multi-scale cropping. + Two data augmentation techniques introduced in: + Towards Good Practices for Very Deep Two-Stream ConvNets, + http://arxiv.org/abs/1507.02159 + Limin Wang, Yuanjun Xiong, Zhe Wang and Yu Qiao + + Parameters: + ---------- + size : int + height and width required by network input, e.g., (224, 224) + scale_ratios : list + efficient scale jittering, e.g., [1.0, 0.875, 0.75, 0.66] + fix_crop : bool + use corner cropping or not. Default: True + more_fix_crop : bool + use more corners or not. Default: True + max_distort : float + maximum distortion. Default: 1 + + Inputs: + - **data**: input tensor with (H x W x C) shape. + + Outputs: + - **out**: output tensor with desired size as 'size' + + """ + + def __init__(self, size, scale_ratios, fix_crop=True, + more_fix_crop=True, max_distort=1): + super(VideoMultiScaleCrop, self).__init__() + self.height = size[0] + self.width = size[1] + self.scale_ratios = scale_ratios + self.fix_crop = fix_crop + self.more_fix_crop = more_fix_crop + self.max_distort = max_distort + + def fillFixOffset(self, datum_height, datum_width): + """Fixed cropping strategy + + Inputs: + - **data**: height and width of input tensor + + Outputs: + - **out**: a list of locations to crop the image + + """ + h_off = int((datum_height - self.height) / 4) + w_off = int((datum_width - self.width) / 4) + + offsets = [] + offsets.append((0, 0)) # upper left + offsets.append((0, 4*w_off)) # upper right + offsets.append((4*h_off, 0)) # lower left + offsets.append((4*h_off, 4*w_off)) # lower right + offsets.append((2*h_off, 2*w_off)) # center + + if self.more_fix_crop: + offsets.append((0, 2*w_off)) # top center + offsets.append((4*h_off, 2*w_off)) # bottom center + offsets.append((2*h_off, 0)) # left center + offsets.append((2*h_off, 4*w_off)) # right center + + offsets.append((1*h_off, 1*w_off)) # upper left quarter + offsets.append((1*h_off, 3*w_off)) # upper right quarter + offsets.append((3*h_off, 1*w_off)) # lower left quarter + offsets.append((3*h_off, 3*w_off)) # lower right quarter + + return offsets + + def fillCropSize(self, input_height, input_width): + """Fixed cropping strategy + + Inputs: + - **data**: height and width of input tensor + + Outputs: + - **out**: a list of crop sizes to crop the image + + """ + crop_sizes = [] + base_size = np.min((input_height, input_width)) + scale_rates = self.scale_ratios + for h, scale_rate_h in enumerate(scale_rates): + crop_h = int(base_size * scale_rate_h) + for w, scale_rate_w in enumerate(scale_rates): + crop_w = int(base_size * scale_rate_w) + # append this cropping size into the list + if (np.absolute(h-w) <= self.max_distort): + crop_sizes.append((crop_h, crop_w)) + + return crop_sizes + + def forward(self, clips): + + from ...utils.filesystem import try_import_cv2 + cv2 = try_import_cv2() + + clips = clips.asnumpy() + h, w, c = clips.shape + is_color = False + if c % 3 == 0: + is_color = True + + crop_size_pairs = self.fillCropSize(h, w) + size_sel = random.randint(0, len(crop_size_pairs)-1) + crop_height = crop_size_pairs[size_sel][0] + crop_width = crop_size_pairs[size_sel][1] + + if self.fix_crop: + offsets = self.fillFixOffset(h, w) + off_sel = random.randint(0, len(offsets)-1) + h_off = offsets[off_sel][0] + w_off = 
offsets[off_sel][1] + else: + h_off = random.randint(0, h - self.height) + w_off = random.randint(0, w - self.width) + + scaled_clips = np.zeros((self.height, self.width, c)) + if is_color: + num_imgs = int(c / 3) + for frame_id in range(num_imgs): + cur_img = clips[:, :, frame_id*3:frame_id*3+3] + crop_img = cur_img[h_off:h_off+crop_height, w_off:w_off+crop_width, :] + scaled_clips[:, :, frame_id*3:frame_id*3+3] = \ + cv2.resize(crop_img, (self.width, self.height), cv2.INTER_LINEAR) + else: + num_imgs = int(c / 1) + for frame_id in range(num_imgs): + cur_img = clips[:, :, frame_id:frame_id+1] + crop_img = cur_img[h_off:h_off+crop_height, w_off:w_off+crop_width, :] + scaled_clips[:, :, frame_id:frame_id+1] = np.expand_dims(\ + cv2.resize(crop_img, (self.width, self.height), cv2.INTER_LINEAR), axis=2) + + return nd.array(scaled_clips) + + +class VideoCenterCrop(Block): + """Crops the given numpy array at the center to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + + Parameters: + ---------- + size : int + height and width required by network input, e.g., (224, 224) + + Inputs: + - **data**: input tensor with (H x W x C) shape. + + Outputs: + - **out**: output tensor with desired size as 'size' + """ + + def __init__(self, size): + super(VideoCenterCrop, self).__init__() + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def forward(self, clips): + h, w, c = clips.shape + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + + is_color = False + if c % 3 == 0: + is_color = True + + scaled_clips = nd.zeros((th, tw, c)) + if is_color: + num_imgs = int(c / 3) + for frame_id in range(num_imgs): + cur_img = clips[:, :, frame_id*3:frame_id*3+3] + crop_img = cur_img[y1:y1+th, x1:x1+tw, :] + assert(crop_img.shape == (th, tw, 3)) + scaled_clips[:, :, frame_id*3:frame_id*3+3] = crop_img + else: + num_imgs = int(c / 1) + for frame_id in range(num_imgs): + cur_img = clips[:, :, frame_id:frame_id+1] + crop_img = cur_img[y1:y1+th, x1:x1+tw, :] + assert(crop_img.shape == (th, tw, 1)) + scaled_clips[:, :, frame_id:frame_id+1] = crop_img + return scaled_clips + + +class VideoTenCrop(Block): + """Crop 10 regions from an array. + This is performed same as: + http://chainercv.readthedocs.io/en/stable/reference/transforms.html#ten-crop + + This method crops 10 regions. All regions will be in shape + :obj`size`. These regions consist of 1 center crop and 4 corner + crops and horizontal flips of them. + The crops are ordered in this order. + * center crop + * top-left crop + * bottom-left crop + * top-right crop + * bottom-right crop + * center crop (flipped horizontally) + * top-left crop (flipped horizontally) + * bottom-left crop (flipped horizontally) + * top-right crop (flipped horizontally) + * bottom-right crop (flipped horizontally) + + Parameters: + ---------- + size : int + height and width required by network input, e.g., (224, 224) + + Inputs: + - **data**: input tensor with (H x W x C) shape. + + Outputs: + - **out**: output tensor with (H x W x 10C) shape. 
+ + """ + def __init__(self, size): + super(VideoTenCrop, self).__init__() + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def forward(self, clips): + h, w, _ = clips.shape + oh, ow = self.size + if h < oh or w < ow: + raise ValueError("Cannot crop area {} from image with size \ + ({}, {})".format(str(self.size), h, w)) + + center = clips[(h - oh) // 2:(h + oh) // 2, (w - ow) // 2:(w + ow) // 2, :] + tl = clips[0:oh, 0:ow, :] + bl = clips[h - oh:h, 0:ow, :] + tr = clips[0:oh, w - ow:w, :] + br = clips[h - oh:h, w - ow:w, :] + crops = nd.concat(*[center, tl, bl, tr, br], dim=2) + crops = nd.concat(*[crops, nd.flip(crops, axis=1)], dim=2) + return crops diff --git a/gluoncv/data/ucf101/__init__.py b/gluoncv/data/ucf101/__init__.py new file mode 100755 index 0000000000..8772d0eaa0 --- /dev/null +++ b/gluoncv/data/ucf101/__init__.py @@ -0,0 +1,6 @@ +# pylint: disable=wildcard-import +"""Video action recognition, UCF101 dataset. +https://www.crcv.ucf.edu/data/UCF101.php +""" +from __future__ import absolute_import +from .classification import * diff --git a/gluoncv/data/ucf101/classification.py b/gluoncv/data/ucf101/classification.py new file mode 100755 index 0000000000..260f0d9c88 --- /dev/null +++ b/gluoncv/data/ucf101/classification.py @@ -0,0 +1,191 @@ +# pylint: disable=line-too-long,too-many-lines,missing-docstring +"""UCF101 action classification dataset.""" +import os +import random +import numpy as np +from mxnet import nd +from mxnet.gluon.data import dataset + +__all__ = ['UCF101'] + +class UCF101(dataset.Dataset): + """Load the UCF101 action recognition dataset. + + Refer to :doc:`../build/examples_datasets/ucf101` for the description of + this dataset and how to prepare it. + + Parameters + ---------- + root : str, default '~/.mxnet/datasets/ucf101' + Path to the folder stored the dataset. + setting : str, required + Config file of the prepared dataset. + train : bool, default True + Whether to load the training or validation set. + test_mode : bool, default False + Whether to perform evaluation on the test set + name_pattern : str, default None + The naming pattern of the decoded video frames. + For example, img_00012.jpg + is_color : bool, default True + Whether the loaded image is color or grayscale + modality : str, default 'rgb' + Input modalities, we support only rgb video frames for now. + Will add support for rgb difference image and optical flow image later. + num_segments : int, default 1 + Number of segments to evenly divide the video into clips. + A useful technique to obtain global video-level information. + Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016 + new_length : int, default 1 + The length of input video clip. Default is a single image, but it can be multiple video frames. + For example, new_length=16 means we will extract a video clip of consecutive 16 frames. + new_width : int, default 340 + Scale the width of loaded image to 'new_width' for later multiscale cropping and resizing. + new_height : int, default 256 + Scale the height of loaded image to 'new_height' for later multiscale cropping and resizing. + target_width : int, default 224 + Scale the width of transformed image to the same 'target_width' for batch forwarding. + target_height : int, default 224 + Scale the height of transformed image to the same 'target_height' for batch forwarding. + transform : function, default None + A function that takes data and label and transforms them. 
+ """ + def __init__(self, + setting, + root=os.path.join('~', '.mxnet', 'datasets', 'ucf101'), + train=True, + test_mode=False, + name_pattern=None, + is_color=True, + modality='rgb', + num_segments=1, + new_length=1, + new_width=340, + new_height=256, + target_width=224, + target_height=224, + transform=None): + + super(UCF101, self).__init__() + + self.root = root + self.setting = setting + self.train = train + self.test_mode = test_mode + self.is_color = is_color + self.modality = modality + self.num_segments = num_segments + self.new_height = new_height + self.new_width = new_width + self.target_height = target_height + self.target_width = target_width + self.new_length = new_length + self.transform = transform + + self.classes, self.class_to_idx = self._find_classes(root) + self.clips = self._make_dataset(root, setting) + if len(self.clips) == 0: + raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" + "Check your data directory (opt.data-dir).")) + + if name_pattern: + self.name_pattern = name_pattern + else: + if self.modality == "rgb": + self.name_pattern = "img_%05d.jpg" + elif self.modality == "flow": + self.name_pattern = "flow_%s_%05d.jpg" + + def __getitem__(self, index): + + directory, duration, target = self.clips[index] + average_duration = int(duration / self.num_segments) + offsets = [] + for seg_id in range(self.num_segments): + if self.train and not self.test_mode: + # training + if average_duration >= self.new_length: + offset = random.randint(0, average_duration - self.new_length) + # No +1 because randint(a,b) return a random integer N such that a <= N <= b. + offsets.append(offset + seg_id * average_duration) + else: + offsets.append(0) + elif not self.train and not self.test_mode: + # validation + if average_duration >= self.new_length: + offsets.append(int((average_duration - self.new_length + 1)/2 + seg_id * average_duration)) + else: + offsets.append(0) + else: + # test + if average_duration >= self.new_length: + offsets.append(int((average_duration - self.new_length + 1)/2 + seg_id * average_duration)) + else: + offsets.append(0) + + clip_input = self._TSN_RGB(directory, offsets, self.new_height, self.new_width, self.new_length, self.is_color, self.name_pattern) + + if self.transform is not None: + clip_input = self.transform(clip_input) + + if self.num_segments > 1 and not self.test_mode: + # For TSN training, reshape the input to B x 3 x H x W. Here, B = batch_size * num_segments + clip_input = clip_input.reshape((-1, 3 * self.new_length, self.target_height, self.target_width)) + + return clip_input, target + + def __len__(self): + return len(self.clips) + + def _find_classes(self, directory): + + classes = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + + def _make_dataset(self, directory, setting): + + if not os.path.exists(setting): + raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. 
" % (setting))) + clips = [] + with open(setting) as split_f: + data = split_f.readlines() + for line in data: + line_info = line.split() + # line format: video_path, video_duration, video_label + clip_path = os.path.join(directory, line_info[0]) + duration = int(line_info[1]) + target = int(line_info[2]) + item = (clip_path, duration, target) + clips.append(item) + return clips + + def _TSN_RGB(self, directory, offsets, new_height, new_width, new_length, is_color, name_pattern): + + from ...utils.filesystem import try_import_cv2 + cv2 = try_import_cv2() + + if is_color: + cv_read_flag = cv2.IMREAD_COLOR + else: + cv_read_flag = cv2.IMREAD_GRAYSCALE + interpolation = cv2.INTER_LINEAR + + sampled_list = [] + for _, offset in enumerate(offsets): + for length_id in range(1, new_length+1): + frame_name = name_pattern % (length_id + offset) + frame_path = directory + "/" + frame_name + cv_img_origin = cv2.imread(frame_path, cv_read_flag) + if cv_img_origin is None: + raise(RuntimeError("Could not load file %s. Check data path." % (frame_path))) + if new_width > 0 and new_height > 0: + cv_img = cv2.resize(cv_img_origin, (new_width, new_height), interpolation) + else: + cv_img = cv_img_origin + cv_img = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB) + sampled_list.append(cv_img) + # the shape of clip_input will be H x W x C, and C = num_segments * new_length * 3 + clip_input = np.concatenate(sampled_list, axis=2) + return nd.array(clip_input) diff --git a/gluoncv/model_zoo/__init__.py b/gluoncv/model_zoo/__init__.py index 56f447fb4c..a263dfdbfe 100644 --- a/gluoncv/model_zoo/__init__.py +++ b/gluoncv/model_zoo/__init__.py @@ -17,6 +17,7 @@ from .se_resnet import * from .nasnet import * from .simple_pose.simple_pose_resnet import * +from .action_recognition import * from .alexnet import * from .densenet import * diff --git a/gluoncv/model_zoo/action_recognition/__init__.py b/gluoncv/model_zoo/action_recognition/__init__.py new file mode 100755 index 0000000000..31ae0ceab6 --- /dev/null +++ b/gluoncv/model_zoo/action_recognition/__init__.py @@ -0,0 +1,5 @@ +# pylint: disable=wildcard-import +"""Video action recognition.""" +from __future__ import absolute_import +from .vgg16_ucf101 import * +from .inceptionv3_ucf101 import * diff --git a/gluoncv/model_zoo/action_recognition/inceptionv3_ucf101.py b/gluoncv/model_zoo/action_recognition/inceptionv3_ucf101.py new file mode 100644 index 0000000000..9e6986e0c0 --- /dev/null +++ b/gluoncv/model_zoo/action_recognition/inceptionv3_ucf101.py @@ -0,0 +1,65 @@ +# pylint: disable=line-too-long,too-many-lines,missing-docstring,arguments-differ,unused-argument +from mxnet import init +from mxnet.gluon import nn +from mxnet.gluon.nn import HybridBlock +from ...nn.block import Consensus +from ..inception import inception_v3 + +__all__ = ['inceptionv3_ucf101', 'ActionRecInceptionV3', 'ActionRecInceptionV3TSN'] + +def inceptionv3_ucf101(nclass=101, pretrained=True, tsn=False, partial_bn=True, num_segments=3, **kwargs): + if tsn: + model = ActionRecInceptionV3TSN(nclass=nclass, pretrained=pretrained, partial_bn=partial_bn, num_segments=num_segments) + else: + model = ActionRecInceptionV3(nclass=nclass, pretrained=pretrained, partial_bn=partial_bn) + return model + +class ActionRecInceptionV3(HybridBlock): + r"""InceptionV3 model for video action recognition + + Parameters + ---------- + nclass : int, number of classes + pretrained : bool, load pre-trained weights or not + + Input: a single image + Output: a single predicted action label + """ + def __init__(self, nclass, 
pretrained=True, partial_bn=True, **kwargs): + super(ActionRecInceptionV3, self).__init__() + + pretrained_model = inception_v3(pretrained=pretrained, partial_bn=partial_bn, **kwargs) + self.features = pretrained_model.features + def update_dropout_ratio(block): + if isinstance(block, nn.basic_layers.Dropout): + block._rate = 0.8 + self.apply(update_dropout_ratio) + self.output = nn.Dense(units=nclass, in_units=2048, weight_initializer=init.Normal(sigma=0.001)) + self.output.initialize() + + def hybrid_forward(self, F, x): + x = self.features(x) + x = self.output(x) + return x + +class ActionRecInceptionV3TSN(HybridBlock): + r"""InceptionV3 model with temporal segments for video action recognition + + Parameters + ---------- + nclass : int, number of classes + pretrained : bool, load pre-trained weights or not + + Input: N images from N segments in a single video + Output: a single predicted action label + """ + def __init__(self, nclass, pretrained=True, partial_bn=True, num_segments=3, **kwargs): + super(ActionRecInceptionV3TSN, self).__init__() + + self.basenet = ActionRecInceptionV3(nclass=nclass, pretrained=pretrained, partial_bn=partial_bn) + self.tsn_consensus = Consensus(nclass=nclass, num_segments=num_segments) + + def hybrid_forward(self, F, x): + pred = self.basenet(x) + consensus_out = self.tsn_consensus(pred) + return consensus_out diff --git a/gluoncv/model_zoo/action_recognition/vgg16_ucf101.py b/gluoncv/model_zoo/action_recognition/vgg16_ucf101.py new file mode 100644 index 0000000000..f279e48dd7 --- /dev/null +++ b/gluoncv/model_zoo/action_recognition/vgg16_ucf101.py @@ -0,0 +1,69 @@ +# pylint: disable=line-too-long,too-many-lines,missing-docstring,arguments-differ,unused-argument +from mxnet import init +from mxnet.gluon import nn +from mxnet.gluon.nn import HybridBlock +from ...nn.block import Consensus +from ..vgg import vgg16 + +__all__ = ['vgg16_ucf101', 'ActionRecVGG16', 'ActionRecVGG16TSN'] + +def vgg16_ucf101(nclass=101, pretrained=True, tsn=False, num_segments=3, **kwargs): + if tsn: + model = ActionRecVGG16TSN(nclass=nclass, pretrained=pretrained, num_segments=num_segments) + else: + model = ActionRecVGG16(nclass=nclass, pretrained=pretrained) + return model + +class ActionRecVGG16(HybridBlock): + r"""VGG16 model for video action recognition + Limin Wang, etal, Towards Good Practices for Very Deep Two-Stream ConvNets, arXiv 2015 + https://arxiv.org/abs/1507.02159 + + Parameters + ---------- + nclass : int, number of classes + pretrained : bool, load pre-trained weights or not + + Input: a single video frame + Output: a single predicted action label + """ + def __init__(self, nclass, pretrained=True, **kwargs): + super(ActionRecVGG16, self).__init__() + + pretrained_model = vgg16(pretrained=pretrained, **kwargs) + self.features = pretrained_model.features + def update_dropout_ratio(block): + if isinstance(block, nn.basic_layers.Dropout): + block._rate = 0.9 + self.apply(update_dropout_ratio) + self.output = nn.Dense(units=nclass, in_units=4096, weight_initializer=init.Normal(sigma=0.001)) + self.output.initialize() + + def hybrid_forward(self, F, x): + x = self.features(x) + x = self.output(x) + return x + +class ActionRecVGG16TSN(HybridBlock): + r"""VGG16 model with temporal segments for video action recognition + Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016 + https://arxiv.org/abs/1608.00859 + + Parameters + ---------- + nclass : int, number of classes + pretrained : bool, load pre-trained weights or 
not + + Input: N images from N segments in a single video + Output: a single predicted action label + """ + def __init__(self, nclass, pretrained=True, num_segments=3, **kwargs): + super(ActionRecVGG16TSN, self).__init__() + + self.basenet = ActionRecVGG16(nclass=nclass, pretrained=pretrained) + self.tsn_consensus = Consensus(nclass=nclass, num_segments=num_segments) + + def hybrid_forward(self, F, x): + pred = self.basenet(x) + consensus_out = self.tsn_consensus(pred) + return consensus_out diff --git a/gluoncv/model_zoo/inception.py b/gluoncv/model_zoo/inception.py index 344bbda2b6..e48fbbe1ec 100644 --- a/gluoncv/model_zoo/inception.py +++ b/gluoncv/model_zoo/inception.py @@ -171,13 +171,21 @@ class Inception3(HybridBlock): Additional `norm_layer` arguments, for example `num_devices=4` for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`. """ - def __init__(self, classes=1000, norm_layer=BatchNorm, norm_kwargs=None, **kwargs): + def __init__(self, classes=1000, norm_layer=BatchNorm, + norm_kwargs=None, partial_bn=False, **kwargs): super(Inception3, self).__init__(**kwargs) # self.use_aux_logits = use_aux_logits with self.name_scope(): self.features = nn.HybridSequential(prefix='') self.features.add(_make_basic_conv(channels=32, kernel_size=3, strides=2, norm_layer=norm_layer, norm_kwargs=norm_kwargs)) + if partial_bn: + if norm_kwargs is not None: + norm_kwargs['use_global_stats'] = True + else: + norm_kwargs = {} + norm_kwargs['use_global_stats'] = True + self.features.add(_make_basic_conv(channels=32, kernel_size=3, norm_layer=norm_layer, norm_kwargs=norm_kwargs)) self.features.add(_make_basic_conv(channels=64, kernel_size=3, padding=1, @@ -211,7 +219,7 @@ def hybrid_forward(self, F, x): # Constructor def inception_v3(pretrained=False, ctx=cpu(), - root='~/.mxnet/models', **kwargs): + root='~/.mxnet/models', partial_bn=False, **kwargs): r"""Inception v3 model from `"Rethinking the Inception Architecture for Computer Vision" `_ paper. @@ -225,6 +233,8 @@ def inception_v3(pretrained=False, ctx=cpu(), The context in which to load the pretrained weights. root : str, default $MXNET_HOME/models Location for keeping the model parameters. + partial_bn : bool, default False + Freeze all batch normalization layers during training except the first layer. norm_layer : object Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`) Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`. @@ -232,6 +242,7 @@ def inception_v3(pretrained=False, ctx=cpu(), Additional `norm_layer` arguments, for example `num_devices=4` for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`. 
""" + net = Inception3(**kwargs) if pretrained: from .model_store import get_model_file diff --git a/gluoncv/model_zoo/model_zoo.py b/gluoncv/model_zoo/model_zoo.py index 840939dd68..4ca5a209c9 100644 --- a/gluoncv/model_zoo/model_zoo.py +++ b/gluoncv/model_zoo/model_zoo.py @@ -29,6 +29,7 @@ from .ssd import * from .vgg import * from .yolo import * +from .action_recognition import * __all__ = ['get_model', 'get_model_list'] @@ -214,6 +215,8 @@ 'ssd_512_mobilenet1.0_voc_int8': ssd_512_mobilenet1_0_voc_int8, 'ssd_512_resnet50_v1_voc_int8': ssd_512_resnet50_v1_voc_int8, 'ssd_512_vgg16_atrous_voc_int8': ssd_512_vgg16_atrous_voc_int8, + 'vgg16_ucf101': vgg16_ucf101, + 'inceptionv3_ucf101': inceptionv3_ucf101, } diff --git a/gluoncv/nn/block.py b/gluoncv/nn/block.py index 4f5542ff02..6066f6d83b 100644 --- a/gluoncv/nn/block.py +++ b/gluoncv/nn/block.py @@ -4,7 +4,7 @@ from __future__ import absolute_import from mxnet.gluon.nn import BatchNorm, HybridBlock -__all__ = ['BatchNormCudnnOff', 'ReLU6', 'HardSigmoid', 'HardSwish'] +__all__ = ['BatchNormCudnnOff', 'Consensus', 'ReLU6', 'HardSigmoid', 'HardSwish'] class BatchNormCudnnOff(BatchNorm): """Batch normalization layer without CUDNN. It is a temporary solution. @@ -20,6 +20,26 @@ def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var): return F.BatchNorm(x, gamma, beta, running_mean, running_var, name='fwd', cudnn_off=True, **self._kwargs) +class Consensus(HybridBlock): + """Consensus used in temporal segment networks. + + Parameters + ---------- + nclass : number of classses + num_segments : number of segments + kwargs : arguments goes to mxnet.gluon.nn.Consensus + """ + + def __init__(self, nclass, num_segments, **kwargs): + super(Consensus, self).__init__(**kwargs) + self.nclass = nclass + self.num_segments = num_segments + + def hybrid_forward(self, F, x): + reshape_out = x.reshape((-1, self.num_segments, self.nclass)) + consensus_out = reshape_out.mean(axis=1) + return consensus_out + class ReLU6(HybridBlock): """RelU6 used in MobileNetV2 and MobileNetV3. diff --git a/gluoncv/utils/__init__.py b/gluoncv/utils/__init__.py index 46360fc182..5e304d8e8f 100644 --- a/gluoncv/utils/__init__.py +++ b/gluoncv/utils/__init__.py @@ -15,3 +15,4 @@ from .lr_scheduler import LRSequential, LRScheduler from .plot_history import TrainingHistory from .export_helper import export_block +from .sync_loader_helper import split_data, split_and_load diff --git a/gluoncv/utils/sync_loader_helper.py b/gluoncv/utils/sync_loader_helper.py new file mode 100644 index 0000000000..292190b5f6 --- /dev/null +++ b/gluoncv/utils/sync_loader_helper.py @@ -0,0 +1,87 @@ +"""Dataloader helper functions. Synchronize slices for both data and label.""" +__all__ = ['split_data', 'split_and_load'] + +from mxnet import ndarray + +def split_data(data, num_slice, batch_axis=0, even_split=True, multiplier=1): + """Splits an NDArray into `num_slice` slices along `batch_axis`. + Usually used for data parallelism where each slices is sent + to one device (i.e. GPU). + + Parameters + ---------- + data : NDArray + A batch of data. + num_slice : int + Number of desired slices. + batch_axis : int, default 0 + The axis along which to slice. + even_split : bool, default True + Whether to force all slices to have the same number of elements. + If `True`, an error will be raised when `num_slice` does not evenly + divide `data.shape[batch_axis]`. 
+ multiplier : int, default 1 + The batch size has to be the multiples of multiplier + + Returns + ------- + list of NDArray + Return value is a list even if `num_slice` is 1. + """ + size = data.shape[batch_axis] + if even_split and size % num_slice != 0: + raise ValueError( + "data with shape %s cannot be evenly split into %d slices along axis %d. " \ + "Use a batch size that's multiple of %d or set even_split=False to allow " \ + "uneven partitioning of data."%( + str(data.shape), num_slice, batch_axis, num_slice)) + + step = (int(size / multiplier) // num_slice) * multiplier + + # If size < num_slice, make fewer slices + if not even_split and size < num_slice: + step = 1 + num_slice = size + + if batch_axis == 0: + slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size] + for i in range(num_slice)] + elif even_split: + slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis) + else: + slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step) + if i < num_slice - 1 else + ndarray.slice_axis(data, batch_axis, i*step, size) + for i in range(num_slice)] + return slices + + +def split_and_load(data, ctx_list, batch_axis=0, even_split=True, multiplier=1): + """Splits an NDArray into `len(ctx_list)` slices along `batch_axis` and loads + each slice to one context in `ctx_list`. + + Parameters + ---------- + data : NDArray + A batch of data. + ctx_list : list of Context + A list of Contexts. + batch_axis : int, default 0 + The axis along which to slice. + even_split : bool, default True + Whether to force all slices to have the same number of elements. + multiplier : int, default 1 + The batch size has to be the multiples of channel multiplier + + Returns + ------- + list of NDArray + Each corresponds to a context in `ctx_list`. 
+ """ + if not isinstance(data, ndarray.NDArray): + data = ndarray.array(data, ctx=ctx_list[0]) + if len(ctx_list) == 1: + return [data.as_in_context(ctx_list[0])] + + slices = split_data(data, len(ctx_list), batch_axis, even_split, multiplier) + return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)] diff --git a/scripts/action-recognition/test_recognizer.py b/scripts/action-recognition/test_recognizer.py new file mode 100644 index 0000000000..250ca277e2 --- /dev/null +++ b/scripts/action-recognition/test_recognizer.py @@ -0,0 +1,212 @@ +import argparse, time, logging, os, sys, math +import cv2 +import numpy as np +import mxnet as mx +import gluoncv as gcv +from mxnet import gluon, nd, gpu, init, context +from mxnet import autograd as ag +from mxnet.gluon import nn +from mxnet.gluon.data.vision import transforms +from mxboard import SummaryWriter + +from gluoncv.data.transforms import video +from gluoncv.data import ucf101 +from gluoncv.model_zoo import get_model +from gluoncv.utils import makedirs, LRSequential, LRScheduler, split_and_load + +# CLI +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model for action recognition.') + parser.add_argument('--data-dir', type=str, default='~/.mxnet/datasets/ucf101', + help='training and validation pictures to use.') + parser.add_argument('--train-list', type=str, default='~/.mxnet/datasets/ucf101/ucfTrainTestlist/ucf101_train_rgb_split1.txt', + help='the list of training data') + parser.add_argument('--val-list', type=str, default='~/.mxnet/datasets/ucf101/ucfTrainTestlist/ucf101_val_rgb_split1.txt', + help='the list of validation data') + parser.add_argument('--batch-size', type=int, default=32, + help='training batch size per device (CPU/GPU).') + parser.add_argument('--dtype', type=str, default='float32', + help='data type for training. default is float32') + parser.add_argument('--num-gpus', type=int, default=0, + help='number of gpus to use.') + parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int, + help='number of preprocessing workers') + parser.add_argument('--num-epochs', type=int, default=3, + help='number of training epochs.') + parser.add_argument('--lr', type=float, default=0.1, + help='learning rate. default is 0.1.') + parser.add_argument('--momentum', type=float, default=0.9, + help='momentum value for optimizer, default is 0.9.') + parser.add_argument('--wd', type=float, default=0.0001, + help='weight decay rate. default is 0.0001.') + parser.add_argument('--lr-mode', type=str, default='step', + help='learning rate scheduler mode. options are step, poly and cosine.') + parser.add_argument('--lr-decay', type=float, default=0.1, + help='decay rate of learning rate. default is 0.1.') + parser.add_argument('--lr-decay-period', type=int, default=0, + help='interval for periodic learning rate decays. default is 0 to disable.') + parser.add_argument('--lr-decay-epoch', type=str, default='40,60', + help='epochs at which learning rate decays. default is 40,60.') + parser.add_argument('--warmup-lr', type=float, default=0.0, + help='starting warmup learning rate. default is 0.0.') + parser.add_argument('--warmup-epochs', type=int, default=0, + help='number of warmup epochs.') + parser.add_argument('--last-gamma', action='store_true', + help='whether to init gamma of the last BN layer in each bottleneck to 0.') + parser.add_argument('--mode', type=str, + help='mode in which to train the model. 
options are symbolic, imperative, hybrid') + parser.add_argument('--model', type=str, required=True, + help='type of model to use. see vision_model for options.') + parser.add_argument('--input-size', type=int, default=224, + help='size of the input image size. default is 224') + parser.add_argument('--crop-ratio', type=float, default=0.875, + help='Crop ratio during validation. default is 0.875') + parser.add_argument('--use-pretrained', action='store_true', + help='enable using pretrained model from gluon.') + parser.add_argument('--use_se', action='store_true', + help='use SE layers or not in resnext. default is false.') + parser.add_argument('--mixup', action='store_true', + help='whether train the model with mix-up. default is false.') + parser.add_argument('--mixup-alpha', type=float, default=0.2, + help='beta distribution parameter for mixup sampling, default is 0.2.') + parser.add_argument('--mixup-off-epoch', type=int, default=0, + help='how many last epochs to train without mixup, default is 0.') + parser.add_argument('--label-smoothing', action='store_true', + help='use label smoothing or not in training. default is false.') + parser.add_argument('--no-wd', action='store_true', + help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.') + parser.add_argument('--teacher', type=str, default=None, + help='teacher model for distillation training') + parser.add_argument('--temperature', type=float, default=20, + help='temperature parameter for distillation teacher model') + parser.add_argument('--hard-weight', type=float, default=0.5, + help='weight for the loss of one-hot label for distillation training') + parser.add_argument('--batch-norm', action='store_true', + help='enable batch normalization or not in vgg. default is false.') + parser.add_argument('--save-frequency', type=int, default=10, + help='frequency of model saving.') + parser.add_argument('--save-dir', type=str, default='params', + help='directory of saved models') + parser.add_argument('--resume-epoch', type=int, default=0, + help='epoch to resume training from.') + parser.add_argument('--resume-params', type=str, default='', + help='path of parameters to load from.') + parser.add_argument('--resume-states', type=str, default='', + help='path of trainer state to load from.') + parser.add_argument('--log-interval', type=int, default=50, + help='Number of batches to wait before logging.') + parser.add_argument('--logging-file', type=str, default='train.log', + help='name of training log file') + parser.add_argument('--use-gn', action='store_true', + help='whether to use group norm.') + parser.add_argument('--eval', action='store_true', + help='directly evaluate the model.') + parser.add_argument('--num-segments', type=int, default=1, + help='number of segments to evenly split the video.') + parser.add_argument('--use-tsn', action='store_true', + help='whether to use temporal segment networks.') + parser.add_argument('--new-height', type=int, default=256, + help='new height of the resize image. default is 256') + parser.add_argument('--new-width', type=int, default=340, + help='new width of the resize image. 
default is 340') + parser.add_argument('--num-classes', type=int, default=101, + help='number of classes.') + parser.add_argument('--ten-crop', action='store_true', + help='whether to use ten crop evaluation.') + opt = parser.parse_args() + return opt + +def batch_fn(batch, ctx): + data = split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False) + label = split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False) + return data, label + +def main(): + opt = parse_args() + + # set env + num_gpus = opt.num_gpus + batch_size = opt.batch_size + batch_size *= max(1, num_gpus) + context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()] + num_workers = opt.num_workers + print('Total batch size is set to %d on %d GPUs' % (batch_size, num_gpus)) + + # get model + classes = opt.num_classes + model_name = opt.model + net = get_model(name=model_name, nclass=classes, pretrained=True, tsn=opt.use_tsn) + net.cast(opt.dtype) + net.collect_params().reset_ctx(context) + if opt.mode == 'hybrid': + net.hybridize(static_alloc=True, static_shape=True) + if opt.resume_params is not '': + net.load_parameters(opt.resume_params, ctx=context) + print('Pre-trained model %s is successfully loaded' % (opt.resume_params)) + + # get data + normalize = video.VideoNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + transform_test = transforms.Compose([ + video.VideoTenCrop(opt.input_size), + video.VideoToTensor(), + normalize + ]) + + val_dataset = ucf101.classification.UCF101(setting=opt.val_list, root=opt.data_dir, train=False, + new_width=opt.new_width, new_height=opt.new_height, + target_width=opt.input_size, target_height=opt.input_size, + test_mode=True, num_segments=opt.num_segments, transform=transform_test) + val_data = gluon.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + print('Load %d test samples.' % len(val_dataset)) + + # start evaluation + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + + """Common practice during evaluation is to evenly sample 25 frames from a single video, and then perform 10-crop data augmentation. + This leads to 250 samples per video (750 channels). If this is too large to fit into one GPU, we can split it into multiple data bacthes. + `num_split_frames` has to be multiples of 3. 
+ """ + num_data_batches = 10 + num_split_frames = int(750 / num_data_batches) + + def test(ctx, val_data): + acc_top1.reset() + acc_top5.reset() + for i, batch in enumerate(val_data): + outputs = [] + for seg_id in range(num_data_batches): + bs = seg_id * num_split_frames + be = (seg_id + 1) * num_split_frames + new_batch = [batch[0][:,bs:be,:,:], batch[1]] + data, label = batch_fn(new_batch, ctx) + for gpu_id, X in enumerate(data): + X_reshaped = X.reshape((-1, 3, opt.input_size, opt.input_size)) + pred = net(X_reshaped.astype(opt.dtype, copy=False)) + if seg_id == 0: + outputs.append(pred) + else: + outputs[gpu_id] = nd.concat(outputs[gpu_id], pred, dim=0) + # Perform the mean operation on 250 samples of each video + for gpu_id, out in enumerate(outputs): + outputs[gpu_id] = nd.expand_dims(out.mean(axis=0), axis=0) + + acc_top1.update(label, outputs) + acc_top5.update(label, outputs) + + if i > 0 and i % opt.log_interval == 0: + print('%04d/%04d is done' % (i, len(val_data))) + + _, top1 = acc_top1.get() + _, top5 = acc_top5.get() + return (top1, top5) + + start_time = time.time() + acc_top1_val, acc_top5_val = test(context, val_data) + end_time = time.time() + + print('Test accuracy: acc-top1=%f acc-top5=%f' % (acc_top1_val*100, acc_top5_val*100)) + print('Total evaluation time is %4.2f minutes' % ((end_time - start_time) / 60)) + +if __name__ == '__main__': + main() diff --git a/scripts/action-recognition/train_recognizer.py b/scripts/action-recognition/train_recognizer.py new file mode 100755 index 0000000000..72ab83602e --- /dev/null +++ b/scripts/action-recognition/train_recognizer.py @@ -0,0 +1,345 @@ +import argparse, time, logging, os, sys, math + +import numpy as np +import mxnet as mx +import gluoncv as gcv +from mxnet import gluon, nd, init, context +from mxnet import autograd as ag +from mxnet.gluon import nn +from mxnet.gluon.data.vision import transforms +from mxboard import SummaryWriter + +from gluoncv.data.transforms import video +from gluoncv.data import ucf101 +from gluoncv.model_zoo import get_model +from gluoncv.utils import makedirs, LRSequential, LRScheduler, split_and_load + +# CLI +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model for action recognition.') + parser.add_argument('--data-dir', type=str, default='~/.mxnet/datasets/ucf101', + help='training and validation pictures to use.') + parser.add_argument('--train-list', type=str, default='~/.mxnet/datasets/ucf101/ucfTrainTestlist/ucf101_train_rgb_split1.txt', + help='the list of training data') + parser.add_argument('--val-list', type=str, default='~/.mxnet/datasets/ucf101/ucfTrainTestlist/ucf101_val_rgb_split1.txt', + help='the list of validation data') + parser.add_argument('--batch-size', type=int, default=32, + help='training batch size per device (CPU/GPU).') + parser.add_argument('--dtype', type=str, default='float32', + help='data type for training. default is float32') + parser.add_argument('--num-gpus', type=int, default=0, + help='number of gpus to use.') + parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int, + help='number of preprocessing workers') + parser.add_argument('--num-epochs', type=int, default=3, + help='number of training epochs.') + parser.add_argument('--lr', type=float, default=0.1, + help='learning rate. 
default is 0.1.') + parser.add_argument('--momentum', type=float, default=0.9, + help='momentum value for optimizer, default is 0.9.') + parser.add_argument('--wd', type=float, default=0.0001, + help='weight decay rate. default is 0.0001.') + parser.add_argument('--lr-mode', type=str, default='step', + help='learning rate scheduler mode. options are step, poly and cosine.') + parser.add_argument('--lr-decay', type=float, default=0.1, + help='decay rate of learning rate. default is 0.1.') + parser.add_argument('--lr-decay-period', type=int, default=0, + help='interval for periodic learning rate decays. default is 0 to disable.') + parser.add_argument('--lr-decay-epoch', type=str, default='40,60', + help='epochs at which learning rate decays. default is 40,60.') + parser.add_argument('--warmup-lr', type=float, default=0.0, + help='starting warmup learning rate. default is 0.0.') + parser.add_argument('--warmup-epochs', type=int, default=0, + help='number of warmup epochs.') + parser.add_argument('--last-gamma', action='store_true', + help='whether to init gamma of the last BN layer in each bottleneck to 0.') + parser.add_argument('--mode', type=str, + help='mode in which to train the model. options are symbolic, imperative, hybrid') + parser.add_argument('--model', type=str, required=True, + help='type of model to use. see vision_model for options.') + parser.add_argument('--input-size', type=int, default=224, + help='size of the input image size. default is 224') + parser.add_argument('--crop-ratio', type=float, default=0.875, + help='Crop ratio during validation. default is 0.875') + parser.add_argument('--use-pretrained', action='store_true', + help='enable using pretrained model from gluon.') + parser.add_argument('--use_se', action='store_true', + help='use SE layers or not in resnext. default is false.') + parser.add_argument('--mixup', action='store_true', + help='whether train the model with mix-up. default is false.') + parser.add_argument('--mixup-alpha', type=float, default=0.2, + help='beta distribution parameter for mixup sampling, default is 0.2.') + parser.add_argument('--mixup-off-epoch', type=int, default=0, + help='how many last epochs to train without mixup, default is 0.') + parser.add_argument('--label-smoothing', action='store_true', + help='use label smoothing or not in training. default is false.') + parser.add_argument('--no-wd', action='store_true', + help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.') + parser.add_argument('--teacher', type=str, default=None, + help='teacher model for distillation training') + parser.add_argument('--temperature', type=float, default=20, + help='temperature parameter for distillation teacher model') + parser.add_argument('--hard-weight', type=float, default=0.5, + help='weight for the loss of one-hot label for distillation training') + parser.add_argument('--batch-norm', action='store_true', + help='enable batch normalization or not in vgg. 
default is false.') + parser.add_argument('--save-frequency', type=int, default=10, + help='frequency of model saving.') + parser.add_argument('--save-dir', type=str, default='params', + help='directory of saved models') + parser.add_argument('--resume-epoch', type=int, default=0, + help='epoch to resume training from.') + parser.add_argument('--resume-params', type=str, default='', + help='path of parameters to load from.') + parser.add_argument('--resume-states', type=str, default='', + help='path of trainer state to load from.') + parser.add_argument('--log-interval', type=int, default=50, + help='Number of batches to wait before logging.') + parser.add_argument('--logging-file', type=str, default='train.log', + help='name of training log file') + parser.add_argument('--use-gn', action='store_true', + help='whether to use group norm.') + parser.add_argument('--eval', action='store_true', + help='directly evaluate the model.') + parser.add_argument('--num-segments', type=int, default=1, + help='number of segments to evenly split the video.') + parser.add_argument('--use-tsn', action='store_true', + help='whether to use temporal segment networks.') + parser.add_argument('--new-height', type=int, default=256, + help='new height of the resize image. default is 256') + parser.add_argument('--new-width', type=int, default=340, + help='new width of the resize image. default is 340') + parser.add_argument('--clip-grad', type=int, default=0, + help='clip gradient to a certain threshold. Set the value to be larger than zero to enable gradient clipping.') + parser.add_argument('--partial-bn', action='store_true', + help='whether to freeze bn layers except the first layer.') + parser.add_argument('--num-classes', type=int, default=101, + help='number of classes.') + opt = parser.parse_args() + return opt + +def tsn_mp_batchify_fn(data): + """Collate data into batch. Use shared memory for stacking. + Modify default batchify function for temporal segment networks. + Change `nd.stack` to `nd.concat` since batch dimension already exists. 
+ """ + if isinstance(data[0], nd.NDArray): + return nd.concat(*data, dim=0) + elif isinstance(data[0], tuple): + data = zip(*data) + return [tsn_mp_batchify_fn(i) for i in data] + else: + data = np.asarray(data) + return nd.array(data, dtype=data.dtype, + ctx=context.Context('cpu_shared', 0)) + +def get_data_loader(opt, batch_size, num_workers, logger): + data_dir = opt.data_dir + normalize = video.VideoNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + scale_ratios = [1.0, 0.875, 0.75, 0.66] + input_size = opt.input_size + + def batch_fn(batch, ctx): + if opt.num_segments > 1: + data = split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False, multiplier=opt.num_segments) + else: + data = split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False) + label = split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False) + return data, label + + transform_train = transforms.Compose([ + video.VideoMultiScaleCrop(size=(input_size, input_size), scale_ratios=scale_ratios), + video.VideoRandomHorizontalFlip(), + video.VideoToTensor(), + normalize + ]) + transform_test = transforms.Compose([ + video.VideoCenterCrop(size=input_size), + video.VideoToTensor(), + normalize + ]) + + train_dataset = ucf101.classification.UCF101(setting=opt.train_list, root=data_dir, train=True, + new_width=opt.new_width, new_height=opt.new_height, + target_width=input_size, target_height=input_size, + num_segments=opt.num_segments, transform=transform_train) + val_dataset = ucf101.classification.UCF101(setting=opt.val_list, root=data_dir, train=False, + new_width=opt.new_width, new_height=opt.new_height, + target_width=input_size, target_height=input_size, + num_segments=opt.num_segments, transform=transform_test) + logger.info('Load %d training samples and %d validation samples.' 
% (len(train_dataset), len(val_dataset))) + + if opt.num_segments > 1: + train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, batchify_fn=tsn_mp_batchify_fn) + val_data = gluon.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, batchify_fn=tsn_mp_batchify_fn) + else: + train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + val_data = gluon.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + return train_data, val_data, batch_fn + +def main(): + opt = parse_args() + + makedirs(opt.save_dir) + + filehandler = logging.FileHandler(os.path.join(opt.save_dir, opt.logging_file)) + streamhandler = logging.StreamHandler() + logger = logging.getLogger('') + logger.setLevel(logging.INFO) + logger.addHandler(filehandler) + logger.addHandler(streamhandler) + logger.info(opt) + + sw = SummaryWriter(logdir=opt.save_dir, flush_secs=5) + + batch_size = opt.batch_size + classes = opt.num_classes + + num_gpus = opt.num_gpus + batch_size *= max(1, num_gpus) + logger.info('Total batch size is set to %d on %d GPUs' % (batch_size, num_gpus)) + context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()] + num_workers = opt.num_workers + + lr_decay = opt.lr_decay + lr_decay_period = opt.lr_decay_period + if opt.lr_decay_period > 0: + lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period)) + else: + lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch] + + optimizer = 'sgd' + if opt.clip_grad > 0: + optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'clip_gradient': opt.clip_grad} + else: + optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum} + + model_name = opt.model + net = get_model(name=model_name, nclass=classes, pretrained=opt.use_pretrained, + tsn=opt.use_tsn, num_segments=opt.num_segments, partial_bn=opt.partial_bn) + net.cast(opt.dtype) + net.collect_params().reset_ctx(context) + logger.info(net) + + if opt.resume_params is not '': + net.load_parameters(opt.resume_params, ctx=context) + + train_data, val_data, batch_fn = get_data_loader(opt, batch_size, num_workers, logger) + + train_metric = mx.metric.Accuracy() + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + + def test(ctx, val_data): + acc_top1.reset() + acc_top5.reset() + for i, batch in enumerate(val_data): + data, label = batch_fn(batch, ctx) + outputs = [net(X.astype(opt.dtype, copy=False)) for X in data] + acc_top1.update(label, outputs) + acc_top5.update(label, outputs) + + _, top1 = acc_top1.get() + _, top5 = acc_top5.get() + return (top1, top5) + + def train(ctx): + if isinstance(ctx, mx.Context): + ctx = [ctx] + + if opt.no_wd: + for k, v in net.collect_params('.*beta|.*gamma|.*bias').items(): + v.wd_mult = 0.0 + + if opt.partial_bn: + train_patterns = None + if 'inceptionv3' in opt.model: + train_patterns = '.*weight|.*bias|inception30_batchnorm0_gamma|inception30_batchnorm0_beta|inception30_batchnorm0_running_mean|inception30_batchnorm0_running_var' + else: + logger.info('Current model does not support partial batch normalization.') + trainer = gluon.Trainer(net.collect_params(train_patterns), optimizer, optimizer_params) + else: + trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params) + + if opt.resume_states is 
not '': + trainer.load_states(opt.resume_states) + + L = gluon.loss.SoftmaxCrossEntropyLoss() + + best_val_score = 0 + lr_decay_count = 0 + + for epoch in range(opt.resume_epoch, opt.num_epochs): + tic = time.time() + train_metric.reset() + btic = time.time() + + if epoch == lr_decay_epoch[lr_decay_count]: + trainer.set_learning_rate(trainer.learning_rate * lr_decay) + lr_decay_count += 1 + + for i, batch in enumerate(train_data): + data, label = batch_fn(batch, ctx) + + with ag.record(): + outputs = [net(X.astype(opt.dtype, copy=False)) for X in data] + loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)] + + for l in loss: + l.backward() + + trainer.step(batch_size) + train_metric.update(label, outputs) + + if opt.log_interval and not (i+1) % opt.log_interval: + train_metric_name, train_metric_score = train_metric.get() + logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f' % ( + epoch, i, batch_size*opt.log_interval/(time.time()-btic), + train_metric_name, train_metric_score*100, trainer.learning_rate)) + btic = time.time() + + train_metric_name, train_metric_score = train_metric.get() + throughput = int(batch_size * i /(time.time() - tic)) + + acc_top1_val, acc_top5_val = test(ctx, val_data) + + logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score*100)) + logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic)) + logger.info('[Epoch %d] validation: acc-top1=%f acc-top5=%f'%(epoch, acc_top1_val*100, acc_top5_val*100)) + + sw.add_scalar(tag='train_acc', value=train_metric_score*100, global_step=epoch) + sw.add_scalar(tag='valid_acc', value=acc_top1_val*100, global_step=epoch) + + if acc_top1_val > best_val_score: + best_val_score = acc_top1_val + if opt.use_tsn: + net.basenet.save_parameters('%s/%.4f-ucf101-%s-%03d-best.params'%(opt.save_dir, best_val_score, model_name, epoch)) + else: + net.save_parameters('%s/%.4f-ucf101-%s-%03d-best.params'%(opt.save_dir, best_val_score, model_name, epoch)) + trainer.save_states('%s/%.4f-ucf101-%s-%03d-best.states'%(opt.save_dir, best_val_score, model_name, epoch)) + + if opt.save_frequency and opt.save_dir and (epoch + 1) % opt.save_frequency == 0: + if opt.use_tsn: + net.basenet.save_parameters('%s/ucf101-%s-%03d.params'%(opt.save_dir, model_name, epoch)) + else: + net.save_parameters('%s/ucf101-%s-%03d.params'%(opt.save_dir, model_name, epoch)) + trainer.save_states('%s/ucf101-%s-%03d.states'%(opt.save_dir, model_name, epoch)) + + # save the last model + if opt.use_tsn: + net.basenet.save_parameters('%s/ucf101-%s-%03d.params'%(opt.save_dir, model_name, opt.num_epochs-1)) + else: + net.save_parameters('%s/ucf101-%s-%03d.params'%(opt.save_dir, model_name, opt.num_epochs-1)) + trainer.save_states('%s/ucf101-%s-%03d.states'%(opt.save_dir, model_name, opt.num_epochs-1)) + + if opt.mode == 'hybrid': + net.hybridize(static_alloc=True, static_shape=True) + + train(context) + sw.close() + + +if __name__ == '__main__': + main()
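
Usage sketch (not part of the diff above). This is a minimal, hedged example of how the pieces added in this patch fit together for inference. It assumes UCF101 frames have already been decoded under ~/.mxnet/datasets/ucf101 and that the split file at the script's default --val-list path exists; batch size, context selection, and the choice of vgg16_ucf101 are illustrative, not requirements.

import os
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.data.vision import transforms

from gluoncv.data import UCF101
from gluoncv.data.transforms import video
from gluoncv.model_zoo import get_model
from gluoncv.utils import split_and_load

# Validation-style preprocessing, mirroring transform_test in train_recognizer.py.
transform = transforms.Compose([
    video.VideoCenterCrop(size=224),
    video.VideoToTensor(),
    video.VideoNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Paths below are the script defaults and assume the dataset has been prepared.
root = os.path.expanduser('~/.mxnet/datasets/ucf101')
setting = os.path.expanduser('~/.mxnet/datasets/ucf101/ucfTrainTestlist/ucf101_val_rgb_split1.txt')
val_dataset = UCF101(setting=setting, root=root, train=False,
                     num_segments=1, transform=transform)
val_data = gluon.data.DataLoader(val_dataset, batch_size=8, shuffle=False)

ctx = [mx.gpu(0)] if mx.context.num_gpus() > 0 else [mx.cpu()]
# pretrained=True loads the ImageNet-pretrained backbone; fine-tuned UCF101
# weights would be loaded separately (e.g. via --resume-params in the scripts).
net = get_model('vgg16_ucf101', nclass=101, pretrained=True, tsn=False)
net.collect_params().reset_ctx(ctx)

# Run a single batch through the recognizer as a smoke test.
for clips, labels in val_data:
    data = split_and_load(clips, ctx_list=ctx, even_split=False)
    label = split_and_load(labels, ctx_list=ctx, even_split=False)
    preds = [net(x) for x in data]
    break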