diff --git a/README.md b/README.md index 033137ab2..a674b18fb 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,8 @@ pip install astro-sdk-python[amazon,google,snowflake,postgres] | json | | ndjson | | parquet | +| xls | +| xlsx | | Database | | :-------- | diff --git a/python-sdk/pyproject.toml b/python-sdk/pyproject.toml index 931cecfe7..8342a50f4 100644 --- a/python-sdk/pyproject.toml +++ b/python-sdk/pyproject.toml @@ -87,7 +87,8 @@ ftp = [ ] openlineage = ["openlineage-airflow>=0.17.0"] -databricks = ["databricks-cli", +databricks = [ + "databricks-cli", "apache-airflow-providers-databricks", # TODO: Remove this once https://github.com/databricks/databricks-sql-python/pull/191 released "databricks-sql-connector<2.9.0", diff --git a/python-sdk/src/astro/constants.py b/python-sdk/src/astro/constants.py index 93c4e677d..a63513447 100644 --- a/python-sdk/src/astro/constants.py +++ b/python-sdk/src/astro/constants.py @@ -39,6 +39,8 @@ class FileType(Enum): JSON = "json" NDJSON = "ndjson" PARQUET = "parquet" + XLS = "xls" + XLSX = "xlsx" # [END filetypes] def __str__(self) -> str: diff --git a/python-sdk/src/astro/dataframes/load_options.py b/python-sdk/src/astro/dataframes/load_options.py index 27f05192f..0d372489d 100644 --- a/python-sdk/src/astro/dataframes/load_options.py +++ b/python-sdk/src/astro/dataframes/load_options.py @@ -16,6 +16,7 @@ class PandasLoadOptions(LoadOptions): 1. CSV file type - https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html 2. NDJSON/JSON file type - https://pandas.pydata.org/docs/reference/api/pandas.read_json.html 3. Parquet file type - https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html + 4. Excel file type - https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html :param delimiter: Delimiter to use. Defaults to None :param dtype: Data type for data or columns. 
diff --git a/python-sdk/src/astro/files/base.py b/python-sdk/src/astro/files/base.py index 9c37447be..1b9ba8f6a 100644 --- a/python-sdk/src/astro/files/base.py +++ b/python-sdk/src/astro/files/base.py @@ -92,7 +92,11 @@ def is_binary(self) -> bool: :return: True or False """ - result: bool = self.type.name == constants.FileType.PARQUET + result: bool = self.type.name in ( + constants.FileType.PARQUET, + constants.FileType.XLSX, + constants.FileType.XLS, + ) return result def is_local(self) -> bool: diff --git a/python-sdk/src/astro/files/types/__init__.py b/python-sdk/src/astro/files/types/__init__.py index e0b79124a..2f3e9dacc 100644 --- a/python-sdk/src/astro/files/types/__init__.py +++ b/python-sdk/src/astro/files/types/__init__.py @@ -5,9 +5,12 @@ from astro.constants import FileType as FileTypeConstants from astro.files.types.base import FileType from astro.files.types.csv import CSVFileType +from astro.files.types.excel import ExcelFileType # noqa: F401 # skipcq: PY-W2000 from astro.files.types.json import JSONFileType from astro.files.types.ndjson import NDJSONFileType from astro.files.types.parquet import ParquetFileType +from astro.files.types.xls import XLSFileType +from astro.files.types.xlsx import XLSXFileType from astro.options import LoadOptionsList @@ -23,6 +26,8 @@ def create_file_type( FileTypeConstants.JSON: JSONFileType, FileTypeConstants.NDJSON: NDJSONFileType, FileTypeConstants.PARQUET: ParquetFileType, + FileTypeConstants.XLS: XLSFileType, + FileTypeConstants.XLSX: XLSXFileType, } if not filetype: filetype = get_filetype(path) @@ -49,7 +54,7 @@ def get_filetype(filepath: str | pathlib.PosixPath) -> FileTypeConstants: :param filepath: URI or Path to a file :type filepath: str or pathlib.PosixPath - :return: The filetype (e.g. csv, ndjson, json, parquet) + :return: The filetype (e.g. 
csv, ndjson, json, parquet, excel) :rtype: astro.constants.FileType """ if isinstance(filepath, pathlib.PosixPath): @@ -67,6 +72,6 @@ def get_filetype(filepath: str | pathlib.PosixPath) -> FileTypeConstants: ) try: - return FileTypeConstants(extension) + return FileTypeConstants(extension.lower()) except ValueError: raise ValueError(f"Unsupported filetype '{extension}' from file '{filepath}'.") diff --git a/python-sdk/src/astro/files/types/excel.py b/python-sdk/src/astro/files/types/excel.py new file mode 100644 index 000000000..1073deaaf --- /dev/null +++ b/python-sdk/src/astro/files/types/excel.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import io + +import pandas as pd + +from astro.dataframes.load_options import PandasLoadOptions +from astro.dataframes.pandas import PandasDataframe +from astro.files.types.base import FileType +from astro.utils.dataframe import convert_columns_names_capitalization + + +class ExcelFileType(FileType): + """Concrete implementation to handle Excel file type""" + + LOAD_OPTIONS_CLASS_NAME = ("PandasLoadOptions",) + + # We need skipcq because it's a method overloading so we don't want to make it a static method + def export_to_dataframe( + self, + stream, + columns_names_capitalization="original", + **kwargs, + ) -> pd.DataFrame: # skipcq PYL-R0201 + """read Excel file from one of the supported locations and return dataframe + + :param stream: file stream object + :param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase + in the resulting dataframe + """ + if isinstance(self.load_options, PandasLoadOptions): + kwargs = self.load_options.populate_kwargs(kwargs) + df = pd.read_excel(stream, **kwargs) + df = convert_columns_names_capitalization( + df=df, columns_names_capitalization=columns_names_capitalization + ) + return PandasDataframe.from_pandas_df(df) + + # We need skipcq because it's a method overloading so we don't want to make it a static method + def 
create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + """Write Excel file to one of the supported locations + + :param df: pandas dataframe + :param stream: file stream object + """ + df.to_excel(stream, index=False) diff --git a/python-sdk/src/astro/files/types/xls.py b/python-sdk/src/astro/files/types/xls.py new file mode 100644 index 000000000..4f54823d3 --- /dev/null +++ b/python-sdk/src/astro/files/types/xls.py @@ -0,0 +1,8 @@ +from astro.constants import FileType as FileTypeConstants +from astro.files.types import ExcelFileType + + +class XLSFileType(ExcelFileType): + @property + def name(self): + return FileTypeConstants.XLS diff --git a/python-sdk/src/astro/files/types/xlsx.py b/python-sdk/src/astro/files/types/xlsx.py new file mode 100644 index 000000000..c743db3ba --- /dev/null +++ b/python-sdk/src/astro/files/types/xlsx.py @@ -0,0 +1,8 @@ +from astro.constants import FileType as FileTypeConstants +from astro.files.types import ExcelFileType + + +class XLSXFileType(ExcelFileType): + @property + def name(self): + return FileTypeConstants.XLSX diff --git a/python-sdk/tests/data/sample.xlsx b/python-sdk/tests/data/sample.xlsx new file mode 100644 index 000000000..4b6ee8fb2 Binary files /dev/null and b/python-sdk/tests/data/sample.xlsx differ diff --git a/python-sdk/tests/files/test_file.py b/python-sdk/tests/files/test_file.py index d86eed362..4fb1a5492 100644 --- a/python-sdk/tests/files/test_file.py +++ b/python-sdk/tests/files/test_file.py @@ -24,6 +24,7 @@ (False, "/tmp/sample.json"), (False, "/tmp/sample.ndjson"), (True, "/tmp/sample.parquet"), + (True, "/tmp/sample.xlsx"), ] @@ -49,12 +50,13 @@ def test_is_binary(filetype): (False, "/tmp/sample.json"), (False, "/tmp/sample.ndjson"), (False, "/tmp/sample.parquet"), + (False, "/tmp/sample.xlsx"), (True, "/tmp/"), (True, "s3://tmp/home_*"), (False, "s3://tmp/.folder/sample.csv"), (True, "s3://tmp/.folder/"), ], - ids=["csv", "json", "ndjson", "parquet", 
"csv", "json", "csv", "json"], + ids=["csv", "json", "ndjson", "parquet", "xlsx", "csv", "json", "csv", "json"], ) def test_is_pattern(filetype): """Test if the file is a file pattern""" @@ -226,8 +228,9 @@ def test_if_file_object_can_be_pickled(): {"type": "ndjson", "expected_class": PandasLoadOptions}, {"type": "json", "expected_class": PandasLoadOptions}, {"type": "parquet", "expected_class": PandasLoadOptions}, + {"type": "xlsx", "expected_class": PandasLoadOptions}, ], - ids=["csv", "ndjson", "json", "parquet"], + ids=["csv", "ndjson", "json", "parquet", "xlsx"], ) @pytest.mark.parametrize( "file_location", diff --git a/python-sdk/tests/files/type/test_excel.py b/python-sdk/tests/files/type/test_excel.py new file mode 100644 index 000000000..1cea09974 --- /dev/null +++ b/python-sdk/tests/files/type/test_excel.py @@ -0,0 +1,46 @@ +import pathlib +import tempfile +from unittest import mock + +import pandas as pd + +from astro.dataframes.load_options import PandasLoadOptions +from astro.dataframes.pandas import PandasDataframe +from astro.files.types import XLSXFileType + +sample_file = pathlib.Path(pathlib.Path(__file__).parent.parent.parent, "data/sample.xlsx") + + +def test_read_excel_file(): + """Test reading of excel file from local location""" + path = str(sample_file.absolute()) + excel_type = XLSXFileType(path) + with open(path, "rb") as file: + df = excel_type.export_to_dataframe(file) + assert df.shape == (3, 2) + assert isinstance(df, PandasDataframe) + + +@mock.patch("astro.files.types.excel.pd.read_excel") +def test_read_excel_file_with_pandas_opts(mock_read_excel): + """Test pandas option get pass to read_excel""" + path = str(sample_file.absolute()) + excel_type = XLSXFileType(path, load_options=PandasLoadOptions()) + with open(path, "rb") as file: + excel_type.export_to_dataframe(file) + mock_read_excel.assert_called_once_with(file) + + +def test_write_excel_file(): + """Test writing of excel file from local location""" + with 
tempfile.NamedTemporaryFile() as temp_file: + path = temp_file.name + data = { + "id": [1, 2, 3], + "name": ["First", "Second", "Third with unicode पांचाल"], + } + df = pd.DataFrame(data=data) + + excel_type = XLSXFileType(path) + excel_type.create_from_dataframe(stream=temp_file, df=df) + assert pd.read_excel(path).shape == (3, 2) diff --git a/python-sdk/tests/files/type/test_type_base.py b/python-sdk/tests/files/type/test_type_base.py index 388514b77..224ad878c 100644 --- a/python-sdk/tests/files/type/test_type_base.py +++ b/python-sdk/tests/files/type/test_type_base.py @@ -11,6 +11,7 @@ (FileType.JSON, "sample.json"), (FileType.NDJSON, "sample.ndjson"), (FileType.PARQUET, "sample.parquet"), + (FileType.XLSX, "sample.xlsx"), ] sample_filetypes = [items[0] for items in sample_filepaths_per_filetype] sample_filepaths = [items[1] for items in sample_filepaths_per_filetype] diff --git a/python-sdk/tests/test_constants.py b/python-sdk/tests/test_constants.py index 58e437493..6dc7d74a1 100644 --- a/python-sdk/tests/test_constants.py +++ b/python-sdk/tests/test_constants.py @@ -7,7 +7,7 @@ def test_supported_file_locations(): def test_supported_file_types(): - expected = {"csv", "json", "ndjson", "parquet"} + expected = {"csv", "json", "ndjson", "parquet", "xls", "xlsx"} assert set(SUPPORTED_FILE_TYPES) == expected diff --git a/python-sdk/tests_integration/conftest.py b/python-sdk/tests_integration/conftest.py index 1a5106ff5..63801a86a 100644 --- a/python-sdk/tests_integration/conftest.py +++ b/python-sdk/tests_integration/conftest.py @@ -301,13 +301,19 @@ def method_map_fixture(method, base_path, classes, get): def type_method_map_fixture(request): """Get paths for type's package for methods""" method = request.param["method"] - classes = ["JSONFileType", "CSVFileType", "NDJSONFileType", "ParquetFileType"] + classes = [ + "JSONFileType", + "CSVFileType", + "NDJSONFileType", + "ParquetFileType", + "XLSXFileType", + "XLSFileType", + ] base_path = ("astro.files.types",) 
- suffix = "FileType" yield method_map_fixture( method=method, classes=classes, base_path=base_path, - get=lambda x: FileType(x.rstrip(suffix).lower()), + get=lambda x: FileType(x[0:-8].lower()), # remove FileType suffix ) diff --git a/python-sdk/tests_integration/sql/operators/test_load_file.py b/python-sdk/tests_integration/sql/operators/test_load_file.py index 15b951e58..8ffaf6d15 100644 --- a/python-sdk/tests_integration/sql/operators/test_load_file.py +++ b/python-sdk/tests_integration/sql/operators/test_load_file.py @@ -834,7 +834,7 @@ def test_load_file_bigquery_error_out(sample_dag, database_table_fixture): indirect=True, ids=["snowflake", "bigquery", "postgresql", "sqlite", "redshift", "duckdb", "mysql"], ) -@pytest.mark.parametrize("file_type", ["parquet", "ndjson", "json", "csv"]) +@pytest.mark.parametrize("file_type", ["parquet", "ndjson", "json", "csv", "xlsx"]) def test_load_file(sample_dag, database_table_fixture, file_type): db, test_table = database_table_fixture diff --git a/python-sdk/tests_integration/sql/operators/utils.py b/python-sdk/tests_integration/sql/operators/utils.py index 191a1c7c5..7c67baa34 100644 --- a/python-sdk/tests_integration/sql/operators/utils.py +++ b/python-sdk/tests_integration/sql/operators/utils.py @@ -75,9 +75,11 @@ def load_to_dataframe(filepath, file_type): "csv": pd.read_csv, "json": pd.read_json, "ndjson": pd.read_json, + "xlsx": pd.read_excel, + "xls": pd.read_excel, } read_params = {"ndjson": {"lines": True}} - mode = {"parquet": "rb"} + mode = {"parquet": "rb", "xls": "rb", "xlsx": "rb"} with open(filepath, mode.get(file_type, "r")) as fp: return read[file_type](fp, **read_params.get(file_type, {}))