diff --git a/dags/dags_utils.py b/dags/dags_utils.py
index e49b62eea..86f5bac4d 100644
--- a/dags/dags_utils.py
+++ b/dags/dags_utils.py
@@ -45,6 +45,35 @@ def push_dag_downstream(key, value):
     return context[TASK_INSTANCE].xcom_push(key=str(key), value=value)
 
 
+def smart_xcom_pull(key: str):
+    context = get_current_context()
+    task_id = context[TASK_INSTANCE].task_id
+    selected_upstream_task_ids = [
+        selected_task_id
+        for selected_task_id in context[TASK_INSTANCE].xcom_pull(key=task_id,
+                                                                 task_ids=context['task'].upstream_task_ids)
+        if selected_task_id
+    ]
+    if selected_upstream_task_ids:
+        return select_first_non_none(context[TASK_INSTANCE].xcom_pull(key=key, task_ids=selected_upstream_task_ids))
+    return None
+
+
+def smart_xcom_push(key: str, value: Any, destination_task_id: str = None):
+    context = get_current_context()
+    current_task_id = context[TASK_INSTANCE].task_id
+    task_ids = [destination_task_id] if destination_task_id else context['task'].downstream_task_ids
+    for task_id in task_ids:
+        context[TASK_INSTANCE].xcom_push(key=task_id, value=current_task_id)
+    context[TASK_INSTANCE].xcom_push(key=key, value=value)
+
+
+def smart_xcom_forward(key: str, destination_task_id: str = None):
+    value = smart_xcom_pull(key=key)
+    if value:
+        smart_xcom_push(key=key, value=value, destination_task_id=destination_task_id)
+
+
 def get_dag_param(key: str, raise_error: bool = False, default_value: Any = None):
     """
diff --git a/dags/load_mapping_suite_in_mongodb.py b/dags/load_mapping_suite_in_mongodb.py
index 36280eb55..021ca543e 100644
--- a/dags/load_mapping_suite_in_mongodb.py
+++ b/dags/load_mapping_suite_in_mongodb.py
@@ -24,8 +24,8 @@
 BRANCH_OR_TAG_NAME_DAG_PARAM_KEY = "branch_or_tag_name"
 GITHUB_REPOSITORY_URL_DAG_PARAM_KEY = "github_repository_url"
 
-TRIGGER_DOCUMENT_PROC_PIPELINE_TASK_ID = "trigger_document_proc_pipeline"
 FINISH_LOADING_MAPPING_SUITE_TASK_ID = "finish_loading_mapping_suite"
+TRIGGER_DOCUMENT_PROC_PIPELINE_TASK_ID = "trigger_document_proc_pipeline"
 CHECK_IF_LOAD_TEST_DATA_TASK_ID = "check_if_load_test_data"
 
 
@@ -58,6 +58,7 @@ def fetch_mapping_suite_package_from_github_into_mongodb(**context_args):
         branch_or_tag_name=branch_or_tag_name,
         github_repository_url=github_repository_url
     )
+    notice_ids = list(set(notice_ids))
     if load_test_data:
         push_dag_downstream(key=NOTICE_IDS_KEY, value=notice_ids)
     handle_event_message_metadata_dag_context(event_message, context)
diff --git a/dags/notice_fetch_by_date_workflow.py b/dags/notice_fetch_by_date_workflow.py
index d96c98545..b78d95aee 100644
--- a/dags/notice_fetch_by_date_workflow.py
+++ b/dags/notice_fetch_by_date_workflow.py
@@ -35,6 +35,8 @@ def notice_fetch_by_date_workflow():
     )
     def fetch_by_date_notice_from_ted():
         notice_ids = notice_fetcher_by_date_pipeline(date_wild_card=get_dag_param(key=WILD_CARD_DAG_KEY))
+        if not notice_ids:
+            raise Exception("No notices have been fetched!")
         push_dag_downstream(key=NOTICE_IDS_KEY, value=notice_ids)
 
     trigger_complete_workflow = TriggerNoticeBatchPipelineOperator(task_id=TRIGGER_COMPLETE_WORKFLOW_TASK_ID,
@@ -47,8 +49,8 @@
     def _branch_selector():
         trigger_complete_workflow = get_dag_param(key=TRIGGER_COMPLETE_WORKFLOW_DAG_KEY, default_value=False)
+        push_dag_downstream(key=NOTICE_IDS_KEY, value=pull_dag_upstream(key=NOTICE_IDS_KEY))
         if trigger_complete_workflow:
-            push_dag_downstream(key=NOTICE_IDS_KEY, value=pull_dag_upstream(key=NOTICE_IDS_KEY))
             return [TRIGGER_COMPLETE_WORKFLOW_TASK_ID]
         return [TRIGGER_PARTIAL_WORKFLOW_TASK_ID]
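The smart_xcom_* helpers above layer a point-to-point convention on Airflow's broadcast-style XComs: a producer records, under a key named after each destination task id, its own task id; a consumer first discovers which upstream tasks addressed it, then pulls the payload only from those producers. A minimal standalone sketch of that convention, with a plain dict standing in for the XCom table (all names below are illustrative, not part of the patch):

xcom_store = {}  # (producer_task_id, key) -> value

def sketch_push(current_task, downstream_tasks, key, value):
    # Address each destination: record which producer has data for it.
    for destination in downstream_tasks:
        xcom_store[(current_task, destination)] = current_task
    xcom_store[(current_task, key)] = value

def sketch_pull(current_task, upstream_tasks, key):
    # Keep only the upstream tasks that explicitly addressed current_task.
    producers = [task for task in upstream_tasks
                 if xcom_store.get((task, current_task))]
    for producer in producers:
        value = xcom_store.get((producer, key))
        if value is not None:  # first non-empty payload wins, as in select_first_non_none
            return value
    return None

sketch_push("branch_selector", ["notice_normalisation_pipeline"],
            "notice_ids", ["057215-2021"])
assert sketch_pull("notice_normalisation_pipeline", ["branch_selector"],
                   "notice_ids") == ["057215-2021"]

This addressing step is what lets a branch selector hand notice_ids to exactly one chosen downstream task instead of to every possible successor.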
diff --git a/dags/notice_process_workflow.py b/dags/notice_process_workflow.py
index dfffd91c4..cc28e312a 100644
--- a/dags/notice_process_workflow.py
+++ b/dags/notice_process_workflow.py
@@ -1,10 +1,12 @@
+from typing import List
+
 from airflow.operators.dummy import DummyOperator
-from airflow.operators.python import BranchPythonOperator
+from airflow.operators.python import BranchPythonOperator, PythonOperator
 from airflow.decorators import dag
 from airflow.utils.trigger_rule import TriggerRule
 
 from dags import DEFAULT_DAG_ARGUMENTS
-from dags.dags_utils import push_dag_downstream, get_dag_param
+from dags.dags_utils import get_dag_param, smart_xcom_push, smart_xcom_forward, smart_xcom_pull
 from dags.operators.DagBatchPipelineOperator import NoticeBatchPipelineOperator, NOTICE_IDS_KEY, \
     EXECUTE_ONLY_ONE_STEP_KEY, START_WITH_STEP_NAME_KEY
 from dags.pipelines.notice_processor_pipelines import notice_normalisation_pipeline, notice_transformation_pipeline, \
@@ -15,14 +17,31 @@
 NOTICE_VALIDATION_PIPELINE_TASK_ID = "notice_validation_pipeline"
 NOTICE_PACKAGE_PIPELINE_TASK_ID = "notice_package_pipeline"
 NOTICE_PUBLISH_PIPELINE_TASK_ID = "notice_publish_pipeline"
-BRANCH_SELECTOR_TASK_ID = 'branch_selector'
 STOP_PROCESSING_TASK_ID = "stop_processing"
+BRANCH_SELECTOR_TASK_ID = 'branch_selector'
 SELECTOR_BRANCH_BEFORE_TRANSFORMATION_TASK_ID = "switch_to_transformation"
 SELECTOR_BRANCH_BEFORE_VALIDATION_TASK_ID = "switch_to_validation"
 SELECTOR_BRANCH_BEFORE_PACKAGE_TASK_ID = "switch_to_package"
 SELECTOR_BRANCH_BEFORE_PUBLISH_TASK_ID = "switch_to_publish"
 DAG_NAME = "notice_process_workflow"
 
+BRANCH_SELECTOR_MAP = {NOTICE_NORMALISATION_PIPELINE_TASK_ID: NOTICE_NORMALISATION_PIPELINE_TASK_ID,
+                       NOTICE_TRANSFORMATION_PIPELINE_TASK_ID: SELECTOR_BRANCH_BEFORE_TRANSFORMATION_TASK_ID,
+                       NOTICE_VALIDATION_PIPELINE_TASK_ID: SELECTOR_BRANCH_BEFORE_VALIDATION_TASK_ID,
+                       NOTICE_PACKAGE_PIPELINE_TASK_ID: SELECTOR_BRANCH_BEFORE_PACKAGE_TASK_ID,
+                       NOTICE_PUBLISH_PIPELINE_TASK_ID: SELECTOR_BRANCH_BEFORE_PUBLISH_TASK_ID}
+
+
+def branch_selector(result_branch: str, xcom_forward_keys: List[str] = [NOTICE_IDS_KEY]) -> str:
+    start_with_step_name = get_dag_param(key=START_WITH_STEP_NAME_KEY,
+                                         default_value=NOTICE_NORMALISATION_PIPELINE_TASK_ID)
+    if start_with_step_name != result_branch:
+        result_branch = STOP_PROCESSING_TASK_ID if get_dag_param(key=EXECUTE_ONLY_ONE_STEP_KEY) else result_branch
+    for xcom_forward_key in xcom_forward_keys:
+        smart_xcom_forward(key=xcom_forward_key, destination_task_id=result_branch)
+    return result_branch
+
 
 @dag(default_args=DEFAULT_DAG_ARGUMENTS,
      schedule_interval=None,
@@ -38,24 +57,26 @@
     def _start_processing():
         notice_ids = get_dag_param(key=NOTICE_IDS_KEY, raise_error=True)
         start_with_step_name = get_dag_param(key=START_WITH_STEP_NAME_KEY,
                                              default_value=NOTICE_NORMALISATION_PIPELINE_TASK_ID)
-        push_dag_downstream(key=NOTICE_IDS_KEY, value=notice_ids)
-        return start_with_step_name
+        task_id = BRANCH_SELECTOR_MAP[start_with_step_name]
+        smart_xcom_push(key=NOTICE_IDS_KEY, value=notice_ids, destination_task_id=task_id)
+        return task_id
 
     def _selector_branch_before_transformation():
-        return STOP_PROCESSING_TASK_ID if get_dag_param(
-            key=EXECUTE_ONLY_ONE_STEP_KEY) else NOTICE_TRANSFORMATION_PIPELINE_TASK_ID
+        return branch_selector(NOTICE_TRANSFORMATION_PIPELINE_TASK_ID)
 
     def _selector_branch_before_validation():
-        return STOP_PROCESSING_TASK_ID if get_dag_param(
-            key=EXECUTE_ONLY_ONE_STEP_KEY) else NOTICE_VALIDATION_PIPELINE_TASK_ID
+        return branch_selector(NOTICE_VALIDATION_PIPELINE_TASK_ID)
 
     def _selector_branch_before_package():
-        return STOP_PROCESSING_TASK_ID if get_dag_param(
-            key=EXECUTE_ONLY_ONE_STEP_KEY) else NOTICE_PACKAGE_PIPELINE_TASK_ID
+        return branch_selector(NOTICE_PACKAGE_PIPELINE_TASK_ID)
 
     def _selector_branch_before_publish():
-        return STOP_PROCESSING_TASK_ID if get_dag_param(
-            key=EXECUTE_ONLY_ONE_STEP_KEY) else NOTICE_PUBLISH_PIPELINE_TASK_ID
+        return branch_selector(NOTICE_PUBLISH_PIPELINE_TASK_ID)
+
+    def _stop_processing():
+        notice_ids = smart_xcom_pull(key=NOTICE_IDS_KEY)
+        if not notice_ids:
+            raise Exception("No notice has been processed!")
 
     start_processing = BranchPythonOperator(
         task_id=BRANCH_SELECTOR_TASK_ID,
@@ -82,9 +103,10 @@
         python_callable=_selector_branch_before_publish,
     )
 
-    stop_processing = DummyOperator(
+    stop_processing = PythonOperator(
         task_id=STOP_PROCESSING_TASK_ID,
-        trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
+        trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+        python_callable=_stop_processing
     )
 
     notice_normalisation_step = NoticeBatchPipelineOperator(python_callable=notice_normalisation_pipeline,
@@ -111,8 +133,9 @@
         trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
     )
 
-    start_processing >> [notice_normalisation_step, notice_transformation_step, notice_validation_step,
-                         notice_package_step, notice_publish_step]
+    start_processing >> [notice_normalisation_step, selector_branch_before_transformation,
+                         selector_branch_before_validation,
+                         selector_branch_before_package, selector_branch_before_publish]
     [selector_branch_before_transformation, selector_branch_before_validation, selector_branch_before_package,
      selector_branch_before_publish, notice_publish_step] >> stop_processing
     notice_normalisation_step >> selector_branch_before_transformation >> notice_transformation_step
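All four selector callables now delegate to branch_selector, whose rule, read as a pure function, is: continue to the requested pipeline step, unless the run started at a different step and EXECUTE_ONLY_ONE_STEP is set, in which case route to stop_processing. A sketch of just that decision (parameter names are mine; the real function additionally forwards the notice_ids XCom to whichever branch it returns):

def choose_branch(result_branch: str, start_with_step: str, only_one_step: bool,
                  stop_task: str = "stop_processing") -> str:
    # Short-circuit only when this selector is NOT the step the run started at.
    if start_with_step != result_branch and only_one_step:
        return stop_task
    return result_branch

assert choose_branch("notice_validation_pipeline",
                     "notice_validation_pipeline", True) == "notice_validation_pipeline"
assert choose_branch("notice_validation_pipeline",
                     "notice_normalisation_pipeline", True) == "stop_processing"
assert choose_branch("notice_validation_pipeline",
                     "notice_normalisation_pipeline", False) == "notice_validation_pipeline"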
diff --git a/dags/operators/DagBatchPipelineOperator.py b/dags/operators/DagBatchPipelineOperator.py
index c58c78e42..a0fb5fbe2 100644
--- a/dags/operators/DagBatchPipelineOperator.py
+++ b/dags/operators/DagBatchPipelineOperator.py
@@ -4,7 +4,8 @@
 from airflow.operators.trigger_dagrun import TriggerDagRunOperator
 from pymongo import MongoClient
 
-from dags.dags_utils import pull_dag_upstream, push_dag_downstream, chunks, get_dag_param
+from dags.dags_utils import pull_dag_upstream, push_dag_downstream, chunks, get_dag_param, smart_xcom_pull, \
+    smart_xcom_push
 from dags.pipelines.pipeline_protocols import NoticePipelineCallable
 from ted_sws import config
 from ted_sws.data_manager.adapters.notice_repository import NoticeRepository
@@ -40,10 +41,11 @@ def execute(self, context: Any):
         This method executes the python_callable for each notice_id in the notice_ids batch.
         """
         logger = get_logger()
-        notice_ids = pull_dag_upstream(key=NOTICE_IDS_KEY)
+        notice_ids = smart_xcom_pull(key=NOTICE_IDS_KEY)
         if not notice_ids:
             raise Exception(f"XCOM key [{NOTICE_IDS_KEY}] is not present in context!")
-        notice_repository = NoticeRepository(mongodb_client=MongoClient(config.MONGO_DB_AUTH_URL))
+        mongodb_client = MongoClient(config.MONGO_DB_AUTH_URL)
+        notice_repository = NoticeRepository(mongodb_client=mongodb_client)
         processed_notice_ids = []
         pipeline_name = self.python_callable.__name__
         number_of_notices = len(notice_ids)
@@ -58,7 +60,7 @@ def execute(self, context: Any):
                 notice_event = NoticeEventMessage(notice_id=notice_id, domain_action=pipeline_name)
                 notice_event.start_record()
                 notice = notice_repository.get(reference=notice_id)
-                result_notice_pipeline = self.python_callable(notice)
+                result_notice_pipeline = self.python_callable(notice, mongodb_client)
                 if result_notice_pipeline.store_result:
                     notice_repository.update(notice=result_notice_pipeline.notice)
                 if result_notice_pipeline.processed:
@@ -67,10 +69,12 @@ def execute(self, context: Any):
                 logger.info(event_message=notice_event)
             except Exception as e:
                 log_error(message=str(e))
 
+        batch_event_message.end_record()
         logger.info(event_message=batch_event_message)
-
-        push_dag_downstream(key=NOTICE_IDS_KEY, value=processed_notice_ids)
+        if not processed_notice_ids:
+            raise Exception("No notice has been processed!")
+        smart_xcom_push(key=NOTICE_IDS_KEY, value=processed_notice_ids)
 
 
 class TriggerNoticeBatchPipelineOperator(BaseOperator):
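After this change the execute() loop has a three-part contract: build one MongoClient per batch and pass it into every pipeline call, log and skip individual notice failures, and fail the task outright when nothing in the batch was processed rather than pushing an empty id list downstream. Reduced to control flow (a sketch, not the operator itself; the stand-in names are mine):

from typing import Callable, Iterable, List

def run_batch(notice_ids: Iterable[str],
              pipeline: Callable[[str, object], None],
              mongodb_client: object) -> List[str]:
    processed = []
    for notice_id in notice_ids:
        try:
            pipeline(notice_id, mongodb_client)  # shared client, built once per batch
            processed.append(notice_id)
        except Exception as error:
            print(f"skipping {notice_id}: {error}")  # stand-in for log_error
    if not processed:
        # An empty result now fails the task instead of propagating silently.
        raise Exception("No notice has been processed!")
    return processed

client = object()  # stands in for MongoClient(config.MONGO_DB_AUTH_URL)
assert run_batch(["a", "b"], lambda notice_id, mongo: None, client) == ["a", "b"]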
""" logger = get_logger() - notice_ids = pull_dag_upstream(key=NOTICE_IDS_KEY) + notice_ids = smart_xcom_pull(key=NOTICE_IDS_KEY) if not notice_ids: raise Exception(f"XCOM key [{NOTICE_IDS_KEY}] is not present in context!") - notice_repository = NoticeRepository(mongodb_client=MongoClient(config.MONGO_DB_AUTH_URL)) + mongodb_client = MongoClient(config.MONGO_DB_AUTH_URL) + notice_repository = NoticeRepository(mongodb_client=mongodb_client) processed_notice_ids = [] pipeline_name = self.python_callable.__name__ number_of_notices = len(notice_ids) @@ -58,7 +60,7 @@ def execute(self, context: Any): notice_event = NoticeEventMessage(notice_id=notice_id, domain_action=pipeline_name) notice_event.start_record() notice = notice_repository.get(reference=notice_id) - result_notice_pipeline = self.python_callable(notice) + result_notice_pipeline = self.python_callable(notice, mongodb_client) if result_notice_pipeline.store_result: notice_repository.update(notice=result_notice_pipeline.notice) if result_notice_pipeline.processed: @@ -67,10 +69,12 @@ def execute(self, context: Any): logger.info(event_message=notice_event) except Exception as e: log_error(message=str(e)) + batch_event_message.end_record() logger.info(event_message=batch_event_message) - - push_dag_downstream(key=NOTICE_IDS_KEY, value=processed_notice_ids) + if not processed_notice_ids: + raise Exception(f"No notice has been processed!") + smart_xcom_push(key=NOTICE_IDS_KEY, value=processed_notice_ids) class TriggerNoticeBatchPipelineOperator(BaseOperator): diff --git a/dags/pipelines/notice_processor_pipelines.py b/dags/pipelines/notice_processor_pipelines.py index e4a5ebcc0..9e668d839 100644 --- a/dags/pipelines/notice_processor_pipelines.py +++ b/dags/pipelines/notice_processor_pipelines.py @@ -16,7 +16,7 @@ from ted_sws.notice_validator.services.xpath_coverage_runner import validate_xpath_coverage_notice -def notice_normalisation_pipeline(notice: Notice) -> NoticePipelineOutput: +def notice_normalisation_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput: """ """ @@ -26,11 +26,10 @@ def notice_normalisation_pipeline(notice: Notice) -> NoticePipelineOutput: return NoticePipelineOutput(notice=normalised_notice) -def notice_transformation_pipeline(notice: Notice) -> NoticePipelineOutput: +def notice_transformation_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput: """ """ - mongodb_client = MongoClient(config.MONGO_DB_AUTH_URL) mapping_suite_repository = MappingSuiteRepositoryMongoDB(mongodb_client=mongodb_client) result = notice_eligibility_checker(notice=notice, mapping_suite_repository=mapping_suite_repository) if not result: @@ -47,12 +46,11 @@ def notice_transformation_pipeline(notice: Notice) -> NoticePipelineOutput: return NoticePipelineOutput(notice=transformed_notice) -def notice_validation_pipeline(notice: Notice) -> NoticePipelineOutput: +def notice_validation_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput: """ """ mapping_suite_id = notice.distilled_rdf_manifestation.mapping_suite_id - mongodb_client = MongoClient(config.MONGO_DB_AUTH_URL) mapping_suite_repository = MappingSuiteRepositoryMongoDB(mongodb_client=mongodb_client) mapping_suite = mapping_suite_repository.get(reference=mapping_suite_id) validate_xpath_coverage_notice(notice=notice, mapping_suite=mapping_suite, mongodb_client=mongodb_client) @@ -61,7 +59,7 @@ def notice_validation_pipeline(notice: Notice) -> NoticePipelineOutput: return NoticePipelineOutput(notice=notice) -def 
diff --git a/dags/pipelines/pipeline_protocols.py b/dags/pipelines/pipeline_protocols.py
index c24391ea4..bc16642dc 100644
--- a/dags/pipelines/pipeline_protocols.py
+++ b/dags/pipelines/pipeline_protocols.py
@@ -1,5 +1,7 @@
 from typing import Protocol
 
+from pymongo import MongoClient
+
 from ted_sws.core.model.notice import Notice
 
 
@@ -13,7 +15,7 @@ def __init__(self, notice: Notice, processed: bool = True, store_result: bool =
 
 class NoticePipelineCallable(Protocol):
 
-    def __call__(self, notice: Notice) -> NoticePipelineOutput:
+    def __call__(self, notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput:
         """
 
         """
diff --git a/ted_sws/notice_transformer/services/notice_transformer.py b/ted_sws/notice_transformer/services/notice_transformer.py
index 657c9fbda..0afb50606 100644
--- a/ted_sws/notice_transformer/services/notice_transformer.py
+++ b/ted_sws/notice_transformer/services/notice_transformer.py
@@ -36,7 +36,7 @@ def transform_notice(notice: Notice, mapping_suite: MappingSuite, rml_mapper: RM
         file.write(notice.xml_manifestation.object_data)
     rdf_result = rml_mapper.execute(package_path=package_path)
     notice.set_rdf_manifestation(
-        rdf_manifestation=RDFManifestation(mapping_suite_id=mapping_suite.identifier,
+        rdf_manifestation=RDFManifestation(mapping_suite_id=mapping_suite.get_mongodb_id(),
                                            object_data=rdf_result))
     return notice
diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py
index 931ece3e8..921ccde5a 100644
--- a/tests/e2e/conftest.py
+++ b/tests/e2e/conftest.py
@@ -1,11 +1,15 @@
+import mongomock
+import pymongo
 import pytest
 from pymongo import MongoClient
 
 from ted_sws import config
 from ted_sws.data_manager.adapters.triple_store import AllegroGraphTripleStore, FusekiAdapter
+
 from tests import TEST_DATA_PATH
 
+
 @pytest.fixture
 def mongodb_client():
     uri = config.MONGO_DB_AUTH_URL
@@ -38,9 +42,20 @@ def fake_mapping_suite_id() -> str:
 
 @pytest.fixture
 def fuseki_triple_store():
-    return FusekiAdapter(host=config.FUSEKI_ADMIN_HOST, user=config.FUSEKI_ADMIN_USER, password=config.FUSEKI_ADMIN_PASSWORD)
+    return FusekiAdapter(host=config.FUSEKI_ADMIN_HOST, user=config.FUSEKI_ADMIN_USER,
+                         password=config.FUSEKI_ADMIN_PASSWORD)
 
 
 @pytest.fixture
 def cellar_sparql_endpoint():
     return "https://publications.europa.eu/webapi/rdf/sparql"
+
+
+@pytest.fixture(scope="function")
+@mongomock.patch(servers=(('server.example.com', 27017),))
+def fake_mongodb_client():
+    mongo_client = pymongo.MongoClient('server.example.com')
+    for database_name in mongo_client.list_database_names():
+        mongo_client.drop_database(database_name)
+    return mongo_client
+
diff --git a/tests/e2e/dags/pipelines/__init__.py b/tests/e2e/dags/pipelines/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/e2e/dags/pipelines/test_notice_processor_pipelines.py b/tests/e2e/dags/pipelines/test_notice_processor_pipelines.py
new file mode 100644
index 000000000..7738bde50
--- /dev/null
+++ b/tests/e2e/dags/pipelines/test_notice_processor_pipelines.py
@@ -0,0 +1,33 @@
+from dags.pipelines.notice_processor_pipelines import notice_normalisation_pipeline, notice_transformation_pipeline, \
+    notice_validation_pipeline, notice_package_pipeline, notice_publish_pipeline
+from ted_sws.core.model.notice import NoticeStatus
+from ted_sws.data_manager.adapters.notice_repository import NoticeRepository
+from ted_sws.mapping_suite_processor.services.conceptual_mapping_processor import \
+    mapping_suite_processor_from_github_expand_and_load_package_in_mongo_db
+
+MAPPING_SUITE_PACKAGE_NAME = "package_F03_test"
+MAPPING_SUITE_PACKAGE_ID = f"{MAPPING_SUITE_PACKAGE_NAME}_v2.3.0"
+NOTICE_ID = "057215-2021"
+
+
+def test_notice_processor_pipelines(fake_mongodb_client):
+    mapping_suite_processor_from_github_expand_and_load_package_in_mongo_db(
+        mapping_suite_package_name=MAPPING_SUITE_PACKAGE_NAME,
+        mongodb_client=fake_mongodb_client,
+        load_test_data=True
+    )
+    notice_id = NOTICE_ID
+    notice_repository = NoticeRepository(mongodb_client=fake_mongodb_client)
+    notice = notice_repository.get(reference=notice_id)
+    pipelines = [notice_normalisation_pipeline, notice_transformation_pipeline, notice_validation_pipeline,
+                 notice_package_pipeline, notice_publish_pipeline]
+    notice_states = [NoticeStatus.RAW, NoticeStatus.NORMALISED_METADATA, NoticeStatus.DISTILLED,
+                     NoticeStatus.VALIDATED, NoticeStatus.PACKAGED, NoticeStatus.PUBLISHED]
+    for index, pipeline in enumerate(pipelines):
+        assert notice.status == notice_states[index]
+        pipeline_output = pipeline(notice=notice, mongodb_client=fake_mongodb_client)
+        assert pipeline_output.processed, f"{pipeline.__name__} not processed!"
+        assert pipeline_output.store_result
+        assert pipeline_output.notice
+        notice = pipeline_output.notice
+        assert notice.status == notice_states[index + 1]
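The e2e test above encodes the notice lifecycle as a list and checks that each pipeline advances a notice exactly one state. A compact restatement of that progression (the Enum is a stand-in mirroring only the NoticeStatus members the test asserts):

from enum import Enum, auto

class Status(Enum):  # stand-in for the NoticeStatus members used by the test
    RAW = auto()
    NORMALISED_METADATA = auto()
    DISTILLED = auto()
    VALIDATED = auto()
    PACKAGED = auto()
    PUBLISHED = auto()

CHAIN = list(Status)  # definition order matches the expected processing order

def expected_status_after(pipeline_index: int) -> Status:
    # Pipeline k is expected to move a notice from CHAIN[k] to CHAIN[k + 1].
    return CHAIN[pipeline_index + 1]

assert expected_status_after(0) == Status.NORMALISED_METADATA
assert expected_status_after(4) == Status.PUBLISHED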
diff --git a/tests/e2e/mapping_suite_processor/conftest.py b/tests/e2e/mapping_suite_processor/conftest.py
index 75bf0eadb..c20c9243c 100644
--- a/tests/e2e/mapping_suite_processor/conftest.py
+++ b/tests/e2e/mapping_suite_processor/conftest.py
@@ -6,12 +6,6 @@
 from tests import TEST_DATA_PATH
 
 
-@pytest.fixture
-@mongomock.patch(servers=(('server.example.com', 27017),))
-def fake_mongodb_client():
-    return pymongo.MongoClient('server.example.com')
-
-
 @pytest.fixture
 def file_system_repository_path():
     return TEST_DATA_PATH / "notice_transformer" / "mapping_suite_processor_repository"
@@ -101,3 +95,4 @@ def mime_type():
 @pytest.fixture
 def github_mapping_suite_id():
     return "package_F03"
+
diff --git a/tests/features/notice_fetcher/conftest.py b/tests/features/notice_fetcher/conftest.py
index b7487db29..19088982c 100644
--- a/tests/features/notice_fetcher/conftest.py
+++ b/tests/features/notice_fetcher/conftest.py
@@ -2,7 +2,6 @@
 
 import pytest
 
-from ted_sws.notice_fetcher.adapters.ted_api import TedAPIAdapter, TedRequestAPI
 from ted_sws.notice_fetcher.services.notice_fetcher import NoticeFetcher
 from tests.fakes.fake_ted_api import FakeTedApiAdapter
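The fake_mongodb_client fixture removed here was re-homed in tests/e2e/conftest.py with one behavioural addition: it drops every database before returning the client. The likely reason, assuming mongomock keeps one in-memory server per patched address across invocations, is that stale data from one test would otherwise leak into the next. An illustrative check of that cleanup behaviour (the database and collection names are hypothetical):

import mongomock
import pymongo

@mongomock.patch(servers=(('server.example.com', 27017),))
def fresh_client():
    client = pymongo.MongoClient('server.example.com')
    for database_name in client.list_database_names():
        client.drop_database(database_name)  # wipe whatever a previous test left behind
    return client

first = fresh_client()
first["ted"]["notices"].insert_one({"_id": "057215-2021"})
second = fresh_client()
assert second["ted"]["notices"].count_documents({}) == 0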