diff --git a/CHANGELOG.md b/CHANGELOG.md
index b1ca2a0f1..97c3d67ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
 ### Added
+- Added option (enabled by default) to automatically validate a process graph before execution. For now, validation issues only trigger warnings. ([#404](https://github.com/Open-EO/openeo-python-client/issues/404))
 
 ### Changed
diff --git a/openeo/rest/_testing.py b/openeo/rest/_testing.py
index c87c56017..e2c78ec15 100644
--- a/openeo/rest/_testing.py
+++ b/openeo/rest/_testing.py
@@ -1,3 +1,4 @@
+import json
 import re
 from typing import Optional, Union
 
@@ -11,6 +12,15 @@ class DummyBackend:
     and allows inspection of posted process graphs
     """
 
+    __slots__ = (
+        "connection",
+        "sync_requests",
+        "batch_jobs",
+        "validation_requests",
+        "next_result",
+        "next_validation_errors",
+    )
+
     # Default result (can serve both as JSON or binary data)
     DEFAULT_RESULT = b'{"what?": "Result data"}'
 
@@ -18,9 +28,17 @@ def __init__(self, requests_mock, connection: Connection):
         self.connection = connection
         self.sync_requests = []
         self.batch_jobs = {}
+        self.validation_requests = []
         self.next_result = self.DEFAULT_RESULT
-        requests_mock.post(connection.build_url("/result"), content=self._handle_post_result)
-        requests_mock.post(connection.build_url("/jobs"), content=self._handle_post_jobs)
+        self.next_validation_errors = []
+        requests_mock.post(
+            connection.build_url("/result"),
+            content=self._handle_post_result,
+        )
+        requests_mock.post(
+            connection.build_url("/jobs"),
+            content=self._handle_post_jobs,
+        )
         requests_mock.post(
             re.compile(connection.build_url(r"/jobs/(job-\d+)/results$")), content=self._handle_post_job_results
         )
@@ -32,12 +50,16 @@ def __init__(self, requests_mock, connection: Connection):
             re.compile(connection.build_url("/jobs/(.*?)/results/result.data$")),
             content=self._handle_get_job_result_asset,
         )
+        requests_mock.post(connection.build_url("/validation"), json=self._handle_post_validation)
 
     def _handle_post_result(self, request, context):
         """handler of `POST /result` (synchronous execute)"""
         pg = request.json()["process"]["process_graph"]
         self.sync_requests.append(pg)
-        return self.next_result
+        result = self.next_result
+        if not isinstance(result, bytes):
+            result = json.dumps(result).encode("utf-8")
+        return result
 
     def _handle_post_jobs(self, request, context):
         """handler of `POST /jobs` (create batch job)"""
@@ -83,6 +105,12 @@ def _handle_get_job_result_asset(self, request, context):
         assert self.batch_jobs[job_id]["status"] == "finished"
         return self.next_result
 
+    def _handle_post_validation(self, request, context):
+        """handler of `POST /validation` (validate process graph)"""
+        pg = request.json()["process_graph"]
+        self.validation_requests.append(pg)
+        return {"errors": self.next_validation_errors}
+
     def get_sync_pg(self) -> dict:
         """Get one and only synchronous process graph"""
         assert len(self.sync_requests) == 1
diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py
index 02738bf7e..57cf088ad 100644
--- a/openeo/rest/connection.py
+++ b/openeo/rest/connection.py
@@ -260,6 +260,7 @@ def __init__(
         refresh_token_store: Optional[RefreshTokenStore] = None,
         slow_response_threshold: Optional[float] = None,
         oidc_auth_renewer: Optional[OidcAuthenticator] = None,
+        auto_validate: bool = True,
     ):
         """
         Constructor of Connection, authenticates user.
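For orientation, this is how the new flag is intended to surface in user code: a minimal sketch, not part of the patch itself; the backend URL and the "S2" collection id are placeholders. With `auto_validate` enabled (the default), validation issues are logged as warnings and execution proceeds:

    import logging
    import openeo

    logging.basicConfig(level=logging.WARNING)

    # auto_validate=True is the default; passed explicitly here for clarity.
    connection = openeo.connect("https://openeo.example", auto_validate=True)
    cube = connection.load_collection("S2")

    # If the backend advertises `POST /validation`, the process graph is sent
    # there before submission; reported errors are logged as warnings, e.g.
    #   Preflight process graph validation raised: [NoAdd] Don't add numbers
    # and the download proceeds regardless.
    cube.download("result.tiff")

    # The per-call `validate` toggle overrules the connection-level setting:
    cube.download("result.tiff", validate=False)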
@@ -282,6 +283,7 @@ def __init__(
         self._auth_config = auth_config
         self._refresh_token_store = refresh_token_store
         self._oidc_auth_renewer = oidc_auth_renewer
+        self._auto_validate = auto_validate
 
     @classmethod
     def version_discovery(
@@ -1052,8 +1054,8 @@ def validate_process_graph(self, process_graph: dict) -> List[dict]:
         :param process_graph: (flat) dict representing process graph
         :return: list of errors (dictionaries with "code" and "message" fields)
         """
-        request = {"process_graph": process_graph}
-        return self.post(path="/validation", json=request, expected_status=200).json()["errors"]
+        pg_with_metadata = self._build_request_with_process_graph(process_graph)["process"]
+        return self.post(path="/validation", json=pg_with_metadata, expected_status=200).json()["errors"]
 
     @property
     def _api_version(self) -> ComparableVersion:
@@ -1393,8 +1395,9 @@ def load_url(self, url: str, format: str, options: Optional[dict] = None):
 
     def create_service(self, graph: dict, type: str, **kwargs) -> Service:
         # TODO: type hint for graph: is it a nested or a flat one?
-        req = self._build_request_with_process_graph(process_graph=graph, type=type, **kwargs)
-        response = self.post(path="/services", json=req, expected_status=201)
+        pg_with_metadata = self._build_request_with_process_graph(process_graph=graph, type=type, **kwargs)
+        self._preflight_validation(pg_with_metadata=pg_with_metadata)
+        response = self.post(path="/services", json=pg_with_metadata, expected_status=201)
         service_id = response.headers.get("OpenEO-Identifier")
         return Service(service_id, self)
 
@@ -1463,23 +1466,54 @@ def upload_file(
     def _build_request_with_process_graph(self, process_graph: Union[dict, FlatGraphableMixin, Any], **kwargs) -> dict:
         """
         Prepare a json payload with a process graph to submit to /result, /services, /jobs, ...
 
         :param process_graph: flat dict representing a process graph
+        :return: payload dict: a "process graph with metadata" construct ({"process": {"process_graph": ...}, ...})
         """
         # TODO: make this a more general helper (like `as_flat_graph`)
         result = kwargs
         process_graph = as_flat_graph(process_graph)
         if "process_graph" not in process_graph:
             process_graph = {"process_graph": process_graph}
-        # TODO: also check if `process_graph` already has "process" key (i.e. is a "process graph with metadata already)
+        # TODO: also check if `process_graph` already has "process" key (i.e. is a "process graph with metadata" already)
         result["process"] = process_graph
         return result
 
+    def _preflight_validation(self, pg_with_metadata: dict, *, validate: Optional[bool] = None):
+        """
+        Preflight validation of the process graph to execute.
+
+        :param pg_with_metadata: flat dict representation of a process graph with metadata,
+            e.g. as produced by `_build_request_with_process_graph`
+        :param validate: Optional toggle to enable/prevent validation of the process graphs before execution
+            (overruling the connection's ``auto_validate`` setting).
+        """
+        if validate is None:
+            validate = self._auto_validate
+        if validate and self.capabilities().supports_endpoint("/validation", "POST"):
+            # At present, the intention is that a failed validation does not block
+            # the job from running; it is only reported as a warning.
+            # Therefore we also want to continue when something *else* goes wrong
+            # *during* the validation.
+ try: + resp = self.post(path="/validation", json=pg_with_metadata["process"], expected_status=200) + validation_errors = resp.json()["errors"] + if validation_errors: + _log.warning( + "Preflight process graph validation raised: " + + (" ".join(f"[{e.get('code')}] {e.get('message')}" for e in validation_errors)) + ) + except Exception as e: + _log.error(f"Preflight process graph validation failed: {e}", exc_info=True) + + # TODO: additional validation and sanity checks: e.g. is there a result node, are all process_ids valid, ...? + # TODO: unify `download` and `execute` better: e.g. `download` always writes to disk, `execute` returns result (raw or as JSON decoded dict) def download( self, graph: Union[dict, FlatGraphableMixin, str, Path], outputfile: Union[Path, str, None] = None, timeout: Optional[int] = None, + validate: Optional[bool] = None, ) -> Union[None, bytes]: """ Downloads the result of a process graph synchronously, @@ -1490,11 +1525,14 @@ def download( or as local file path or URL :param outputfile: output file :param timeout: timeout to wait for response + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). """ - request = self._build_request_with_process_graph(process_graph=graph) + pg_with_metadata = self._build_request_with_process_graph(process_graph=graph) + self._preflight_validation(pg_with_metadata=pg_with_metadata, validate=validate) response = self.post( path="/result", - json=request, + json=pg_with_metadata, expected_status=200, stream=True, timeout=timeout or DEFAULT_TIMEOUT_SYNCHRONOUS_EXECUTE, @@ -1511,21 +1549,26 @@ def execute( self, process_graph: Union[dict, str, Path], timeout: Optional[int] = None, + validate: Optional[bool] = None, ): """ Execute a process graph synchronously and return the result (assumed to be JSON). :param process_graph: (flat) dict representing a process graph, or process graph as raw JSON string, or as local file path or URL + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). + :return: parsed JSON response """ - req = self._build_request_with_process_graph(process_graph=process_graph) + pg_with_metadata = self._build_request_with_process_graph(process_graph=process_graph) + self._preflight_validation(pg_with_metadata=pg_with_metadata, validate=validate) return self.post( path="/result", - json=req, + json=pg_with_metadata, expected_status=200, timeout=timeout or DEFAULT_TIMEOUT_SYNCHRONOUS_EXECUTE, - ).json() + ).json() # TODO: only do JSON decoding when mimetype is actually JSON? def create_job( self, @@ -1536,6 +1579,7 @@ def create_job( plan: Optional[str] = None, budget: Optional[float] = None, additional: Optional[dict] = None, + validate: Optional[bool] = None, ) -> BatchJob: """ Create a new job from given process graph on the back-end. @@ -1547,18 +1591,22 @@ def create_job( :param plan: billing plan :param budget: maximum cost the request is allowed to produce :param additional: additional job options to pass to the backend + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). :return: Created job """ # TODO move all this (BatchJob factory) logic to BatchJob? 
- req = self._build_request_with_process_graph( + + pg_with_metadata = self._build_request_with_process_graph( process_graph=process_graph, **dict_no_none(title=title, description=description, plan=plan, budget=budget) ) if additional: # TODO: get rid of this non-standard field? https://github.com/Open-EO/openeo-api/issues/276 - req["job_options"] = additional + pg_with_metadata["job_options"] = additional - response = self.post("/jobs", json=req, expected_status=201) + self._preflight_validation(pg_with_metadata=pg_with_metadata, validate=validate) + response = self.post("/jobs", json=pg_with_metadata, expected_status=201) job_id = None if "openeo-identifier" in response.headers: @@ -1636,8 +1684,8 @@ def as_curl( cmd += ["-H", "Content-Type: application/json"] if isinstance(self.auth, BearerAuth): cmd += ["-H", f"Authorization: Bearer {'...' if obfuscate_auth else self.auth.bearer}"] - post_data = self._build_request_with_process_graph(data) - post_json = json.dumps(post_data, separators=(',', ':')) + pg_with_metadata = self._build_request_with_process_graph(data) + post_json = json.dumps(pg_with_metadata, separators=(",", ":")) cmd += ["--data", post_json] cmd += [self.build_url(path)] return " ".join(shlex.quote(c) for c in cmd) @@ -1657,17 +1705,20 @@ def version_info(self): def connect( - url: Optional[str] = None, - auth_type: Optional[str] = None, auth_options: Optional[dict] = None, - session: Optional[requests.Session] = None, - default_timeout: Optional[int] = None, + url: Optional[str] = None, + *, + auth_type: Optional[str] = None, + auth_options: Optional[dict] = None, + session: Optional[requests.Session] = None, + default_timeout: Optional[int] = None, + auto_validate: bool = True, ) -> Connection: """ This method is the entry point to OpenEO. You typically create one connection object in your script or application and re-use it for all calls to that backend. - If the backend requires authentication, you can pass authentication data directly to this function + If the backend requires authentication, you can pass authentication data directly to this function, but it could be easier to authenticate as follows: >>> # For basic authentication @@ -1679,7 +1730,10 @@ def connect( :param auth_type: Which authentication to use: None, "basic" or "oidc" (for OpenID Connect) :param auth_options: Options/arguments specific to the authentication type :param default_timeout: default timeout (in seconds) for requests - :rtype: openeo.connections.Connection + :param auto_validate: toggle to automatically validate process graphs before execution + + .. 
versionadded:: 0.24.0 + added ``auto_validate`` argument """ def _config_log(message): @@ -1704,7 +1758,7 @@ def _config_log(message): if not url: raise OpenEoClientException("No openEO back-end URL given or known to connect to.") - connection = Connection(url, session=session, default_timeout=default_timeout) + connection = Connection(url, session=session, default_timeout=default_timeout, auto_validate=auto_validate) auth_type = auth_type.lower() if isinstance(auth_type, str) else auth_type if auth_type in {None, False, 'null', 'none'}: diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py index 35594ce28..906122dd6 100644 --- a/openeo/rest/datacube.py +++ b/openeo/rest/datacube.py @@ -1945,6 +1945,8 @@ def download( outputfile: Optional[Union[str, pathlib.Path]] = None, format: Optional[str] = None, options: Optional[dict] = None, + *, + validate: Optional[bool] = None, ) -> Union[None, bytes]: """ Execute synchronously and download the raster data cube, e.g. as GeoTIFF. @@ -1955,13 +1957,15 @@ def download( :param outputfile: Optional, an output file if the result needs to be stored on disk. :param format: Optional, an output format supported by the backend. :param options: Optional, file format options + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). :return: None if the result is stored to disk, or a bytes object returned by the backend. """ if format is None and outputfile: # TODO #401/#449 don't guess/override format if there is already a save_result with format? format = guess_format(outputfile) cube = self._ensure_save_result(format=format, options=options) - return self._connection.download(cube.flat_graph(), outputfile) + return self._connection.download(cube.flat_graph(), outputfile, validate=validate) def validate(self) -> List[dict]: """ @@ -2062,6 +2066,7 @@ def execute_batch( max_poll_interval: float = 60, connection_retry_interval: float = 30, job_options: Optional[dict] = None, + validate: Optional[bool] = None, # TODO: avoid `format_options` as keyword arguments **format_options, ) -> BatchJob: @@ -2074,6 +2079,8 @@ def execute_batch( :param outputfile: The path of a file to which a result can be written :param out_format: (optional) File format to use for the job result. :param job_options: + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). """ if "format" in format_options and not out_format: out_format = format_options["format"] # align with 'download' call arg name @@ -2081,9 +2088,7 @@ def execute_batch( # TODO #401/#449 don't guess/override format if there is already a save_result with format? 
out_format = guess_format(outputfile) - job = self.create_job( - out_format=out_format, job_options=job_options, **format_options - ) + job = self.create_job(out_format=out_format, job_options=job_options, validate=validate, **format_options) return job.run_synchronous( outputfile=outputfile, print=print, max_poll_interval=max_poll_interval, connection_retry_interval=connection_retry_interval @@ -2098,6 +2103,7 @@ def create_job( plan: Optional[str] = None, budget: Optional[float] = None, job_options: Optional[dict] = None, + validate: Optional[bool] = None, # TODO: avoid `format_options` as keyword arguments **format_options, ) -> BatchJob: @@ -2115,6 +2121,9 @@ def create_job( :param plan: billing plan :param budget: maximum cost the request is allowed to produce :param job_options: custom job options. + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). + :return: Created job. """ # TODO: add option to also automatically start the job? @@ -2127,6 +2136,7 @@ def create_job( description=description, plan=plan, budget=budget, + validate=validate, additional=job_options, ) @@ -2162,9 +2172,9 @@ def save_user_defined_process( returns=returns, categories=categories, examples=examples, links=links, ) - def execute(self) -> dict: - """Executes the process graph of the imagery. """ - return self._connection.execute(self.flat_graph()) + def execute(self, *, validate: Optional[bool] = None) -> dict: + """Executes the process graph.""" + return self._connection.execute(self.flat_graph(), validate=validate) @staticmethod @deprecated(reason="Use :py:func:`openeo.udf.run_code.execute_local_udf` instead", version="0.7.0") diff --git a/openeo/rest/udp.py b/openeo/rest/udp.py index 25aa7c511..5cac35347 100644 --- a/openeo/rest/udp.py +++ b/openeo/rest/udp.py @@ -94,6 +94,7 @@ def store( # TODO: this "public" flag is not standardized yet EP-3609, https://github.com/Open-EO/openeo-api/issues/310 process["public"] = public + self._connection._preflight_validation(pg_with_metadata=process) self._connection.put( path="/process_graphs/{}".format(self.user_defined_process_id), json=process, expected_status=200 ) diff --git a/openeo/rest/vectorcube.py b/openeo/rest/vectorcube.py index 0f3cc7eac..6f4e7b0f5 100644 --- a/openeo/rest/vectorcube.py +++ b/openeo/rest/vectorcube.py @@ -226,15 +226,17 @@ def _ensure_save_result( cube = self.save_result(format=format or "GeoJSON", options=options) return cube - def execute(self) -> dict: - """Executes the process graph of the imagery.""" - return self._connection.execute(self.flat_graph()) + def execute(self, *, validate: Optional[bool] = None) -> dict: + """Executes the process graph.""" + return self._connection.execute(self.flat_graph(), validate=validate) def download( self, outputfile: Optional[Union[str, pathlib.Path]] = None, format: Optional[str] = None, options: Optional[dict] = None, + *, + validate: Optional[bool] = None, ) -> Union[None, bytes]: """ Execute synchronously and download the vector cube. @@ -245,7 +247,8 @@ def download( :param outputfile: (optional) output file to store the result to :param format: (optional) output format to use. :param options: (optional) additional output format options. - :return: + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). .. 
versionchanged:: 0.21.0 When not specified explicitly, output format is guessed from output file extension. @@ -256,16 +259,18 @@ def download( if format is None and outputfile: format = guess_format(outputfile) cube = self._ensure_save_result(format=format, options=options) - return self._connection.download(cube.flat_graph(), outputfile) + return self._connection.download(cube.flat_graph(), outputfile=outputfile, validate=validate) def execute_batch( self, outputfile: Optional[Union[str, pathlib.Path]] = None, out_format: Optional[str] = None, + *, print=print, max_poll_interval: float = 60, connection_retry_interval: float = 30, job_options: Optional[dict] = None, + validate: Optional[bool] = None, # TODO: avoid using kwargs as format options **format_options, ) -> BatchJob: @@ -279,6 +284,8 @@ def execute_batch( :param outputfile: The path of a file to which a result can be written :param out_format: (optional) output format to use. :param format_options: (optional) additional output format options + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). .. versionchanged:: 0.21.0 When not specified explicitly, output format is guessed from output file extension. @@ -287,7 +294,7 @@ def execute_batch( # TODO #401/#449 don't guess/override format if there is already a save_result with format? out_format = guess_format(outputfile) - job = self.create_job(out_format, job_options=job_options, **format_options) + job = self.create_job(out_format, job_options=job_options, validate=validate, **format_options) return job.run_synchronous( # TODO #135 support multi file result sets too outputfile=outputfile, @@ -303,6 +310,7 @@ def create_job( plan: Optional[str] = None, budget: Optional[float] = None, job_options: Optional[dict] = None, + validate: Optional[bool] = None, **format_options, ) -> BatchJob: """ @@ -315,6 +323,9 @@ def create_job( :param budget: maximum cost the request is allowed to produce :param job_options: A dictionary containing (custom) job options :param format_options: String Parameters for the job result format + :param validate: Optional toggle to enable/prevent validation of the process graphs before execution + (overruling the connection's ``auto_validate`` setting). + :return: Created job. """ # TODO: avoid using all kwargs as format_options @@ -327,6 +338,7 @@ def create_job( plan=plan, budget=budget, additional=job_options, + validate=validate, ) send_job = legacy_alias(create_job, name="send_job", since="0.10.0") diff --git a/tests/rest/conftest.py b/tests/rest/conftest.py index 3914168ff..e84dd467d 100644 --- a/tests/rest/conftest.py +++ b/tests/rest/conftest.py @@ -6,7 +6,7 @@ import pytest import time_machine -from openeo.rest._testing import DummyBackend +from openeo.rest._testing import DummyBackend, build_capabilities from openeo.rest.connection import Connection API_URL = "https://oeo.test/" @@ -71,15 +71,31 @@ def assert_oidc_device_code_flow(url: str = "https://oidc.test/dc", elapsed: flo @pytest.fixture -def con100(requests_mock): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) +def api_capabilities() -> dict: + """ + Fixture to be overridden for customizing the capabilities doc used by connection fixtures. 
+ To be used as kwargs for `build_capabilities` + """ + return {} + + +@pytest.fixture +def connection(api_version, requests_mock, api_capabilities) -> Connection: + requests_mock.get(API_URL, json=build_capabilities(api_version=api_version, **api_capabilities)) + con = Connection(API_URL) + return con + + +@pytest.fixture +def con100(requests_mock, api_capabilities): + requests_mock.get(API_URL, json=build_capabilities(api_version="1.0.0", **api_capabilities)) con = Connection(API_URL) return con @pytest.fixture -def con120(requests_mock): - requests_mock.get(API_URL, json={"api_version": "1.2.0"}) +def con120(requests_mock, api_capabilities): + requests_mock.get(API_URL, json=build_capabilities(api_version="1.2.0", **api_capabilities)) con = Connection(API_URL) return con diff --git a/tests/rest/datacube/conftest.py b/tests/rest/datacube/conftest.py index 158807d75..4785690aa 100644 --- a/tests/rest/datacube/conftest.py +++ b/tests/rest/datacube/conftest.py @@ -67,22 +67,18 @@ def setup_collection_metadata(requests_mock, cid: str, bands: List[str]): }) -@pytest.fixture -def support_udp() -> bool: - """Per-test overridable `build_capabilities_kwargs(udp=...)` value for connection fixtures""" - return False - @pytest.fixture -def connection(api_version, requests_mock) -> Connection: +def connection(api_version, requests_mock, api_capabilities) -> Connection: """Connection fixture to a backend of given version with some image collections.""" - return _setup_connection(api_version, requests_mock) + return _setup_connection(api_version, requests_mock, build_capabilities_kwargs=api_capabilities) @pytest.fixture -def con100(requests_mock, support_udp) -> Connection: +def con100(requests_mock, api_capabilities) -> Connection: """Connection fixture to a 1.0.0 backend with some image collections.""" - return _setup_connection("1.0.0", requests_mock, build_capabilities_kwargs={"udp": support_udp}) + return _setup_connection("1.0.0", requests_mock, build_capabilities_kwargs=api_capabilities) + @pytest.fixture diff --git a/tests/rest/datacube/test_datacube.py b/tests/rest/datacube/test_datacube.py index babd72129..c4d31149f 100644 --- a/tests/rest/datacube/test_datacube.py +++ b/tests/rest/datacube/test_datacube.py @@ -14,7 +14,10 @@ import shapely.geometry from openeo.rest import BandMathException +from openeo.rest._testing import build_capabilities +from openeo.rest.connection import Connection from openeo.rest.datacube import DataCube +from openeo.util import dict_no_none from ... import load_json_resource from .. import get_download_graph @@ -807,3 +810,178 @@ def test_save_result_format_options_vs_execute_batch(elf, s2cube, get_create_job }, "result": True, } + + +class TestDataCubeValidation: + """ + Test (auto) validation of datacube execution with `download`, `execute`, ... 
+ """ + + _PG_S2 = { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + "result": True, + }, + } + _PG_S2_SAVE = { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}}, + "result": True, + }, + } + + @pytest.fixture(params=[False, True]) + def auto_validate(self, request) -> bool: + """Fixture to parametrize auto_validate setting.""" + return request.param + + @pytest.fixture + def connection(self, api_version, requests_mock, api_capabilities, auto_validate) -> Connection: + requests_mock.get(API_URL, json=build_capabilities(api_version=api_version, **api_capabilities)) + con = Connection(API_URL, **dict_no_none(auto_validate=auto_validate)) + return con + + @pytest.fixture(autouse=True) + def dummy_backend_setup(self, dummy_backend): + dummy_backend.next_validation_errors = [{"code": "NoAdd", "message": "Don't add numbers"}] + + # Reusable list of (fixture) parameterization + # of ["api_capabilities", "auto_validate", "validate", "validation_expected"] + _VALIDATION_PARAMETER_SETS = [ + # No validation supported by backend: don't attempt to validate + ({}, None, None, False), + ({}, True, True, False), + # Validation supported by backend, default behavior -> validate + ({"validation": True}, None, None, True), + # (Validation supported by backend) no explicit validation enabled: follow auto_validate setting + ({"validation": True}, True, None, True), + ({"validation": True}, False, None, False), + # (Validation supported by backend) follow explicit `validate` toggle regardless of auto_validate + ({"validation": True}, False, True, True), + ({"validation": True}, True, False, False), + ] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_cube_download_validation(self, dummy_backend, connection, validate, validation_expected, caplog, tmp_path): + """The DataCube should pass through request for the validation to the + connection and the validation endpoint should only be called when + validation was requested. 
+ """ + cube = connection.load_collection("S2") + + output = tmp_path / "result.tiff" + cube.download(outputfile=output, **dict_no_none(validate=validate)) + assert output.read_bytes() == b'{"what?": "Result data"}' + assert dummy_backend.get_sync_pg() == self._PG_S2_SAVE + + if validation_expected: + assert dummy_backend.validation_requests == [self._PG_S2_SAVE] + assert caplog.messages == ["Preflight process graph validation raised: [NoAdd] Don't add numbers"] + else: + assert dummy_backend.validation_requests == [] + assert caplog.messages == [] + + @pytest.mark.parametrize("api_capabilities", [{"validation": True}]) + def test_cube_download_validation_broken(self, dummy_backend, connection, requests_mock, caplog, tmp_path): + """Test resilience against broken validation response.""" + requests_mock.post( + connection.build_url("/validation"), status_code=500, json={"code": "Internal", "message": "nope!"} + ) + + cube = connection.load_collection("S2") + + output = tmp_path / "result.tiff" + cube.download(outputfile=output, validate=True) + assert output.read_bytes() == b'{"what?": "Result data"}' + assert dummy_backend.get_sync_pg() == self._PG_S2_SAVE + + assert caplog.messages == ["Preflight process graph validation failed: [500] Internal: nope!"] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_cube_execute_validation(self, dummy_backend, connection, validate, validation_expected, caplog): + """The DataCube should pass through request for the validation to the + connection and the validation endpoint should only be called when + validation was requested. + """ + cube = connection.load_collection("S2") + + res = cube.execute(**dict_no_none(validate=validate)) + assert res == {"what?": "Result data"} + assert dummy_backend.get_sync_pg() == self._PG_S2 + + if validation_expected: + assert dummy_backend.validation_requests == [self._PG_S2] + assert caplog.messages == ["Preflight process graph validation raised: [NoAdd] Don't add numbers"] + else: + assert dummy_backend.validation_requests == [] + assert caplog.messages == [] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_cube_create_job_validation( + self, dummy_backend, connection: Connection, validate, validation_expected, caplog + ): + """The DataCube should pass through request for the validation to the + connection and the validation endpoint should only be called when + validation was requested. 
+ """ + cube = connection.load_collection("S2") + job = cube.create_job(**dict_no_none(validate=validate)) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self._PG_S2_SAVE + + if validation_expected: + assert dummy_backend.validation_requests == [self._PG_S2_SAVE] + assert caplog.messages == ["Preflight process graph validation raised: [NoAdd] Don't add numbers"] + else: + assert dummy_backend.validation_requests == [] + assert caplog.messages == [] + + @pytest.mark.parametrize("api_capabilities", [{"validation": True}]) + def test_cube_create_job_validation_broken(self, dummy_backend, connection, requests_mock, caplog, tmp_path): + """Test resilience against broken validation response.""" + requests_mock.post( + connection.build_url("/validation"), status_code=500, json={"code": "Internal", "message": "nope!"} + ) + + cube = connection.load_collection("S2") + job = cube.create_job(validate=True) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self._PG_S2_SAVE + + assert caplog.messages == ["Preflight process graph validation failed: [500] Internal: nope!"] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_cube_execute_batch_validation(self, dummy_backend, connection, validate, validation_expected, caplog): + """The DataCube should pass through request for the validation to the + connection and the validation endpoint should only be called when + validation was requested. + """ + cube = connection.load_collection("S2") + job = cube.execute_batch(**dict_no_none(validate=validate)) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self._PG_S2_SAVE + + if validation_expected: + assert dummy_backend.validation_requests == [self._PG_S2_SAVE] + assert caplog.messages == ["Preflight process graph validation raised: [NoAdd] Don't add numbers"] + else: + assert dummy_backend.validation_requests == [] + assert caplog.messages == [] diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py index e7c65ca60..4835707a6 100644 --- a/tests/rest/datacube/test_datacube100.py +++ b/tests/rest/datacube/test_datacube100.py @@ -27,7 +27,6 @@ from openeo.internal.warnings import UserDeprecationWarning from openeo.processes import ProcessBuilder from openeo.rest import OpenEoClientException -from openeo.rest._testing import build_capabilities from openeo.rest.connection import Connection from openeo.rest.datacube import THIS, UDF, DataCube @@ -1931,9 +1930,9 @@ def test_custom_process_arguments_namespacd(con100: Connection): assert res.flat_graph() == expected -@pytest.mark.parametrize("support_udp", [True]) + +@pytest.mark.parametrize("api_capabilities", [{"udp": True}]) def test_save_user_defined_process(con100, requests_mock): - requests_mock.get(API_URL + "/", json=build_capabilities(udp=True)) requests_mock.get(API_URL + "/processes", json={"processes": [{"id": "add"}]}) expected_body = load_json_resource("data/1.0.0/save_user_defined_process.json") @@ -1955,9 +1954,8 @@ def check_body(request): assert adapter.called -@pytest.mark.parametrize("support_udp", [True]) +@pytest.mark.parametrize("api_capabilities", [{"udp": True}]) def test_save_user_defined_process_public(con100, requests_mock): - requests_mock.get(API_URL + "/", json=build_capabilities(udp=True)) requests_mock.get(API_URL + "/processes", json={"processes": [{"id": "add"}]}) expected_body = 
load_json_resource("data/1.0.0/save_user_defined_process.json") diff --git a/tests/rest/datacube/test_vectorcube.py b/tests/rest/datacube/test_vectorcube.py index 10d490fd3..fbfbd298c 100644 --- a/tests/rest/datacube/test_vectorcube.py +++ b/tests/rest/datacube/test_vectorcube.py @@ -1,3 +1,4 @@ +import json import re from pathlib import Path @@ -6,9 +7,12 @@ import openeo.processes from openeo.api.process import Parameter -from openeo.rest._testing import DummyBackend +from openeo.rest._testing import DummyBackend, build_capabilities +from openeo.rest.connection import Connection from openeo.rest.vectorcube import VectorCube -from openeo.util import InvalidBBoxException +from openeo.util import InvalidBBoxException, dict_no_none + +API_URL = "https://oeo.test" @pytest.fixture @@ -476,3 +480,167 @@ def test_filter_vector_shapely(vector_cube, dummy_backend, geometries): "result": True, }, } + + +class TestVectorCubeValidation: + """ + Test (auto) validation of vector cube execution with `download`, `execute`, ... + """ + + _PG_GEOJSON = { + "loadgeojson1": { + "process_id": "load_geojson", + "arguments": {"data": {"type": "Point", "coordinates": [1, 2]}, "properties": []}, + "result": True, + }, + } + _PG_GEOJSON_SAVE = { + "loadgeojson1": { + "process_id": "load_geojson", + "arguments": {"data": {"type": "Point", "coordinates": [1, 2]}, "properties": []}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadgeojson1"}, "format": "GeoJSON", "options": {}}, + "result": True, + }, + } + + @pytest.fixture(params=[False, True]) + def auto_validate(self, request) -> bool: + """Fixture to parametrize auto_validate setting.""" + return request.param + + @pytest.fixture + def connection(self, api_version, requests_mock, api_capabilities, auto_validate) -> Connection: + requests_mock.get(API_URL, json=build_capabilities(api_version=api_version, **api_capabilities)) + con = Connection(API_URL, **dict_no_none(auto_validate=auto_validate)) + return con + + @pytest.fixture(autouse=True) + def dummy_backend_setup(self, dummy_backend): + dummy_backend.next_result = {"type": "Point", "coordinates": [1, 2]} + dummy_backend.next_validation_errors = [{"code": "NoPoints", "message": "Don't use points."}] + + # Reusable list of (fixture) parameterization + # of ["api_capabilities", "auto_validate", "validate", "validation_expected"] + _VALIDATION_PARAMETER_SETS = [ + # No validation supported by backend: don't attempt to validate + ({}, None, None, False), + ({}, True, True, False), + # Validation supported by backend, default behavior -> validate + ({"validation": True}, None, None, True), + # (Validation supported by backend) no explicit validation enabled: follow auto_validate setting + ({"validation": True}, True, None, True), + ({"validation": True}, False, None, False), + # (Validation supported by backend) follow explicit `validate` toggle regardless of auto_validate + ({"validation": True}, False, True, True), + ({"validation": True}, True, False, False), + ] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_vectorcube_download_validation( + self, dummy_backend, connection, validate, validation_expected, caplog, tmp_path + ): + """The DataCube should pass through request for the validation to the + connection and the validation endpoint should only be called when + validation was requested. 
+ """ + vector_cube = VectorCube.load_geojson(connection=connection, data={"type": "Point", "coordinates": [1, 2]}) + + output = tmp_path / "result.geojson" + vector_cube.download(outputfile=output, **dict_no_none(validate=validate)) + assert json.loads(output.read_text()) == {"type": "Point", "coordinates": [1, 2]} + assert dummy_backend.get_sync_pg() == self._PG_GEOJSON_SAVE + + if validation_expected: + assert dummy_backend.validation_requests == [self._PG_GEOJSON_SAVE] + assert caplog.messages == ["Preflight process graph validation raised: [NoPoints] Don't use points."] + else: + assert dummy_backend.validation_requests == [] + assert caplog.messages == [] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_vectorcube_execute_validation(self, dummy_backend, connection, validate, validation_expected, caplog): + """The DataCube should pass through request for the validation to the + connection and the validation endpoint should only be called when + validation was requested. + """ + vector_cube = VectorCube.load_geojson(connection=connection, data={"type": "Point", "coordinates": [1, 2]}) + + res = vector_cube.execute(**dict_no_none(validate=validate)) + assert res == {"type": "Point", "coordinates": [1, 2]} + assert dummy_backend.get_sync_pg() == self._PG_GEOJSON + + if validation_expected: + assert dummy_backend.validation_requests == [self._PG_GEOJSON] + assert caplog.messages == ["Preflight process graph validation raised: [NoPoints] Don't use points."] + else: + assert dummy_backend.validation_requests == [] + assert caplog.messages == [] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_vectorcube_create_job_validation(self, dummy_backend, connection, validate, validation_expected, caplog): + """The DataCube should pass through request for the validation to the + connection and the validation endpoint should only be called when + validation was requested. 
+ """ + vector_cube = VectorCube.load_geojson(connection=connection, data={"type": "Point", "coordinates": [1, 2]}) + + job = vector_cube.create_job(**dict_no_none(validate=validate)) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self._PG_GEOJSON_SAVE + + if validation_expected: + assert dummy_backend.validation_requests == [self._PG_GEOJSON_SAVE] + assert caplog.messages == ["Preflight process graph validation raised: [NoPoints] Don't use points."] + else: + assert dummy_backend.validation_requests == [] + assert caplog.messages == [] + + @pytest.mark.parametrize("api_capabilities", [{"validation": True}]) + def test_vectorcube_create_job_validation_broken(self, dummy_backend, connection, requests_mock, caplog): + """Test resilience against broken validation response.""" + requests_mock.post( + connection.build_url("/validation"), status_code=500, json={"code": "Internal", "message": "nope!"} + ) + vector_cube = VectorCube.load_geojson(connection=connection, data={"type": "Point", "coordinates": [1, 2]}) + + job = vector_cube.create_job(validate=True) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self._PG_GEOJSON_SAVE + + assert caplog.messages == ["Preflight process graph validation failed: [500] Internal: nope!"] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_vectorcube_execute_batch_validation( + self, dummy_backend, connection, validate, validation_expected, caplog + ): + """The DataCube should pass through request for the validation to the + connection and the validation endpoint should only be called when + validation was requested. + """ + vector_cube = VectorCube.load_geojson(connection=connection, data={"type": "Point", "coordinates": [1, 2]}) + + job = vector_cube.execute_batch(**dict_no_none(validate=validate)) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self._PG_GEOJSON_SAVE + + if validation_expected: + assert dummy_backend.validation_requests == [self._PG_GEOJSON_SAVE] + assert caplog.messages == ["Preflight process graph validation raised: [NoPoints] Don't use points."] + else: + assert dummy_backend.validation_requests == [] + assert caplog.messages == [] diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 09a97ac25..2cdfb68e3 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -32,7 +32,7 @@ paginate, ) from openeo.rest.vectorcube import VectorCube -from openeo.util import ContextTimer +from openeo.util import ContextTimer, dict_no_none from .. import load_json_resource from .auth.test_cli import auth_config, refresh_token_store @@ -3174,126 +3174,237 @@ def test_vectorcube_from_paths(requests_mock): } -class TestExecute: +class TestExecuteFromJsonResources: + """ + Tests for executing process graphs directly from JSON resources (JSON dumps, files, URLs, ...) 
+ """ # Dummy process graphs PG_JSON_1 = '{"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": true}}' + PG_DICT_1 = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} PG_JSON_2 = '{"process_graph": {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": true}}}' - # Dummy `POST /result` handlers - def _post_result_handler_tiff(self, response: requests.Request, context): - pg = response.json()["process"]["process_graph"] - assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} - return b"TIFF data" - - def _post_result_handler_json(self, response: requests.Request, context): - pg = response.json()["process"]["process_graph"] - assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} - return {"answer": 8} - - def _post_jobs_handler_json(self, response: requests.Request, context): - pg = response.json()["process"]["process_graph"] - assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} - context.headers["OpenEO-Identifier"] = "j-123" - return b"" - @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) - def test_download_pg_json(self, requests_mock, tmp_path, pg_json: str): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "result", content=self._post_result_handler_tiff) - - conn = Connection(API_URL) + def test_download_pg_json(self, dummy_backend, connection, tmp_path, pg_json: str): output = tmp_path / "result.tiff" - conn.download(pg_json, outputfile=output) - assert output.read_bytes() == b"TIFF data" + connection.download(pg_json, outputfile=output) + assert output.read_bytes() == b'{"what?": "Result data"}' + assert dummy_backend.get_sync_pg() == self.PG_DICT_1 @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) - def test_execute_pg_json(self, requests_mock, pg_json: str): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "result", json=self._post_result_handler_json) - - conn = Connection(API_URL) - result = conn.execute(pg_json) + def test_execute_pg_json(self, dummy_backend, connection, pg_json: str): + dummy_backend.next_result = {"answer": 8} + result = connection.execute(pg_json) assert result == {"answer": 8} + assert dummy_backend.get_sync_pg() == self.PG_DICT_1 @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) - def test_create_job_pg_json(self, requests_mock, pg_json: str): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "jobs", status_code=201, content=self._post_jobs_handler_json) - - conn = Connection(API_URL) - job = conn.create_job(pg_json) - assert job.job_id == "j-123" + def test_create_job_pg_json(self, dummy_backend, connection, pg_json: str): + job = connection.create_job(pg_json) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self.PG_DICT_1 @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) @pytest.mark.parametrize("path_factory", [str, Path]) - def test_download_pg_json_file(self, requests_mock, tmp_path, pg_json: str, path_factory): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "result", content=self._post_result_handler_tiff) + def test_download_pg_json_file(self, dummy_backend, connection, tmp_path, pg_json: str, path_factory): json_file = tmp_path / "input.json" json_file.write_text(pg_json) json_file = path_factory(json_file) - conn = Connection(API_URL) output = 
tmp_path / "result.tiff" - conn.download(json_file, outputfile=output) - assert output.read_bytes() == b"TIFF data" + connection.download(json_file, outputfile=output) + assert output.read_bytes() == b'{"what?": "Result data"}' + assert dummy_backend.get_sync_pg() == self.PG_DICT_1 @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) @pytest.mark.parametrize("path_factory", [str, Path]) - def test_execute_pg_json_file(self, requests_mock, pg_json: str, tmp_path, path_factory): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "result", json=self._post_result_handler_json) + def test_execute_pg_json_file(self, dummy_backend, connection, pg_json: str, tmp_path, path_factory): + dummy_backend.next_result = {"answer": 8} + json_file = tmp_path / "input.json" json_file.write_text(pg_json) json_file = path_factory(json_file) - conn = Connection(API_URL) - result = conn.execute(json_file) + result = connection.execute(json_file) assert result == {"answer": 8} + assert dummy_backend.get_sync_pg() == self.PG_DICT_1 @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) @pytest.mark.parametrize("path_factory", [str, Path]) - def test_create_job_pg_json_file(self, requests_mock, pg_json: str, tmp_path, path_factory): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "jobs", status_code=201, content=self._post_jobs_handler_json) + def test_create_job_pg_json_file(self, dummy_backend, connection, pg_json: str, tmp_path, path_factory): json_file = tmp_path / "input.json" json_file.write_text(pg_json) json_file = path_factory(json_file) - conn = Connection(API_URL) - job = conn.create_job(json_file) - assert job.job_id == "j-123" + job = connection.create_job(json_file) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self.PG_DICT_1 @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) - def test_download_pg_json_url(self, requests_mock, tmp_path, pg_json: str): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "result", content=self._post_result_handler_tiff) + def test_download_pg_json_url(self, dummy_backend, connection, requests_mock, tmp_path, pg_json: str): url = "https://jsonbin.test/pg.json" requests_mock.get(url, text=pg_json) - conn = Connection(API_URL) output = tmp_path / "result.tiff" - conn.download(url, outputfile=output) - assert output.read_bytes() == b"TIFF data" + connection.download(url, outputfile=output) + assert output.read_bytes() == b'{"what?": "Result data"}' + assert dummy_backend.get_sync_pg() == self.PG_DICT_1 @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) - def test_execute_pg_json_url(self, requests_mock, pg_json: str): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "result", json=self._post_result_handler_json) + def test_execute_pg_json_url(self, dummy_backend, connection, requests_mock, pg_json: str): + dummy_backend.next_result = {"answer": 8} + url = "https://jsonbin.test/pg.json" requests_mock.get(url, text=pg_json) - conn = Connection(API_URL) - result = conn.execute(url) + result = connection.execute(url) assert result == {"answer": 8} + assert dummy_backend.get_sync_pg() == self.PG_DICT_1 @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) - def test_create_job_pg_json_url(self, requests_mock, pg_json: str): - requests_mock.get(API_URL, json={"api_version": "1.0.0"}) - requests_mock.post(API_URL + "jobs", status_code=201, 
content=self._post_jobs_handler_json)
+    def test_create_job_pg_json_url(self, dummy_backend, connection, requests_mock, pg_json: str):
         url = "https://jsonbin.test/pg.json"
         requests_mock.get(url, text=pg_json)
 
-        conn = Connection(API_URL)
-        job = conn.create_job(url)
-        assert job.job_id == "j-123"
+        job = connection.create_job(url)
+        assert job.job_id == "job-000"
+        assert dummy_backend.get_batch_pg() == self.PG_DICT_1
+
+
+class TestExecuteWithValidation:
+    # Dummy process graphs
+    PG_DICT_1 = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
+
+    @pytest.fixture(params=[False, True])
+    def auto_validate(self, request) -> bool:
+        """Fixture to parametrize the auto_validate setting."""
+        return request.param
+
+    @pytest.fixture
+    def connection(self, api_version, requests_mock, api_capabilities, auto_validate) -> Connection:
+        requests_mock.get(API_URL, json=build_capabilities(api_version=api_version, **api_capabilities))
+        con = Connection(API_URL, **dict_no_none(auto_validate=auto_validate))
+        return con
+
+    # Reusable list of (fixture) parameterizations
+    # of ["api_capabilities", "auto_validate", "validate", "validation_expected"]
+    _VALIDATION_PARAMETER_SETS = [
+        # No validation supported by backend: don't attempt to validate
+        ({}, None, None, False),
+        ({}, True, True, False),
+        # Validation supported by backend, default behavior -> validate
+        ({"validation": True}, None, None, True),
+        # (Validation supported by backend) no explicit validation enabled: follow the auto_validate setting
+        ({"validation": True}, True, None, True),
+        ({"validation": True}, False, None, False),
+        # (Validation supported by backend) follow the explicit `validate` toggle regardless of auto_validate
+        ({"validation": True}, False, True, True),
+        ({"validation": True}, True, False, False),
+    ]
+
+    @pytest.mark.parametrize(
+        ["api_capabilities", "auto_validate", "validate", "validation_expected"],
+        _VALIDATION_PARAMETER_SETS,
+    )
+    def test_download_validation(
+        self,
+        dummy_backend,
+        connection,
+        tmp_path,
+        caplog,
+        api_capabilities,
+        validate,
+        validation_expected,
+    ):
+        caplog.set_level(logging.WARNING)
+        dummy_backend.next_result = b"TIFF data"
+        dummy_backend.next_validation_errors = [
+            {"code": "OddSupport", "message": "Odd values are not supported."},
+            {"code": "ComplexityOverflow", "message": "Too complex."},
+        ]
+
+        output = tmp_path / "result.tiff"
+        connection.download(self.PG_DICT_1, outputfile=output, **dict_no_none(validate=validate))
+        assert output.read_bytes() == b"TIFF data"
+        assert dummy_backend.get_sync_pg() == self.PG_DICT_1
+
+        if validation_expected:
+            assert caplog.messages == [
+                "Preflight process graph validation raised: [OddSupport] Odd values are not supported. [ComplexityOverflow] Too complex."
+            ]
+            assert dummy_backend.validation_requests == [self.PG_DICT_1]
+        else:
+            assert caplog.messages == []
+            assert dummy_backend.validation_requests == []
+
+    @pytest.mark.parametrize(
+        ["api_capabilities", "auto_validate"],
+        [
+            ({"validation": True}, True),
+        ],
+    )
+    def test_download_validation_broken(
+        self, dummy_backend, connection, requests_mock, tmp_path, caplog, api_capabilities
+    ):
+        """
+        Verify the job won't be blocked by errors that occur during validation but are
+        not related to the validity of the graph (e.g. the HTTP request itself fails),
+        because we don't want to break existing workflows.
+        """
+        caplog.set_level(logging.WARNING)
+        dummy_backend.next_result = b"TIFF data"
+
+        # Simulate a server-side error during the validation request.
+ m = requests_mock.post(API_URL + "validation", json={"code": "Internal", "message": "Nope!"}, status_code=500) + + output = tmp_path / "result.tiff" + connection.download(self.PG_DICT_1, outputfile=output, validate=True) + assert output.read_bytes() == b"TIFF data" + + # We still want to see those warnings in the logs though: + assert caplog.messages == ["Preflight process graph validation failed: [500] Internal: Nope!"] + assert m.call_count == 1 + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_execute_validation( + self, dummy_backend, connection, caplog, api_capabilities, validate, validation_expected + ): + caplog.set_level(logging.WARNING) + dummy_backend.next_result = {"answer": 8} + dummy_backend.next_validation_errors = [{"code": "OddSupport", "message": "Odd values are not supported."}] + + result = connection.execute(self.PG_DICT_1, **dict_no_none(validate=validate)) + assert result == {"answer": 8} + if validation_expected: + assert caplog.messages == [ + "Preflight process graph validation raised: [OddSupport] Odd values are not supported." + ] + assert dummy_backend.validation_requests == [self.PG_DICT_1] + else: + assert caplog.messages == [] + assert dummy_backend.validation_requests == [] + + @pytest.mark.parametrize( + ["api_capabilities", "auto_validate", "validate", "validation_expected"], + _VALIDATION_PARAMETER_SETS, + ) + def test_create_job_validation( + self, dummy_backend, connection, caplog, api_capabilities, validate, validation_expected + ): + caplog.set_level(logging.WARNING) + dummy_backend.next_validation_errors = [{"code": "OddSupport", "message": "Odd values are not supported."}] + + job = connection.create_job(self.PG_DICT_1, **dict_no_none(validate=validate)) + assert job.job_id == "job-000" + assert dummy_backend.get_batch_pg() == self.PG_DICT_1 + + if validation_expected: + assert caplog.messages == [ + "Preflight process graph validation raised: [OddSupport] Odd values are not supported." + ] + assert dummy_backend.validation_requests == [self.PG_DICT_1] + else: + assert caplog.messages == [] + assert dummy_backend.validation_requests == []
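The `_VALIDATION_PARAMETER_SETS` tables used throughout these tests all encode the same decision rule. As a summary, a sketch of that rule as a pure function (mirroring the `validate`/`auto_validate` handling in `Connection._preflight_validation`; not part of the patch):

    from typing import Optional

    def should_validate(
        backend_supports_validation: bool,
        auto_validate: bool = True,
        validate: Optional[bool] = None,
    ) -> bool:
        # An explicit per-call `validate` wins over the connection-level
        # `auto_validate`; either way, validation only happens when the
        # backend advertises `POST /validation`.
        effective = auto_validate if validate is None else validate
        return bool(effective and backend_supports_validation)

    # The truth table from `_VALIDATION_PARAMETER_SETS`:
    assert should_validate(False) is False
    assert should_validate(False, validate=True) is False
    assert should_validate(True) is True
    assert should_validate(True, auto_validate=False) is False
    assert should_validate(True, auto_validate=False, validate=True) is True
    assert should_validate(True, auto_validate=True, validate=False) is False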