Skip to content

Commit

Permalink
Renaming and moving some functionality into base Factory class
Browse files Browse the repository at this point in the history
  • Loading branch information
marksgraham committed Oct 24, 2023
1 parent dfaa6a0 commit 99ecb5a
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 163 deletions.
94 changes: 40 additions & 54 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def use_factory(fact_args):

import warnings
from collections.abc import Callable
from typing import Any
from typing import Any, Iterable

import torch.nn as nn

Expand All @@ -82,15 +82,7 @@ class LayerFactory(Factory):
def __init__(self) -> None:
self.factories: dict[str, Callable] = {}

@property
def names(self) -> tuple[str, ...]:
"""
Produces all factory names.
"""

return tuple(self.factories)

def add_factory_callable(self, name: str, func: Callable) -> None:
def add(self, name: str, func: Callable) -> None:
"""
Add the factory function to this object under the given name.
"""
Expand All @@ -103,17 +95,6 @@ def add_factory_callable(self, name: str, func: Callable) -> None:
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
)

def factory_function(self, name: str) -> Callable:
"""
Decorator for adding a factory function with the given name.
"""

def _add(func: Callable) -> Callable:
self.add_factory_callable(name, func)
return func

return _add

def get_constructor(self, factory_name: str, *args) -> Any:
"""
Get the constructor for the given factory name and arguments.
Expand Down Expand Up @@ -158,6 +139,11 @@ def __getattr__(self, key):

return super().__getattribute__(key)

def __iter__(self) -> Iterable:
"""Yields name/component pairs."""
for k, v in self.factories.items():
yield k, v


def split_args(args):
"""
Expand Down Expand Up @@ -203,50 +189,50 @@ def split_args(args):
Pad = LayerFactory()


@Dropout.factory_function("dropout")
@Dropout.factory_item("dropout")
def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]:
types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
return types[dim - 1]


@Dropout.factory_function("alphadropout")
@Dropout.factory_item("alphadropout")
def alpha_dropout_factory(_dim):
return nn.AlphaDropout


@Norm.factory_function("instance")
@Norm.factory_item("instance")
def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]:
types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
return types[dim - 1]


@Norm.factory_function("batch")
@Norm.factory_item("batch")
def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]:
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
return types[dim - 1]


@Norm.factory_function("group")
@Norm.factory_item("group")
def group_factory(_dim) -> type[nn.GroupNorm]:
return nn.GroupNorm


@Norm.factory_function("layer")
@Norm.factory_item("layer")
def layer_factory(_dim) -> type[nn.LayerNorm]:
return nn.LayerNorm


@Norm.factory_function("localresponse")
@Norm.factory_item("localresponse")
def local_response_factory(_dim) -> type[nn.LocalResponseNorm]:
return nn.LocalResponseNorm


@Norm.factory_function("syncbatch")
@Norm.factory_item("syncbatch")
def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]:
return nn.SyncBatchNorm


