Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] 7145 common factory class #7159

Merged
merged 17 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 171 additions & 64 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,40 +68,40 @@ def use_factory(fact_args):
import torch.nn as nn

from monai.networks.utils import has_nvfuser_instance_norm
from monai.utils import look_up_option, optional_import
from monai.utils import ComponentStore, look_up_option, optional_import

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]


class LayerFactory:
class LayerFactory(ComponentStore):
"""
Factory object for creating layers, this uses given factory functions to actually produce the types or constructing
callables. These functions are referred to by name and can be added at any time.
"""

def __init__(self) -> None:
self.factories: dict[str, Callable] = {}
def __init__(self, name: str, description: str) -> None:
super().__init__(name, description)
self.__doc__ = (
f"Layer Factory '{name}': {description}\n".strip()
+ "\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
+ "\n\nThe supported members are:"
)

@property
def names(self) -> tuple[str, ...]:
def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None:
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
"""
Produces all factory names.
Add the factory function to this object under the given name, with optional description.
"""
description: str = desc or func.__doc__ or ""
self.add(name.upper(), description, func)
# append name to the docstring
assert self.__doc__ is not None
self.__doc__ += f"{', ' if len(self.names)>1 else ' '}``{name}``"

return tuple(self.factories)

def add_factory_callable(self, name: str, func: Callable) -> None:
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None:
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
"""
Add the factory function to this object under the given name.
Adds a factory function which returns the supplied class under the given name, with optional description.
"""

self.factories[name.upper()] = func
self.__doc__ = (
"The supported member"
+ ("s are: " if len(self.names) > 1 else " is: ")
+ ", ".join(f"``{name}``" for name in self.names)
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
)
self.add_factory_callable(name, lambda x=None: cls, desc)

def factory_function(self, name: str) -> Callable:
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
"""
Expand All @@ -126,8 +126,9 @@ def get_constructor(self, factory_name: str, *args) -> Any:
if not isinstance(factory_name, str):
raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.")

func = look_up_option(factory_name.upper(), self.factories)
return func(*args)
component = look_up_option(factory_name.upper(), self.components)

return component.value(*args)

def __getitem__(self, args) -> Any:
"""
Expand All @@ -153,7 +154,7 @@ def __getattr__(self, key):
as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo.
"""

if key in self.factories:
if key in self.components:
return key

return super().__getattribute__(key)
Expand Down Expand Up @@ -194,56 +195,60 @@ def split_args(args):


# Define factories for these layer types

Dropout = LayerFactory()
Norm = LayerFactory()
Act = LayerFactory()
Conv = LayerFactory()
Pool = LayerFactory()
Pad = LayerFactory()
Dropout = LayerFactory(name="Dropout layers", description="Factory for creating dropout layers.")
Norm = LayerFactory(name="Normalization layers", description="Factory for creating normalization layers.")
Act = LayerFactory(name="Activation layers", description="Factory for creating activation layers.")
Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")


@Dropout.factory_function("dropout")
def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]:
"""
Dropout layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the dropout layer

Returns:
Dropout[dim]d
"""
types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
return types[dim - 1]


@Dropout.factory_function("alphadropout")
def alpha_dropout_factory(_dim):
return nn.AlphaDropout
Dropout.add_factory_class("alphadropout", nn.AlphaDropout)


@Norm.factory_function("instance")
def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]:
"""
Instance normalization layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the instance normalization layer

Returns:
InstanceNorm[dim]d
"""
types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
return types[dim - 1]


@Norm.factory_function("batch")
def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]:
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
return types[dim - 1]


@Norm.factory_function("group")
def group_factory(_dim) -> type[nn.GroupNorm]:
return nn.GroupNorm


@Norm.factory_function("layer")
def layer_factory(_dim) -> type[nn.LayerNorm]:
return nn.LayerNorm


@Norm.factory_function("localresponse")
def local_response_factory(_dim) -> type[nn.LocalResponseNorm]:
return nn.LocalResponseNorm
"""
Batch normalization layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the batch normalization layer

