Skip to content

Commit

Permalink
Conceptual Mappings Differ
Browse files Browse the repository at this point in the history
  • Loading branch information
Kolea Plesco committed Oct 5, 2022
2 parents 456f569 + 41c53c5 commit fe8b7d6
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 42 deletions.
29 changes: 29 additions & 0 deletions dags/dags_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,35 @@ def push_dag_downstream(key, value):
return context[TASK_INSTANCE].xcom_push(key=str(key), value=value)


def smart_xcom_pull(key: str):
context = get_current_context()
task_id = context[TASK_INSTANCE].task_id
selected_upstream_task_ids = [selected_task_id
for selected_task_id in context[TASK_INSTANCE].xcom_pull(key=task_id,
task_ids=context[
'task'].upstream_task_ids)
if selected_task_id
]
if selected_upstream_task_ids:
return select_first_non_none(context[TASK_INSTANCE].xcom_pull(key=key, task_ids=selected_upstream_task_ids))
return None


def smart_xcom_push(key: str, value: Any, destination_task_id: str = None):
context = get_current_context()
current_task_id = context[TASK_INSTANCE].task_id
task_ids = [destination_task_id] if destination_task_id else context['task'].downstream_task_ids
for task_id in task_ids:
context[TASK_INSTANCE].xcom_push(key=task_id, value=current_task_id)
context[TASK_INSTANCE].xcom_push(key=key, value=value)


def smart_xcom_forward(key: str, destination_task_id: str = None):
value = smart_xcom_pull(key=key)
if value:
smart_xcom_push(key=key, value=value, destination_task_id=destination_task_id)