@Norm.factory_function("instance_nvfuser")
@Norm.factory_item("instance_nvfuser")
def instance_nvfuser_factory(dim):
"""
`InstanceNorm3dNVFuser` is a faster version of InstanceNorm layer and implemented in `apex`.
Expand Down Expand Up @@ -274,91 +260,91 @@ def instance_nvfuser_factory(dim):
return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0]


Act.add_factory_callable("elu", lambda: nn.modules.ELU)
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add_factory_callable("prelu", lambda: nn.modules.PReLU)
Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6)
Act.add_factory_callable("selu", lambda: nn.modules.SELU)
Act.add_factory_callable("celu", lambda: nn.modules.CELU)
Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)
Act.add("elu", lambda: nn.modules.ELU)
Act.add("relu", lambda: nn.modules.ReLU)
Act.add("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add("prelu", lambda: nn.modules.PReLU)
Act.add("relu6", lambda: nn.modules.ReLU6)
Act.add("selu", lambda: nn.modules.SELU)
Act.add("celu", lambda: nn.modules.CELU)
Act.add("gelu", lambda: nn.modules.GELU)
Act.add("sigmoid", lambda: nn.modules.Sigmoid)
Act.add("tanh", lambda: nn.modules.Tanh)
Act.add("softmax", lambda: nn.modules.Softmax)
Act.add("logsoftmax", lambda: nn.modules.LogSoftmax)


@Act.factory_function("swish")
@Act.factory_item("swish")
def swish_factory():
from monai.networks.blocks.activation import Swish

return Swish


@Act.factory_function("memswish")
@Act.factory_item("memswish")
def memswish_factory():
from monai.networks.blocks.activation import MemoryEfficientSwish

return MemoryEfficientSwish


@Act.factory_function("mish")
@Act.factory_item("mish")
def mish_factory():
from monai.networks.blocks.activation import Mish

return Mish


@Act.factory_function("geglu")
@Act.factory_item("geglu")
def geglu_factory():
from monai.networks.blocks.activation import GEGLU

return GEGLU


@Conv.factory_function("conv")
@Conv.factory_item("conv")
def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
return types[dim - 1]


@Conv.factory_function("convtrans")
@Conv.factory_item("convtrans")
def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]:
types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
return types[dim - 1]


@Pool.factory_function("max")
@Pool.factory_item("max")
def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]:
types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)
return types[dim - 1]


@Pool.factory_function("adaptivemax")
@Pool.factory_item("adaptivemax")
def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]:
types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)
return types[dim - 1]


@Pool.factory_function("avg")
@Pool.factory_item("avg")
def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]:
types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
return types[dim - 1]


@Pool.factory_function("adaptiveavg")
@Pool.factory_item("adaptiveavg")
def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]:
types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)
return types[dim - 1]


@Pad.factory_function("replicationpad")
@Pad.factory_item("replicationpad")
def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]:
types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d)
return types[dim - 1]


@Pad.factory_function("constantpad")
@Pad.factory_item("constantpad")
def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
return types[dim - 1]
39 changes: 12 additions & 27 deletions monai/utils/component_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from collections import namedtuple
from keyword import iskeyword
from textwrap import dedent, indent
from typing import Any, Callable, Iterable, TypeVar
from typing import Any, Iterable, TypeVar

T = TypeVar("T")
from monai.utils.factory import Factory

T = TypeVar("T")


def is_variable(name):
"""Returns True if `name` is a valid Python variable name and also not a keyword."""
Expand Down Expand Up @@ -54,45 +55,29 @@ def _my_func(a, b):
_Component = namedtuple("Component", ("description", "value")) # internal value pair

def __init__(self, name: str, description: str) -> None:
self.components: dict[str, self._Component] = {}
self.factories: dict[str, self._Component] = {}
self.name: str = name
self.description: str = description

self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip()

def add(self, name: str, desc: str, value: T) -> T:
"""Store the object `value` under the name `name` with description `desc`."""

if not is_variable(name):
raise ValueError("Name of component must be valid Python identifier")

self.components[name] = self._Component(desc, value)
self.factories[name] = self._Component(desc, value)
return value

def add_def(self, name: str, desc: str) -> Callable:
"""Returns a decorator which stores the decorated function under `name` with description `desc`."""

def deco(func):
"""Decorator to add a function to a store."""
return self.add(name, desc, func)

return deco

def __contains__(self, name: str) -> bool:
"""Returns True if the given name is stored."""
return name in self.components

def __len__(self) -> int:
"""Returns the number of stored components."""
return len(self.components)

def __iter__(self) -> Iterable:
"""Yields name/component pairs."""
for k, v in self.components.items():
for k, v in self.factories.items():
yield k, v.value

def __str__(self):
result = f"Component Store '{self.name}': {self.description}\nAvailable components:"
for k, v in self.components.items():
for k, v in self.factories.items():
result += f"\n* {k}:"

if hasattr(v.value, "__doc__"):
Expand All @@ -105,14 +90,14 @@ def __str__(self):

def __getattr__(self, name: str) -> Any:
"""Returns the stored object under the given name."""
if name in self.components:
return self.components[name].value
if name in self.factories:
return self.factories[name].value
else:
return self.__getattribute__(name)

def __getitem__(self, name: str) -> Any:
"""Returns the stored object under the given name."""
if name in self.components:
return self.components[name].value
if name in self.factories:
return self.factories[name].value
else:
raise ValueError(f"Component '{name}' not found")
Loading

0 comments on commit 99ecb5a

Please sign in to comment.