From 748a52eacc5021aa09ee0f8074a7020681c4f2c7 Mon Sep 17 00:00:00 2001 From: Johan Schreurs Date: Tue, 3 Oct 2023 09:48:34 +0200 Subject: [PATCH] Issue #404 Add automatic process graph validation when the backend supports it --- openeo/rest/connection.py | 41 +++++++++++- tests/rest/conftest.py | 34 +++++++++- tests/rest/test_connection.py | 113 ++++++++++++++++++++++++++++++++++ 3 files changed, 186 insertions(+), 2 deletions(-) diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 02738bf7e..43007fb0d 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -72,6 +72,11 @@ DEFAULT_TIMEOUT_SYNCHRONOUS_EXECUTE = 30 * 60 +# TODO: remove temporary constant that is intended for refactoring +# constant for refactoring to switch default validation of process graph on or off. +VALIDATE_PROCESS_GRAPH_BY_DEFAULT = True + + class RestApiConnection: """Base connection class implementing generic REST API request functionality""" @@ -1052,7 +1057,14 @@ def validate_process_graph(self, process_graph: dict) -> List[dict]: :param process_graph: (flat) dict representing process graph :return: list of errors (dictionaries with "code" and "message" fields) """ - request = {"process_graph": process_graph} + # TODO: sometimes process_graph is already in the graph. Should we really *always* add it? + # Was getting errors in some new unit tests because of the double process_graph but + # perhaps the error is really not here but somewhere else that adds process_graph + # when it should not? Still needs to be confirmed. 
+ if "process_graph" not in process_graph: + request = {"process_graph": process_graph} + else: + request = process_graph return self.post(path="/validation", json=request, expected_status=200).json()["errors"] @property @@ -1474,12 +1486,27 @@ def _build_request_with_process_graph(self, process_graph: Union[dict, FlatGraph result["process"] = process_graph return result + def _warn_if_process_graph_invalid(self, process_graph: Union[dict, FlatGraphableMixin, str, Path]): + if not self.capabilities().supports_endpoint("/validation", "POST"): + return + + graph = as_flat_graph(process_graph) + if "process_graph" not in graph: + graph = {"process_graph": graph} + + validation_errors = self.validate_process_graph(process_graph=graph) + if validation_errors: + _log.warning( + "Process graph is not valid. Validation errors:\n" + "\n".join(e["message"] for e in validation_errors) + ) + # TODO: unify `download` and `execute` better: e.g. `download` always writes to disk, `execute` returns result (raw or as JSON decoded dict) def download( self, graph: Union[dict, FlatGraphableMixin, str, Path], outputfile: Union[Path, str, None] = None, timeout: Optional[int] = None, + validate: bool = VALIDATE_PROCESS_GRAPH_BY_DEFAULT, ) -> Union[None, bytes]: """ Downloads the result of a process graph synchronously, @@ -1491,6 +1518,9 @@ def download( :param outputfile: output file :param timeout: timeout to wait for response """ + if validate: + self._warn_if_process_graph_invalid(process_graph=graph) + request = self._build_request_with_process_graph(process_graph=graph) response = self.post( path="/result", @@ -1511,6 +1541,7 @@ def execute( self, process_graph: Union[dict, str, Path], timeout: Optional[int] = None, + validate: bool = VALIDATE_PROCESS_GRAPH_BY_DEFAULT, ): """ Execute a process graph synchronously and return the result (assumed to be JSON). 
@@ -1519,6 +1550,9 @@ def execute( or as local file path or URL :return: parsed JSON response """ + if validate: + self._warn_if_process_graph_invalid(process_graph=process_graph) + req = self._build_request_with_process_graph(process_graph=process_graph) return self.post( path="/result", @@ -1536,6 +1570,7 @@ def create_job( plan: Optional[str] = None, budget: Optional[float] = None, additional: Optional[dict] = None, + validate: bool = VALIDATE_PROCESS_GRAPH_BY_DEFAULT, ) -> BatchJob: """ Create a new job from given process graph on the back-end. @@ -1550,6 +1585,10 @@ def create_job( :return: Created job """ # TODO move all this (BatchJob factory) logic to BatchJob? + + if validate: + self._warn_if_process_graph_invalid(process_graph=process_graph) + req = self._build_request_with_process_graph( process_graph=process_graph, **dict_no_none(title=title, description=description, plan=plan, budget=budget) diff --git a/tests/rest/conftest.py b/tests/rest/conftest.py index 3914168ff..842ea9b40 100644 --- a/tests/rest/conftest.py +++ b/tests/rest/conftest.py @@ -1,12 +1,14 @@ import contextlib import re import typing +from typing import List, Optional from unittest import mock import pytest import time_machine -from openeo.rest._testing import DummyBackend +import openeo +from openeo.rest._testing import DummyBackend, build_capabilities from openeo.rest.connection import Connection API_URL = "https://oeo.test/" @@ -87,3 +89,33 @@ def con120(requests_mock): @pytest.fixture def dummy_backend(requests_mock, con100) -> DummyBackend: yield DummyBackend(requests_mock=requests_mock, connection=con100) + + +def _setup_connection(api_version, requests_mock, build_capabilities_kwargs: Optional[dict] = None) -> Connection: + # TODO: make this more reusable? 
+ requests_mock.get(API_URL, json=build_capabilities(api_version=api_version, **(build_capabilities_kwargs or {}))) + requests_mock.get( + API_URL + "file_formats", + json={ + "output": { + "GTiff": {"gis_data_types": ["raster"]}, + "netCDF": {"gis_data_types": ["raster"]}, + "csv": {"gis_data_types": ["table"]}, + } + }, + ) + requests_mock.get( + API_URL + "udf_runtimes", + json={ + "Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}}, + "R": {"type": "language", "default": "4", "versions": {"4": {"libraries": {}}}}, + }, + ) + + return openeo.connect(API_URL) + + +@pytest.fixture +def connection_with_pgvalidation(api_version, requests_mock) -> Connection: + """Connection fixture to a backend of given version with some image collections.""" + return _setup_connection(api_version, requests_mock, build_capabilities_kwargs={"validation": True}) diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 09a97ac25..589518738 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -30,6 +30,7 @@ RestApiConnection, connect, paginate, + VALIDATE_PROCESS_GRAPH_BY_DEFAULT, ) from openeo.rest.vectorcube import VectorCube from openeo.util import ContextTimer @@ -3179,23 +3180,50 @@ class TestExecute: PG_JSON_1 = '{"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": true}}' PG_JSON_2 = '{"process_graph": {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": true}}}' + PG_INVALID_DICT_INNER = { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + "result": True, + } + } + PG_INVALID_DICT_OUTER = {"process_graph": PG_INVALID_DICT_INNER} + PG_INVALID_INNER = json.dumps(PG_INVALID_DICT_INNER) + PG_INVALID_OUTER = json.dumps(PG_INVALID_DICT_OUTER) + # Dummy `POST /result` handlers def _post_result_handler_tiff(self, response: requests.Request, context): pg = 
response.json()["process"]["process_graph"] assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} return b"TIFF data" + def _post_result_handler_tiff_invalid_pg(self, response: requests.Request, context): + pg = response.json()["process"]["process_graph"] + assert pg == self.PG_INVALID_DICT_INNER + return b"TIFF data" + def _post_result_handler_json(self, response: requests.Request, context): pg = response.json()["process"]["process_graph"] assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} return {"answer": 8} + def _post_result_handler_json_invalid_pg(self, response: requests.Request, context): + pg = response.json()["process"]["process_graph"] + assert pg == self.PG_INVALID_DICT_INNER + return {"answer": 8} + def _post_jobs_handler_json(self, response: requests.Request, context): pg = response.json()["process"]["process_graph"] assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} context.headers["OpenEO-Identifier"] = "j-123" return b"" + def _post_jobs_handler_json_invalid_pg(self, response: requests.Request, context): + pg = response.json()["process"]["process_graph"] + assert pg == self.PG_INVALID_DICT_INNER + context.headers["OpenEO-Identifier"] = "j-123" + return b"" + @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) def test_download_pg_json(self, requests_mock, tmp_path, pg_json: str): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) @@ -3206,6 +3234,28 @@ def test_download_pg_json(self, requests_mock, tmp_path, pg_json: str): conn.download(pg_json, outputfile=output) assert output.read_bytes() == b"TIFF data" + @pytest.mark.parametrize("pg_json", [PG_INVALID_INNER, PG_INVALID_OUTER]) + def test_download_pg_json_with_invalid_pg( + self, requests_mock, connection_with_pgvalidation, tmp_path, pg_json: str, caplog + ): + caplog.set_level(logging.WARNING) + requests_mock.post(API_URL + "result", 
content=self._post_result_handler_tiff_invalid_pg) + + validation_errors = [{"code": "Invalid", "message": "Invalid process graph"}] + + def validation(request, context): + assert request.json() == self.PG_INVALID_DICT_OUTER + return {"errors": validation_errors} + + m = requests_mock.post(API_URL + "validation", json=validation) + + output = tmp_path / "result.tiff" + connection_with_pgvalidation.download(pg_json, outputfile=output, validate=True) + + assert output.read_bytes() == b"TIFF data" + assert caplog.messages == ["Process graph is not valid. Validation errors:\nInvalid process graph"] + assert m.call_count == 1 + @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) def test_execute_pg_json(self, requests_mock, pg_json: str): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) @@ -3215,6 +3265,24 @@ def test_execute_pg_json(self, requests_mock, pg_json: str): result = conn.execute(pg_json) assert result == {"answer": 8} + @pytest.mark.parametrize("pg_json", [PG_INVALID_INNER, PG_INVALID_OUTER]) + def test_execute_pg_json_with_invalid_pg(self, requests_mock, connection_with_pgvalidation, pg_json: str, caplog): + caplog.set_level(logging.WARNING) + requests_mock.post(API_URL + "result", json=self._post_result_handler_json_invalid_pg) + + validation_errors = [{"code": "Invalid", "message": "Invalid process graph"}] + + def validation(request, context): + assert request.json() == self.PG_INVALID_DICT_OUTER + return {"errors": validation_errors} + + m = requests_mock.post(API_URL + "validation", json=validation) + + result = connection_with_pgvalidation.execute(pg_json, validate=True) + assert result == {"answer": 8} + assert caplog.messages == ["Process graph is not valid. 
Validation errors:\nInvalid process graph"] + assert m.call_count == 1 + @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) def test_create_job_pg_json(self, requests_mock, pg_json: str): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) @@ -3224,6 +3292,26 @@ def test_create_job_pg_json(self, requests_mock, pg_json: str): job = conn.create_job(pg_json) assert job.job_id == "j-123" + @pytest.mark.parametrize("pg_json", [PG_INVALID_INNER, PG_INVALID_OUTER]) + def test_create_job_pg_json_with_invalid_pg( + self, requests_mock, connection_with_pgvalidation, pg_json: str, caplog + ): + caplog.set_level(logging.WARNING) + requests_mock.post(API_URL + "jobs", status_code=201, content=self._post_jobs_handler_json_invalid_pg) + + validation_errors = [{"code": "Invalid", "message": "Invalid process graph"}] + + def validation(request, context): + assert request.json() == self.PG_INVALID_DICT_OUTER + return {"errors": validation_errors} + + m = requests_mock.post(API_URL + "validation", json=validation) + + job = connection_with_pgvalidation.create_job(pg_json, validate=True) + assert job.job_id == "j-123" + assert caplog.messages == ["Process graph is not valid. 
Validation errors:\nInvalid process graph"] + assert m.call_count == 1 + @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) @pytest.mark.parametrize("path_factory", [str, Path]) def test_download_pg_json_file(self, requests_mock, tmp_path, pg_json: str, path_factory): @@ -3238,6 +3326,31 @@ def test_download_pg_json_file(self, requests_mock, tmp_path, pg_json: str, path conn.download(json_file, outputfile=output) assert output.read_bytes() == b"TIFF data" + @pytest.mark.parametrize("pg_json", [PG_INVALID_INNER, PG_INVALID_OUTER]) + @pytest.mark.parametrize("path_factory", [str, Path]) + def test_download_pg_json_file_with_invalid_pg( + self, requests_mock, connection_with_pgvalidation, tmp_path, pg_json: str, path_factory, caplog + ): + caplog.set_level(logging.WARNING) + requests_mock.post(API_URL + "result", content=self._post_result_handler_tiff_invalid_pg) + + validation_errors = [{"code": "Invalid", "message": "Invalid process graph"}] + + def validation(request, context): + assert request.json() == self.PG_INVALID_DICT_OUTER + return {"errors": validation_errors} + + m = requests_mock.post(API_URL + "validation", json=validation) + + json_file = tmp_path / "input.json" + json_file.write_text(pg_json) + json_file = path_factory(json_file) + + output = tmp_path / "result.tiff" + connection_with_pgvalidation.download(json_file, outputfile=output, validate=True) + assert caplog.messages == ["Process graph is not valid. Validation errors:\nInvalid process graph"] + assert m.call_count == 1 + @pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2]) @pytest.mark.parametrize("path_factory", [str, Path]) def test_execute_pg_json_file(self, requests_mock, pg_json: str, tmp_path, path_factory):