def get_dag_param(key: str, raise_error: bool = False, default_value: Any = None):
"""
Expand Down
3 changes: 2 additions & 1 deletion dags/load_mapping_suite_in_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
BRANCH_OR_TAG_NAME_DAG_PARAM_KEY = "branch_or_tag_name"
GITHUB_REPOSITORY_URL_DAG_PARAM_KEY = "github_repository_url"

TRIGGER_DOCUMENT_PROC_PIPELINE_TASK_ID = "trigger_document_proc_pipeline"
FINISH_LOADING_MAPPING_SUITE_TASK_ID = "finish_loading_mapping_suite"
TRIGGER_DOCUMENT_PROC_PIPELINE_TASK_ID = "trigger_document_proc_pipeline"
CHECK_IF_LOAD_TEST_DATA_TASK_ID = "check_if_load_test_data"


Expand Down Expand Up @@ -58,6 +58,7 @@ def fetch_mapping_suite_package_from_github_into_mongodb(**context_args):
branch_or_tag_name=branch_or_tag_name,
github_repository_url=github_repository_url
)
notice_ids = list(set(notice_ids))
if load_test_data:
push_dag_downstream(key=NOTICE_IDS_KEY, value=notice_ids)
handle_event_message_metadata_dag_context(event_message, context)
Expand Down
4 changes: 3 additions & 1 deletion dags/notice_fetch_by_date_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def notice_fetch_by_date_workflow():
)
def fetch_by_date_notice_from_ted():
notice_ids = notice_fetcher_by_date_pipeline(date_wild_card=get_dag_param(key=WILD_CARD_DAG_KEY))
if not notice_ids:
raise Exception("No notices has been fetched!")
push_dag_downstream(key=NOTICE_IDS_KEY, value=notice_ids)

trigger_complete_workflow = TriggerNoticeBatchPipelineOperator(task_id=TRIGGER_COMPLETE_WORKFLOW_TASK_ID,
Expand All @@ -47,8 +49,8 @@ def fetch_by_date_notice_from_ted():

def _branch_selector():
trigger_complete_workflow = get_dag_param(key=TRIGGER_COMPLETE_WORKFLOW_DAG_KEY, default_value=False)
push_dag_downstream(key=NOTICE_IDS_KEY, value=pull_dag_upstream(key=NOTICE_IDS_KEY))
if trigger_complete_workflow:
push_dag_downstream(key=NOTICE_IDS_KEY, value=pull_dag_upstream(key=NOTICE_IDS_KEY))
return [TRIGGER_COMPLETE_WORKFLOW_TASK_ID]
return [TRIGGER_PARTIAL_WORKFLOW_TASK_ID]

Expand Down
57 changes: 40 additions & 17 deletions dags/notice_process_workflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List

from airflow.operators.dummy import DummyOperator
from airflow.operators.python import BranchPythonOperator
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.decorators import dag
from airflow.utils.trigger_rule import TriggerRule

from dags import DEFAULT_DAG_ARGUMENTS
from dags.dags_utils import push_dag_downstream, get_dag_param
from dags.dags_utils import get_dag_param, smart_xcom_push, smart_xcom_forward, smart_xcom_pull
from dags.operators.DagBatchPipelineOperator import NoticeBatchPipelineOperator, NOTICE_IDS_KEY, \
EXECUTE_ONLY_ONE_STEP_KEY, START_WITH_STEP_NAME_KEY
from dags.pipelines.notice_processor_pipelines import notice_normalisation_pipeline, notice_transformation_pipeline, \
Expand All @@ -15,14 +17,31 @@
NOTICE_VALIDATION_PIPELINE_TASK_ID = "notice_validation_pipeline"
NOTICE_PACKAGE_PIPELINE_TASK_ID = "notice_package_pipeline"
NOTICE_PUBLISH_PIPELINE_TASK_ID = "notice_publish_pipeline"
BRANCH_SELECTOR_TASK_ID = 'branch_selector'
STOP_PROCESSING_TASK_ID = "stop_processing"
BRANCH_SELECTOR_TASK_ID = 'branch_selector'
SELECTOR_BRANCH_BEFORE_TRANSFORMATION_TASK_ID = "switch_to_transformation"
SELECTOR_BRANCH_BEFORE_VALIDATION_TASK_ID = "switch_to_validation"
SELECTOR_BRANCH_BEFORE_PACKAGE_TASK_ID = "switch_to_package"
SELECTOR_BRANCH_BEFORE_PUBLISH_TASK_ID = "switch_to_publish"
DAG_NAME = "notice_process_workflow"

BRANCH_SELECTOR_MAP = {NOTICE_NORMALISATION_PIPELINE_TASK_ID: NOTICE_NORMALISATION_PIPELINE_TASK_ID,
NOTICE_TRANSFORMATION_PIPELINE_TASK_ID: SELECTOR_BRANCH_BEFORE_TRANSFORMATION_TASK_ID,
NOTICE_VALIDATION_PIPELINE_TASK_ID: SELECTOR_BRANCH_BEFORE_VALIDATION_TASK_ID,
NOTICE_PACKAGE_PIPELINE_TASK_ID: SELECTOR_BRANCH_BEFORE_PACKAGE_TASK_ID,
NOTICE_PUBLISH_PIPELINE_TASK_ID: SELECTOR_BRANCH_BEFORE_PUBLISH_TASK_ID
}


def branch_selector(result_branch: str, xcom_forward_keys: List[str] = [NOTICE_IDS_KEY]) -> str:
start_with_step_name = get_dag_param(key=START_WITH_STEP_NAME_KEY,
default_value=NOTICE_NORMALISATION_PIPELINE_TASK_ID)
if start_with_step_name != result_branch:
result_branch = STOP_PROCESSING_TASK_ID if get_dag_param(key=EXECUTE_ONLY_ONE_STEP_KEY) else result_branch
for xcom_forward_key in xcom_forward_keys:
smart_xcom_forward(key=xcom_forward_key, destination_task_id=result_branch)
return result_branch


@dag(default_args=DEFAULT_DAG_ARGUMENTS,
schedule_interval=None,
Expand All @@ -38,24 +57,26 @@ def _start_processing():
notice_ids = get_dag_param(key=NOTICE_IDS_KEY, raise_error=True)
start_with_step_name = get_dag_param(key=START_WITH_STEP_NAME_KEY,
default_value=NOTICE_NORMALISATION_PIPELINE_TASK_ID)
push_dag_downstream(key=NOTICE_IDS_KEY, value=notice_ids)
return start_with_step_name
task_id = BRANCH_SELECTOR_MAP[start_with_step_name]
smart_xcom_push(key=NOTICE_IDS_KEY, value=notice_ids, destination_task_id=task_id)
return task_id

def _selector_branch_before_transformation():
return STOP_PROCESSING_TASK_ID if get_dag_param(
key=EXECUTE_ONLY_ONE_STEP_KEY) else NOTICE_TRANSFORMATION_PIPELINE_TASK_ID
return branch_selector(NOTICE_TRANSFORMATION_PIPELINE_TASK_ID)

def _selector_branch_before_validation():
return STOP_PROCESSING_TASK_ID if get_dag_param(
key=EXECUTE_ONLY_ONE_STEP_KEY) else NOTICE_VALIDATION_PIPELINE_TASK_ID
return branch_selector(NOTICE_VALIDATION_PIPELINE_TASK_ID)

def _selector_branch_before_package():
return STOP_PROCESSING_TASK_ID if get_dag_param(
key=EXECUTE_ONLY_ONE_STEP_KEY) else NOTICE_PACKAGE_PIPELINE_TASK_ID
return branch_selector(NOTICE_PACKAGE_PIPELINE_TASK_ID)

def _selector_branch_before_publish():
return STOP_PROCESSING_TASK_ID if get_dag_param(
key=EXECUTE_ONLY_ONE_STEP_KEY) else NOTICE_PUBLISH_PIPELINE_TASK_ID
return branch_selector(NOTICE_PUBLISH_PIPELINE_TASK_ID)

def _stop_processing():
notice_ids = smart_xcom_pull(key=NOTICE_IDS_KEY)
if not notice_ids:
raise Exception(f"No notice has been processed!")

start_processing = BranchPythonOperator(
task_id=BRANCH_SELECTOR_TASK_ID,
Expand All @@ -82,9 +103,10 @@ def _selector_branch_before_publish():
python_callable=_selector_branch_before_publish,
)

stop_processing = DummyOperator(
stop_processing = PythonOperator(
task_id=STOP_PROCESSING_TASK_ID,
trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
python_callable=_stop_processing
)

notice_normalisation_step = NoticeBatchPipelineOperator(python_callable=notice_normalisation_pipeline,
Expand All @@ -111,8 +133,9 @@ def _selector_branch_before_publish():
trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
)

start_processing >> [notice_normalisation_step, notice_transformation_step, notice_validation_step,
notice_package_step, notice_publish_step]
start_processing >> [notice_normalisation_step, selector_branch_before_transformation,
selector_branch_before_validation,
selector_branch_before_package, selector_branch_before_publish]
[selector_branch_before_transformation, selector_branch_before_validation,
selector_branch_before_package, selector_branch_before_publish, notice_publish_step] >> stop_processing
notice_normalisation_step >> selector_branch_before_transformation >> notice_transformation_step
Expand Down
16 changes: 10 additions & 6 deletions dags/operators/DagBatchPipelineOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from pymongo import MongoClient

from dags.dags_utils import pull_dag_upstream, push_dag_downstream, chunks, get_dag_param
from dags.dags_utils import pull_dag_upstream, push_dag_downstream, chunks, get_dag_param, smart_xcom_pull, \
smart_xcom_push
from dags.pipelines.pipeline_protocols import NoticePipelineCallable
from ted_sws import config
from ted_sws.data_manager.adapters.notice_repository import NoticeRepository
Expand Down Expand Up @@ -40,10 +41,11 @@ def execute(self, context: Any):
This method executes the python_callable for each notice_id in the notice_ids batch.
"""
logger = get_logger()
notice_ids = pull_dag_upstream(key=NOTICE_IDS_KEY)
notice_ids = smart_xcom_pull(key=NOTICE_IDS_KEY)
if not notice_ids:
raise Exception(f"XCOM key [{NOTICE_IDS_KEY}] is not present in context!")
notice_repository = NoticeRepository(mongodb_client=MongoClient(config.MONGO_DB_AUTH_URL))
mongodb_client = MongoClient(config.MONGO_DB_AUTH_URL)
notice_repository = NoticeRepository(mongodb_client=mongodb_client)
processed_notice_ids = []
pipeline_name = self.python_callable.__name__
number_of_notices = len(notice_ids)
Expand All @@ -58,7 +60,7 @@ def execute(self, context: Any):
notice_event = NoticeEventMessage(notice_id=notice_id, domain_action=pipeline_name)
notice_event.start_record()
notice = notice_repository.get(reference=notice_id)
result_notice_pipeline = self.python_callable(notice)
result_notice_pipeline = self.python_callable(notice, mongodb_client)
if result_notice_pipeline.store_result:
notice_repository.update(notice=result_notice_pipeline.notice)
if result_notice_pipeline.processed:
Expand All @@ -67,10 +69,12 @@ def execute(self, context: Any):
logger.info(event_message=notice_event)
except Exception as e:
log_error(message=str(e))

