Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache entry point lookups #6124

Merged
merged 18 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aiida/manage/configuration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
###########################################################################
"""Module that defines the configuration file of an AiiDA instance and functions to create and load it."""
import codecs
from functools import lru_cache
from functools import cache
from importlib.resources import files
import json
import os
Expand All @@ -28,7 +28,7 @@
SCHEMA_FILE = 'config-v9.schema.json'


@lru_cache(1)
sphuber marked this conversation as resolved.
Show resolved Hide resolved
@cache
def config_schema() -> Dict[str, Any]:
    """Load and return the JSON schema of the configuration file.

    The schema is read as UTF-8 text from the ``SCHEMA_FILE`` resource inside ``schema_module`` and parsed as JSON.
    Since the function is decorated with ``cache``, the resource is only read and parsed on the first call and the
    same dictionary object is returned on every subsequent call.

    :return: the parsed configuration schema.
    """
    schema_resource = files(schema_module).joinpath(SCHEMA_FILE)
    schema_text = schema_resource.read_text(encoding='utf8')
    return json.loads(schema_text)
Expand Down
51 changes: 24 additions & 27 deletions aiida/manage/tests/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import asyncio
import contextlib
import copy
import inspect
import io
import os
Expand All @@ -33,6 +32,7 @@
import uuid
import warnings

from importlib_metadata import EntryPoints
import plumpy
import pytest
import wrapt
Expand Down Expand Up @@ -759,9 +759,16 @@ def suppress_deprecations(wrapped, _, args, kwargs):
class EntryPointManager:
"""Manager to temporarily add or remove entry points."""

@staticmethod
def eps():
return plugins.entry_point.eps()
def __init__(self, entry_points: EntryPoints):
    """Construct the manager.

    :param entry_points: the initial collection of entry points held by this manager.
    """
    self.entry_points = entry_points

def eps(self) -> EntryPoints:
    """Return the entry points currently held by this manager."""
    return self.entry_points

def eps_select(self, group, name=None) -> EntryPoints:
    """Return the entry points of this manager filtered by group and, optionally, by name.

    :param group: the entry point group to select.
    :param name: the entry point name to select; when ``None`` the selection is made on the group only.
    :return: the result of ``select()`` called on the entry points returned by :meth:`eps`.
    """
    criteria = {'group': group}
    if name is not None:
        # Only pass ``name`` when it was given, so that ``select`` is not asked to match on ``name=None``.
        criteria['name'] = name
    return self.eps().select(**criteria)

