diff --git a/importlib_metadata/__init__.py b/importlib_metadata/__init__.py index 95087b55..67c925d4 100644 --- a/importlib_metadata/__init__.py +++ b/importlib_metadata/__init__.py @@ -600,9 +600,15 @@ class FastPath: children. """ - def __init__(self, root): + @functools.lru_cache() # type: ignore + def __new__(cls, root): + self = object().__new__(cls) self.root = str(root) self.base = os.path.basename(self.root).lower() + self.last_mtime = -1 + self.infos = {} + self.eggs = {} + return self def joinpath(self, child): return pathlib.Path(self.root, child) @@ -618,15 +624,47 @@ def zip_children(self): zip_path = zipp.Path(self.root) names = zip_path.root.namelist() self.joinpath = zip_path.joinpath - return dict.fromkeys(child.split(posixpath.sep, 1)[0] for child in names) - def search(self, name): - return ( - self.joinpath(child) - for child in self.children() - if name.matches(child, self.base) - ) + def update_cache(self): + root = self.root or "." + try: + mtime = os.stat(root).st_mtime + except OSError: + self.infos.clear() + self.eggs.clear() + self.last_mtime = -1 + return + if mtime == self.last_mtime: + return + self.infos.clear() + self.eggs.clear() + base_is_egg = self.base.endswith(".egg") + for child in self.children(): + low = child.lower() + if low.endswith((".dist-info", ".egg-info")): + # rpartition is faster than splitext and suitable for this purpose. + name = low.rpartition(".")[0].partition("-")[0] + normalized = Prepared.normalize(name) + self.infos.setdefault(normalized, []).append(child) + elif base_is_egg and low == "egg-info": + name = self.base.rpartition(".")[0].partition("-")[0] + legacy_normalized = Prepared.legacy_normalize(name) + self.eggs.setdefault(legacy_normalized, []).append(child) + self.last_mtime = mtime + + def search(self, prepared): + self.update_cache() + if prepared.name: + infos = self.infos.get(prepared.normalized, []) + yield from map(self.joinpath, infos) + eggs = self.eggs.get(prepared.legacy_normalized, []) + yield from map(self.joinpath, eggs) + else: + for infos in self.infos.values(): + yield from map(self.joinpath, infos) + for eggs in self.eggs.values(): + yield from map(self.joinpath, eggs) class Prepared: @@ -635,22 +673,14 @@ class Prepared: """ normalized = None - suffixes = 'dist-info', 'egg-info' - exact_matches = [''][:0] - egg_prefix = '' - versionless_egg_name = '' + legacy_normalized = None def __init__(self, name): self.name = name if name is None: return self.normalized = self.normalize(name) - self.exact_matches = [ - self.normalized + '.' + suffix for suffix in self.suffixes - ] - legacy_normalized = self.legacy_normalize(self.name) - self.egg_prefix = legacy_normalized + '-' - self.versionless_egg_name = legacy_normalized + '.egg' + self.legacy_normalized = self.legacy_normalize(name) @staticmethod def normalize(name): @@ -667,27 +697,6 @@ def legacy_normalize(name): """ return name.lower().replace('-', '_') - def matches(self, cand, base): - low = cand.lower() - # rpartition is faster than splitext and suitable for this purpose. - pre, _, ext = low.rpartition('.') - name, _, rest = pre.partition('-') - return ( - low in self.exact_matches - or ext in self.suffixes - and (not self.normalized or name.replace('.', '_') == self.normalized) - # legacy case: - or self.is_egg(base) - and low == 'egg-info' - ) - - def is_egg(self, base): - return ( - base == self.versionless_egg_name - or base.startswith(self.egg_prefix) - and base.endswith('.egg') - ) - @install class MetadataPathFinder(NullFinder, DistributionFinder): @@ -717,6 +726,9 @@ def _search_paths(cls, name, paths): path.search(prepared) for path in map(FastPath, paths) ) + def invalidate_caches(cls): + FastPath.__new__.cache_clear() + class PathDistribution(Distribution): def __init__(self, path): diff --git a/tests/test_api.py b/tests/test_api.py index 8c8d9abb..fef99033 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,6 +2,7 @@ import textwrap import unittest import warnings +import importlib from . import fixtures from importlib_metadata import ( @@ -275,3 +276,9 @@ def test_distribution_at_str(self): dist_info_path = self.site_dir / 'distinfo_pkg-1.0.0.dist-info' dist = Distribution.at(str(dist_info_path)) assert dist.version == '1.0.0' + + +class InvalidateCache(unittest.TestCase): + def test_invalidate_cache(self): + # No externally observable behavior, but ensures test coverage... + importlib.invalidate_caches()