Skip to content

Commit

Permalink
Merge pull request #128 from yoshitomo-matsubara/dev
Browse files Browse the repository at this point in the history
Add docstrings to transforms package
  • Loading branch information
yoshitomo-matsubara authored Jul 23, 2023
2 parents a4de421 + 5a1bff9 commit 7669b8f
Show file tree
Hide file tree
Showing 8 changed files with 868 additions and 229 deletions.
249 changes: 247 additions & 2 deletions sc2bench/models/backbone.py

Large diffs are not rendered by default.

273 changes: 232 additions & 41 deletions sc2bench/models/layer.py

Large diffs are not rendered by default.

65 changes: 64 additions & 1 deletion sc2bench/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,60 @@
COMPRESSION_MODEL_FUNC_DICT = dict()


def register_compressai_model_class(cls_or_func):
def register_compressai_model(cls_or_func):
"""
Registers a compression model class or a function to build a compression model in `compressai`.
:param cls_or_func: compression model or function to build a compression model to be registered
:type cls_or_func: class or typing.Callable
:return: registered compression model class or function
:rtype: class or typing.Callable
"""
COMPRESSAI_DICT[cls_or_func.__name__] = cls_or_func
return cls_or_func


def register_compression_model_class(cls):
"""
Registers a compression model class.
:param cls: compression model to be registered
:type cls: class
:return: registered compression model class
:rtype: class
"""
COMPRESSION_MODEL_CLASS_DICT[cls.__name__] = cls
return cls


def register_compression_model_func(func):
"""
Registers a function to build a compression model.
:param func: function to build a compression model to be registered
:type func: typing.Callable
:return: registered function
:rtype: typing.Callable
"""
COMPRESSION_MODEL_FUNC_DICT[func.__name__] = func
return func


def get_compressai_model(compression_model_name, ckpt_file_path=None, updates=False, **compression_model_kwargs):
"""
Gets a model in `compressai`.
:param compression_model_name: `compressai` model name
:type compression_model_name: str
:param ckpt_file_path: checkpoint file path
:type ckpt_file_path: str or None
:param updates: if True, updates the parameters for entropy coding
:type updates: bool
:param compression_model_kwargs: kwargs for the model class or function to build the model
:type compression_model_kwargs: dict
:return: `compressai` model
:rtype: nn.Module
"""
compression_model = COMPRESSAI_DICT[compression_model_name](**compression_model_kwargs)
if ckpt_file_path is not None:
load_ckpt(ckpt_file_path, model=compression_model, strict=None)
Expand All @@ -43,6 +81,16 @@ def get_compressai_model(compression_model_name, ckpt_file_path=None, updates=Fa


def get_compression_model(compression_model_config, device):
"""
Gets a compression model.
:param compression_model_config: compression model configuration
:type compression_model_config: dict
:param device: torch device
:type device: str or torch.device
:return: compression model
:rtype: nn.Module
"""
if compression_model_config is None:
return None

Expand All @@ -58,6 +106,21 @@ def get_compression_model(compression_model_config, device):


def load_classification_model(model_config, device, distributed, strict=True):
"""
Loads an image classification model.
:param model_config: image classification model configuration
:type model_config: dict
:param device: torch device
:type device: str or torch.device
:param distributed: whether to use the model in distributed training mode
:type distributed: bool
:param strict: whether to strictly enforce that the keys in state_dict match the keys returned by the model’s
`state_dict()` function
:type strict: bool
:return: image classification model
:rtype: nn.Module
"""
model = get_image_classification_model(model_config, distributed)
model_name = model_config['name']
if model is None and model_name in timm.models.__dict__:
Expand Down
4 changes: 1 addition & 3 deletions sc2bench/models/segmentation/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def get_wrapped_segmentation_model(wrapper_model_config, device):
:type wrapper_model_config: dict
:param device: torch device
:type device: torch.device
:return: model: wrapped semantic segmentation model
:rtype: model: nn.Module
:return: semantic segmentation model
:return: wrapped semantic segmentation model
:rtype: nn.Module
"""
wrapper_model_name = wrapper_model_config['name']
Expand Down
Loading

0 comments on commit 7669b8f

Please sign in to comment.