Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache entry point lookups #6124

Merged
merged 18 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aiida/manage/configuration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
###########################################################################
"""Module that defines the configuration file of an AiiDA instance and functions to create and load it."""
import codecs
from functools import lru_cache
from functools import cache
from importlib.resources import files
import json
import os
Expand All @@ -28,7 +28,7 @@
SCHEMA_FILE = 'config-v9.schema.json'


@lru_cache(1)
sphuber marked this conversation as resolved.
Show resolved Hide resolved
@cache
def config_schema() -> Dict[str, Any]:
"""Return the configuration schema."""
return json.loads(files(schema_module).joinpath(SCHEMA_FILE).read_text(encoding='utf8'))
Expand Down
1 change: 1 addition & 0 deletions aiida/manage/tests/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,4 +862,5 @@ def entry_points(monkeypatch) -> EntryPointManager:
warnings.filterwarnings('ignore', category=DeprecationWarning)
eps_copy = copy.deepcopy(plugins.entry_point.eps())
monkeypatch.setattr(plugins.entry_point, 'eps', lambda: eps_copy)
plugins.entry_point.eps_select.cache_clear()
yield EntryPointManager()
16 changes: 11 additions & 5 deletions aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,18 @@
ENTRY_POINT_STRING_SEPARATOR = ':'


@functools.lru_cache(maxsize=1)
@functools.cache
def eps():
return _eps()


@functools.lru_cache(maxsize=100)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting, so do I understand correctly that eps() is cached, nothing is being read from disk here, and processing the few entry points in memory is still taking 25ms?

Can you please document this function to explain the quadruple loop problem here, and why this additional cache is necessary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll write a comment. You can look at the cProfile info on the forum. https://aiida.discourse.group/t/why-is-aiidalab-base-widget-import-so-slow/32/16?u=danielhollas

Together with the fact that we were calling this during each Node class build (in a metaclass), this completely explained why the aiida.orm import was so slow. See #6091

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we worried about the size of the cache? I think the number of different calls to eps_select should be reasonable, not exceeding on the order of 100. So wouldn't we be better of with lru_cache(maxsize=None) i.e. cache which will be faster. Not sure how much faster it will be compared to an LRU cache of max 100 items. Might be negligible

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am actually worried. I wouldn't be surprised if the number of calls was bigger then 100, especially if plugins are installed, since in some functions we're essentially iterating over all existing entry points when looking for the entry point for a given class. I'll take a closer look and do some more benchmarking.

Copy link
Member

@ltalirz ltalirz Sep 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we doing this iteration or is it importlib internally?

Are we trying to find an entry point for a class without knowing the name of the entry point?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since in some functions we're essentially iterating over all existing entry points when looking for the entry point for a given class. I'll take a closer look and do some more benchmarking.

These don't call this eps_select function though, do they? The cache here simply applies to the number of combinations of arguments with which it is called. Since it just has group and name, it should just be the list of all (group, name) tuples with which the function is called. This should, reasonably, not be much larger than the entry points that exist.

def eps_select(group, name=None) -> EntryPoints:
danielhollas marked this conversation as resolved.
Show resolved Hide resolved
if name is None:
return eps().select(group=group)
return eps().select(group=group, name=name)


class EntryPointFormat(enum.Enum):
"""
Enum to distinguish between the various possible entry point string formats. An entry point string
Expand Down Expand Up @@ -254,8 +261,7 @@ def get_entry_point_groups() -> Set[str]:

def get_entry_point_names(group: str, sort: bool = True) -> List[str]:
"""Return the entry points within a group."""
all_eps = eps()
group_names = list(all_eps.select(group=group).names)
group_names = list(get_entry_points(group).names)
if sort:
return sorted(group_names)
return group_names
Expand All @@ -268,7 +274,7 @@ def get_entry_points(group: str) -> EntryPoints:
:param group: the entry point group
:return: a list of entry points
"""
return eps().select(group=group)
return eps_select(group=group)


def get_entry_point(group: str, name: str) -> EntryPoint:
Expand All @@ -283,7 +289,7 @@ def get_entry_point(group: str, name: str) -> EntryPoint:
"""
# The next line should be removed for ``aiida-core==3.0`` when the old deprecated entry points are fully removed.
name = convert_potentially_deprecated_entry_point(group, name)
found = eps().select(group=group, name=name)
found = eps_select(group=group, name=name)
if name not in found.names:
raise MissingEntryPointError(f"Entry point '{name}' not found in group '{group}'")
# If multiple entry points are found and they have different values we raise, otherwise if they all
Expand Down
9 changes: 4 additions & 5 deletions tests/plugins/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def select(group, name): # pylint: disable=unused-argument

@pytest.mark.parametrize(
'eps, name, exception', (
((EP(name='ep', group='gr', value='x'),), None, None),
ltalirz marked this conversation as resolved.
Show resolved Hide resolved
((EP(name='ep', group='gr', value='x'),), 'ep', None),
((EP(name='ep', group='gr', value='x'),), 'non-existing', MissingEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='y')), None, MultipleEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='x')), None, None),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='y')), 'ep', MultipleEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='x')), 'ep', None),
),
indirect=['eps']
)
Expand All @@ -91,8 +91,7 @@ def test_get_entry_point(eps, name, exception, monkeypatch):

"""
monkeypatch.setattr(entry_point, 'eps', eps)

name = name or 'ep' # Try to load the entry point with name ``ep`` unless the fixture provides one
entry_point.eps_select.cache_clear()

if exception:
with pytest.raises(exception):
Expand Down