From 4e93cb69e7fa881129c32281e8553d6ffc325d04 Mon Sep 17 00:00:00 2001
From: Yoshitomo Matsubara
Date: Sun, 23 Jul 2023 00:04:47 -0700
Subject: [PATCH 1/4] Add docstrings

---
 sc2bench/models/backbone.py | 248 +++++++++++++++++++++++++++++++-
 sc2bench/models/layer.py    | 273 ++++++++++++++++++++++++++++++------
 2 files changed, 478 insertions(+), 43 deletions(-)

diff --git a/sc2bench/models/backbone.py b/sc2bench/models/backbone.py
index c65d503..ce67f6f 100644
--- a/sc2bench/models/backbone.py
+++ b/sc2bench/models/backbone.py
@@ -3,12 +3,12 @@
 import torch
 from compressai.models import CompressionModel
 from timm.models import resnest, regnet, vision_transformer_hybrid
+from torchdistill.common.main_util import load_ckpt
 from torchdistill.datasets.util import build_transform
 from torchdistill.models.registry import register_model_class, register_model_func
 from torchvision import models
 from torchvision.ops import misc as misc_nn_ops
 
-from torchdistill.common.main_util import load_ckpt
 from .layer import get_layer
 from ..analysis import AnalyzableModule
 
@@ -17,18 +17,40 @@
 
 
 def register_backbone_class(cls):
+    """
+    Registers a backbone model class (usually a classification model).
+
+    :param cls: backbone model class to be registered
+    :type cls: class
+    :return: registered backbone model class
+    :rtype: class
+    """
     BACKBONE_CLASS_DICT[cls.__name__] = cls
     register_model_class(cls)
     return cls
 
 
 def register_backbone_func(func):
+    """
+    Registers a function to build a backbone model (usually a classification model).
+
+    :param func: function to build a backbone model to be registered
+    :type func: typing.Callable
+    :return: registered function
+    :rtype: typing.Callable
+    """
     BACKBONE_FUNC_DICT[func.__name__] = func
     register_model_func(func)
     return func
 
 
 class UpdatableBackbone(AnalyzableModule):
+    """
+    A base, updatable backbone model.
+
+    :param analyzer_configs: list of analysis configurations
+    :type analyzer_configs: list[dict] or None
+    """
     def __init__(self, analyzer_configs=None):
         super().__init__(analyzer_configs)
         self.bottleneck_updated = False
@@ -37,17 +59,49 @@ def forward(self, *args, **kwargs):
         raise NotImplementedError()
 
     def update(self, **kwargs):
+        """
+        Updates compression-specific parameters like `CompressAI models do `_.
+
+        This should be overridden by all subclasses.
+        """
         raise NotImplementedError()
 
     def get_aux_module(self, **kwargs):
+        """
+        Returns an auxiliary module to compute auxiliary loss if necessary, like `CompressAI models do `_.
+
+        This should be overridden by all subclasses.
+        """
         raise NotImplementedError()
 
 
 def check_if_updatable(model):
+    """
+    Checks if the given model is updatable.
+
+    :param model: model
+    :type model: nn.Module
+    :return: True if the model is updatable, False otherwise
+    :rtype: bool
+    """
     return isinstance(model, UpdatableBackbone)
 
 
 class FeatureExtractionBackbone(UpdatableBackbone):
+    """
+    A feature extraction-based backbone model.
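+
+    For example, a `torchvision` classifier can be turned into a feature extractor as follows
+    (a minimal sketch; the layer names and the empty analyzer list are illustrative):
+
+    .. code-block:: python
+
+        from torchvision.models import resnet50
+
+        backbone = FeatureExtractionBackbone(resnet50(), {'layer3': '3', 'layer4': '4'}, [])
+        features = backbone(torch.rand(1, 3, 224, 224))  # dict with keys '3' and '4'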
+ + :param model: model + :type model: nn.Module + :param return_layer_dict: mapping from name of module to return its output to a specified key + :type return_layer_dict: dict + :param analyzer_configs: list of analysis configurations + :type analyzer_configs: list[dict] or None + :param analyzes_after_compress: run analysis with `analyzer_configs` if True + :type analyzes_after_compress: bool + :param analyzable_layer_key: key of analyzable layer + :type analyzable_layer_key: str or None + """ # Referred to the IntermediateLayerGetter implementation at https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py def __init__(self, model, return_layer_dict, analyzer_configs, analyzes_after_compress=False, analyzable_layer_key=None): @@ -89,7 +143,13 @@ def forward(self, x): out[out_name] = x return out - def check_if_updatable(self, strict=True): + def check_if_updatable(self): + """ + Checks if this module is updatable with respect to CompressAI modules. + + :return: True if the model is updatable, False otherwise + :rtype: bool + """ if self.analyzable_layer_key is None or self.analyzable_layer_key not in self._modules \ or not isinstance(self._modules[self.analyzable_layer_key], CompressionModel): return False @@ -113,6 +173,30 @@ def get_aux_module(self, **kwargs): class SplittableResNet(UpdatableBackbone): + """ + ResNet/ResNeSt-based splittable image classification model containing neural encoder, entropy bottleneck, + and decoder. + + - Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: `"Deep Residual Learning for Image Recognition" `_ @ CVPR 2016 (2016) + - Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Haibin Lin, Zhi Zhang, Yue Sun, Tong He, Jonas Mueller, R. Manmatha, Mu Li, Alexander Smola: `"ResNeSt: Split-Attention Networks" `_ @ CVPRW 2022 (2022) + - Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt: `"Supervised Compression for Resource-Constrained Edge Computing Systems" `_ @ WACV 2022 (2022) + - Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt: `"SC2 Benchmark: Supervised Compression for Split Computing" `_ @ TMLR (2023) + + :param bottleneck_layer: high-level bottleneck layer that consists of encoder and decoder + :type bottleneck_layer: nn.Module + :param resnet_model: ResNet model to be used as a base model + :type resnet_model: nn.Module + :param inplanes: ResNet model's inplanes + :type inplanes: int or None + :param skips_avgpool: if True, skips avgpool (average pooling) after layer4 + :type skips_avgpool: bool + :param skips_fc: if True, skips fc (fully-connected layer) after layer4 + :type skips_fc: bool + :param pre_transform_params: pre-transform parameters + :type pre_transform_params: dict or None + :param analysis_config: analysis configuration + :type analysis_config: dict or None + """ # Referred to the ResNet implementation at https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py def __init__(self, bottleneck_layer, resnet_model, inplanes=None, skips_avgpool=True, skips_fc=True, pre_transform_params=None, analysis_config=None): @@ -161,6 +245,12 @@ def update(self): self.bottleneck_updated = True def load_state_dict(self, state_dict, **kwargs): + """ + Loads parameters for all the sub-modules except bottleneck_layer and then bottleneck_layer. 
+
+        :param state_dict: dict containing parameters and persistent buffers
+        :type state_dict: dict
+        """
         entropy_bottleneck_state_dict = OrderedDict()
         for key in list(state_dict.keys()):
             if key.startswith('bottleneck_layer.'):
@@ -174,6 +264,25 @@
 
 
 class SplittableRegNet(UpdatableBackbone):
+    """
+    RegNet-based splittable image classification model containing neural encoder, entropy bottleneck, and decoder.
+
+    - Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár: `"Designing Network Design Spaces" `_ @ CVPR 2020 (2020)
+    - Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt: `"SC2 Benchmark: Supervised Compression for Split Computing" `_ @ TMLR (2023)
+
+    :param bottleneck_layer: high-level bottleneck layer that consists of encoder and decoder
+    :type bottleneck_layer: nn.Module
+    :param regnet_model: RegNet model (`timm`-style) to be used as a base model
+    :type regnet_model: nn.Module
+    :param inplanes: RegNet model's inplanes
+    :type inplanes: int or None
+    :param skips_head: if True, skips the classification head
+    :type skips_head: bool
+    :param pre_transform_params: pre-transform parameters
+    :type pre_transform_params: dict or None
+    :param analysis_config: analysis configuration
+    :type analysis_config: dict or None
+    """
     # Referred to the RegNet implementation at https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/regnet.py
     def __init__(self, bottleneck_layer, regnet_model, inplanes=None, skips_head=True,
                  pre_transform_params=None, analysis_config=None):
@@ -214,6 +323,12 @@
         self.bottleneck_updated = True
 
     def load_state_dict(self, state_dict, **kwargs):
+        """
+        Loads parameters for all the sub-modules except bottleneck_layer and then bottleneck_layer.
+
+        :param state_dict: dict containing parameters and persistent buffers
+        :type state_dict: dict
+        """
         entropy_bottleneck_state_dict = OrderedDict()
         for key in list(state_dict.keys()):
             if key.startswith('bottleneck_layer.'):
@@ -227,6 +342,25 @@
 
 
 class SplittableHybridViT(UpdatableBackbone):
+    """
+    Hybrid ViT-based splittable image classification model containing neural encoder, entropy bottleneck, and decoder.
+
+    - Andreas Peter Steiner, Alexander Kolesnikov, Xiaohua Zhai, Ross Wightman, Jakob Uszkoreit, Lucas Beyer: `"How to train your ViT?
Data, Augmentation, and Regularization in Vision Transformers" `_ @ TMLR (2022) + - Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt: `"SC2 Benchmark: Supervised Compression for Split Computing" `_ @ TMLR (2023) + + :param bottleneck_layer: high-level bottleneck layer that consists of encoder and decoder + :type bottleneck_layer: nn.Module + :param hybrid_vit_model: Hybrid Vision Transformer model (`timm`-style) to be used as a base model + :type hybrid_vit_model: nn.Module + :param num_pruned_stages: number of stages in the ResNet backbone of Hybrid ViT to be pruned + :type num_pruned_stages: int + :param skips_head: if True, skips classification head + :type skips_head: bool + :param pre_transform_params: pre-transform parameters + :type pre_transform_params: dict or None + :param analysis_config: analysis configuration + :type analysis_config: dict or None + """ # Referred to Hybrid ViT implementation at https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py def __init__(self, bottleneck_layer, hybrid_vit_model, num_pruned_stages=1, skips_head=True, pre_transform_params=None, analysis_config=None): @@ -281,6 +415,12 @@ def update(self): self.bottleneck_updated = True def load_state_dict(self, state_dict, **kwargs): + """ + Loads parameters for all the sub-modules except bottleneck_layer and then bottleneck_layer. + + :param state_dict: dict containing parameters and persistent buffers + :type state_dict: dict + """ entropy_bottleneck_state_dict = OrderedDict() for key in list(state_dict.keys()): if key.startswith('bottleneck_layer.'): @@ -297,6 +437,30 @@ def get_aux_module(self, **kwargs): def splittable_resnet(bottleneck_config, resnet_name='resnet50', inplanes=None, skips_avgpool=True, skips_fc=True, pre_transform_params=None, analysis_config=None, org_model_ckpt_file_path_or_url=None, org_ckpt_strict=True, **resnet_kwargs): + """ + Builds ResNet-based splittable image classification model containing neural encoder, entropy bottleneck, and decoder. 
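+
+    For example, a bottleneck-injected ResNet-50 can be built as follows (a minimal sketch;
+    the bottleneck name refers to `FPBasedResNetBottleneck` registered in `sc2bench.models.layer`,
+    and the parameter values are illustrative):
+
+    .. code-block:: python
+
+        bottleneck_config = {
+            'name': 'FPBasedResNetBottleneck',
+            'params': {'num_bottleneck_channels': 24, 'num_target_channels': 256}
+        }
+        model = splittable_resnet(bottleneck_config, resnet_name='resnet50')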
+ + :param bottleneck_config: bottleneck configuration + :type bottleneck_config: dict + :param resnet_name: name of ResNet function in `torchvision` + :type resnet_name: str + :param inplanes: ResNet model's inplanes + :type inplanes: int or None + :param skips_avgpool: if True, skips avgpool (average pooling) after layer4 + :type skips_avgpool: bool + :param skips_fc: if True, skips fc (fully-connected layer) after layer4 + :type skips_fc: bool + :param pre_transform_params: pre-transform parameters + :type pre_transform_params: dict or None + :param analysis_config: analysis configuration + :type analysis_config: dict or None + :param org_model_ckpt_file_path_or_url: original ResNet model checkpoint file path or URL + :type org_model_ckpt_file_path_or_url: str or None + :param org_ckpt_strict: whether to strictly enforce that the keys in state_dict match the keys returned by original ResNet model’s `state_dict()` function + :type org_ckpt_strict: bool + :return: splittable ResNet model + :rtype: SplittableResNet + """ bottleneck_layer = get_layer(bottleneck_config['name'], **bottleneck_config['params']) if resnet_kwargs.pop('norm_layer', '') == 'FrozenBatchNorm2d': resnet_model = models.__dict__[resnet_name](norm_layer=misc_nn_ops.FrozenBatchNorm2d, **resnet_kwargs) @@ -313,6 +477,31 @@ def splittable_resnet(bottleneck_config, resnet_name='resnet50', inplanes=None, def splittable_resnest(bottleneck_config, resnest_name='resnest50d', inplanes=None, skips_avgpool=True, skips_fc=True, pre_transform_params=None, analysis_config=None, org_model_ckpt_file_path_or_url=None, org_ckpt_strict=True, **resnest_kwargs): + """ + Builds ResNeSt-based splittable image classification model containing neural encoder, entropy bottleneck, + and decoder. + + :param bottleneck_config: bottleneck configuration + :type bottleneck_config: dict + :param resnest_name: name of ResNeSt function in `timm` + :type resnest_name: str + :param inplanes: ResNeSt model's inplanes + :type inplanes: int or None + :param skips_avgpool: if True, skips avgpool (average pooling) after layer4 + :type skips_avgpool: bool + :param skips_fc: if True, skips fc (fully-connected layer) after layer4 + :type skips_fc: bool + :param pre_transform_params: pre-transform parameters + :type pre_transform_params: dict or None + :param analysis_config: analysis configuration + :type analysis_config: dict or None + :param org_model_ckpt_file_path_or_url: original ResNeSt model checkpoint file path or URL + :type org_model_ckpt_file_path_or_url: str or None + :param org_ckpt_strict: whether to strictly enforce that the keys in state_dict match the keys returned by original ResNeSt model’s `state_dict()` function + :type org_ckpt_strict: bool + :return: splittable ResNeSt model + :rtype: SplittableResNet + """ bottleneck_layer = get_layer(bottleneck_config['name'], **bottleneck_config['params']) resnest_model = resnest.__dict__[resnest_name](**resnest_kwargs) if org_model_ckpt_file_path_or_url is not None: @@ -325,6 +514,28 @@ def splittable_resnest(bottleneck_config, resnest_name='resnest50d', inplanes=No def splittable_regnet(bottleneck_config, regnet_name='regnety_064', inplanes=None, skips_head=True, pre_transform_params=None, analysis_config=None, org_model_ckpt_file_path_or_url=None, org_ckpt_strict=True, **regnet_kwargs): + """ + Builds RegNet-based splittable image classification model containing neural encoder, entropy bottleneck, and decoder. 
+
+    :param bottleneck_config: bottleneck configuration
+    :type bottleneck_config: dict
+    :param regnet_name: name of RegNet function in `timm`
+    :type regnet_name: str
+    :param inplanes: RegNet model's inplanes
+    :type inplanes: int or None
+    :param skips_head: if True, skips the classification head
+    :type skips_head: bool
+    :param pre_transform_params: pre-transform parameters
+    :type pre_transform_params: dict or None
+    :param analysis_config: analysis configuration
+    :type analysis_config: dict or None
+    :param org_model_ckpt_file_path_or_url: original RegNet model checkpoint file path or URL
+    :type org_model_ckpt_file_path_or_url: str or None
+    :param org_ckpt_strict: whether to strictly enforce that the keys in state_dict match the keys returned by original RegNet model’s `state_dict()` function
+    :type org_ckpt_strict: bool
+    :return: splittable RegNet model
+    :rtype: SplittableRegNet
+    """
     bottleneck_layer = get_layer(bottleneck_config['name'], **bottleneck_config['params'])
     regnet_model = regnet.__dict__[regnet_name](**regnet_kwargs)
     if org_model_ckpt_file_path_or_url is not None:
@@ -336,6 +547,29 @@
 def splittable_hybrid_vit(bottleneck_config, hybrid_vit_name='vit_small_r26_s32_224', num_pruned_stages=1,
                           skips_head=True, pre_transform_params=None, analysis_config=None,
                           org_model_ckpt_file_path_or_url=None, org_ckpt_strict=True, **hybrid_vit_kwargs):
+    """
+    Builds Hybrid ViT-based splittable image classification model containing neural encoder, entropy bottleneck, and decoder.
+
+    :param bottleneck_config: bottleneck configuration
+    :type bottleneck_config: dict
+    :param hybrid_vit_name: name of Hybrid ViT function in `timm`
+    :type hybrid_vit_name: str
+    :param num_pruned_stages: number of stages in the ResNet backbone of Hybrid ViT to be pruned
+    :type num_pruned_stages: int
+    :param skips_head: if True, skips classification head
+    :type skips_head: bool
+    :param pre_transform_params: pre-transform parameters
+    :type pre_transform_params: dict or None
+    :param analysis_config: analysis configuration
+    :type analysis_config: dict or None
+    :param org_model_ckpt_file_path_or_url: original Hybrid ViT model checkpoint file path or URL
+    :type org_model_ckpt_file_path_or_url: str or None
+    :param org_ckpt_strict: whether to strictly enforce that the keys in state_dict match the keys returned by original Hybrid ViT model’s `state_dict()` function
+    :type org_ckpt_strict: bool
+    :return: splittable Hybrid ViT model
+    :rtype: SplittableHybridViT
+    """
     bottleneck_layer = get_layer(bottleneck_config['name'], **bottleneck_config['params'])
     hybrid_vit_model = vision_transformer_hybrid.__dict__[hybrid_vit_name](**hybrid_vit_kwargs)
     if org_model_ckpt_file_path_or_url is not None:
@@ -345,6 +579,16 @@
 
 
 def get_backbone(cls_or_func_name, **kwargs):
+    """
+    Gets a backbone model.
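+
+    Both registered classes and registered functions are looked up by name, e.g.
+    (a minimal sketch; `custom_backbone` is a hypothetical user-defined builder):
+
+    .. code-block:: python
+
+        @register_backbone_func
+        def custom_backbone(**kwargs):
+            return models.resnet18(**kwargs)
+
+        model = get_backbone('custom_backbone')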
+
+    :param cls_or_func_name: backbone class or function name
+    :type cls_or_func_name: str
+    :param kwargs: kwargs for the backbone class or function to build the backbone model
+    :type kwargs: dict
+    :return: backbone model
+    :rtype: nn.Module or None
+    """
     if cls_or_func_name in BACKBONE_CLASS_DICT:
         return BACKBONE_CLASS_DICT[cls_or_func_name](**kwargs)
     elif cls_or_func_name in BACKBONE_FUNC_DICT:
diff --git a/sc2bench/models/layer.py b/sc2bench/models/layer.py
index 8364308..7f163d1 100644
--- a/sc2bench/models/layer.py
+++ b/sc2bench/models/layer.py
@@ -15,11 +15,12 @@
 def register_layer_class(cls):
     """
-    Args:
-        cls (class): layer module to be registered.
+    Registers a layer class.
 
-    Returns:
-        cls (class): registered layer module.
+    :param cls: layer class to be registered
+    :type cls: class
+    :return: registered layer class
+    :rtype: class
     """
     LAYER_CLASS_DICT[cls.__name__] = cls
     return cls
@@ -27,11 +28,12 @@ def register_layer_class(cls):
 
 def register_layer_func(func):
     """
-    Args:
-        func (function): layer module to be registered.
+    Registers a function to build a layer module.
 
-    Returns:
-        func (function): registered layer module.
+    :param func: function to build a layer module
+    :type func: typing.Callable
+    :return: registered function
+    :rtype: typing.Callable
     """
     LAYER_FUNC_DICT[func.__name__] = func
     return func
@@ -39,7 +41,18 @@ def register_layer_func(func):
 
 class SimpleBottleneck(nn.Module):
     """
-    Simple encoder-decoder layer to treat encoder's output as bottleneck
+    Simple neural encoder-decoder that treats encoder's output as bottleneck.
+
+    The forward path is encoder -> compressor (if provided) -> decompressor (if provided) -> decoder.
+
+    :param encoder: encoder
+    :type encoder: nn.Module
+    :param decoder: decoder
+    :type decoder: nn.Module
+    :param compressor: module to compress the encoded data
+    :type compressor: nn.Module or None
+    :param decompressor: module to decompress the compressed data
+    :type decompressor: nn.Module or None
     """
     def __init__(self, encoder, decoder, compressor=None, decompressor=None):
         super().__init__()
@@ -49,12 +62,28 @@ def __init__(self, encoder, decoder, compressor=None, decompressor=None):
         self.decompressor = decompressor
 
     def encode(self, x):
+        """
+        Encodes the input data.
+
+        :param x: input batch
+        :type x: torch.Tensor
+        :return: dict of encoded (and compressed if `compressor` is provided) data
+        :rtype: dict
+        """
         z = self.encoder(x)
         if self.compressor is not None:
             z = self.compressor(z)
         return {'z': z}
 
     def decode(self, z):
+        """
+        Decodes the encoded data.
+
+        :param z: encoded data
+        :type z: torch.Tensor
+        :return: decoded data
+        :rtype: torch.Tensor
+        """
         if self.decompressor is not None:
             z = self.decompressor(z)
         return self.decoder(z)
@@ -69,6 +98,11 @@ def forward(self, x):
         return self.decoder(z)
 
     def update(self):
+        """
+        Shows a message that this module has no updatable parameters for entropy coding.
+
+        Dummy function to be compatible with other layers.
+        """
         logger.info('This module has no updatable parameters for entropy coding')
 
 
@@ -76,7 +110,24 @@ def update(self):
 def larger_resnet_bottleneck(bottleneck_channel=12, bottleneck_idx=12, output_channel=256,
                              compressor_transform_params=None, decompressor_transform_params=None):
     """
-    "Neural Compression and Filtering for Edge-assisted Real-time Object Detection in Challenged Networks"
+    Builds a bottleneck layer with a ResNet-based encoder and decoder (24 layers in total).
+
+    Compatible with ResNet-50, -101, and -152.
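+
+    For example (a minimal sketch; the input resolution is illustrative):
+
+    .. code-block:: python
+
+        bottleneck_layer = larger_resnet_bottleneck(bottleneck_channel=12, bottleneck_idx=12)
+        reconstructed = bottleneck_layer(torch.rand(1, 3, 224, 224))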
+
+    Yoshitomo Matsubara, Marco Levorato: `"Neural Compression and Filtering for Edge-assisted Real-time Object Detection in Challenged Networks" `_ @ ICPR 2020 (2021)
+
+    :param bottleneck_channel: number of channels for the bottleneck point
+    :type bottleneck_channel: int
+    :param bottleneck_idx: number of the first layers to be used as an encoder (the remaining layers are for decoder)
+    :type bottleneck_idx: int
+    :param output_channel: number of output channels for decoder's output
+    :type output_channel: int
+    :param compressor_transform_params: transform parameters for compressor
+    :type compressor_transform_params: dict or None
+    :param decompressor_transform_params: transform parameters for decompressor
+    :type decompressor_transform_params: dict or None
+    :return: bottleneck layer consisting of encoder and decoder
+    :rtype: SimpleBottleneck
+    """
     modules = [
         nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
@@ -113,9 +164,12 @@
 
 class EntropyBottleneckLayer(CompressionModel):
     """
-    Entropy bottleneck layer as a simple CompressionModel in compressai
-    The entropy bottleneck layer is proposed in "Variational Image Compression with a Scale Hyperprior" by
-    J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston.
+    An entropy bottleneck layer as a simple `CompressionModel` in `compressai`.
+
+    Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston: `"Variational Image Compression with a Scale Hyperprior" `_ @ ICLR 2018 (2018)
+
+    :param kwargs: kwargs for `CompressionModel` in `compressai`
+    :type kwargs: dict
     """
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -125,33 +179,83 @@ def forward(self, x):
         return self.entropy_bottleneck(x)
 
     def compress(self, x):
+        """
+        Compresses input data.
+
+        :param x: input data
+        :type x: torch.Tensor
+        :return: entropy-coded compressed data ('strings' as key) and shape of the input data ('shape' as key)
+        :rtype: dict
+        """
         strings = self.entropy_bottleneck.compress(x)
         return {'strings': [strings], 'shape': x.size()[-2:]}
 
     def decompress(self, strings, shape):
+        """
+        Decompresses compressed data.
+
+        :param strings: entropy-coded compressed data
+        :type strings: list[str]
+        :param shape: shape of the input data
+        :type shape: list[int]
+        :return: decompressed data
+        :rtype: torch.Tensor
+        """
         assert isinstance(strings, list) and len(strings) == 1
         return self.entropy_bottleneck.decompress(strings[0], shape)
 
     def update(self, force=False):
+        """
+        Updates compression-specific parameters like `CompressAI models do `_.
+
+        :param force: if True, overwrites previous values
+        :type force: bool
+        :return: True if one of the EntropyBottlenecks was updated
+        :rtype: bool
+        """
         self.updated = True
         return super().update(force=force)
 
 
 class BaseBottleneck(CompressionModel):
+    """
+    An abstract class for entropy bottleneck-based layers.
+
+    :param entropy_bottleneck_channels: number of entropy bottleneck channels
+    :type entropy_bottleneck_channels: int
+    """
     def __init__(self, entropy_bottleneck_channels):
         super().__init__(entropy_bottleneck_channels=entropy_bottleneck_channels)
         self.updated = False
 
     def encode(self, *args, **kwargs):
+        """
+        Encodes data.
+
+        This should be overridden by all subclasses.
+        """
         raise NotImplementedError()
 
     def decode(self, *args, **kwargs):
+        """
+        Decodes encoded data.
+
+        This should be overridden by all subclasses.
+ """ raise NotImplementedError() def forward(self, *args): raise NotImplementedError() def update(self, force=False): + """ + Updates compression-specific parameters like `CompressAI models do `_. + + :param force: if True, overwrites previous values + :type force: bool + :return: True if one of the EntropyBottlenecks was updated + :rtype: bool + """ self.updated = True return super().update(force=force) @@ -159,11 +263,22 @@ def update(self, force=False): @register_layer_class class FPBasedResNetBottleneck(BaseBottleneck): """ - Factorized Prior(FP)-based bottleneck for ResNet proposed in - "Supervised Compression for Resource-Constrained Edge Computing Systems" - by Y. Matsubara, R. Yang, M. Levorato, S. Mandt. - Factorized Prior is proposed in "Variational Image Compression with a Scale Hyperprior" by - J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston. + Factorized Prior(FP)-based encoder-decoder designed to create bottleneck for ResNet and variants. + + - Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston: `"Variational Image Compression with a Scale Hyperprior" `_ @ ICLR 2018 (2018) + - Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt: `"Supervised Compression for Resource-Constrained Edge Computing Systems" `_ @ WACV 2022 (2022) + - Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt: `"SC2 Benchmark: Supervised Compression for Split Computing" `_ @ TMLR (2023) + + :param num_input_channels: number of input channels + :type num_input_channels: int + :param num_bottleneck_channels: number of bottleneck channels + :type num_bottleneck_channels: int + :param num_target_channels: number of output channels for decoder's output + :type num_target_channels: int + :param encoder_channel_sizes: list of 4 numbers of channels for encoder + :type encoder_channel_sizes: list[int] or None + :param decoder_channel_sizes: list of 4 numbers of channels for decoder + :type decoder_channel_sizes: list[int] or None """ def __init__(self, num_input_channels=3, num_bottleneck_channels=24, num_target_channels=256, encoder_channel_sizes=None, decoder_channel_sizes=None): @@ -198,21 +313,39 @@ def __init__(self, num_input_channels=3, num_bottleneck_channels=24, num_target_ ) def encode(self, x, **kwargs): + """ + Encodes input data. + + :param x: input data + :type x: torch.Tensor + :return: entropy-coded compressed data ('strings' as key) and shape of the input data ('shape' as key) + :rtype: dict + """ latent = self.encoder(x) latent_strings = self.entropy_bottleneck.compress(latent) return {'strings': [latent_strings], 'shape': latent.size()[-2:]} def decode(self, strings, shape): + """ + Decodes encoded data. 
+ + :param strings: entropy-coded compressed data + :type strings: list[str] + :param shape: shape of the input data + :type shape: list[int] + :return: decompressed data + :rtype: torch.Tensor + """ latent_hat = self.entropy_bottleneck.decompress(strings[0], shape) return self.decoder(latent_hat) - def get_means(self, x): + def _get_means(self, x): medians = self.entropy_bottleneck._get_medians().detach() spatial_dims = len(x.size()) - 2 medians = self.entropy_bottleneck._extend_ndims(medians, spatial_dims) return medians.expand(x.size(0), *([-1] * (spatial_dims + 1))) - def forward2train(self, x): + def _forward2train(self, x): encoded_obj = self.encoder(x) y_hat, y_likelihoods = self.entropy_bottleneck(encoded_obj) decoded_obj = self.decoder(y_hat) @@ -229,19 +362,37 @@ def forward(self, x): encoded_output = self.encoder(x) decoder_input =\ self.entropy_bottleneck.dequantize( - self.entropy_bottleneck.quantize(encoded_output, 'dequantize', self.get_means(encoded_output)) + self.entropy_bottleneck.quantize(encoded_output, 'dequantize', self._get_means(encoded_output)) ) decoder_input = decoder_input.detach() return self.decoder(decoder_input) - return self.forward2train(x) + return self._forward2train(x) @register_layer_class class SHPBasedResNetBottleneck(BaseBottleneck): """ - Scale Hyperprior(SHP)-based bottleneck for ResNet. - Scale Hyperprior is proposed in "Variational Image Compression with a Scale Hyperprior" by - J. Balle, D. Minnen, S. Singh, S.J. Hwang, N. Johnston. + Scale Hyperprior(SHP)-based bottleneck for ResNet and variants. + + - Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston: `"Variational Image Compression with a Scale Hyperprior" `_ @ ICLR 2018 (2018) + - Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt: `"SC2 Benchmark: Supervised Compression for Split Computing" `_ @ TMLR (2023) + + :param num_input_channels: number of input channels + :type num_input_channels: int + :param num_latent_channels: number of latent channels + :type num_latent_channels: int + :param num_bottleneck_channels: number of bottleneck channels + :type num_bottleneck_channels: int + :param num_target_channels: number of output channels for decoder's output + :type num_target_channels: int + :param h_a: parametric transform :math:`h_a` + :type h_a: nn.Module or None + :param h_s: parametric transform :math:`h_s` + :type h_s: nn.Module or None + :param g_a_channel_sizes: list of 4 numbers of channels for parametric transform :math:`g_a` + :type g_a_channel_sizes: list[int] or None + :param g_s_channel_sizes: list of 4 numbers of channels for parametric transform :math:`g_s` + :type g_s_channel_sizes: list[int] or None """ def __init__(self, num_input_channels=3, num_latent_channels=16, num_bottleneck_channels=24, num_target_channels=256, h_a=None, h_s=None, @@ -297,6 +448,14 @@ def __init__(self, num_input_channels=3, num_latent_channels=16, self.num_bottleneck_channels = num_bottleneck_channels def encode(self, x, **kwargs): + """ + Encodes input data. + + :param x: input data + :type x: torch.Tensor + :return: entropy-coded compressed data ('strings' as key) and shape of the input data ('shape' as key) + :rtype: dict + """ y = self.g_a(x) z = self.h_a(torch.abs(y)) z_shape = z.size()[-2:] @@ -308,6 +467,16 @@ def encode(self, x, **kwargs): return {'strings': [y_strings, z_strings], 'shape': z_shape} def decode(self, strings, shape): + """ + Decodes encoded data. 
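+
+        Unlike `FPBasedResNetBottleneck.decode`, `strings` here contains two entropy-coded
+        strings, e.g. (a minimal sketch; `bottleneck` is a hypothetical updated instance):
+
+        .. code-block:: python
+
+            compressed_obj = bottleneck.encode(torch.rand(1, 3, 224, 224))
+            assert len(compressed_obj['strings']) == 2
+            reconstructed = bottleneck.decode(compressed_obj['strings'], compressed_obj['shape'])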
+ + :param strings: entropy-coded compressed data + :type strings: list[str] + :param shape: shape of the input data + :type shape: list[int] + :return: decompressed data + :rtype: torch.Tensor + """ assert isinstance(strings, list) and len(strings) == 2 z_hat = self.entropy_bottleneck.decompress(strings[1], shape) scales_hat = self.h_s(z_hat) @@ -315,13 +484,13 @@ def decode(self, strings, shape): y_hat = self.gaussian_conditional.decompress(strings[0], indices, z_hat.dtype) return self.g_s(y_hat) - def get_means(self, x): + def _get_means(self, x): medians = self.entropy_bottleneck._get_medians().detach() spatial_dims = len(x.size()) - 2 medians = self.entropy_bottleneck._extend_ndims(medians, spatial_dims) return medians.expand(x.size(0), *([-1] * (spatial_dims + 1))) - def forward2train(self, x): + def _forward2train(self, x): y = self.g_a(x) z = self.h_a(torch.abs(y)) z_hat, z_likelihoods = self.entropy_bottleneck(z) @@ -339,11 +508,11 @@ def forward(self, x): y = self.g_a(x) y_hat = self.gaussian_conditional.dequantize( - self.gaussian_conditional.quantize(y, 'dequantize', self.get_means(y)) + self.gaussian_conditional.quantize(y, 'dequantize', self._get_means(y)) ) y_hat = y_hat.detach() return self.g_s(y_hat) - return self.forward2train(x) + return self._forward2train(x) def update(self, scale_table=None, force=False): if scale_table is None: @@ -355,6 +524,12 @@ def update(self, scale_table=None, force=False): return updated def load_state_dict(self, state_dict, **kwargs): + """ + Updates registered buffers and loads parameters. + + :param state_dict: dict containing parameters and persistent buffers + :type state_dict: dict + """ update_registered_buffers( self.gaussian_conditional, 'gaussian_conditional', @@ -367,9 +542,23 @@ def load_state_dict(self, state_dict, **kwargs): @register_layer_class class MSHPBasedResNetBottleneck(SHPBasedResNetBottleneck): """ - Mean-Scale Hyperprior(MSHP)-based bottleneck for ResNet. - Mean-Scale Hyperprior is proposed in "Joint Autoregressive and Hierarchical Priors for Learned Image Compression" by - D. Minnen, J. Balle, G.D. Toderici. + Mean-Scale Hyperprior(MSHP)-based bottleneck for ResNet and variants. 
+ + - David Minnen, Johannes Ballé, George Toderici: `"Joint Autoregressive and Hierarchical Priors for Learned Image Compression" `_ @ NeurIPS 2018 (2018) + - Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt: `"SC2 Benchmark: Supervised Compression for Split Computing" `_ @ TMLR (2023) + + :param num_input_channels: number of input channels + :type num_input_channels: int + :param num_latent_channels: number of latent channels + :type num_latent_channels: int + :param num_bottleneck_channels: number of bottleneck channels + :type num_bottleneck_channels: int + :param num_target_channels: number of output channels for decoder's output + :type num_target_channels: int + :param g_a_channel_sizes: list of 4 numbers of channels for parametric transform :math:`g_a` + :type g_a_channel_sizes: list[int] or None + :param g_s_channel_sizes: list of 4 numbers of channels for parametric transform :math:`g_s` + :type g_s_channel_sizes: list[int] or None """ def __init__(self, num_input_channels=3, num_latent_channels=16, num_bottleneck_channels=24, num_target_channels=256, @@ -415,7 +604,7 @@ def decode(self, strings, shape): y_hat = self.gaussian_conditional.decompress(strings[0], indices, means=means_hat) return self.g_s(y_hat) - def forward2train(self, x): + def _forward2train(self, x): y = self.g_a(x) z = self.h_a(y) z_hat, z_likelihoods = self.entropy_bottleneck(z) @@ -435,7 +624,7 @@ def forward(self, x): y = self.g_a(x) z = self.h_a(y) z_hat = self.entropy_bottleneck.dequantize( - self.entropy_bottleneck.quantize(z, 'dequantize', self.get_means(z)) + self.entropy_bottleneck.quantize(z, 'dequantize', self._get_means(z)) ) gaussian_params = self.h_s(z_hat) scales_hat, means_hat = gaussian_params.chunk(2, 1) @@ -444,17 +633,19 @@ def forward(self, x): ) y_hat = y_hat.detach() return self.g_s(y_hat) - return self.forward2train(x) + return self._forward2train(x) def get_layer(cls_or_func_name, **kwargs): """ - Args: - cls_or_func_name (str): layer class name. - kwargs (dict): keyword arguments. - - Returns: - nn.Module or None: layer module that is instance of `nn.Module` if found. None otherwise. + Gets a layer module. 
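+
+    For example (a minimal sketch; the parameter value is illustrative):
+
+    .. code-block:: python
+
+        bottleneck_layer = get_layer('MSHPBasedResNetBottleneck', num_bottleneck_channels=24)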
+ + :param cls_or_func_name: layer class or function name + :type cls_or_func_name: str + :param kwargs: kwargs for the layer class or function to build a layer + :type kwargs: dict + :return: layer module + :rtype: nn.Module or None """ if cls_or_func_name in LAYER_CLASS_DICT: return LAYER_CLASS_DICT[cls_or_func_name](**kwargs) From d16a662ebabae615cc63d918b1ff8435206bc5bf Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Sun, 23 Jul 2023 12:11:22 -0700 Subject: [PATCH 2/4] Make minor updates --- sc2bench/models/segmentation/wrapper.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sc2bench/models/segmentation/wrapper.py b/sc2bench/models/segmentation/wrapper.py index 04f9820..7fde78a 100644 --- a/sc2bench/models/segmentation/wrapper.py +++ b/sc2bench/models/segmentation/wrapper.py @@ -126,9 +126,7 @@ def get_wrapped_segmentation_model(wrapper_model_config, device): :type wrapper_model_config: dict :param device: torch device :type device: torch.device - :return: model: wrapped semantic segmentation model - :rtype: model: nn.Module - :return: semantic segmentation model + :return: wrapped semantic segmentation model :rtype: nn.Module """ wrapper_model_name = wrapper_model_config['name'] From 19e95ee54105aa33bc332be894c318b6b961ba3e Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Sun, 23 Jul 2023 12:11:54 -0700 Subject: [PATCH 3/4] Add docstrings --- sc2bench/models/backbone.py | 3 +- sc2bench/models/registry.py | 65 +++++++++++- sc2bench/models/wrapper.py | 201 ++++++++++++++++++++---------------- 3 files changed, 178 insertions(+), 91 deletions(-) diff --git a/sc2bench/models/backbone.py b/sc2bench/models/backbone.py index ce67f6f..04e2188 100644 --- a/sc2bench/models/backbone.py +++ b/sc2bench/models/backbone.py @@ -565,7 +565,8 @@ def splittable_hybrid_vit(bottleneck_config, hybrid_vit_name='vit_small_r26_s32_ :type analysis_config: dict or None :param org_model_ckpt_file_path_or_url: original Hybrid ViT model checkpoint file path or URL :type org_model_ckpt_file_path_or_url: str or None - :param org_ckpt_strict: whether to strictly enforce that the keys in state_dict match the keys returned by original Hybrid ViT model’s `state_dict()` function + :param org_ckpt_strict: whether to strictly enforce that the keys in state_dict match the keys returned by + original Hybrid ViT model’s `state_dict()` function :type org_ckpt_strict: bool :return: splittable Hybrid ViT model :rtype: SplittableHybridViT diff --git a/sc2bench/models/registry.py b/sc2bench/models/registry.py index 56a54fa..071b843 100644 --- a/sc2bench/models/registry.py +++ b/sc2bench/models/registry.py @@ -16,22 +16,60 @@ COMPRESSION_MODEL_FUNC_DICT = dict() -def register_compressai_model_class(cls_or_func): +def register_compressai_model(cls_or_func): + """ + Registers a compression model class or a function to build a compression model in `compressai`. + + :param cls_or_func: compression model or function to build a compression model to be registered + :type cls_or_func: class or typing.Callable + :return: registered compression model class or function + :rtype: class or typing.Callable + """ COMPRESSAI_DICT[cls_or_func.__name__] = cls_or_func return cls_or_func def register_compression_model_class(cls): + """ + Registers a compression model class. 
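+
+    For example (a minimal sketch; `MyCompressionModel` is a hypothetical user-defined class):
+
+    .. code-block:: python
+
+        from compressai.models import CompressionModel
+
+        @register_compression_model_class
+        class MyCompressionModel(CompressionModel):
+            ...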
+ + :param cls: compression model to be registered + :type cls: class + :return: registered compression model class + :rtype: class + """ COMPRESSION_MODEL_CLASS_DICT[cls.__name__] = cls return cls def register_compression_model_func(func): + """ + Registers a function to build a compression model. + + :param func: function to build a compression model to be registered + :type func: typing.Callable + :return: registered function + :rtype: typing.Callable + """ COMPRESSION_MODEL_FUNC_DICT[func.__name__] = func return func def get_compressai_model(compression_model_name, ckpt_file_path=None, updates=False, **compression_model_kwargs): + """ + Gets a model in `compressai`. + + :param compression_model_name: `compressai` model name + :type compression_model_name: str + :param ckpt_file_path: checkpoint file path + :type ckpt_file_path: str or None + :param updates: if True, updates the parameters for entropy coding + :type updates: bool + :param compression_model_kwargs: kwargs for the model class or function to build the model + :type compression_model_kwargs: dict + :return: `compressai` model + :rtype: nn.Module + """ compression_model = COMPRESSAI_DICT[compression_model_name](**compression_model_kwargs) if ckpt_file_path is not None: load_ckpt(ckpt_file_path, model=compression_model, strict=None) @@ -43,6 +81,16 @@ def get_compressai_model(compression_model_name, ckpt_file_path=None, updates=Fa def get_compression_model(compression_model_config, device): + """ + Gets a compression model. + + :param compression_model_config: compression model configuration + :type compression_model_config: dict + :param device: torch device + :type device: str or torch.device + :return: compression model + :rtype: nn.Module + """ if compression_model_config is None: return None @@ -58,6 +106,21 @@ def get_compression_model(compression_model_config, device): def load_classification_model(model_config, device, distributed, strict=True): + """ + Loads an image classification model. + + :param model_config: image classification model configuration + :type model_config: dict + :param device: torch device + :type device: str or torch.device + :param distributed: whether to use the model in distributed training mode + :type distributed: bool + :param strict: whether to strictly enforce that the keys in state_dict match the keys returned by the model’s + `state_dict()` function + :type strict: bool + :return: image classification model + :rtype: nn.Module + """ model = get_image_classification_model(model_config, distributed) model_name = model_config['name'] if model is None and model_name in timm.models.__dict__: diff --git a/sc2bench/models/wrapper.py b/sc2bench/models/wrapper.py index 9f82d3c..2d51ef8 100644 --- a/sc2bench/models/wrapper.py +++ b/sc2bench/models/wrapper.py @@ -16,11 +16,12 @@ def register_wrapper_class(cls): """ - Args: - cls (class): wrapper module to be registered. + Registers a model wrapper class. - Returns: - cls (class): registered wrapper module. + :param cls: model wrapper to be registered + :type cls: class + :return: registered model wrapper class + :rtype: class """ WRAPPER_CLASS_DICT[cls.__name__] = cls return cls @@ -29,13 +30,18 @@ def register_wrapper_class(cls): @register_wrapper_class class CodecInputCompressionClassifier(AnalyzableModule): """ - Wrapper module for codec input compression model followed by classifier. 
-    Args:
-        classification_model (nn.Module): classification model
-        device (torch.device): torch device
-        codec_params (dict): keyword configurations for transform sequence for codec
-        post_transform_params (dict): keyword configurations for transform sequence after compression model
-        analysis_config (dict): configuration for analysis
+    A wrapper module for codec input compression model followed by a classification model.
+
+    :param classification_model: image classification model
+    :type classification_model: nn.Module
+    :param device: torch device
+    :type device: torch.device or str
+    :param codec_params: transform sequence configuration for codec
+    :type codec_params: dict or None
+    :param post_transform_params: post-transform parameters
+    :type post_transform_params: dict or None
+    :param analysis_config: analysis configuration
+    :type analysis_config: dict or None
     """
     def __init__(self, classification_model, device, codec_params=None,
                  post_transform_params=None, analysis_config=None, **kwargs):
@@ -75,13 +81,20 @@ def forward(self, x):
 
 @register_wrapper_class
 class NeuralInputCompressionClassifier(AnalyzableModule):
     """
-    Wrapper module for neural input compression model followed by classifier.
-    Args:
-        classification_model (nn.Module): classification model
-        pre_transform_params (dict): keyword configurations for transform sequence for input data
-        compression_model (nn.Module): neural input compression model
-        post_transform_params (dict): keyword configurations for transform sequence after compression model
-        analysis_config (dict): configuration for analysis
+    A wrapper module for neural input compression model followed by a classification model.
+
+    :param classification_model: image classification model
+    :type classification_model: nn.Module
+    :param pre_transform_params: pre-transform parameters
+    :type pre_transform_params: dict or None
+    :param compression_model: compression model
+    :type compression_model: nn.Module or None
+    :param uses_cpu4compression_model: whether to use CPU instead of GPU for `compression_model`
+    :type uses_cpu4compression_model: bool
+    :param post_transform_params: post-transform parameters
+    :type post_transform_params: dict or None
+    :param analysis_config: analysis configuration
+    :type analysis_config: dict or None
     """
     def __init__(self, classification_model, pre_transform_params=None, compression_model=None,
                  uses_cpu4compression_model=False, post_transform_params=None, analysis_config=None, **kwargs):
@@ -98,18 +111,13 @@ def __init__(self, classification_model, pre_transform_params=None, compression_
         self.post_transform = build_transform(post_transform_params)
 
     def use_cpu4compression(self):
+        """
+        Changes the device of the compression model to CPU.
+        """
         if self.uses_cpu4compression_model and self.compression_model is not None:
             self.compression_model = self.compression_model.cpu()
 
     def forward(self, x):
-        """
-        Args:
-            x (list of PIL Images or Tensor): input sample.
-
-        Returns:
-            Tensor: output tensor from self.classification_model.
-        """
-
         if self.pre_transform is not None:
             x = self.pre_transform(x)
         if not self.training and self.analyzes_after_pre_transform:
@@ -131,16 +139,24 @@ def forward(self, x):
 
 @register_wrapper_class
 class CodecFeatureCompressionClassifier(AnalyzableModule):
     """
-    Wrapper module for codec feature compression model injected to a classifier.
- Args: - classification_model (nn.Module): classification model - device (torch.device): torch device - encoder_config (dict): keyword configurations to design an encoder from modules in classification_model - codec_params (dict): keyword configurations for transform sequence for codec - decoder_config (dict): keyword configurations to design a decoder from modules in classification_model - classifier_config (dict): keyword configurations to design a classifier from modules in classification_model - post_transform_params (dict): keyword configurations for transform sequence after compression model - analysis_config (dict): configuration for analysis + A wrapper module for codec feature compression model injected to a classification model. + + :param classification_model: image classification model + :type classification_model: nn.Module + :param device: torch device + :type device: torch.device or str + :param encoder_config: configuration to design an encoder using modules in classification_model + :type encoder_config: dict or None + :param codec_params: transform sequence configuration for codec + :type codec_params: dict or None + :param decoder_config: configuration to design a decoder using modules in classification_model + :type decoder_config: dict or None + :param classifier_config: configuration to design a classifier using modules in classification_model + :type classifier_config: dict or None + :param post_transform_params: post-transform parameters + :type post_transform_params: dict or None + :param analysis_config: analysis configuration + :type analysis_config: dict or None """ def __init__(self, classification_model, device, encoder_config=None, codec_params=None, decoder_config=None, classifier_config=None, post_transform_params=None, analysis_config=None, **kwargs): @@ -159,13 +175,6 @@ def __init__(self, classification_model, device, encoder_config=None, codec_para self.post_transform = build_transform(post_transform_params) def forward(self, x): - """ - Args: - x (Tensor): input sample. - - Returns: - Tensor: output tensor from self.classifier. - """ x = self.encoder(x) tmp_list = list() for sub_x in x: @@ -187,14 +196,22 @@ def forward(self, x): @register_wrapper_class class EntropicClassifier(UpdatableBackbone): """ - Wrapper module for entropic compression model injected to a classifier. - Args: - classification_model (nn.Module): classification model - encoder_config (dict): keyword configurations to design an encoder from modules in classification_model - compression_model_params (dict): keyword configurations for CompressionModel in compressai - decoder_config (dict): keyword configurations to design a decoder from modules in classification_model - classifier_config (dict): keyword configurations to design a classifier from modules in classification_model - analysis_config (dict): configuration for analysis + A wrapper module for entropic compression model injected to a classification model. 
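+
+    After training, call `update()` so that the eval-mode forward pass entropy-codes the
+    bottleneck representation, e.g. (a minimal sketch; the configs are assumed to be
+    torchdistill-style dicts prepared elsewhere):
+
+    .. code-block:: python
+
+        model = EntropicClassifier(classification_model, encoder_config,
+                                   {'entropy_bottleneck_channels': 24},
+                                   decoder_config, classifier_config)
+        # ... training ...
+        model.update()
+        model.eval()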
+
+    :param classification_model: image classification model
+    :type classification_model: nn.Module
+    :param encoder_config: configuration to design an encoder using modules in classification_model
+    :type encoder_config: dict
+    :param compression_model_params: kwargs for `EntropyBottleneckLayer` in `compressai`
+    :type compression_model_params: dict
+    :param decoder_config: configuration to design a decoder using modules in classification_model
+    :type decoder_config: dict
+    :param classifier_config: configuration to design a classifier using modules in classification_model
+    :type classifier_config: dict
+    :param analysis_config: analysis configuration
+    :type analysis_config: dict or None
     """
     def __init__(self, classification_model, encoder_config, compression_model_params,
                  decoder_config, classifier_config, analysis_config=None, **kwargs):
@@ -211,13 +228,6 @@ def __init__(self, classification_model, encoder_config, compression_model_param
         self.classifier = redesign_model(classification_model, classifier_config, model_label='classification')
 
     def forward(self, x):
-        """
-        Args:
-            x (Tensor): input sample.
-
-        Returns:
-            Tensor: output tensor from self.classifier.
-        """
         x = self.encoder(x)
         if self.bottleneck_updated and not self.training:
             x = self.entropy_bottleneck.compress(x)
@@ -236,6 +246,12 @@ def update(self):
         self.bottleneck_updated = True
 
     def load_state_dict(self, state_dict, **kwargs):
+        """
+        Loads parameters for all the sub-modules except entropy_bottleneck and then entropy_bottleneck.
+
+        :param state_dict: dict containing parameters and persistent buffers
+        :type state_dict: dict
+        """
         entropy_bottleneck_state_dict = OrderedDict()
         for key in list(state_dict.keys()):
             if key.startswith('entropy_bottleneck.'):
@@ -251,15 +267,22 @@ def get_aux_module(self, **kwargs):
 
 @register_wrapper_class
 class SplitClassifier(UpdatableBackbone):
     """
-    Wrapper module for naively splitting a classifier.
-    Args:
-        classification_model (nn.Module): classification model
-        encoder_config (dict): keyword configurations to design an encoder from modules in classification_model
-        decoder_config (dict): keyword configurations to design a decoder from modules in classification_model
-        classifier_config (dict): keyword configurations to design a classifier from modules in classification_model
-        compressor_transform_params (dict): keyword configurations to build transform for compression
-        decompressor_transform_params (dict): keyword configurations to build transform for decompression
-        analysis_config (dict): configuration for analysis
+    A wrapper module for naively splitting a classification model.
+ + :param classification_model: image classification model + :type classification_model: nn.Module + :param encoder_config: configuration to design an encoder using modules in classification_model + :type encoder_config: dict or None + :param decoder_config: configuration to design a decoder using modules in classification_model + :type decoder_config: dict or None + :param classifier_config: configuration to design a classifier using modules in classification_model + :type classifier_config: dict or None + :param compressor_transform_params: transform parameters for compressor + :type compressor_transform_params: dict or None + :param decompressor_transform_params: transform parameters for decompressor + :type decompressor_transform_params: dict or None + :param analysis_config: analysis configuration + :type analysis_config: dict or None """ def __init__(self, classification_model, encoder_config, decoder_config, classifier_config, compressor_transform_params=None, decompressor_transform_params=None, @@ -278,13 +301,6 @@ def __init__(self, classification_model, encoder_config, decoder_config, self.classifier = redesign_model(classification_model, classifier_config, model_label='classification') def forward(self, x): - """ - Args: - x (Tensor): input sample. - - Returns: - Tensor: output tensor from self.classifier. - """ x = self.encoder(x) if self.bottleneck_updated and not self.training: x = self.compressor(x) @@ -305,14 +321,18 @@ def get_aux_module(self, **kwargs): def wrap_model(wrapper_model_name, model, compression_model, **kwargs): """ - Args: - wrapper_model_name (str): wrapper model key in wrapper model register. - model (nn.Module): model to be wrapped. - compression_model (nn.Module): compressor to be wrapped. - **kwargs (dict): keyword arguments to instantiate a wrapper object. - - Returns: - nn.Module: a wrapper module. + Wraps a model and a compression model with a wrapper module. + + :param wrapper_model_name: wrapper model name + :type wrapper_model_name: str + :param model: model + :type model: nn.Module + :param compression_model: compression model + :type compression_model: nn.Module + :param kwargs: kwargs for the wrapper class or function to build the wrapper module + :type kwargs: dict + :return: wrapped model + :rtype: nn.Module """ if wrapper_model_name not in WRAPPER_CLASS_DICT: raise ValueError('wrapper_model_name `{}` is not expected'.format(wrapper_model_name)) @@ -321,13 +341,16 @@ def wrap_model(wrapper_model_name, model, compression_model, **kwargs): def get_wrapped_classification_model(wrapper_model_config, device, distributed): """ - Args: - wrapper_model_config (dict): wrapper model configuration. - device (device): torch device. - distributed (bool): uses distributed training model. - - Returns: - nn.Module: a wrapped module. + Gets a wrapped image classification model. 
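+
+    For example (a minimal sketch; only the 'name' key is fixed, and the remaining entries
+    depend on the chosen wrapper class):
+
+    .. code-block:: python
+
+        wrapper_model_config = {
+            'name': 'NeuralInputCompressionClassifier',
+            # plus wrapper-specific entries, e.g., compression model and transform configurations
+        }
+        model = get_wrapped_classification_model(wrapper_model_config, device, distributed=False)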
+ + :param wrapper_model_config: wrapper model configuration + :type wrapper_model_config: dict + :param device: torch device + :type device: torch.device + :param distributed: whether to use the model in distributed training mode + :type distributed: bool + :return: wrapped image classification model + :rtype: nn.Module """ wrapper_model_name = wrapper_model_config['name'] if wrapper_model_name not in WRAPPER_CLASS_DICT: From 5a1bff9606abb6d94860f32460da01a9e2d903f2 Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Sun, 23 Jul 2023 16:27:37 -0700 Subject: [PATCH 4/4] Add docstrings --- sc2bench/transforms/codec.py | 163 ++++++++++++++++++++------------ sc2bench/transforms/collator.py | 27 ++++++ sc2bench/transforms/misc.py | 115 +++++++++++++++------- 3 files changed, 212 insertions(+), 93 deletions(-) diff --git a/sc2bench/transforms/codec.py b/sc2bench/transforms/codec.py index 8471f18..d468e04 100644 --- a/sc2bench/transforms/codec.py +++ b/sc2bench/transforms/codec.py @@ -26,11 +26,12 @@ def register_codec_transform_module(cls): """ - Args: - cls (class): codec transform module to be registered. + Registers a codec transform class. - Returns: - cls (class): registered codec transform module. + :param cls: codec transform class to be registered + :type cls: class + :return: registered codec transform class + :rtype: class """ CODEC_TRANSFORM_MODULE_DICT[cls.__name__] = cls register_transform_class(cls) @@ -40,10 +41,12 @@ def register_codec_transform_module(cls): @register_codec_transform_module class WrappedRandomResizedCrop(RandomResizedCrop): """ - `RandomResizedCrop` in torchvision wrapped to be defined by `interpolation` as a str object - Args: - interpolation (str or None): Desired interpolation mode (`nearest`, `bicubic`, `bilinear`, `box`, `hamming`, `lanczos`) - kwargs (dict): kwargs for `RandomResizedCrop` in torchvision + `RandomResizedCrop` in torchvision wrapped to be defined by `interpolation` as a str object. + + :param interpolation: desired interpolation mode ('nearest', 'bicubic', 'bilinear', 'box', 'hamming', 'lanczos') + :type interpolation: str or None + :param kwargs: kwargs for `RandomResizedCrop` in torchvision + :type kwargs: dict """ def __init__(self, interpolation=None, **kwargs): if interpolation is not None: @@ -54,10 +57,12 @@ def __init__(self, interpolation=None, **kwargs): @register_codec_transform_module class WrappedResize(Resize): """ - `Resize` in torchvision wrapped to be defined by `interpolation` as a str object - Args: - interpolation (str or None): Desired interpolation mode (`nearest`, `bicubic`, `bilinear`, `box`, `hamming`, `lanczos`) - kwargs (dict): kwargs for `Resize` in torchvision + `Resize` in torchvision wrapped to be defined by `interpolation` as a str object. + + :param interpolation: desired interpolation mode ('nearest', 'bicubic', 'bilinear', 'box', 'hamming', 'lanczos') + :type interpolation: str or None + :param kwargs: kwargs for `Resize` in torchvision + :type kwargs: dict """ def __init__(self, interpolation=None, **kwargs): if interpolation is not None: @@ -66,13 +71,16 @@ def __init__(self, interpolation=None, **kwargs): @register_codec_transform_module -class PillowImageModule(nn.Module): +class PILImageModule(nn.Module): """ - Generalized Pillow module to compress (decompress) images e.g., as part of transform pipeline. - Args: - returns_file_size (bool): return file size of compressed object in addition to PIL image if true. 
- open_kwargs (dict or None): kwargs to be used as part of Image.open(img_buffer, **open_kwargs). - save_kwargs (dict or None): kwargs to be used as part of Image.save(img_buffer, **save_kwargs). + A generalized PIL module to compress (decompress) images e.g., as part of transform pipeline. + + :param returns_file_size: returns file size of compressed object in addition to PIL image if True + :type returns_file_size: bool + :param open_kwargs: kwargs to be used as part of Image.open(img_buffer, **open_kwargs) + :type open_kwargs: dict or None + :param save_kwargs: kwargs to be used as part of Image.save(img_buffer, **save_kwargs) + :type save_kwargs: dict or None """ def __init__(self, returns_file_size=False, open_kwargs=None, **save_kwargs): super().__init__() @@ -82,11 +90,12 @@ def __init__(self, returns_file_size=False, open_kwargs=None, **save_kwargs): def forward(self, pil_img, *args): """ - Args: - pil_img (PIL Image): Image to be transformed. + Saves PIL Image to BytesIO and reopens the image saved in the buffer. - Returns: - PIL Image or a tuple of PIL Image and int: Affine transformed image or with its file size if returns_file_size=True. + :param pil_img: image to be transformed. + :type pil_img: PIL.Image.Image + :return: Affine transformed image or with its file size if returns_file_size=True + :rtype: PIL.Image.Image or (PIL.Image.Image, int) """ img_buffer = BytesIO() pil_img.save(img_buffer, **self.save_kwargs) @@ -103,13 +112,16 @@ def __repr__(self): @register_codec_transform_module -class PillowTensorModule(nn.Module): +class PILTensorModule(nn.Module): """ - Generalized Pillow module to compress (decompress) tensors e.g., as part of transform pipeline. - Args: - returns_file_size (bool): return file size of compressed object in addition to PIL image if true. - open_kwargs (dict or None): kwargs to be used as part of Image.open(img_buffer, **open_kwargs). - save_kwargs (dict or None): kwargs to be used as part of Image.save(img_buffer, **save_kwargs). + A generalized PIL module to compress (decompress) tensors e.g., as part of transform pipeline. + + :param returns_file_size: returns file size of compressed object in addition to PIL image if True + :type returns_file_size: bool + :param open_kwargs: kwargs to be used as part of Image.open(img_buffer, **open_kwargs) + :type open_kwargs: dict or None + :param save_kwargs: kwargs to be used as part of Image.save(img_buffer, **save_kwargs) + :type save_kwargs: dict or None """ def __init__(self, returns_file_size=False, open_kwargs=None, **save_kwargs): super().__init__() @@ -119,11 +131,14 @@ def __init__(self, returns_file_size=False, open_kwargs=None, **save_kwargs): def forward(self, x, *args): """ - Args: - x (torch.Tensor): Tensor (C, H, W) to be transformed - - Returns: - torch.Tensor or a tuple of torch.Tensor and int: Affine transformed image or with its file size if returns_file_size=True. + Splits tensor's channels into sub-tensors (3 or fewer channels each), + normalizes each using its min and max values, saves the normalized sub-tensor to BytesIO, + and reopens the sub-tensor saved in the buffer to reconstruct the input tensor. + + :param x: image tensor (C, H, W) to be transformed. 
+
+        :param x: image tensor (C, H, W) to be transformed
+        :type x: torch.Tensor
+        :return: reconstructed image tensor, or the tensor with its file size if returns_file_size=True
+        :rtype: torch.Tensor or (torch.Tensor, int)
         """
         device = x.device
         split_features = x.split(3, dim=0)
@@ -174,17 +189,32 @@ def __repr__(self):

 @register_codec_transform_module
 class BPGModule(nn.Module):
     """
-    BPG module to compress (decompress) images e.g., as part of transform pipeline.
+    A BPG module to compress (decompress) images e.g., as part of a transform pipeline.
+
     Modified https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/utils/bench/codecs.py
-    Args:
-        encoder_path (str): file path of BPG encoder you manually installed.
-        decoder_path (str): file path of BPG decoder you manually installed.
-        color_mode (str): color mode ("ycbcr" or "rgb").
-        encoder (str): encoder type ("x265" or "jctvc").
-        subsampling_mode (str or int): subsampling mode (420 or 444).
-        bit_depth (str or int): bit depth (8 or 10).
-        quality (int): quality value in range [0, 51].
-        returns_file_size (bool): flag to return file size.
+
+    Fabrice Bellard: `"BPG Image format" `_
+
+    .. warning::
+        You need to manually install BPG software beforehand and confirm the encoder and decoder paths.
+        For Debian machines (e.g., Ubuntu), you can use `this script `_.
+
+    :param encoder_path: file path of BPG encoder you manually installed
+    :type encoder_path: str
+    :param decoder_path: file path of BPG decoder you manually installed
+    :type decoder_path: str
+    :param color_mode: color mode ('ycbcr' or 'rgb')
+    :type color_mode: str
+    :param encoder: encoder type ('x265' or 'jctvc')
+    :type encoder: str
+    :param subsampling_mode: subsampling mode ('420' or '444')
+    :type subsampling_mode: str or int
+    :param bit_depth: bit depth (8 or 10)
+    :type bit_depth: str or int
+    :param quality: quality value in range [0, 51]
+    :type quality: int
+    :param returns_file_size: returns file size of compressed object in addition to PIL image if True
+    :type returns_file_size: bool
     """
     fmt = '.bpg'

@@ -247,11 +277,12 @@ def _get_decode_cmd(self, output_file_path, reconst_file_path):

     def forward(self, pil_img):
         """
-        Args:
-            pil_img (PIL Image): Image to be transformed.
+        Compresses and decompresses PIL Image using BPG software.

-        Returns:
-            PIL Image or a tuple of PIL Image and float: Affine transformed image or with its file size of BPG compressed data if returns_file_size=True.
+        :param pil_img: image to be transformed
+        :type pil_img: PIL.Image.Image
+        :return: BPG-compressed and decompressed image, or the image with its file size if returns_file_size=True
+        :rtype: PIL.Image.Image or (PIL.Image.Image, int)
         """
         fd_i, resized_input_filepath = mkstemp(suffix='.jpg')
         fd_r, reconst_file_path = mkstemp(suffix='.jpg')
@@ -289,15 +320,28 @@ def __repr__(self):

 @register_codec_transform_module
 class VTMModule(nn.Module):
     """
-    VTM module to compress (decompress) images e.g., as part of transform pipeline.
+    A VTM module to compress (decompress) images e.g., as part of a transform pipeline.
+
     Modified https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/utils/bench/codecs.py
-    Args:
-        encoder_path (str): file path of BPG encoder you manually installed.
-        decoder_path (str): file path of BPG decoder you manually installed.
-        config_path (str): VTM configuration file path.
-        color_mode (str): color mode ("ycbcr" or "rgb").
-        quality (int): quality value in range [0, 63].
-        returns_file_size (bool): flag to return file size.
+
+    The Joint Video Exploration Team: `"VTM reference software for VVC" `_
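+
+    A minimal usage sketch with placeholder paths (the binaries and config below are assumptions;
+    see the warning below for installation):
+
+    .. code-block:: python
+
+        from PIL import Image
+
+        vtm_module = VTMModule('/path/to/EncoderApp', '/path/to/DecoderApp',
+                               '/path/to/encoder_intra_vtm.cfg', quality=63, returns_file_size=True)
+        img, file_size = vtm_module(Image.open('sample.jpg'))
+        # img: VTM-compressed and then decompressed PIL Image
+        # file_size: size of the VTM-compressed data in bytes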
+
+    .. warning::
+        You need to manually install VTM software beforehand and confirm the encoder and decoder paths.
+        For Debian machines (e.g., Ubuntu), you can use `this script `_.
+
+    :param encoder_path: file path of VTM encoder you manually installed
+    :type encoder_path: str
+    :param decoder_path: file path of VTM decoder you manually installed
+    :type decoder_path: str
+    :param config_path: VTM configuration file path
+    :type config_path: str
+    :param color_mode: color mode ('ycbcr' or 'rgb')
+    :type color_mode: str
+    :param quality: quality value in range [0, 63]
+    :type quality: int
+    :param returns_file_size: returns file size of compressed object in addition to PIL image if True
+    :type returns_file_size: bool
     """
     fmt = '.bin'

@@ -322,11 +366,12 @@ def __init__(self, encoder_path, decoder_path, config_path, color_mode='ycbcr',

     def forward(self, pil_img):
         """
-        Args:
-            pil_img (PIL Image): Image to be transformed.
+        Compresses and decompresses PIL Image using VTM software.

-        Returns:
-            PIL Image or a tuple of PIL Image and float: Affine transformed image or with its file size of VTM compressed data if returns_file_size=True.
+        :param pil_img: image to be transformed
+        :type pil_img: PIL.Image.Image
+        :return: VTM-compressed and decompressed image, or the image with its file size if returns_file_size=True
+        :rtype: PIL.Image.Image or (PIL.Image.Image, int)
         """

         # Taking 8bit input for now
diff --git a/sc2bench/transforms/collator.py b/sc2bench/transforms/collator.py
index 0a0944d..2db8716 100644
--- a/sc2bench/transforms/collator.py
+++ b/sc2bench/transforms/collator.py
@@ -3,6 +3,17 @@

 def cat_list(images, fill_value=0):
+    """
+    Concatenates a list of images into a single batched tensor, using the max height and
+    max width in the list and filling empty spaces with a specified value.
+
+    :param images: list of images
+    :type images: list[torch.Tensor]
+    :param fill_value: value used to fill empty spaces
+    :type fill_value: int
+    :return: batched image tensor
+    :rtype: torch.Tensor
+    """
     if len(images) == 1 and not isinstance(images[0], torch.Tensor):
         return images

@@ -16,6 +27,14 @@

 @register_collate_func
 def pascal_seg_collate_fn(batch):
+    """
+    Collates input data for PASCAL VOC 2012 segmentation.
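+
+    A minimal usage sketch (the shapes below are illustrative assumptions):
+
+    .. code-block:: python
+
+        import torch
+
+        batch = [
+            (torch.rand(3, 480, 512), torch.randint(0, 21, (480, 512)), dict()),
+            (torch.rand(3, 512, 480), torch.randint(0, 21, (512, 480)), dict()),
+        ]
+        images, targets, supp_dicts = pascal_seg_collate_fn(batch)
+        # images: (2, 3, 512, 512) zero-padded; targets: (2, 512, 512) padded with 255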
+
+    :param batch: list/tuple of triplets (image, target, supp_dict), where supp_dict can be an empty dict
+    :type batch: list or tuple
+    :return: collated images, targets, and supplementary dicts
+    :rtype: (torch.Tensor, torch.Tensor, list[dict])
+    """
     images, targets, supp_dicts = list(zip(*batch))
     batched_imgs = cat_list(images, fill_value=0)
     batched_targets = cat_list(targets, fill_value=255)
@@ -24,6 +43,14 @@

 @register_collate_func
 def pascal_seg_eval_collate_fn(batch):
+    """
+    Collates input data for PASCAL VOC 2012 segmentation in evaluation mode.
+
+    :param batch: list/tuple of tuples (image, target)
+    :type batch: list or tuple
+    :return: collated images and targets
+    :rtype: (torch.Tensor, torch.Tensor)
+    """
     images, targets = list(zip(*batch))
     batched_imgs = cat_list(images, fill_value=0)
     batched_targets = cat_list(targets, fill_value=255)
diff --git a/sc2bench/transforms/misc.py b/sc2bench/transforms/misc.py
index 00b2325..472b6b3 100644
--- a/sc2bench/transforms/misc.py
+++ b/sc2bench/transforms/misc.py
@@ -4,7 +4,6 @@
 import torch
 from PIL.Image import Image
 from torch import nn
-from torch._six import string_classes
 from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
 from torchdistill.common import tensor_util
 from torchdistill.datasets.collator import register_collate_func
@@ -17,11 +16,12 @@

 def register_misc_transform_module(cls):
     """
-    Args:
-        cls (class): codec transform module to be registered.
+    Registers a miscellaneous transform class.

-    Returns:
-        cls (class): registered codec transform module.
+    :param cls: miscellaneous transform class to be registered
+    :type cls: class
+    :return: registered miscellaneous transform class
+    :rtype: class
     """
     MISC_TRANSFORM_MODULE_DICT[cls.__name__] = cls
     register_transform_class(cls)

@@ -29,8 +29,13 @@

 @register_collate_func
-def default_collate_w_pillow(batch):
-    r"""Puts each data field into a tensor or PIL Image with outer dimension batch size"""
+def default_collate_w_pil(batch):
+    """
+    Puts each data field into a tensor or PIL Image with outer dimension batch size.
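+
+    A minimal sketch of the intended behavior (the sample data is assumed for illustration;
+    PIL Images are kept as-is while numbers and tensors are batched):
+
+    .. code-block:: python
+
+        from PIL import Image
+
+        batch = [(Image.new('RGB', (32, 32)), 0), (Image.new('RGB', (32, 32)), 1)]
+        images, labels = default_collate_w_pil(batch)
+        # images: the two PIL Images kept as-is; labels: tensor([0, 1])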
+
+    :param batch: single batch to be collated
+    :return: collated batch
+    """
     # Extended `default_collate` function in PyTorch
     elem = batch[0]
@@ -51,19 +56,19 @@
             if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                 raise TypeError(default_collate_err_msg_format.format(elem.dtype))

-            return default_collate_w_pillow([torch.as_tensor(b) for b in batch])
+            return default_collate_w_pil([torch.as_tensor(b) for b in batch])
     elif elem.shape == ():  # scalars
         return torch.as_tensor(batch)
     elif isinstance(elem, float):
         return torch.tensor(batch, dtype=torch.float64)
     elif isinstance(elem, int):
         return torch.tensor(batch)
-    elif isinstance(elem, string_classes):
+    elif isinstance(elem, (str, bytes)):
         return batch
     elif isinstance(elem, collections.abc.Mapping):
-        return {key: default_collate_w_pillow([d[key] for d in batch]) for key in elem}
+        return {key: default_collate_w_pil([d[key] for d in batch]) for key in elem}
     elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
-        return elem_type(*(default_collate_w_pillow(samples) for samples in zip(*batch)))
+        return elem_type(*(default_collate_w_pil(samples) for samples in zip(*batch)))
     elif isinstance(elem, collections.abc.Sequence):
         # check to make sure that the elements in batch have consistent size
         it = iter(batch)
@@ -71,7 +76,7 @@
         if not all(len(elem) == elem_size for elem in it):
             raise RuntimeError('each element in list of batch should be of equal size')
         transposed = zip(*batch)
-        return [default_collate_w_pillow(samples) for samples in transposed]
+        return [default_collate_w_pil(samples) for samples in transposed]
     elif isinstance(elem, Image):
         return batch

@@ -81,18 +86,19 @@
 @register_misc_transform_module
 class ClearTargetTransform(nn.Module):
     """
-    Transform module that replaces target with an empty list.
+    A transform module that replaces target with an empty list.
     """
     def __init__(self):
         super().__init__()

     def forward(self, sample, *args):
         """
-        Args:
-            sample (PIL Image or Tensor): input sample.
+        Replaces the target data field with an empty list.

-        Returns:
-            tuple: a pair of transformed sample and original target.
+        :param sample: image or image tensor
+        :type sample: PIL.Image.Image or torch.Tensor
+        :return: sample and an empty list
+        :rtype: (PIL.Image.Image or torch.Tensor, list)
         """
         return sample, list()

@@ -100,14 +106,19 @@
 @register_misc_transform_module
 class AdaptivePad(nn.Module):
     """
-    Transform module that adaptively determines the size of padded sample.
-    Args:
-        fill (int): padded value.
-        padding_position (str): 'hw' (default) to pad left and right for padding horizontal size // 2 and top and
-            bottom for padding vertical size // 2; 'right_bottom' to pad bottom and right only.
-        padding_mode (str): padding mode passed to pad module.
-        factor (int): factor value for the padded input sample.
-        returns_org_patch_size (bool): returns original patch size.
+    A transform module that adaptively determines the size of a padded sample.
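+
+    A minimal sketch with an assumed input size and the default factor:
+
+    .. code-block:: python
+
+        import torch
+
+        adaptive_pad = AdaptivePad(factor=128, returns_org_patch_size=True)
+        x = torch.rand(3, 300, 500)
+        padded_x, org_patch_size = adaptive_pad(x)
+        # padded_x: (3, 384, 512) since 384 and 512 are the smallest multiples of 128
+        # greater than or equal to 300 and 500; org_patch_size: the original (height, width)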
+
+    :param fill: padded value
+    :type fill: int
+    :param padding_position: 'hw' (default) to pad the left and right sides by half the horizontal padding size
+        each and the top and bottom sides by half the vertical padding size each; 'right_bottom' to pad
+        the bottom and right sides only
+    :type padding_position: str
+    :param padding_mode: padding mode passed to pad module
+    :type padding_mode: str
+    :param factor: factor by which the height and width of the padded sample should be divisible
+    :type factor: int
+    :param returns_org_patch_size: if True, returns the patch size of the original input
+    :type returns_org_patch_size: bool
     """
     def __init__(self, fill=0, padding_position='hw', padding_mode='constant',
                  factor=128, returns_org_patch_size=False):
@@ -120,12 +131,13 @@ def __init__(self, fill=0, padding_position='hw', padding_mode='constant',

     def forward(self, x):
         """
-        Args:
-            x (PIL Image or Tensor): input sample.
+        Adaptively pads the image or image tensor so that both its height and width are multiples of `factor`.

-        Returns:
-            PIL Image or a tuple of PIL Image and int: padded input sample or with its patch size (height, width)
-            if returns_org_patch_size=True.
+        :param x: image or image tensor
+        :type x: PIL.Image.Image or torch.Tensor
+        :return: padded image or image tensor, and the patch size of the input (height, width)
+            if returns_org_patch_size=True
+        :rtype: PIL.Image.Image or torch.Tensor or (PIL.Image.Image or torch.Tensor, list[int])
         """
         height, width = x.shape[-2:]
         vertical_pad_size = 0 if height % self.factor == 0 else int((height // self.factor + 1) * self.factor - height)
@@ -146,10 +158,12 @@ def forward(self, x):

 @register_misc_transform_module
 class CustomToTensor(nn.Module):
     """
-    Customized ToTensor module that can be applied to sample and target selectively.
-    Args:
-        converts_sample (bool): apply to_tensor to sample if True.
-        converts_target (bool): apply torch.as_tensor to target if True.
+    A customized ToTensor module that can be applied to sample and target selectively.
+
+    :param converts_sample: if True, applies to_tensor to sample
+    :type converts_sample: bool
+    :param converts_target: if True, applies torch.as_tensor to target
+    :type converts_target: bool
     """
     def __init__(self, converts_sample=True, converts_target=True):
         super().__init__()
@@ -167,19 +181,52 @@

 @register_misc_transform_module
 class SimpleQuantizer(nn.Module):
+    """
+    A module to quantize a tensor with its half() function if num_bits=16 (FP16) or
+    with Jacob et al.'s method if num_bits=8 (INT8 + one FP32 scale parameter).
+
+    Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: `"Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" `_ @ CVPR 2018 (2018)
+
+    :param num_bits: number of bits for quantization
+    :type num_bits: int
+    """
     def __init__(self, num_bits):
         super().__init__()
         self.num_bits = num_bits

     def forward(self, z):
+        """
+        Quantizes a tensor.
+
+        :param z: tensor
+        :type z: torch.Tensor
+        :return: quantized tensor
+        :rtype: torch.Tensor or torchdistill.common.tensor_util.QuantizedTensor
+        """
         return z.half() if self.num_bits == 16 else tensor_util.quantize_tensor(z, self.num_bits)

 @register_misc_transform_module
 class SimpleDequantizer(nn.Module):
+    """
+    A module to dequantize a quantized tensor to FP32. If num_bits=8, it uses Jacob et al.'s method.
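+
+    A minimal round-trip sketch with SimpleQuantizer (the input tensor is an assumption):
+
+    .. code-block:: python
+
+        import torch
+
+        quantizer = SimpleQuantizer(num_bits=8)
+        dequantizer = SimpleDequantizer(num_bits=8)
+        z = torch.randn(1, 24, 28, 28)
+        z_hat = dequantizer(quantizer(z))
+        # z_hat: FP32 tensor approximating z up to quantization error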
+
+    Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: `"Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" `_ @ CVPR 2018 (2018)
+
+    :param num_bits: number of bits used for quantization
+    :type num_bits: int
+    """
     def __init__(self, num_bits):
         super().__init__()
         self.num_bits = num_bits

     def forward(self, z):
+        """
+        Dequantizes a quantized tensor.
+
+        :param z: quantized tensor
+        :type z: torch.Tensor or torchdistill.common.tensor_util.QuantizedTensor
+        :return: dequantized tensor
+        :rtype: torch.Tensor
+        """
         return z.float() if self.num_bits == 16 else tensor_util.dequantize_tensor(z)