Skip to content

Commit

Permalink
concat support
Browse files Browse the repository at this point in the history
Update narwhals/_dask/namespace.py

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

import change

inner kwarg
  • Loading branch information
benrutter committed Aug 21, 2024
1 parent 48148b4 commit 436c323
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
40 changes: 39 additions & 1 deletion narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from __future__ import annotations

from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import NoReturn

from narwhals import dtypes
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._expression_parsing import parse_into_exprs
from narwhals.dependencies import get_dask_dataframe

if TYPE_CHECKING:
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.typing import IntoDaskExpr


Expand Down Expand Up @@ -135,6 +138,41 @@ def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
return reduce(lambda x, y: x + y, parse_into_exprs(*exprs, namespace=self))

def concat(
self,
items: Iterable[DaskLazyFrame],
*,
how: str = "vertical",
) -> DaskLazyFrame:
dd = get_dask_dataframe()

if len(list(items)) == 0:
msg = "No items to concatenate"
raise ValueError(msg)
native_frames = [i._native_frame for i in items]
axis: int
if how == "vertical":
all_columns = set(chain(*[i.columns for i in native_frames]))
if not all(set(i.columns) == all_columns for i in native_frames):
msg = "unable to vstack with non-matching columns"
raise TypeError(msg)
axis = 0
elif how == "horizontal":
if len({len(i) for i in native_frames}) != 1:
msg = "cannot vertically concatenate dataframes of different length"
raise ValueError(msg)
axis = 1
else:
msg = (
"Only valid options for concat are 'vertical' and 'horizontal' "
f"({how} not recognised)"
)
raise NotImplementedError(msg)
return DaskLazyFrame(
dd.concat(native_frames, axis=axis, join="inner"),
backend_version=self._backend_version,
)

def _create_expr_from_series(self, _: Any) -> NoReturn:
msg = "`_create_expr_from_series` for DaskNamespace exists only for compatibility"
raise NotImplementedError(msg)
Expand Down
8 changes: 2 additions & 6 deletions tests/frame/concat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from tests.utils import compare_dicts


def test_concat_horizontal(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_concat_horizontal(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_left = nw.from_native(constructor(data))

Expand All @@ -29,9 +27,7 @@ def test_concat_horizontal(constructor: Any, request: Any) -> None:
nw.concat([])


def test_concat_vertical(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_concat_vertical(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_left = (
nw.from_native(constructor(data)).rename({"a": "c", "b": "d"}).drop("z").lazy()
Expand Down

0 comments on commit 436c323

Please sign in to comment.