From bf0d3cde42a8053a9389790c58c66df06425fb32 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 8 Dec 2022 10:37:04 +0000 Subject: [PATCH] Expose `get_file_list_func` to users (#1380) This exposes `get_file_list_func` to users so that they can use it in `@task` or `@task.virtualenv` Context: https://astronomer.slack.com/archives/C02B8SPT93K/p1670408501306539?thread_ts=1669893212.365849&cid=C02B8SPT93K --- python-sdk/src/astro/files/__init__.py | 4 ++-- python-sdk/src/astro/files/operators/files.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/python-sdk/src/astro/files/__init__.py b/python-sdk/src/astro/files/__init__.py index 41f555e5f..accd4ed4e 100644 --- a/python-sdk/src/astro/files/__init__.py +++ b/python-sdk/src/astro/files/__init__.py @@ -2,8 +2,8 @@ from airflow.hooks.base import BaseHook -from astro.files.base import File # noqa: F401 # skipcq: PY-W2000 -from astro.files.base import resolve_file_path_pattern # noqa: F401 # skipcq: PY-W2000 +from astro.files.base import File, resolve_file_path_pattern # noqa: F401 # skipcq: PY-W2000 +from astro.files.operators.files import get_file_list_func # noqa: F401 # skipcq: PY-W2000 if TYPE_CHECKING: from airflow.models.xcom_arg import XComArg diff --git a/python-sdk/src/astro/files/operators/files.py b/python-sdk/src/astro/files/operators/files.py index dc3fd4fae..c962103fd 100644 --- a/python-sdk/src/astro/files/operators/files.py +++ b/python-sdk/src/astro/files/operators/files.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any from airflow.decorators.base import get_unique_task_id @@ -29,8 +31,13 @@ def __init__(self, path: str, conn_id: str, task_id: str = "", **kwargs): super().__init__(task_id=task_id, **kwargs) def execute(self, context: Context) -> Any: # skipcq: PYL-W0613 - location = create_file_location(self.path, self.conn_id) + files = get_file_list_func(self.path, self.conn_id) # Get list of files excluding folders - return [ - File(path=path, conn_id=location.conn_id) for path in location.paths if not path.endswith("/") - ] + return [File(path=file, conn_id=self.conn_id) for file in files] + + +def get_file_list_func(path: str, conn_id: str) -> list[str]: + """Function to get list of files from a bucket""" + location = create_file_location(path, conn_id) + # Get list of files excluding folders + return [path for path in location.paths if not path.endswith("/")]