Add create_new_es_index DAGs #3537

Merged 19 commits on Jan 23, 2024
3 changes: 3 additions & 0 deletions catalog/dags/common/constants.py
@@ -12,6 +12,9 @@

MediaType = Literal["audio", "image"]

STAGING = "staging"
PRODUCTION = "production"

CONTACT_EMAIL = os.getenv("CONTACT_EMAIL")

DAG_DEFAULT_ARGS = {
29 changes: 28 additions & 1 deletion catalog/dags/common/sensors/utils.py
@@ -1,6 +1,6 @@
from datetime import datetime

from airflow.decorators import task
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowSensorTimeout
from airflow.models import DagRun
from airflow.sensors.external_task import ExternalTaskSensor
@@ -39,6 +39,14 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
"""
Return a Sensor task which will wait if the given external DAG is
running.

To fully ensure that the waiting DAG and the external DAG do not run
concurrently, the external DAG should have a `prevent_concurrency_with_dag`
task which fails immediately if the waiting DAG is running.

If the external DAG should _not_ fail when the waiting DAG is running,
but instead wait its turn, use the SingleRunExternalDAGsSensor in both
DAGs to avoid deadlock.
"""
if not task_id:
task_id = f"wait_for_{external_dag_id}"
@@ -57,6 +65,16 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
)


@task_group(group_id="wait_for_external_dags")
def wait_for_external_dags(external_dag_ids: list[str]):
"""
Wait for all DAGs with the given external DAG ids to no longer be
in a running state before continuing.
"""
for dag_id in external_dag_ids:
wait_for_external_dag(dag_id)


@task(retries=0)
def prevent_concurrency_with_dag(external_dag_id: str, **context):
"""
@@ -73,3 +91,12 @@ def prevent_concurrency_with_dag(external_dag_id: str, **context):
wait_for_dag.execute(context)
except AirflowSensorTimeout:
raise ValueError(f"Concurrency check with {external_dag_id} failed.")


@task_group(group_id="prevent_concurrency")
def prevent_concurrency_with_dags(external_dag_ids: list[str]):
Review comment (collaborator): This is excellent and so simple!

"""Fail immediately if any of the given external dags are in progress."""
for dag_id in external_dag_ids:
prevent_concurrency_with_dag.override(
task_id=f"prevent_concurrency_with_{dag_id}"
)(dag_id)
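
The two new task groups are drop-in guards for any DAG that must not overlap with other DAGs. A minimal usage sketch, assuming a hypothetical caller DAG (the DAG id below is a placeholder; the external DAG ids are ones defined by this PR's configs):

from datetime import datetime

from airflow import DAG

from common.sensors.utils import prevent_concurrency_with_dags, wait_for_external_dags

with DAG(
    dag_id="example_es_maintenance_dag",  # hypothetical, not added by this PR
    schedule=None,
    start_date=datetime(2024, 1, 1),
) as dag:
    # Fail fast if any of these DAGs are already running.
    prevent_concurrency = prevent_concurrency_with_dags(
        external_dag_ids=["image_data_refresh", "audio_data_refresh"]
    )

    # Or, wait for the filtered index DAG to finish instead of failing.
    wait_for_dags = wait_for_external_dags(
        external_dag_ids=["create_filtered_image_index"]
    )

    prevent_concurrency >> wait_for_dags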
22 changes: 15 additions & 7 deletions catalog/dags/data_refresh/create_filtered_index_dag.py
@@ -55,12 +55,15 @@
from airflow import DAG
from airflow.models.param import Param

from common.constants import DAG_DEFAULT_ARGS
from common.sensors.utils import prevent_concurrency_with_dag
from common.constants import DAG_DEFAULT_ARGS, PRODUCTION
from common.sensors.utils import prevent_concurrency_with_dags
from data_refresh.create_filtered_index import (
create_filtered_index_creation_task_groups,
)
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh
from elasticsearch_cluster.create_new_es_index.create_new_es_index_types import (
CREATE_NEW_INDEX_CONFIGS,
)


# Note: We can't use the TaskFlow `@dag` DAG factory decorator
@@ -80,7 +83,7 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
media_type = data_refresh.media_type

with DAG(
dag_id=f"create_filtered_{media_type}_index",
dag_id=data_refresh.filtered_index_dag_id,
default_args=DAG_DEFAULT_ARGS,
schedule=None,
start_date=datetime(2023, 4, 1),
@@ -113,10 +116,15 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
},
render_template_as_native_obj=True,
) as dag:
# Immediately fail if the associated data refresh is running.
prevent_concurrency = prevent_concurrency_with_dag.override(
task_id=f"prevent_concurrency_with_{media_type}_data_refresh"
)(external_dag_id=f"{media_type}_data_refresh")
# Immediately fail if the associated data refresh is running, or the
# create_new_production_es_index DAG is running. This prevents multiple
# DAGs from reindexing from a single production index simultaneously.
prevent_concurrency = prevent_concurrency_with_dags(
external_dag_ids=[
data_refresh.dag_id,
CREATE_NEW_INDEX_CONFIGS[PRODUCTION].dag_id,
]
)

# Once the concurrency check has passed, actually create the filtered
# index.
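
`CREATE_NEW_INDEX_CONFIGS` is defined in `create_new_es_index_types.py`, which is not included in this excerpt. Judging from how it is keyed by environment and exposes a `dag_id` (the comment above names the production DAG `create_new_production_es_index`), its shape is roughly as sketched below; everything other than `dag_id` is an assumption:

from dataclasses import dataclass, field

from common.constants import PRODUCTION, STAGING


@dataclass
class CreateNewIndex:
    """Assumed shape only; the real dataclass may carry additional fields."""

    environment: str
    dag_id: str = field(init=False)

    def __post_init__(self):
        self.dag_id = f"create_new_{self.environment}_es_index"


CREATE_NEW_INDEX_CONFIGS = {
    STAGING: CreateNewIndex(environment=STAGING),
    PRODUCTION: CreateNewIndex(environment=PRODUCTION),
}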
28 changes: 17 additions & 11 deletions catalog/dags/data_refresh/data_refresh_task_factory.py
@@ -53,13 +53,16 @@
from airflow.utils.trigger_rule import TriggerRule

from common import ingestion_server
from common.constants import XCOM_PULL_TEMPLATE
from common.constants import PRODUCTION, XCOM_PULL_TEMPLATE
from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor
from common.sensors.utils import wait_for_external_dag
from common.sensors.utils import wait_for_external_dags
from data_refresh.create_filtered_index import (
create_filtered_index_creation_task_groups,
)
from data_refresh.data_refresh_types import DataRefresh
from elasticsearch_cluster.create_new_es_index.create_new_es_index_types import (
CREATE_NEW_INDEX_CONFIGS,
)


logger = logging.getLogger(__name__)
@@ -112,16 +115,19 @@ def create_data_refresh_task_group(
pool=DATA_REFRESH_POOL,
)

# If filtered index creation was manually triggered before the data refresh
# started, we need to wait for it to finish or the data refresh could destroy
# the origin index. Realistically the data refresh is too slow to beat the
# filtered index creation process, even if it was triggered immediately after
# filtered index creation. However, it is safer to avoid the possibility
# of the race condition altogether.
wait_for_filtered_index_creation = wait_for_external_dag(
external_dag_id=f"create_filtered_{data_refresh.media_type}_index",
# Wait for other DAGs that operate on the ES cluster. If a new or filtered index
# is being created by one of these DAGs, we need to wait for it to finish or else
# the data refresh might destroy the index being used as the source index.
# Realistically the data refresh is too slow to beat the index creation process,
# even if it was triggered immediately after one of these DAGs; however, it is
# always safer to avoid the possibility of the race condition altogether.
wait_for_es_dags = wait_for_external_dags.override(group_id="wait_for_es_dags")(
external_dag_ids=[
data_refresh.filtered_index_dag_id,
CREATE_NEW_INDEX_CONFIGS[PRODUCTION].dag_id,
]
)
tasks.append([wait_for_data_refresh, wait_for_filtered_index_creation])
tasks.append([wait_for_data_refresh, wait_for_es_dags])

# Get the index currently mapped to our target alias, to delete later.
get_current_index = ingestion_server.get_current_index(target_alias)
2 changes: 2 additions & 0 deletions catalog/dags/data_refresh/data_refresh_types.py
@@ -53,6 +53,7 @@ class DataRefresh:
"""

dag_id: str = field(init=False)
filtered_index_dag_id: str = field(init=False)
media_type: str
start_date: datetime = datetime(2020, 1, 1)
schedule: str | None = "0 0 * * 1" # Mondays 00:00 UTC
@@ -69,6 +70,7 @@

def __post_init__(self):
self.dag_id = f"{self.media_type}_data_refresh"
self.filtered_index_dag_id = f"create_filtered_{self.media_type}_index"


DATA_REFRESH_CONFIGS = {
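
With the new field, both DAG ids fall out of the media type in `__post_init__`; for example (illustration only, assuming the omitted fields keep their defaults):

refresh = DataRefresh(media_type="image")
refresh.dag_id                  # "image_data_refresh"
refresh.filtered_index_dag_id   # "create_filtered_image_index"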
Changes to an additional file (filename not shown in this view):
@@ -29,9 +29,6 @@
from airflow.providers.amazon.aws.operators.rds import RdsDeleteDbInstanceOperator
from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor
from airflow.utils.trigger_rule import TriggerRule
from es.recreate_staging_index.recreate_full_staging_index import (
DAG_ID as RECREATE_STAGING_INDEX_DAG_ID,
)

from common.constants import (
AWS_RDS_CONN_ID,
@@ -51,6 +48,9 @@
restore_staging_from_snapshot,
skip_restore,
)
from elasticsearch_cluster.recreate_staging_index.recreate_full_staging_index import (
DAG_ID as RECREATE_STAGING_INDEX_DAG_ID,
)


log = logging.getLogger(__name__)
New file (filename not shown in this view):
@@ -0,0 +1,195 @@
import logging
from datetime import timedelta

from airflow.decorators import task, task_group
from airflow.models.connection import Connection
from airflow.providers.elasticsearch.hooks.elasticsearch import ElasticsearchPythonHook
from airflow.sensors.python import PythonSensor

from common.constants import REFRESH_POKE_INTERVAL
from elasticsearch_cluster.create_new_es_index.utils import merge_configurations


logger = logging.getLogger(__name__)


# Index settings that should not be copied over from the base configuration when
# creating a new index.
EXCLUDED_INDEX_SETTINGS = ["provided_name", "creation_date", "uuid", "version"]

GET_FINAL_INDEX_CONFIG_TASK_NAME = "get_final_index_configuration"
GET_CURRENT_INDEX_CONFIG_TASK_NAME = "get_current_index_configuration"


@task
def get_es_host(environment: str):
conn = Connection.get_connection_from_secrets(f"elasticsearch_http_{environment}")
return conn.host


@task
def get_index_name(media_type: str, index_suffix: str):
return f"{media_type}-{index_suffix}".lower()


@task.branch
def check_override_config(override):
if override:
# Skip the steps to fetch the current index configuration
# and merge changes in.
return GET_FINAL_INDEX_CONFIG_TASK_NAME

return GET_CURRENT_INDEX_CONFIG_TASK_NAME


@task
def get_current_index_configuration(
source_index: str,
es_host: str,
):
"""
Return the configuration for the current index, identified by the
`source_index` param. `source_index` may be either an index name
or an alias, but must uniquely identify one existing index or an
error will be raised.
"""
es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn

response = es_conn.indices.get(
index=source_index,
# Return empty dict instead of throwing error if no index can be
# found. We raise our own error instead.
ignore_unavailable=True,
)

if len(response) != 1:
raise ValueError(f"Index {source_index} could not be uniquely identified.")

# The response has the form:
# { index_name: index_configuration }
# However, since `source_index` can be an alias rather than the index name,
# we do not necessarily know the index_name so we cannot access the configuration
# directly by key. We instead get the first value from the dict, knowing that we
# have already ensured in a previous check that there is exactly one value in the
# response.
config = next(iter(response.values()))
return config


@task
def merge_index_configurations(new_index_config, current_index_config):
"""
Merge the `new_index_config` into the `current_index_config`, and
return an index configuration in the appropriate format for being
passed to the `create_index` API.
"""
# Do not automatically apply any aliases to the new index
current_index_config.pop("aliases")

# Remove fields from the current_index_config that should not be copied
# over into the new index (such as uuid)
for setting in EXCLUDED_INDEX_SETTINGS:
current_index_config.get("settings", {}).get("index", {}).pop(setting)

# Merge the new configuration values into the current configuration
return merge_configurations(current_index_config, new_index_config)


@task
def get_final_index_configuration(
override_config: bool,
index_config,
merged_config,
index_name: str,
):
"""
Resolve the final index configuration to be used in the `create_index`
task.

Required arguments:

override_config: Whether the index_config should be used instead of
the merged_config
index_config: The new index configuration which was passed in via
DAG params
merged_config: The result of merging the index_config with the current
index configuration. This may be None if the merge
tasks were skipped using the override param.
index_name: Name of the index to update.
"""
config = index_config if override_config else merged_config

# Apply the desired index name
config["index"] = index_name
return config


@task
def create_index(index_config, es_host: str):
es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn

new_index = es_conn.indices.create(**index_config)

return new_index


@task_group(group_id="trigger_and_wait_for_reindex")
def trigger_and_wait_for_reindex(
index_name: str,
source_index: str,
query: dict,
timeout: timedelta,
requests_per_second: int,
es_host: str,
):
@task
def trigger_reindex(
index_name: str,
source_index: str,
query: dict,
requests_per_second: int,
es_host: str,
):
es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn

source = {"index": source_index}
# An empty query is not accepted; only pass it
# if a query was actually supplied
if query:
source["query"] = query

response = es_conn.reindex(
source=source,
dest={"index": index_name},
# Parallelize indexing
slices="auto",
# Do not hold the slot while awaiting completion
wait_for_completion=False,
# Immediately refresh the index after completion to make
# the data available for search
refresh=True,
# Throttle
requests_per_second=requests_per_second,
)

return response["task"]

def _wait_for_reindex(task_id: str, es_host: str):
es_conn = ElasticsearchPythonHook(hosts=[es_host]).get_conn

response = es_conn.tasks.get(task_id=task_id)
return response.get("completed")

trigger_reindex_task = trigger_reindex(
index_name, source_index, query, requests_per_second, es_host
)

wait_for_reindex = PythonSensor(
task_id="wait_for_reindex",
python_callable=_wait_for_reindex,
timeout=timeout,
poke_interval=REFRESH_POKE_INTERVAL,
op_kwargs={"task_id": trigger_reindex_task, "es_host": es_host},
)

trigger_reindex_task >> wait_for_reindex
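
The DAG definition that assembles these tasks is not part of this excerpt. A hedged sketch of how they might be chained inside a `create_new_es_index` DAG, with placeholder parameter values and an assumed `none_failed` trigger rule so the final-configuration task still runs when the merge branch is skipped:

from datetime import timedelta

# Illustrative wiring only, inside a `with DAG(...)` block; the real DAG
# handles params, environments, and defaults.
es_host = get_es_host(environment="staging")
index_name = get_index_name(media_type="image", index_suffix="my-suffix")

branch = check_override_config(override=False)
current_config = get_current_index_configuration(
    source_index="image", es_host=es_host
)
merged_config = merge_index_configurations(
    new_index_config={"settings": {"index": {"number_of_replicas": 0}}},
    current_index_config=current_config,
)
final_config = get_final_index_configuration.override(trigger_rule="none_failed")(
    override_config=False,
    index_config={},
    merged_config=merged_config,
    index_name=index_name,
)
new_index = create_index(index_config=final_config, es_host=es_host)
reindex = trigger_and_wait_for_reindex(
    index_name=index_name,
    source_index="image",
    query={},
    timeout=timedelta(hours=12),
    requests_per_second=20_000,
    es_host=es_host,
)

# The branch decides whether the merge steps run or are skipped entirely.
branch >> [current_config, final_config]
new_index >> reindex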
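
Finally, `merge_configurations` is imported from `elasticsearch_cluster.create_new_es_index.utils`, which is not shown here. A plausible minimal implementation, assuming it is a recursive dictionary merge in which values from the new configuration take precedence:

def merge_configurations(current_config: dict, new_config: dict) -> dict:
    """
    Recursively merge new_config into current_config, preferring values from
    new_config on key collisions. (Sketch only; the real helper may differ.)
    """
    merged = dict(current_config)
    for key, new_value in new_config.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(new_value, dict):
            merged[key] = merge_configurations(existing, new_value)
        else:
            merged[key] = new_value
    return merged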