Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] 7145 common factory class #7159

Merged
merged 17 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 41 additions & 56 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ def use_factory(fact_args):

import warnings
from collections.abc import Callable
from typing import Any
from typing import Any, Iterable

import torch.nn as nn

from monai.networks.utils import has_nvfuser_instance_norm
from monai.utils import look_up_option, optional_import
from monai.utils import Factory, look_up_option, optional_import

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]


class LayerFactory:
class LayerFactory(Factory):
"""
Factory object for creating layers, this uses given factory functions to actually produce the types or constructing
callables. These functions are referred to by name and can be added at any time.
Expand All @@ -82,15 +82,7 @@ class LayerFactory:
def __init__(self) -> None:
self.factories: dict[str, Callable] = {}

@property
def names(self) -> tuple[str, ...]:
"""
Produces all factory names.
"""

return tuple(self.factories)

def add_factory_callable(self, name: str, func: Callable) -> None:
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
def add(self, name: str, func: Callable) -> None:
"""
Add the factory function to this object under the given name.
"""
Expand All @@ -103,17 +95,6 @@ def add_factory_callable(self, name: str, func: Callable) -> None:
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
)

def factory_function(self, name: str) -> Callable:
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
"""
Decorator for adding a factory function with the given name.
"""

def _add(func: Callable) -> Callable:
self.add_factory_callable(name, func)
return func

return _add

def get_constructor(self, factory_name: str, *args) -> Any:
"""
Get the constructor for the given factory name and arguments.
Expand Down Expand Up @@ -158,6 +139,10 @@ def __getattr__(self, key):

return super().__getattribute__(key)

def __iter__(self) -> Iterable:
"""Yields name/component pairs."""
yield from self.factories.items()


def split_args(args):
"""
Expand Down Expand Up @@ -203,50 +188,50 @@ def split_args(args):
Pad = LayerFactory()


@Dropout.factory_function("dropout")
@Dropout.factory_item("dropout")
def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]:
types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
return types[dim - 1]


@Dropout.factory_function("alphadropout")
@Dropout.factory_item("alphadropout")
def alpha_dropout_factory(_dim):
return nn.AlphaDropout


@Norm.factory_function("instance")
@Norm.factory_item("instance")
def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]:
types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
return types[dim - 1]


@Norm.factory_function("batch")
@Norm.factory_item("batch")
def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]:
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
return types[dim - 1]


@Norm.factory_function("group")
@Norm.factory_item("group")
def group_factory(_dim) -> type[nn.GroupNorm]:
return nn.GroupNorm


@Norm.factory_function("layer")
@Norm.factory_item("layer")
def layer_factory(_dim) -> type[nn.LayerNorm]:
return nn.LayerNorm


@Norm.factory_function("localresponse")
@Norm.factory_item("localresponse")
def local_response_factory(_dim) -> type[nn.LocalResponseNorm]:
return nn.LocalResponseNorm


@Norm.factory_function("syncbatch")
@Norm.factory_item("syncbatch")
def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]:
return nn.SyncBatchNorm


@Norm.factory_function("instance_nvfuser")
@Norm.factory_item("instance_nvfuser")
def instance_nvfuser_factory(dim):
"""
`InstanceNorm3dNVFuser` is a faster version of InstanceNorm layer and implemented in `apex`.
Expand Down Expand Up @@ -274,91 +259,91 @@ def instance_nvfuser_factory(dim):
return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0]


Act.add_factory_callable("elu", lambda: nn.modules.ELU)
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add_factory_callable("prelu", lambda: nn.modules.PReLU)
Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6)
Act.add_factory_callable("selu", lambda: nn.modules.SELU)
Act.add_factory_callable("celu", lambda: nn.modules.CELU)
Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)
Act.add("elu", lambda: nn.modules.ELU)
wyli marked this conversation as resolved.
Show resolved Hide resolved
Act.add("relu", lambda: nn.modules.ReLU)
Act.add("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add("prelu", lambda: nn.modules.PReLU)
Act.add("relu6", lambda: nn.modules.ReLU6)
Act.add("selu", lambda: nn.modules.SELU)
Act.add("celu", lambda: nn.modules.CELU)
Act.add("gelu", lambda: nn.modules.GELU)
Act.add("sigmoid", lambda: nn.modules.Sigmoid)
Act.add("tanh", lambda: nn.modules.Tanh)
Act.add("softmax", lambda: nn.modules.Softmax)
Act.add("logsoftmax", lambda: nn.modules.LogSoftmax)


@Act.factory_function("swish")
@Act.factory_item("swish")
def swish_factory():
from monai.networks.blocks.activation import Swish

return Swish


@Act.factory_function("memswish")
@Act.factory_item("memswish")
def memswish_factory():
from monai.networks.blocks.activation import MemoryEfficientSwish

return MemoryEfficientSwish


@Act.factory_function("mish")
@Act.factory_item("mish")
def mish_factory():
from monai.networks.blocks.activation import Mish

return Mish


@Act.factory_function("geglu")
@Act.factory_item("geglu")
def geglu_factory():
from monai.networks.blocks.activation import GEGLU

return GEGLU


@Conv.factory_function("conv")
@Conv.factory_item("conv")
def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
return types[dim - 1]


@Conv.factory_function("convtrans")
@Conv.factory_item("convtrans")
def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]:
types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
return types[dim - 1]