batch_event_message.end_record()
logger.info(event_message=batch_event_message)

push_dag_downstream(key=NOTICE_IDS_KEY, value=processed_notice_ids)
if not processed_notice_ids:
raise Exception(f"No notice has been processed!")
smart_xcom_push(key=NOTICE_IDS_KEY, value=processed_notice_ids)


class TriggerNoticeBatchPipelineOperator(BaseOperator):
Expand Down
13 changes: 6 additions & 7 deletions dags/pipelines/notice_processor_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ted_sws.notice_validator.services.xpath_coverage_runner import validate_xpath_coverage_notice


def notice_normalisation_pipeline(notice: Notice) -> NoticePipelineOutput:
def notice_normalisation_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput:
"""
"""
Expand All @@ -26,11 +26,10 @@ def notice_normalisation_pipeline(notice: Notice) -> NoticePipelineOutput:
return NoticePipelineOutput(notice=normalised_notice)


def notice_transformation_pipeline(notice: Notice) -> NoticePipelineOutput:
def notice_transformation_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput:
"""
"""
mongodb_client = MongoClient(config.MONGO_DB_AUTH_URL)
mapping_suite_repository = MappingSuiteRepositoryMongoDB(mongodb_client=mongodb_client)
result = notice_eligibility_checker(notice=notice, mapping_suite_repository=mapping_suite_repository)
if not result:
Expand All @@ -47,12 +46,11 @@ def notice_transformation_pipeline(notice: Notice) -> NoticePipelineOutput:
return NoticePipelineOutput(notice=transformed_notice)


