Skip to content

Commit

Permalink
Revert to component store and layer factory as they were
Browse files Browse the repository at this point in the history
  • Loading branch information
marksgraham committed Oct 25, 2023
1 parent 99ecb5a commit 8ebd0b6
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 57 deletions.
98 changes: 56 additions & 42 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ def use_factory(fact_args):

import warnings
from collections.abc import Callable
from typing import Any, Iterable
from typing import Any

import torch.nn as nn

from monai.networks.utils import has_nvfuser_instance_norm
from monai.utils import Factory, look_up_option, optional_import
from monai.utils import look_up_option, optional_import

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]


class LayerFactory(Factory):
class LayerFactory:
"""
Factory object for creating layers, this uses given factory functions to actually produce the types or constructing
callables. These functions are referred to by name and can be added at any time.
Expand All @@ -82,7 +82,15 @@ class LayerFactory(Factory):
def __init__(self) -> None:
self.factories: dict[str, Callable] = {}

def add(self, name: str, func: Callable) -> None:
@property
def names(self) -> tuple[str, ...]:
"""
Produces all factory names.
"""

return tuple(self.factories)

def add_factory_callable(self, name: str, func: Callable) -> None:
"""
Add the factory function to this object under the given name.
"""
Expand All @@ -95,6 +103,17 @@ def add(self, name: str, func: Callable) -> None:
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
)

def factory_function(self, name: str) -> Callable:
"""
Decorator for adding a factory function with the given name.
"""

def _add(func: Callable) -> Callable:
self.add_factory_callable(name, func)
return func

return _add

def get_constructor(self, factory_name: str, *args) -> Any:
"""
Get the constructor for the given factory name and arguments.
Expand Down Expand Up @@ -139,11 +158,6 @@ def __getattr__(self, key):

return super().__getattribute__(key)

def __iter__(self) -> Iterable:
"""Yields name/component pairs."""
for k, v in self.factories.items():
yield k, v


def split_args(args):
"""
Expand Down Expand Up @@ -189,50 +203,50 @@ def split_args(args):
Pad = LayerFactory()


@Dropout.factory_item("dropout")
@Dropout.factory_function("dropout")
def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]:
types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
return types[dim - 1]


@Dropout.factory_item("alphadropout")
@Dropout.factory_function("alphadropout")
def alpha_dropout_factory(_dim):
return nn.AlphaDropout


@Norm.factory_item("instance")
@Norm.factory_function("instance")
def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]:
types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
return types[dim - 1]


@Norm.factory_item("batch")
@Norm.factory_function("batch")
def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]:
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
return types[dim - 1]


@Norm.factory_item("group")
@Norm.factory_function("group")
def group_factory(_dim) -> type[nn.GroupNorm]:
return nn.GroupNorm


@Norm.factory_item("layer")
@Norm.factory_function("layer")
def layer_factory(_dim) -> type[nn.LayerNorm]:
return nn.LayerNorm


@Norm.factory_item("localresponse")
@Norm.factory_function("localresponse")
def local_response_factory(_dim) -> type[nn.LocalResponseNorm]:
return nn.LocalResponseNorm


@Norm.factory_item("syncbatch")
@Norm.factory_function("syncbatch")
def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]:
return nn.SyncBatchNorm


@Norm.factory_item("instance_nvfuser")
@Norm.factory_function("instance_nvfuser")
def instance_nvfuser_factory(dim):
"""
`InstanceNorm3dNVFuser` is a faster version of InstanceNorm layer and implemented in `apex`.
Expand Down Expand Up @@ -260,91 +274,91 @@ def instance_nvfuser_factory(dim):
return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0]


Act.add("elu", lambda: nn.modules.ELU)
Act.add("relu", lambda: nn.modules.ReLU)
Act.add("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add("prelu", lambda: nn.modules.PReLU)
Act.add("relu6", lambda: nn.modules.ReLU6)
Act.add("selu", lambda: nn.modules.SELU)
Act.add("celu", lambda: nn.modules.CELU)
Act.add("gelu", lambda: nn.modules.GELU)
Act.add("sigmoid", lambda: nn.modules.Sigmoid)
Act.add("tanh", lambda: nn.modules.Tanh)
Act.add("softmax", lambda: nn.modules.Softmax)
Act.add("logsoftmax", lambda: nn.modules.LogSoftmax)
Act.add_factory_callable("elu", lambda: nn.modules.ELU)
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add_factory_callable("prelu", lambda: nn.modules.PReLU)
Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6)
Act.add_factory_callable("selu", lambda: nn.modules.SELU)
Act.add_factory_callable("celu", lambda: nn.modules.CELU)
Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)


@Act.factory_item("swish")
@Act.factory_function("swish")
def swish_factory():
from monai.networks.blocks.activation import Swish

return Swish


