Skip to content

Commit

Permalink
Expose get_file_list_func to users (#1380)
Browse files Browse the repository at this point in the history
This exposes `get_file_list_func` to users so that they can call it from within
`@task`- or `@task.virtualenv`-decorated functions.

Context:
https://astronomer.slack.com/archives/C02B8SPT93K/p1670408501306539?thread_ts=1669893212.365849&cid=C02B8SPT93K
  • Loading branch information
kaxil authored Dec 8, 2022
1 parent 27b35ea commit bf0d3cd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python-sdk/src/astro/files/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from airflow.hooks.base import BaseHook

from astro.files.base import File # noqa: F401 # skipcq: PY-W2000
from astro.files.base import resolve_file_path_pattern # noqa: F401 # skipcq: PY-W2000
from astro.files.base import File, resolve_file_path_pattern # noqa: F401 # skipcq: PY-W2000
from astro.files.operators.files import get_file_list_func # noqa: F401 # skipcq: PY-W2000

if TYPE_CHECKING:
from airflow.models.xcom_arg import XComArg
Expand Down
15 changes: 11 additions & 4 deletions python-sdk/src/astro/files/operators/files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any

from airflow.decorators.base import get_unique_task_id
Expand Down Expand Up @@ -29,8 +31,13 @@ def __init__(self, path: str, conn_id: str, task_id: str = "", **kwargs):
super().__init__(task_id=task_id, **kwargs)

def execute(self, context: Context) -> Any: # skipcq: PYL-W0613
location = create_file_location(self.path, self.conn_id)
files = get_file_list_func(self.path, self.conn_id)
# Get list of files excluding folders
return [
File(path=path, conn_id=location.conn_id) for path in location.paths if not path.endswith("/")
]
return [File(path=file, conn_id=self.conn_id) for file in files]


def get_file_list_func(path: str, conn_id: str) -> list[str]:
    """Return the file paths found at ``path``, leaving out folder entries.

    :param path: location to list; forwarded to ``create_file_location``
    :param conn_id: connection id; forwarded to ``create_file_location``
    :return: every entry of the resolved location's ``paths`` that does not
        end with ``/``
    """
    file_location = create_file_location(path, conn_id)
    files = []
    for candidate in file_location.paths:
        # A trailing slash marks a folder entry, which is not a file.
        if candidate.endswith("/"):
            continue
        files.append(candidate)
    return files

0 comments on commit bf0d3cd

Please sign in to comment.