Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make minor updates #127

Merged
merged 1 commit into from
Jul 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sc2bench/common/config_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
def overwrite_config(org_config, sub_config):
"""
Overwrites a configuration (dict).
Overwrites a configuration.

:param org_config: (nested) dictionary of configuration to be updated.
:type org_config: dict
Expand Down
14 changes: 7 additions & 7 deletions sc2bench/models/detection/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

def register_detection_model_class(cls):
"""
Registers an object detection model
Registers an object detection model class.

:param cls: object detection model to be registered
:param cls: object detection model class to be registered
:type cls: class
:return: object detection model
:return: registered object detection model class
:rtype: class
"""
DETECTION_MODEL_CLASS_DICT[cls.__name__] = cls
Expand All @@ -25,11 +25,11 @@ def register_detection_model_class(cls):

def register_detection_model_func(func):
"""
Registers a function to build an object detection model
Registers a function to build an object detection model.

:param func: function to build an object detection model to be registered
:type func: typing.Callable
:return: function to build an object detection model
:return: registered function
:rtype: typing.Callable
"""
DETECTION_MODEL_FUNC_DICT[func.__name__] = func
Expand All @@ -39,7 +39,7 @@ def register_detection_model_func(func):

def get_detection_model(cls_or_func_name, **kwargs):
"""
Gets an object detection model
Gets an object detection model.

:param cls_or_func_name: model class or function name
:type cls_or_func_name: str
Expand All @@ -55,7 +55,7 @@ def get_detection_model(cls_or_func_name, **kwargs):

def load_detection_model(model_config, device, strict=True):
"""
Loads an object detection model
Loads an object detection model.

:param model_config: model configuration
:type model_config: dict
Expand Down
2 changes: 1 addition & 1 deletion sc2bench/models/detection/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class RCNNTransformWithCompression(GeneralizedRCNNTransform, AnalyzableModule):
"""
An R-CNN Transform with codec-based or model-based compression
An R-CNN Transform with codec-based or model-based compression.

:param transform: performs the data transformation from the inputs to feed into the model
:type transform: nn.Module
Expand Down
2 changes: 1 addition & 1 deletion sc2bench/models/detection/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def clear_analysis(self):

def get_wrapped_detection_model(wrapper_model_config, device):
"""
Gets a wrapped object detection model
Gets a wrapped object detection model.

:param wrapper_model_config: wrapper model configuration
:type wrapper_model_config: dict
Expand Down
12 changes: 6 additions & 6 deletions sc2bench/models/segmentation/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

def register_segmentation_model_class(cls):
"""
Registers a semantic segmentation model
Registers a semantic segmentation model class.

:param cls: semantic segmentation model to be registered
:type cls: class
:return: semantic segmentation model
:return: registered semantic segmentation model class
:rtype: class
"""
SEGMENTATION_MODEL_CLASS_DICT[cls.__name__] = cls
Expand All @@ -25,11 +25,11 @@ def register_segmentation_model_class(cls):

def register_segmentation_model_func(func):
"""
Registers a function to build a semantic segmentation model
Registers a function to build a semantic segmentation model.

:param func: function to build a semantic segmentation model to be registered
:type func: typing.Callable
:return: function to build a semantic segmentation model
:return: registered function
:rtype: typing.Callable
"""
SEGMENTATION_MODEL_FUNC_DICT[func.__name__] = func
Expand All @@ -39,7 +39,7 @@ def register_segmentation_model_func(func):

def get_segmentation_model(cls_or_func_name, **kwargs):
"""
Gets a semantic segmentation model
Gets a semantic segmentation model.

:param cls_or_func_name: model class or function name
:type cls_or_func_name: str
Expand All @@ -55,7 +55,7 @@ def get_segmentation_model(cls_or_func_name, **kwargs):

def load_segmentation_model(model_config, device, strict=True):
"""
Loads a semantic segmentation model
Loads a semantic segmentation model.

:param model_config: model configuration
:type model_config: dict
Expand Down
2 changes: 1 addition & 1 deletion sc2bench/models/segmentation/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def forward(self, x):

def get_wrapped_segmentation_model(wrapper_model_config, device):
"""
Gets a wrapped semantic segmentation model
Gets a wrapped semantic segmentation model.

:param wrapper_model_config: wrapper model configuration
:type wrapper_model_config: dict
Expand Down