diff --git a/daft/convert.py b/daft/convert.py index 4bebd7220b..e734d0a176 100644 --- a/daft/convert.py +++ b/daft/convert.py @@ -1,6 +1,6 @@ # isort: dont-add-import: from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Union from daft.api_annotations import PublicAPI @@ -55,7 +55,7 @@ def from_pydict(data: Dict[str, InputListType]) -> "DataFrame": @PublicAPI -def from_arrow(data: Union["pa.Table", List["pa.Table"]]) -> "DataFrame": +def from_arrow(data: Union["pa.Table", List["pa.Table"], Iterable["pa.Table"]]) -> "DataFrame": """Creates a DataFrame from a pyarrow Table. Example: diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index c5d9a6ad01..c5a1948bf0 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -312,8 +312,10 @@ def _from_pydict(cls, data: Dict[str, InputListType]) -> "DataFrame": return cls._from_tables(data_vpartition) @classmethod - def _from_arrow(cls, data: Union["pyarrow.Table", List["pyarrow.Table"]]) -> "DataFrame": + def _from_arrow(cls, data: Union["pyarrow.Table", List["pyarrow.Table"], Iterable["pyarrow.Table"]]) -> "DataFrame": """Creates a DataFrame from a `pyarrow Table `__.""" + if isinstance(data, Iterable): + data = list(data) if not isinstance(data, list): data = [data] data_vpartitions = [MicroPartition.from_arrow(table) for table in data] diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py index 95ef19b884..216b66dae2 100644 --- a/tests/table/test_from_py.py +++ b/tests/table/test_from_py.py @@ -9,6 +9,7 @@ import pyarrow.compute as pac import pytest +import daft from daft import DataType, TimeUnit from daft.context import get_context from daft.series import Series @@ -649,3 +650,17 @@ def test_nested_struct_dates(levels: int) -> None: assert back_again.to_arrow().type == expected_arrow_type assert back_again.to_pylist() == data + + +def test_from_arrow_iterable() -> None: + class CustomIterable: + def __iter__(self): + yield pa.table({"text": ["foo1", "bar2"]}) + yield pa.table({"text": ["foo2", "bar2"]}) + yield pa.table({"text": ["foo3", "bar3"]}) + + my_iter = CustomIterable() + + table = daft.from_arrow(my_iter) + tbl = table.to_pydict() + assert tbl == {"text": ["foo1", "bar2", "foo2", "bar2", "foo3", "bar3"]}