def notice_validation_pipeline(notice: Notice) -> NoticePipelineOutput:
def notice_validation_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput:
"""
"""
mapping_suite_id = notice.distilled_rdf_manifestation.mapping_suite_id
mongodb_client = MongoClient(config.MONGO_DB_AUTH_URL)
mapping_suite_repository = MappingSuiteRepositoryMongoDB(mongodb_client=mongodb_client)
mapping_suite = mapping_suite_repository.get(reference=mapping_suite_id)
validate_xpath_coverage_notice(notice=notice, mapping_suite=mapping_suite, mongodb_client=mongodb_client)
Expand All @@ -61,7 +59,7 @@ def notice_validation_pipeline(notice: Notice) -> NoticePipelineOutput:
return NoticePipelineOutput(notice=notice)


def notice_package_pipeline(notice: Notice) -> NoticePipelineOutput:
def notice_package_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput:
"""
"""
Expand All @@ -71,10 +69,11 @@ def notice_package_pipeline(notice: Notice) -> NoticePipelineOutput:
return NoticePipelineOutput(notice=packaged_notice)


def notice_publish_pipeline(notice: Notice) -> NoticePipelineOutput:
def notice_publish_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput:
"""
"""
notice.set_is_eligible_for_publishing(eligibility=True)
result = publish_notice(notice=notice)
if result:
return NoticePipelineOutput(notice=notice)
Expand Down
4 changes: 3 additions & 1 deletion dags/pipelines/pipeline_protocols.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Protocol

from pymongo import MongoClient

from ted_sws.core.model.notice import Notice


