Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: dask namespace concat method #840

Merged
merged 11 commits into from
Aug 28, 2024
44 changes: 43 additions & 1 deletion narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import NoReturn
from typing import cast

from narwhals import dtypes
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import validate_comparand
Expand All @@ -16,7 +18,6 @@
if TYPE_CHECKING:
import dask_expr

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.typing import IntoDaskExpr


Expand Down Expand Up @@ -142,6 +143,47 @@ def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
[expr.fill_null(0) for expr in parse_into_exprs(*exprs, namespace=self)],
)

def concat(
    self,
    items: Iterable[DaskLazyFrame],
    *,
    how: str = "vertical",
) -> DaskLazyFrame:
    """Concatenate several Dask-backed lazy frames into one.

    Arguments:
        items: Frames to combine. Any iterable is accepted, including
            one-shot iterators and generators; it is traversed exactly once.
        how: ``"vertical"`` stacks the frames along axis 0 and requires
            every frame to have the same columns in the same order.
            ``"horizontal"`` places the frames side by side along axis 1
            and requires column names to be unique across all frames.

    Returns:
        A new ``DaskLazyFrame`` wrapping the result of ``dd.concat``,
        carrying this namespace's backend version.

    Raises:
        AssertionError: If ``items`` is empty, if the columns differ for a
            vertical concat, or if a column name occurs more than once for
            a horizontal concat.
        NotImplementedError: If ``how`` is neither ``"vertical"`` nor
            ``"horizontal"``.
    """
    import dask.dataframe as dd  # ignore-banned-import

    # Traverse `items` exactly once. Checking emptiness with
    # `len(list(items))` and then iterating again would exhaust a generator
    # on the first pass and leave nothing to concatenate.
    native_frames = [item._native_frame for item in items]
    if not native_frames:
        msg = "No items to concatenate"  # pragma: no cover
        raise AssertionError(msg)

    if how == "vertical":
        expected_columns = tuple(native_frames[0].columns)
        if any(
            tuple(frame.columns) != expected_columns for frame in native_frames
        ):  # pragma: no cover
            msg = "unable to vstack with non-matching columns"
            raise AssertionError(msg)
        return DaskLazyFrame(
            dd.concat(native_frames, axis=0, join="inner"),
            backend_version=self._backend_version,
        )

    if how == "horizontal":
        all_column_names: list[str] = [
            column for frame in native_frames for column in frame.columns
        ]
        if len(all_column_names) != len(set(all_column_names)):  # pragma: no cover
            # Report each offending name once, in first-seen order, rather
            # than once per occurrence.
            duplicates = list(
                dict.fromkeys(
                    name
                    for name in all_column_names
                    if all_column_names.count(name) > 1
                )
            )
            msg = (
                f"Columns with name(s): {', '.join(duplicates)} "
                "have more than one occurrence"
            )
            raise AssertionError(msg)
        return DaskLazyFrame(
            dd.concat(native_frames, axis=1, join="outer"),
            backend_version=self._backend_version,
        )

    raise NotImplementedError

def mean_horizontal(self, *exprs: IntoDaskExpr) -> IntoDaskExpr:
dask_exprs = parse_into_exprs(*exprs, namespace=self)
total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in dask_exprs))
Expand Down
8 changes: 2 additions & 6 deletions tests/frame/concat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from tests.utils import compare_dicts


def test_concat_horizontal(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_concat_horizontal(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_left = nw.from_native(constructor(data)).lazy()

Expand All @@ -29,9 +27,7 @@ def test_concat_horizontal(constructor: Any, request: Any) -> None:
nw.concat([])


def test_concat_vertical(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_concat_vertical(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_left = (
nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z")
Expand Down
Loading