Skip to content

Commit

Permalink
refactor: Allow passing multiple extensions to load_extensions inst…
Browse files Browse the repository at this point in the history
…ead of a sequence

Issue-268: #268
  • Loading branch information
pawamoy committed Jun 8, 2024
1 parent c4e3bf2 commit fadb72b
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 16 deletions.
4 changes: 2 additions & 2 deletions benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ from griffe.loader import GriffeLoader
from griffe.extensions import load_extensions
from griffe.extensions import Extension
stdlib_packages = sorted([m for m in sys.stdlib_module_names if not m.startswith("_")])
# extensions = load_extensions([
# extensions = load_extensions(
# Extension, Extension, Extension, Extension,
# Extension, Extension, Extension, Extension,
# Extension, Extension, Extension, Extension,
# Extension, Extension, Extension, Extension,
# ])
# )
extensions = None
loader = GriffeLoader(allow_inspection=False, extensions=extensions)
for package in stdlib_packages:
Expand Down
12 changes: 5 additions & 7 deletions docs/extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,11 @@ import griffe
from mypackage.extensions import ThisExtension, ThisOtherExtension

extensions = griffe.load_extensions(
[
{"pydantic": {"schema": true}},
{"scripts/exts.py:DynamicDocstrings": {"paths": ["mypkg.mymod.myobj"]}},
"griffe_attrs",
ThisExtension(option="value"),
ThisOtherExtension,
]
{"pydantic": {"schema": true}},
{"scripts/exts.py:DynamicDocstrings": {"paths": ["mypkg.mymod.myobj"]}},
"griffe_attrs",
ThisExtension(option="value"),
ThisOtherExtension,
)

data = griffe.load("mypackage", extensions=extensions)
Expand Down
4 changes: 2 additions & 2 deletions src/griffe/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def dump(
search_paths.extend(sys.path)

try:
loaded_extensions = load_extensions(extensions or ())
loaded_extensions = load_extensions(*(extensions or ()))
except ExtensionError as error:
logger.exception(str(error)) # noqa: TRY401
return 1
Expand Down Expand Up @@ -464,7 +464,7 @@ def check(
repository = get_repo_root(against_path)

try:
loaded_extensions = load_extensions(extensions or ())
loaded_extensions = load_extensions(*(extensions or ()))
except ExtensionError as error:
logger.exception(str(error)) # noqa: TRY401
return 1
Expand Down
27 changes: 24 additions & 3 deletions src/griffe/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from importlib.util import module_from_spec, spec_from_file_location
from inspect import isclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, Sequence, Type, Union

from griffe.agents.nodes import ast_children, ast_kind
from griffe.enumerations import When
Expand Down Expand Up @@ -236,6 +236,7 @@ def on_package_loaded(self, *, pkg: Module) -> None:


ExtensionType = Union[VisitorExtension, InspectorExtension, Extension]
"""All the types that can be passed to `Extensions.add`."""


class Extensions:
Expand Down Expand Up @@ -455,8 +456,13 @@ def _load_extension(
return [ext(**options) for ext in extensions]


LoadableExtension = Union[str, Dict[str, Any], ExtensionType, Type[ExtensionType]]
"""All the types that can be passed to `load_extensions`."""


def load_extensions(
exts: Sequence[str | dict[str, Any] | ExtensionType | type[ExtensionType]] | None = None,
# TODO: Only accept LoadableExtension at some point.
*exts: LoadableExtension | Sequence[LoadableExtension],
) -> Extensions:
"""Load configured extensions.
Expand All @@ -467,7 +473,22 @@ def load_extensions(
An extensions container.
"""
extensions = Extensions()
for extension in exts or ():

# TODO: Remove at some point.
all_exts: list[LoadableExtension] = []
for ext in exts:
if isinstance(ext, (list, tuple)):
warnings.warn(
"Passing multiple extensions as a single list or tuple is deprecated. "
"Please pass them as separate arguments instead.",
DeprecationWarning,
stacklevel=2,
)
all_exts.extend(ext)
else:
all_exts.append(ext) # type: ignore[arg-type]

for extension in all_exts:
ext = _load_extension(extension)
if isinstance(ext, list):
extensions.add(*ext)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_loading_extensions(extension: str | dict[str, dict[str, Any]] | Extensi
Parameters:
extension: Extension specification (parametrized).
"""
extensions = load_extensions([extension])
extensions = load_extensions(extension)
loaded: ExtensionTest = extensions._extensions[0] # type: ignore[assignment]
# We cannot use isinstance here,
# because loading from a filepath drops the parent `tests` package,
Expand All @@ -115,7 +115,7 @@ class Class:
cattr = 1
def method(self): ...
""",
extensions=load_extensions([extension]),
extensions=load_extensions(extension),
):
pass
events = [
Expand Down

0 comments on commit fadb72b

Please sign in to comment.