@staticmethod
def _validate_entry_point(entry_point_string: str | None, group: str | None, name: str | None) -> tuple[str, str]:
Expand Down Expand Up @@ -791,7 +798,6 @@ def _validate_entry_point(entry_point_string: str | None, group: str | None, nam

return group, name

@suppress_deprecations
def add(
self,
value: type | str,
Expand All @@ -817,9 +823,8 @@ def add(

group, name = self._validate_entry_point(entry_point_string, group, name)
entry_point = plugins.entry_point.EntryPoint(name, value, group)
self.eps()[group].append(entry_point)
self.entry_points = EntryPoints(self.entry_points + (entry_point,))

@suppress_deprecations
def remove(
self, entry_point_string: str | None = None, *, name: str | None = None, group: str | None = None
) -> None:
Expand All @@ -835,31 +840,23 @@ def remove(
:raises ValueError: If `entry_point_string` is not a complete entry point string with group and name.
"""
group, name = self._validate_entry_point(entry_point_string, group, name)

for entry_point in self.eps()[group]:
if entry_point.name == name:
self.eps()[group].remove(entry_point)
break
else:
try:
self.entry_points[name]
except KeyError:
raise KeyError(f'entry point `{name}` does not exist in group `{group}`.')
self.entry_points = EntryPoints((ep for ep in self.entry_points if not (ep.name == name and ep.group == group)))


@pytest.fixture
def entry_points(monkeypatch) -> EntryPointManager:
"""Return an instance of the ``EntryPointManager`` which allows to temporarily add or remove entry points.

This fixture creates a deep copy of the entry point cache returned by the :func:`aiida.plugins.entry_point.eps`
method and then monkey patches that function to return the deepcopy. This ensures that the changes on the entry
point cache performed during the test through the manager are undone at the end of the function scope.

.. note:: This fixture does not use the ``suppress_deprecations`` decorator on purpose, but instead adds it manually
inside the fixture's body. The reason is that otherwise all deprecations would be suppressed for the entire
scope of the fixture, including those raised by the code run in the test using the fixture, which is not
desirable.

This fixture monkey patches the entry point caches returned by
the :func:`aiida.plugins.entry_point.eps` and :func:`aiida.plugins.entry_point.eps_select` functions
to class methods of the ``EntryPointManager`` so that we can dynamically add / remove entry points.
Note that we do not need a deepcopy here as ``eps()`` returns an immutable ``EntryPoints`` tuple type.
"""
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
eps_copy = copy.deepcopy(plugins.entry_point.eps())
monkeypatch.setattr(plugins.entry_point, 'eps', lambda: eps_copy)
yield EntryPointManager()
epm = EntryPointManager(plugins.entry_point.eps())
monkeypatch.setattr(plugins.entry_point, 'eps', epm.eps)
monkeypatch.setattr(plugins.entry_point, 'eps_select', epm.eps_select)
yield epm
2 changes: 1 addition & 1 deletion aiida/orm/autogroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def validate(strings: list[str] | None):
"""Validate the list of strings passed to set_include and set_exclude."""
if strings is None:
return
valid_prefixes = set(['aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data'])
danielhollas marked this conversation as resolved.
Show resolved Hide resolved
valid_prefixes = {'aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data'}
for string in strings:
pieces = string.split(':')
if len(pieces) != 2:
Expand Down
47 changes: 34 additions & 13 deletions aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module to manage loading entrypoints."""
from __future__ import annotations

import enum
import functools
import traceback
Expand All @@ -30,9 +32,30 @@
ENTRY_POINT_STRING_SEPARATOR = ':'


@functools.lru_cache(maxsize=1)
def eps():
return _eps()
@functools.cache
def eps() -> EntryPoints:
    """Return all entry points, sorted by group name, caching the result.

    The underlying ``entry_points()`` call takes around 50ms, which is why the result is cached.

    NOTE: The ``EntryPoints`` are sorted alphabetically by group name so that the 'aiida.' groups come up first,
    with the aim of faster lookups. Unfortunately, this does not help with the ``entry_points.select()`` filter,
    which will always iterate over all entry points since it looks for possible duplicate entries.
    """
    sorted_by_group = sorted(_eps(), key=lambda entry_point: entry_point.group)
    return EntryPoints(sorted_by_group)


@functools.lru_cache(maxsize=100)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting, so do I understand correctly that eps() is cached, nothing is being read from disk here, and processing the few entry points in memory is still taking 25ms?

Can you please document this function to explain the quadruple loop problem here, and why this additional cache is necessary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll write a comment. You can look at the cProfile info on the forum. https://aiida.discourse.group/t/why-is-aiidalab-base-widget-import-so-slow/32/16?u=danielhollas

Together with the fact that we were calling this during each Node class build (in a metaclass), this completely explained why the aiida.orm import was so slow. See #6091

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we worried about the size of the cache? I think the number of different calls to eps_select should be reasonable, not exceeding on the order of 100. So wouldn't we be better off with lru_cache(maxsize=None), i.e. cache, which will be faster? Not sure how much faster it will be compared to an LRU cache of max 100 items. Might be negligible.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am actually worried. I wouldn't be surprised if the number of calls was bigger than 100, especially if plugins are installed, since in some functions we're essentially iterating over all existing entry points when looking for the entry point for a given class. I'll take a closer look and do some more benchmarking.

Copy link
Member

@ltalirz ltalirz Sep 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we doing this iteration or is it importlib internally?

Are we trying to find an entry point for a class without knowing the name of the entry point?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since in some functions we're essentially iterating over all existing entry points when looking for the entry point for a given class. I'll take a closer look and do some more benchmarking.

These don't call this eps_select function though, do they? The cache here simply applies to the number of combinations of arguments with which it is called. Since it just has group and name, it should just be the list of all (group, name) tuples with which the function is called. This should, reasonably, not be much larger than the entry points that exist.

def eps_select(group: str, name: str | None = None) -> EntryPoints:
    """Return the entry points of ``eps()`` filtered by group and, optionally, by name.

    This is a thin wrapper around ``entry_points.select()`` calls, which are expensive, so we want to cache them.

    :param group: the entry point group to select.
    :param name: the entry point name to select; when ``None`` the selection is made on the group only.
    :return: the entry points matching the selection.
    """
    criteria = {'group': group}
    if name is not None:
        # Only pass ``name`` when it was given, so that ``select`` is not asked to match on ``name=None``.
        criteria['name'] = name
    return eps().select(**criteria)


class EntryPointFormat(enum.Enum):
Expand Down Expand Up @@ -254,8 +277,7 @@ def get_entry_point_groups() -> Set[str]:

def get_entry_point_names(group: str, sort: bool = True) -> List[str]:
"""Return the entry points within a group."""
all_eps = eps()
group_names = list(all_eps.select(group=group).names)
group_names = list(get_entry_points(group).names)
if sort:
return sorted(group_names)
return group_names
Expand All @@ -268,7 +290,7 @@ def get_entry_points(group: str) -> EntryPoints:
:param group: the entry point group
:return: a list of entry points
"""
return eps().select(group=group)
return eps_select(group=group)


def get_entry_point(group: str, name: str) -> EntryPoint:
Expand All @@ -283,7 +305,7 @@ def get_entry_point(group: str, name: str) -> EntryPoint:
"""
# The next line should be removed for ``aiida-core==3.0`` when the old deprecated entry points are fully removed.
name = convert_potentially_deprecated_entry_point(group, name)
found = eps().select(group=group, name=name)
found = eps_select(group=group, name=name)
if name not in found.names:
raise MissingEntryPointError(f"Entry point '{name}' not found in group '{group}'")
# If multiple entry points are found and they have different values we raise, otherwise if they all
Expand Down Expand Up @@ -326,14 +348,13 @@ def get_entry_point_from_class(class_module: str, class_name: str) -> Tuple[Opti
:param class_name: name of the class
:return: a tuple of the corresponding group and entry point or None if not found
"""
for group in get_entry_point_groups():
for entry_point in get_entry_points(group):
for entry_point in eps():

if entry_point.module != class_module:
continue
if entry_point.module != class_module:
continue

if entry_point.attr == class_name:
return group, entry_point
if entry_point.attr == class_name:
return entry_point.group, entry_point

return None, None

Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies:
- jinja2~=3.0
- jsonschema~=3.0
- kiwipy[rmq]~=0.7.7
- importlib-metadata~=4.13
- importlib-metadata~=6.0
danielhollas marked this conversation as resolved.
Show resolved Hide resolved
- numpy~=1.21
- paramiko>=2.7.2,~=2.7
- plumpy~=0.21.6
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies = [
"jinja2~=3.0",
"jsonschema~=3.0",
"kiwipy[rmq]~=0.7.7",
"importlib-metadata~=4.13",
"importlib-metadata~=6.0",
"numpy~=1.21",
"paramiko~=2.7,>=2.7.2",
"plumpy~=0.21.6",
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-py-3.10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ graphviz==0.20.1
greenlet==2.0.2
idna==3.4
imagesize==1.4.1
importlib-metadata==4.13.0
importlib-metadata==6.8.0
iniconfig==2.0.0
ipykernel==6.23.2
ipython==8.14.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-py-3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ graphviz==0.20.1
greenlet==2.0.2
idna==3.4
imagesize==1.4.1
importlib-metadata==4.13.0
importlib-metadata==6.8.0
iniconfig==2.0.0
ipykernel==6.23.2
ipython==8.14.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-py-3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ graphviz==0.20.1
greenlet==2.0.2
idna==3.4
imagesize==1.4.1
importlib-metadata==4.13.0
importlib-metadata==6.8.0
importlib-resources==5.12.0
iniconfig==2.0.0
ipykernel==6.23.2
Expand Down
9 changes: 4 additions & 5 deletions tests/plugins/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def select(group, name): # pylint: disable=unused-argument

@pytest.mark.parametrize(
'eps, name, exception', (
((EP(name='ep', group='gr', value='x'),), None, None),
ltalirz marked this conversation as resolved.
Show resolved Hide resolved
((EP(name='ep', group='gr', value='x'),), 'ep', None),
((EP(name='ep', group='gr', value='x'),), 'non-existing', MissingEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='y')), None, MultipleEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='x')), None, None),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='y')), 'ep', MultipleEntryPointError),
((EP(name='ep', group='gr', value='x'), EP(name='ep', group='gr', value='x')), 'ep', None),
),
indirect=['eps']
)
Expand All @@ -91,8 +91,7 @@ def test_get_entry_point(eps, name, exception, monkeypatch):

"""
monkeypatch.setattr(entry_point, 'eps', eps)

name = name or 'ep' # Try to load the entry point with name ``ep`` unless the fixture provides one
entry_point.eps_select.cache_clear()

if exception:
with pytest.raises(exception):
Expand Down
Loading