Allow users to customise the task name when using @df #126
Context

As of Astro 0.5.1, if a user calls a function decorated with @df multiple times (example below), the tasks will automatically be named rows_to_gcs_files, rows_to_gcs_files__1, and so on. At the moment, the user cannot override the default task_id when using @df.

@df
def rows_to_gcs_files():
    ...
This behavior is different from the @task decorator, which allows users to specify task_id. We should make them consistent. This works:

@task(task_id='load_json')
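For context, Airflow de-duplicates repeated task IDs by appending __N suffixes, which is where the rows_to_gcs_files__1 name above comes from. A minimal, self-contained sketch of that naming scheme (a hypothetical helper that mirrors, but does not reproduce, Airflow's actual code):

```python
def get_unique_task_id(task_id: str, existing_ids: set) -> str:
    """Return task_id, or task_id__N for the smallest N not yet taken.

    Sketch of Airflow-style duplicate-task naming: repeated calls to a
    decorated function yield rows_to_gcs_files, rows_to_gcs_files__1, ...
    """
    if task_id not in existing_ids:
        return task_id
    n = 1
    while f"{task_id}__{n}" in existing_ids:
        n += 1
    return f"{task_id}__{n}"
```

For example, get_unique_task_id("rows_to_gcs_files", {"rows_to_gcs_files"}) returns "rows_to_gcs_files__1".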
Acceptance criteria

Users are able to override the default task_id when using the @df decorator, just as they can with the @task decorator
Ask @mag3141592 (bug reporter) to review the PR
dimberman committed Mar 5, 2022
1 parent 7633f79 commit 48cc11c
Showing 2 changed files with 31 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/astro/dataframe/__init__.py
@@ -28,12 +28,14 @@ def dataframe(
     database: Optional[str] = None,
     schema: Optional[str] = None,
     warehouse: Optional[str] = None,
+    task_id: Optional[str] = None,
 ):
     """
     This function allows a user to run python functions in Airflow but with the huge benefit that SQL files
     will automatically be turned into dataframes and resulting dataframes can automatically used in astro.sql functions
     """
     return task_decorator_factory(
+        task_id=task_id,
         python_callable=python_callable,
         multiple_outputs=multiple_outputs,
         decorated_operator_class=SqlDataframeOperator,  # type: ignore
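The shape of the fix can be illustrated with a self-contained sketch (hypothetical df decorator and module-level registry; the real code delegates to Airflow's task_decorator_factory): the decorator accepts an optional task_id and falls back to the function name, with the usual __N de-duplication.

```python
import functools
from typing import Callable, Optional

_task_ids = []  # stands in for the DAG's list of registered task IDs


def df(task_id: Optional[str] = None):
    """Sketch of a @df-style decorator with an optional task_id override."""

    def decorator(func: Callable):
        base = task_id or func.__name__  # explicit override, else function name
        final, n = base, 1
        while final in _task_ids:  # Airflow-style __N suffixing
            final = f"{base}__{n}"
            n += 1
        _task_ids.append(final)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        wrapper.task_id = final
        return wrapper

    return decorator
```

Decorating two functions with task_id="load_json" yields the IDs load_json and load_json__1, matching the behavior asserted in the new test below.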
29 changes: 29 additions & 0 deletions tests/operators/test_dataframe.py
@@ -130,6 +130,35 @@ def my_df_func(df: pandas.DataFrame):
             == 200
         )
 
+    def test_dataframe_from_sql_custom_task_id(self):
+        @df(task_id="foo")
+        def my_df_func(df: pandas.DataFrame):
+            return df.actor_id.count()
+
+        with self.dag:
+            for i in range(5):
+                # ensure we can create multiple tasks
+                f = my_df_func(
+                    df=Table(
+                        "actor",
+                        conn_id="postgres_conn",
+                        database="pagila",
+                        schema="public",
+                    )
+                )
+
+        task_ids = [x.task_id for x in self.dag.tasks]
+        assert task_ids == ["foo", "foo__1", "foo__2", "foo__3", "foo__4"]
+
+        test_utils.run_dag(self.dag)
+
+        assert (
+            XCom.get_one(
+                execution_date=DEFAULT_DATE, key=f.key, task_id=f.operator.task_id
+            )
+            == 200
+        )
+
     def test_dataframe_from_sql_basic_op_arg(self):
         @df(conn_id="postgres_conn", database="pagila")
         def my_df_func(df: pandas.DataFrame):