@Pool.factory_function("max")
@Pool.factory_item("max")
def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]:
types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)
return types[dim - 1]


@Pool.factory_function("adaptivemax")
@Pool.factory_item("adaptivemax")
def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]:
types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)
return types[dim - 1]


@Pool.factory_function("avg")
@Pool.factory_item("avg")
def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]:
types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
return types[dim - 1]


@Pool.factory_function("adaptiveavg")
@Pool.factory_item("adaptiveavg")
def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]:
types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)
return types[dim - 1]


@Pad.factory_function("replicationpad")
@Pad.factory_item("replicationpad")
def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]:
types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d)
return types[dim - 1]


@Pad.factory_function("constantpad")
@Pad.factory_item("constantpad")
def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
return types[dim - 1]
2 changes: 2 additions & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# have to explicitly bring these in here to resolve circular import issues
from .aliases import alias, resolve_name
from .component_store import ComponentStore
from .decorators import MethodReplacer, RestartGenerator
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
Expand Down Expand Up @@ -61,6 +62,7 @@
Weight,
WSIPatchKeys,
)
from .factory import Factory
from .jupyter_utils import StatusMembers, ThreadContainer
from .misc import (
MAX_SEED,
Expand Down
103 changes: 103 additions & 0 deletions monai/utils/component_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import namedtuple
from keyword import iskeyword
from textwrap import dedent, indent
from typing import Any, Iterable, TypeVar

from monai.utils.factory import Factory

T = TypeVar("T")


def is_variable(name):
"""Returns True if `name` is a valid Python variable name and also not a keyword."""
return name.isidentifier() and not iskeyword(name)


class ComponentStore(Factory):
"""
Represents a storage object for other objects (specifically functions) keyed to a name with a description.

These objects act as global named places for storing components for objects parameterised by component names.
Typically this is functions although other objects can be added. Printing a component store will produce a
list of members along with their docstring information if present.

Example:

.. code-block:: python

TestStore = ComponentStore("Test Store", "A test store for demo purposes")

@TestStore.add_def("my_func_name", "Some description of your function")
def _my_func(a, b):
'''A description of your function here.'''
return a * b

print(TestStore) # will print out name, description, and 'my_func_name' with the docstring

func = TestStore["my_func_name"]
result = func(7, 6)

"""

_Component = namedtuple("Component", ("description", "value")) # internal value pair

def __init__(self, name: str, description: str) -> None:
self.factories: dict[str, self._Component] = {}
self.name: str = name
self.description: str = description

self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip()

def add(self, name: str, desc: str, value: T) -> T:
"""Store the object `value` under the name `name` with description `desc`."""

if not is_variable(name):
raise ValueError("Name of component must be valid Python identifier")

self.factories[name] = self._Component(desc, value)
return value

def __iter__(self) -> Iterable:
"""Yields name/component pairs."""
for k, v in self.factories.items():
yield k, v.value

def __str__(self):
result = f"Component Store '{self.name}': {self.description}\nAvailable components:"
for k, v in self.factories.items():
result += f"\n* {k}:"

if hasattr(v.value, "__doc__"):
doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ")
result += f"\n{doc}\n"
else:
result += f" {v.description}"

return result

def __getattr__(self, name: str) -> Any:
"""Returns the stored object under the given name."""
if name in self.factories:
return self.factories[name].value
else:
return self.__getattribute__(name)

def __getitem__(self, name: str) -> Any:
"""Returns the stored object under the given name."""
if name in self.factories:
return self.factories[name].value
else:
raise ValueError(f"Component '{name}' not found")
Loading
Loading