Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docstrings to transforms package #128

Merged
merged 4 commits into from
Jul 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 247 additions & 2 deletions sc2bench/models/backbone.py

Large diffs are not rendered by default.

273 changes: 232 additions & 41 deletions sc2bench/models/layer.py

Large diffs are not rendered by default.

65 changes: 64 additions & 1 deletion sc2bench/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,60 @@
COMPRESSION_MODEL_FUNC_DICT = dict()


def register_compressai_model_class(cls_or_func):
def register_compressai_model(cls_or_func):
"""
Registers a compression model class or a function to build a compression model in `compressai`.

:param cls_or_func: compression model or function to build a compression model to be registered
:type cls_or_func: class or typing.Callable
:return: registered compression model class or function
:rtype: class or typing.Callable
"""
COMPRESSAI_DICT[cls_or_func.__name__] = cls_or_func
return cls_or_func


def register_compression_model_class(cls):
"""
Registers a compression model class.

:param cls: compression model to be registered
:type cls: class
:return: registered compression model class
:rtype: class
"""
COMPRESSION_MODEL_CLASS_DICT[cls.__name__] = cls
return cls


def register_compression_model_func(func):
"""
Registers a function to build a compression model.

:param func: function to build a compression model to be registered
:type func: typing.Callable
:return: registered function
:rtype: typing.Callable
"""
COMPRESSION_MODEL_FUNC_DICT[func.__name__] = func
return func


def get_compressai_model(compression_model_name, ckpt_file_path=None, updates=False, **compression_model_kwargs):
"""
Gets a model in `compressai`.

:param compression_model_name: `compressai` model name
:type compression_model_name: str
:param ckpt_file_path: checkpoint file path
:type ckpt_file_path: str or None
:param updates: if True, updates the parameters for entropy coding
:type updates: bool
:param compression_model_kwargs: kwargs for the model class or function to build the model
:type compression_model_kwargs: dict
:return: `compressai` model
:rtype: nn.Module
"""
compression_model = COMPRESSAI_DICT[compression_model_name](**compression_model_kwargs)
if ckpt_file_path is not None:
load_ckpt(ckpt_file_path, model=compression_model, strict=None)
Expand All @@ -43,6 +81,16 @@ def get_compressai_model(compression_model_name, ckpt_file_path=None, updates=Fa


def get_compression_model(compression_model_config, device):
"""
Gets a compression model.

:param compression_model_config: compression model configuration
:type compression_model_config: dict
:param device: torch device
:type device: str or torch.device
:return: compression model
:rtype: nn.Module
"""
if compression_model_config is None:
return None

Expand All @@ -58,6 +106,21 @@ def get_compression_model(compression_model_config, device):


def load_classification_model(model_config, device, distributed, strict=True):
"""
Loads an image classification model.

:param model_config: image classification model configuration
:type model_config: dict
:param device: torch device
:type device: str or torch.device
:param distributed: whether to use the model in distributed training mode
:type distributed: bool
:param strict: whether to strictly enforce that the keys in state_dict match the keys returned by the model’s
`state_dict()` function
:type strict: bool
:return: image classification model
:rtype: nn.Module
"""
model = get_image_classification_model(model_config, distributed)
model_name = model_config['name']
if model is None and model_name in timm.models.__dict__:
Expand Down
4 changes: 1 addition & 3 deletions sc2bench/models/segmentation/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def get_wrapped_segmentation_model(wrapper_model_config, device):
:type wrapper_model_config: dict
:param device: torch device
:type device: torch.device
:return: model: wrapped semantic segmentation model
:rtype: model: nn.Module
:return: semantic segmentation model
:return: wrapped semantic segmentation model
:rtype: nn.Module
"""
wrapper_model_name = wrapper_model_config['name']
Expand Down
Loading