@Act.factory_item("memswish")
@Act.factory_function("memswish")
def memswish_factory():
from monai.networks.blocks.activation import MemoryEfficientSwish

return MemoryEfficientSwish


@Act.factory_item("mish")
@Act.factory_function("mish")
def mish_factory():
from monai.networks.blocks.activation import Mish

return Mish


@Act.factory_item("geglu")
@Act.factory_function("geglu")
def geglu_factory():
from monai.networks.blocks.activation import GEGLU

return GEGLU


@Conv.factory_item("conv")
@Conv.factory_function("conv")
def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
return types[dim - 1]


@Conv.factory_item("convtrans")
@Conv.factory_function("convtrans")
def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]:
types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
return types[dim - 1]


@Pool.factory_item("max")
@Pool.factory_function("max")
def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]:
types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)
return types[dim - 1]


@Pool.factory_item("adaptivemax")
@Pool.factory_function("adaptivemax")
def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]:
types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)
return types[dim - 1]


@Pool.factory_item("avg")
@Pool.factory_function("avg")
def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]:
types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
return types[dim - 1]


@Pool.factory_item("adaptiveavg")
@Pool.factory_function("adaptiveavg")
def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]:
types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)
return types[dim - 1]


@Pad.factory_item("replicationpad")
@Pad.factory_function("replicationpad")
def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]:
types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d)
return types[dim - 1]


@Pad.factory_item("constantpad")
@Pad.factory_function("constantpad")
def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
return types[dim - 1]
40 changes: 27 additions & 13 deletions monai/utils/component_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
from collections import namedtuple
from keyword import iskeyword
from textwrap import dedent, indent
from typing import Any, Iterable, TypeVar

from monai.utils.factory import Factory
from typing import Any, Callable, Iterable, TypeVar

T = TypeVar("T")

Expand All @@ -26,7 +24,7 @@ def is_variable(name):
return name.isidentifier() and not iskeyword(name)


class ComponentStore(Factory):
class ComponentStore:
"""
Represents a storage object for other objects (specifically functions) keyed to a name with a description.
Expand Down Expand Up @@ -55,29 +53,45 @@ def _my_func(a, b):
_Component = namedtuple("Component", ("description", "value")) # internal value pair

def __init__(self, name: str, description: str) -> None:
self.factories: dict[str, self._Component] = {}
self.components: dict[str, self._Component] = {}
self.name: str = name
self.description: str = description

self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip()

def add(self, name: str, desc: str, value: T) -> T:
"""Store the object `value` under the name `name` with description `desc`."""

if not is_variable(name):
raise ValueError("Name of component must be valid Python identifier")

self.factories[name] = self._Component(desc, value)
self.components[name] = self._Component(desc, value)
return value

def add_def(self, name: str, desc: str) -> Callable:
"""Returns a decorator which stores the decorated function under `name` with description `desc`."""

def deco(func):
"""Decorator to add a function to a store."""
return self.add(name, desc, func)

return deco

def __contains__(self, name: str) -> bool:
"""Returns True if the given name is stored."""
return name in self.components

def __len__(self) -> int:
"""Returns the number of stored components."""
return len(self.components)

def __iter__(self) -> Iterable:
"""Yields name/component pairs."""
for k, v in self.factories.items():
for k, v in self.components.items():
yield k, v.value

def __str__(self):
result = f"Component Store '{self.name}': {self.description}\nAvailable components:"
for k, v in self.factories.items():
for k, v in self.components.items():
result += f"\n* {k}:"

if hasattr(v.value, "__doc__"):
Expand All @@ -90,14 +104,14 @@ def __str__(self):

def __getattr__(self, name: str) -> Any:
"""Returns the stored object under the given name."""
if name in self.factories:
return self.factories[name].value
if name in self.components:
return self.components[name].value
else:
return self.__getattribute__(name)

def __getitem__(self, name: str) -> Any:
"""Returns the stored object under the given name."""
if name in self.factories:
return self.factories[name].value
if name in self.components:
return self.components[name].value
else:
raise ValueError(f"Component '{name}' not found")
4 changes: 2 additions & 2 deletions tests/test_component_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_add2(self):
def test_add_def(self):
self.assertFalse("test_func" in self.cs)

@self.cs.factory_item("test_func", "Test function")
@self.cs.add_def("test_func", "Test function")
def test_func():
return 123

Expand All @@ -66,7 +66,7 @@ def test_func():
self.assertEqual(self.cs["test_func"], test_func)

# try adding the same function again
self.cs.factory_item("test_func", "Test function but with new description")(test_func)
self.cs.add_def("test_func", "Test function but with new description")(test_func)

self.assertEqual(len(self.cs), 1)
self.assertEqual(self.cs.test_func, test_func)

0 comments on commit 8ebd0b6

Please sign in to comment.