Expand All @@ -13,7 +15,7 @@ def __init__(self, notice: Notice, processed: bool = True, store_result: bool =

class NoticePipelineCallable(Protocol):

def __call__(self, notice: Notice) -> NoticePipelineOutput:
def __call__(self, notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput:
"""
"""
2 changes: 1 addition & 1 deletion ted_sws/notice_transformer/services/notice_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def transform_notice(notice: Notice, mapping_suite: MappingSuite, rml_mapper: RM
file.write(notice.xml_manifestation.object_data)
rdf_result = rml_mapper.execute(package_path=package_path)
notice.set_rdf_manifestation(
rdf_manifestation=RDFManifestation(mapping_suite_id=mapping_suite.identifier,
rdf_manifestation=RDFManifestation(mapping_suite_id=mapping_suite.get_mongodb_id(),
object_data=rdf_result))
return notice

Expand Down
17 changes: 16 additions & 1 deletion tests/e2e/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import mongomock
import pymongo
import pytest
from pymongo import MongoClient

from ted_sws import config
from ted_sws.data_manager.adapters.triple_store import AllegroGraphTripleStore, FusekiAdapter

from tests import TEST_DATA_PATH



@pytest.fixture
def mongodb_client():
uri = config.MONGO_DB_AUTH_URL
Expand Down Expand Up @@ -38,9 +42,20 @@ def fake_mapping_suite_id() -> str:

@pytest.fixture
def fuseki_triple_store():
return FusekiAdapter(host=config.FUSEKI_ADMIN_HOST, user=config.FUSEKI_ADMIN_USER, password=config.FUSEKI_ADMIN_PASSWORD)
return FusekiAdapter(host=config.FUSEKI_ADMIN_HOST, user=config.FUSEKI_ADMIN_USER,
password=config.FUSEKI_ADMIN_PASSWORD)


@pytest.fixture
def cellar_sparql_endpoint():
return "https://publications.europa.eu/webapi/rdf/sparql"


@pytest.fixture(scope="function")
@mongomock.patch(servers=(('server.example.com', 27017),))
def fake_mongodb_client():
mongo_client = pymongo.MongoClient('server.example.com')
for database_name in mongo_client.list_database_names():
mongo_client.drop_database(database_name)
return mongo_client

Empty file.
33 changes: 33 additions & 0 deletions tests/e2e/dags/pipelines/test_notice_processor_pipelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from dags.pipelines.notice_processor_pipelines import notice_normalisation_pipeline, notice_transformation_pipeline, \
notice_validation_pipeline, notice_package_pipeline, notice_publish_pipeline
from ted_sws.core.model.notice import NoticeStatus
from ted_sws.data_manager.adapters.notice_repository import NoticeRepository
from ted_sws.mapping_suite_processor.services.conceptual_mapping_processor import \
mapping_suite_processor_from_github_expand_and_load_package_in_mongo_db

MAPPING_SUITE_PACKAGE_NAME = "package_F03_test"
MAPPING_SUITE_PACKAGE_ID = f"{MAPPING_SUITE_PACKAGE_NAME}_v2.3.0"
NOTICE_ID = "057215-2021"


def test_notice_processor_pipelines(fake_mongodb_client):
mapping_suite_processor_from_github_expand_and_load_package_in_mongo_db(
mapping_suite_package_name=MAPPING_SUITE_PACKAGE_NAME,
mongodb_client=fake_mongodb_client,
load_test_data=True
)
notice_id = NOTICE_ID
notice_repository = NoticeRepository(mongodb_client=fake_mongodb_client)
notice = notice_repository.get(reference=notice_id)
pipelines = [notice_normalisation_pipeline, notice_transformation_pipeline, notice_validation_pipeline,
notice_package_pipeline, notice_publish_pipeline]
notice_states = [NoticeStatus.RAW, NoticeStatus.NORMALISED_METADATA, NoticeStatus.DISTILLED,
NoticeStatus.VALIDATED, NoticeStatus.PACKAGED, NoticeStatus.PUBLISHED]
for index, pipeline in enumerate(pipelines):
assert notice.status == notice_states[index]
pipeline_output = pipeline(notice=notice, mongodb_client=fake_mongodb_client)
assert pipeline_output.processed, f"{pipeline.__name__} not processed!"
assert pipeline_output.store_result
assert pipeline_output.notice
notice = pipeline_output.notice
assert notice.status == notice_states[index + 1]
Loading

0 comments on commit fe8b7d6

Please sign in to comment.