@Norm.factory_function("syncbatch")
def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]:
return nn.SyncBatchNorm
Returns:
BatchNorm[dim]d
"""
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
return types[dim - 1]


@Norm.factory_function("instance_nvfuser")
Expand Down Expand Up @@ -274,91 +279,193 @@ def instance_nvfuser_factory(dim):
return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0]


Act.add_factory_callable("elu", lambda: nn.modules.ELU)
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add_factory_callable("prelu", lambda: nn.modules.PReLU)
Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6)
Act.add_factory_callable("selu", lambda: nn.modules.SELU)
Act.add_factory_callable("celu", lambda: nn.modules.CELU)
Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)
Norm.add_factory_class("group", nn.GroupNorm)
Norm.add_factory_class("layer", nn.LayerNorm)
Norm.add_factory_class("localresponse", nn.LocalResponseNorm)
Norm.add_factory_class("syncbatch", nn.SyncBatchNorm)


Act.add_factory_class("elu", nn.modules.ELU)
Act.add_factory_class("relu", nn.modules.ReLU)
Act.add_factory_class("leakyrelu", nn.modules.LeakyReLU)
Act.add_factory_class("prelu", nn.modules.PReLU)
Act.add_factory_class("relu6", nn.modules.ReLU6)
Act.add_factory_class("selu", nn.modules.SELU)
Act.add_factory_class("celu", nn.modules.CELU)
Act.add_factory_class("gelu", nn.modules.GELU)
Act.add_factory_class("sigmoid", nn.modules.Sigmoid)
Act.add_factory_class("tanh", nn.modules.Tanh)
Act.add_factory_class("softmax", nn.modules.Softmax)
Act.add_factory_class("logsoftmax", nn.modules.LogSoftmax)


@Act.factory_function("swish")
def swish_factory():
"""
Swish activation layer.

Returns:
Swish
"""
from monai.networks.blocks.activation import Swish

return Swish


@Act.factory_function("memswish")
def memswish_factory():
"""
Memory efficient swish activation layer.

Returns:
MemoryEfficientSwish
"""
from monai.networks.blocks.activation import MemoryEfficientSwish

return MemoryEfficientSwish


@Act.factory_function("mish")
def mish_factory():
"""
Mish activation layer.

Returns:
Mish
"""
from monai.networks.blocks.activation import Mish

return Mish


@Act.factory_function("geglu")
def geglu_factory():
"""
GEGLU activation layer.

Returns:
GEGLU
"""
from monai.networks.blocks.activation import GEGLU

return GEGLU


@Conv.factory_function("conv")
def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
"""
Convolutional layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the convolutional layer

Returns:
Conv[dim]d
"""
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
return types[dim - 1]


@Conv.factory_function("convtrans")
def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]:
"""
Transposed convolutional layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the transposed convolutional layer

Returns:
ConvTranspose[dim]d
"""
types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
return types[dim - 1]


@Pool.factory_function("max")
def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]:
"""
Max pooling layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the max pooling layer

Returns:
MaxPool[dim]d
"""
types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)
return types[dim - 1]


@Pool.factory_function("adaptivemax")
def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]:
"""
Adaptive max pooling layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the adaptive max pooling layer

Returns:
AdaptiveMaxPool[dim]d
"""
types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)
return types[dim - 1]


@Pool.factory_function("avg")
def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]:
"""
Average pooling layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the average pooling layer

Returns:
AvgPool[dim]d
"""
types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
return types[dim - 1]


@Pool.factory_function("adaptiveavg")
def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]:
"""
Adaptive average pooling layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the adaptive average pooling layer

Returns:
AdaptiveAvgPool[dim]d
"""
types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)
return types[dim - 1]


@Pad.factory_function("replicationpad")
def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]:
"""
Replication padding layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the replication padding layer

Returns:
ReplicationPad[dim]d
"""
types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d)
return types[dim - 1]


@Pad.factory_function("constantpad")
def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:
"""
Constant padding layers in 1,2,3 dimensions.

Args:
dim: desired dimension of the constant padding layer

Returns:
ConstantPad[dim]d
"""
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
return types[dim - 1]
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# have to explicitly bring these in here to resolve circular import issues
from .aliases import alias, resolve_name
from .component_store import ComponentStore
from .decorators import MethodReplacer, RestartGenerator
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
Expand Down
Loading