Support Apache Airflow 2.6
Modify the following so the SDK is compatible with Apache Airflow 2.6:
1. XCom serialization/deserialization for Table and Column
2. Airflow breaking changes that affected `airflow.jobs.scheduler_job.SchedulerJob` and
`airflow.jobs.backfill_job.BackfillJob` (both replaced in 2.6 by `airflow.jobs.job.Job`
plus job-runner classes)

Closes #1590
Closes #1904

Co-authored-by: utkarsh sharma <utkarsharma2@gmail.com>
Co-authored-by: Tatiana Al-Chueyr <tatiana.alchueyr@gmail.com>
3 people authored May 3, 2023
1 parent 95326de commit 1ea5c0c
Showing 8 changed files with 139 additions and 17 deletions.
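
The changes follow two patterns. For the job classes, a version-gated import keeps a single code path for old and new Airflow; a minimal sketch of the shim, matching the cleanup.py diff below:

    from airflow import __version__ as airflow_version
    from packaging import version

    # Airflow 2.6 replaces the BaseJob ORM class with airflow.jobs.job.Job;
    # on older versions, alias BaseJob to the same name so callers stay version-agnostic.
    if version.parse(airflow_version) >= version.parse("2.6"):
        from airflow.jobs.job import Job
    else:
        from airflow.jobs.base_job import BaseJob as Job

For XCom, Table and FileType gain serialize()/deserialize() hooks (and a version ClassVar) so they round-trip as plain dicts; see the sketches after each file's diff.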
10 changes: 5 additions & 5 deletions .github/workflows/ci-python-sdk.yaml
@@ -242,7 +242,7 @@ jobs:
           key: ${{ runner.os }}-2.5-${{ hashFiles('python-sdk/pyproject.toml') }}-${{ hashFiles('python-sdk/src/astro/__init__.py') }}
       - run: sqlite3 /tmp/sqlite_default.db "VACUUM;"
       - run: pip3 install nox
-      - run: nox -s "test-${{ matrix.version }}(airflow='2.5.3')" -- tests/ --cov=src --cov-report=xml --cov-branch
+      - run: nox -s "test-${{ matrix.version }}(airflow='2.6.0')" -- tests/ --cov=src --cov-report=xml --cov-branch
       - name: Upload coverage
         uses: actions/upload-artifact@v2
         with:
@@ -319,7 +319,7 @@ jobs:
       - run: python -c 'import os; print(os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON", "").strip())' > ${{ env.GOOGLE_APPLICATION_CREDENTIALS }}
       - run: sqlite3 /tmp/sqlite_default.db "VACUUM;"
       - run: pip3 install nox
-      - run: nox -s "test-3.8(airflow='2.5.3')" -- tests_integration/ -k "test_load_file.py" --splits 3 --group ${{ matrix.group }} --store-durations --durations-path /tmp/durations-${{ matrix.group }} --cov=src --cov-report=xml --cov-branch
+      - run: nox -s "test-3.8(airflow='2.6.0')" -- tests_integration/ -k "test_load_file.py" --splits 3 --group ${{ matrix.group }} --store-durations --durations-path /tmp/durations-${{ matrix.group }} --cov=src --cov-report=xml --cov-branch
       - run: cat /tmp/durations-${{ matrix.group }}
       - name: Upload coverage
         uses: actions/upload-artifact@v2
@@ -415,7 +415,7 @@ jobs:
       - run: python -c 'import os; print(os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON", "").strip())' > ${{ env.GOOGLE_APPLICATION_CREDENTIALS }}
       - run: sqlite3 /tmp/sqlite_default.db "VACUUM;"
       - run: pip3 install nox
-      - run: nox -s "test-3.8(airflow='2.5.3')" -- tests_integration/ -k "test_example_dags.py" --splits 3 --group ${{ matrix.group }} --store-durations --durations-path /tmp/durations-${{ matrix.group }} --cov=src --cov-report=xml --cov-branch
+      - run: nox -s "test-3.8(airflow='2.6.0')" -- tests_integration/ -k "test_example_dags.py" --splits 3 --group ${{ matrix.group }} --store-durations --durations-path /tmp/durations-${{ matrix.group }} --cov=src --cov-report=xml --cov-branch
       - run: cat /tmp/durations-${{ matrix.group }}
       - name: Upload coverage
         uses: actions/upload-artifact@v2
@@ -511,7 +511,7 @@ jobs:
       - run: python -c 'import os; print(os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON", "").strip())' > ${{ env.GOOGLE_APPLICATION_CREDENTIALS }}
       - run: sqlite3 /tmp/sqlite_default.db "VACUUM;"
       - run: pip3 install nox
-      - run: nox -s "test-3.8(airflow='2.5.3')" -- tests_integration/ -k "not test_load_file.py and not test_example_dags.py" --splits 11 --group ${{ matrix.group }} --store-durations --durations-path /tmp/durations-${{ matrix.group }} --cov=src --cov-report=xml --cov-branch
+      - run: nox -s "test-3.8(airflow='2.6.0')" -- tests_integration/ -k "not test_load_file.py and not test_example_dags.py" --splits 11 --group ${{ matrix.group }} --store-durations --durations-path /tmp/durations-${{ matrix.group }} --cov=src --cov-report=xml --cov-branch
       - run: cat /tmp/durations-${{ matrix.group }}
       - name: Upload coverage
         uses: actions/upload-artifact@v2
@@ -627,7 +627,7 @@ jobs:
       fail-fast: false
       matrix:
         python: [ '3.7', '3.8', '3.9', '3.10' ]
-        airflow: [ '2.2.5', '2.3.4', '2.4.2', '2.5.3' ]
+        airflow: [ '2.2.5', '2.3.4', '2.4.2', '2.5.3', '2.6.0']
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
4 changes: 2 additions & 2 deletions python-sdk/noxfile.py
@@ -20,7 +20,7 @@ def dev(session: nox.Session) -> None:


 @nox.session(python=["3.7", "3.8", "3.9", "3.10"])
-@nox.parametrize("airflow", ["2.2.5", "2.4", "2.5.3"])
+@nox.parametrize("airflow", ["2.2.5", "2.4", "2.5.3", "2.6.0"])
 def test(session: nox.Session, airflow) -> None:
     """Run both unit and integration tests."""
     env = {
@@ -151,7 +151,7 @@ def build_docs(session: nox.Session) -> None:


 @nox.session(python=["3.7", "3.8", "3.9", "3.10"])
-@nox.parametrize("airflow", ["2.2.5", "2.3.4", "2.4.2", "2.5.3"])
+@nox.parametrize("airflow", ["2.2.5", "2.3.4", "2.4.2", "2.5.3", "2.6.0"])
 def generate_constraints(session: nox.Session, airflow) -> None:
     """Generate constraints file"""
     session.install("wheel")
11 changes: 10 additions & 1 deletion python-sdk/src/astro/constants.py
@@ -5,7 +5,7 @@

 # typing.Literal was only introduced in Python 3.8, and we support Python 3.7
 if sys.version_info >= (3, 8):
-    from typing import Literal
+    from typing import Any, Literal
 else:
     from typing_extensions import Literal

@@ -44,6 +44,15 @@ class FileType(Enum):
     def __str__(self) -> str:
         return self.value

+    def serialize(self) -> dict[str, Any]:
+        return {
+            "value": self.value,
+        }
+
+    @staticmethod
+    def deserialize(data: dict[str, Any], _: int):
+        return FileType(data["value"])
+

 class Database(Enum):
     # [START database]
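
A quick round-trip sketch of the new FileType hooks (assumes the CSV member; the version argument is ignored by FileType, hence the `_` parameter in deserialize):

    from astro.constants import FileType

    data = FileType.CSV.serialize()           # {"value": "csv"}
    restored = FileType.deserialize(data, 1)  # second arg is the serde version, unused here
    assert restored is FileType.CSV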
9 changes: 7 additions & 2 deletions python-sdk/src/astro/sql/operators/cleanup.py
@@ -4,10 +4,12 @@
 from datetime import timedelta
 from typing import Any

+from airflow import __version__ as airflow_version
 from airflow.decorators.base import get_unique_task_id
 from airflow.exceptions import AirflowException
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagrun import DagRun
+from packaging import version

 try:
     # Airflow >= 2.3
@@ -215,11 +217,14 @@ def _is_single_worker_mode(cls, current_dagrun: DagRun) -> bool:

     @staticmethod
     def _get_executor_from_job_id(job_id: int) -> str | None:
-        from airflow.jobs.base_job import BaseJob
+        if version.parse(airflow_version) >= version.parse("2.6"):
+            from airflow.jobs.job import Job
+        else:
+            from airflow.jobs.base_job import BaseJob as Job
         from airflow.utils.session import create_session

         with create_session() as session:
-            job = session.get(BaseJob, job_id)
+            job = session.get(Job, job_id)
             return job.executor_class if job else None

     def get_all_task_outputs(self, context: Context) -> list[BaseTable]:
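
One subtlety: the gate compares against "2.6" rather than "2.6.0". Under PEP 440 (which packaging implements) the two are equal, patch releases sort above, and pre-releases sort below, so the shim behaves as intended across the whole 2.6 line:

    from packaging import version

    assert version.parse("2.6.0") == version.parse("2.6")   # release segments are zero-padded
    assert version.parse("2.6.1") > version.parse("2.6")    # patch releases pass the gate
    assert version.parse("2.6.0b1") < version.parse("2.6")  # pre-releases do not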
22 changes: 21 additions & 1 deletion python-sdk/src/astro/table.py
@@ -2,7 +2,7 @@

 import random
 import string
-from typing import Any
+from typing import Any, ClassVar

 from attr import define, field, fields_dict
 from sqlalchemy import Column, MetaData
@@ -50,6 +50,7 @@ class BaseTable:
     """

     template_fields = ("name",)
+    version: ClassVar[int] = 1

     # TODO: discuss alternative names to this class, since it contains metadata as opposed to be the
     # SQL table itself
@@ -178,6 +179,25 @@ def openlineage_emit_temp_table_event(self):
             isinstance(self, TempTable) and OPENLINEAGE_EMIT_TEMP_TABLE_EVENT
         )

+    def serialize(self) -> dict[str, Any]:
+        return {
+            "name": self.name,
+            "temp": self.temp,
+            "conn_id": self.conn_id,
+            "metadata": {"schema": self.metadata.schema, "database": self.metadata.database},
+        }
+
+    @staticmethod
+    def deserialize(data: dict[str, Any], version: int):
+        if version > 1:
+            raise TypeError(f"version > {BaseTable.version}")
+        return Table(
+            name=data["name"],
+            temp=data["temp"],
+            conn_id=data["conn_id"],
+            metadata=Metadata(**data["metadata"]),
+        )
+

 @define(slots=False)
 class TempTable(BaseTable):
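
With the version ClassVar and the serialize/deserialize pair, a Table round-trips through plain dicts, which is what XCom needs. A minimal sketch (the connection id and schema are illustrative):

    from astro.table import Metadata, Table

    table = Table(name="users", conn_id="postgres_conn", metadata=Metadata(schema="public"))
    data = table.serialize()                       # plain dict: name, temp, conn_id, metadata
    restored = Table.deserialize(data, version=1)  # raises TypeError for versions newer than BaseTable.version
    assert (restored.name, restored.conn_id) == (table.name, table.conn_id)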
92 changes: 86 additions & 6 deletions python-sdk/tests/sql/operators/test_cleanup.py
@@ -3,17 +3,16 @@
 from unittest import mock

 import pytest
-from airflow import DAG, AirflowException
+from airflow import DAG, AirflowException, __version__ as airflow_version
 from airflow.executors.local_executor import LocalExecutor
 from airflow.executors.sequential_executor import SequentialExecutor
-from airflow.jobs.backfill_job import BackfillJob
-from airflow.jobs.scheduler_job import SchedulerJob
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
 from airflow.operators.bash import BashOperator
 from airflow.settings import Session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
+from packaging import version

 from astro.constants import Database
 from astro.files import File
@@ -101,6 +100,10 @@ def test_error_raised_with_blocking_op_executors(
         cleanup_task.execute({"dag_run": dr})


+@pytest.mark.skipif(
+    version.parse(airflow_version) < version.parse("2.6.0"),
+    reason="BackfillJobRunner and Job classes are only available in airflow >= 2.6",
+)
 @pytest.mark.parametrize(
     "executor_in_job,executor_in_cfg,expected_val",
     [
@@ -112,6 +115,42 @@ def test_error_raised_with_blocking_op_executors(
 )
 def test_single_worker_mode_backfill(executor_in_job, executor_in_cfg, expected_val):
     """Test that if we run Backfill Job it should be marked as single worker node"""
+    from airflow.jobs.backfill_job_runner import BackfillJobRunner
+    from airflow.jobs.job import Job
+
+    dag = DAG("test_single_worker_mode_backfill", start_date=datetime(2022, 1, 1))
+    dr = DagRun(dag_id=dag.dag_id)
+
+    with mock.patch.dict(os.environ, {"AIRFLOW__CORE__EXECUTOR": executor_in_cfg}):
+        job = Job(executor=executor_in_job)
+        session = Session()
+        session.add(job)
+        session.flush()
+        BackfillJobRunner(job=job, dag=dag)
+
+        dr.creating_job_id = job.id
+        assert CleanupOperator._is_single_worker_mode(dr) == expected_val
+
+    session.rollback()
+
+
+@pytest.mark.skipif(
+    version.parse(airflow_version) >= version.parse("2.6.0"),
+    reason="BackfillJob class is not available in airflow >= 2.6",
+)
+@pytest.mark.parametrize(
+    "executor_in_job,executor_in_cfg,expected_val",
+    [
+        (SequentialExecutor(), "LocalExecutor", True),
+        (LocalExecutor(), "LocalExecutor", False),
+        (None, "LocalExecutor", False),
+        (None, "SequentialExecutor", True),
+    ],
+)
+def test_single_worker_mode_backfill_airflow_2_5(executor_in_job, executor_in_cfg, expected_val):
+    """Test that if we run Backfill Job it should be marked as single worker node"""
+    from airflow.jobs.backfill_job import BackfillJob
+
     dag = DAG("test_single_worker_mode_backfill", start_date=datetime(2022, 1, 1))
     dr = DagRun(dag_id=dag.dag_id)

@@ -127,20 +166,61 @@ def test_single_worker_mode_backfill(executor_in_job, executor_in_cfg, expected_
     session.rollback()


+@pytest.mark.skipif(
+    version.parse(airflow_version) < version.parse("2.6.0"),
+    reason="SchedulerJobRunner and Job classes are only available in airflow >= 2.6.0",
+)
+@pytest.mark.parametrize(
+    "executor_in_job,executor_in_cfg,expected_val",
+    [
+        (SequentialExecutor(), "LocalExecutor", True),
+        (LocalExecutor(), "LocalExecutor", False),
+        (None, "LocalExecutor", False),
+        (None, "SequentialExecutor", True),
+    ],
+)
+def test_single_worker_mode_scheduler_job(executor_in_job, executor_in_cfg, expected_val):
+    """Test that if we run Scheduler Job it should be marked as single worker node"""
+    from airflow.jobs.job import Job
+    from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
+
+    dag = DAG("test_single_worker_mode_scheduler_job", start_date=datetime(2022, 1, 1))
+    dr = DagRun(dag_id=dag.dag_id)
+
+    with mock.patch.dict(os.environ, {"AIRFLOW__CORE__EXECUTOR": executor_in_cfg}):
+        # Scheduler Job in Airflow sets executor from airflow.cfg
+        job = Job(executor=executor_in_job)
+        session = Session()
+        session.add(job)
+        session.flush()
+        SchedulerJobRunner(job=job)
+
+        dr.creating_job_id = job.id
+        assert CleanupOperator._is_single_worker_mode(dr) == expected_val
+
+    session.rollback()
+
+
+@pytest.mark.skipif(
+    version.parse(airflow_version) >= version.parse("2.6.0"),
+    reason="SchedulerJob class is not available in airflow >= 2.6",
+)
 @pytest.mark.parametrize(
-    "executor_in_cfg,expected_val",
+    "executor_in_job,expected_val",
     [
         ("LocalExecutor", False),
         ("SequentialExecutor", True),
         ("CeleryExecutor", False),
     ],
 )
-def test_single_worker_mode_scheduler_job(executor_in_cfg, expected_val):
+def test_single_worker_mode_scheduler_job_airflow_2_5(executor_in_job, expected_val):
     """Test that if we run Scheduler Job it should be marked as single worker node"""
+    from airflow.jobs.scheduler_job import SchedulerJob
+
     dag = DAG("test_single_worker_mode_scheduler_job", start_date=datetime(2022, 1, 1))
     dr = DagRun(dag_id=dag.dag_id)

-    with mock.patch.dict(os.environ, {"AIRFLOW__CORE__EXECUTOR": executor_in_cfg}):
+    with mock.patch.dict(os.environ, {"AIRFLOW__CORE__EXECUTOR": executor_in_job}):
         # Scheduler Job in Airflow sets executor from airflow.cfg
         job = SchedulerJob()
         session = Session()
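
The split visible in these tests mirrors Airflow 2.6's job refactor: Job is now the ORM row that gets persisted via the session, while BackfillJobRunner and SchedulerJobRunner carry the scheduling logic and attach to a Job instance; on 2.5 and earlier, BackfillJob and SchedulerJob were single classes playing both roles.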
1 change: 1 addition & 0 deletions python-sdk/tests/sql/operators/test_dataframe.py
@@ -20,6 +20,7 @@
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 CWD = pathlib.Path(__file__).parent

+# Trying out need to replace.
 test_df = pandas.DataFrame({"numbers": [1, 2, 3], "Colors": ["red", "white", "blue"]})
 test_df_2 = pandas.DataFrame({"Numbers": [1, 2, 3], "Colors": ["red", "white", "blue"]})

7 changes: 7 additions & 0 deletions python-sdk/tests/sql/test_table.py
@@ -159,3 +159,10 @@ def test_openlineage_emit_temp_table_event():
     with mock.patch("astro.table.OPENLINEAGE_EMIT_TEMP_TABLE_EVENT", new=False):
         tb = TempTable(name="_tmp_xyz")
         assert tb.openlineage_emit_temp_table_event() is False
+
+
+def test_serialization__deserialization():
+    table = Table(
+        conn_id="postgres_conn",
+    )
+    Table.deserialize(table.serialize(), 1)
