From ddc9ee599656d22f3a52803008a95581aba962d2 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Mon, 14 Oct 2024 15:57:47 +0200 Subject: [PATCH] Issue #604/#644 replace lru_cache trick with cleaner cache --- openeo/extra/job_management.py | 12 ++++++------ tests/extra/test_job_management.py | 2 ++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/openeo/extra/job_management.py b/openeo/extra/job_management.py index d18dd6a27..d277f058d 100644 --- a/openeo/extra/job_management.py +++ b/openeo/extra/job_management.py @@ -2,7 +2,6 @@ import collections import contextlib import datetime -import functools import json import logging import re @@ -27,7 +26,7 @@ parse_remote_process_definition, ) from openeo.rest import OpenEoApiError -from openeo.util import deep_get, repr_truncate, rfc3339 +from openeo.util import LazyLoadCache, deep_get, repr_truncate, rfc3339 _log = logging.getLogger(__name__) @@ -994,11 +993,15 @@ def __init__( self._namespace = namespace self._parameter_defaults = parameter_defaults or {} self._parameter_column_map = parameter_column_map + self._cache = LazyLoadCache() def _get_process_definition(self, connection: Connection) -> Process: if isinstance(self._namespace, str) and re.match("https?://", self._namespace): # Remote process definition handling - return self._get_remote_process_definition() + return self._cache.get( + key=("remote_process_definition", self._namespace, self._process_id), + load=lambda: parse_remote_process_definition(namespace=self._namespace, process_id=self._process_id), + ) elif self._namespace is None: # Handling of a user-specific UDP udp_raw = connection.user_defined_process(self._process_id).describe() @@ -1008,9 +1011,6 @@ def _get_process_definition(self, connection: Connection) -> Process: f"Unsupported process definition source udp_id={self._process_id!r} namespace={self._namespace!r}" ) - @functools.lru_cache() - def _get_remote_process_definition(self) -> Process: - return parse_remote_process_definition(namespace=self._namespace, process_id=self._process_id) def start_job(self, row: pd.Series, connection: Connection, **_) -> BatchJob: """ diff --git a/tests/extra/test_job_management.py b/tests/extra/test_job_management.py index ea0a13f10..fae8cc2d5 100644 --- a/tests/extra/test_job_management.py +++ b/tests/extra/test_job_management.py @@ -1235,6 +1235,8 @@ def test_with_job_manager_remote_basic( } ) assert set(job_db.read().status) == {"finished"} + + # Verify caching of HTTP request of remote process definition assert remote_process_definitions["increment"].call_count == 1 assert dummy_backend.batch_jobs == {