From 436c3233c4cbd1108321b68eae575a513638ea23 Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Wed, 21 Aug 2024 14:54:40 +0100 Subject: [PATCH] concat support Update narwhals/_dask/namespace.py Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> import change inner kwarg --- narwhals/_dask/namespace.py | 40 ++++++++++++++++++++++++++++++++++++- tests/frame/concat_test.py | 8 ++------ 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 2baf1cf3f..4de29cb47 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -1,18 +1,21 @@ from __future__ import annotations from functools import reduce +from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Iterable from typing import NoReturn from narwhals import dtypes +from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals._dask.selectors import DaskSelectorNamespace from narwhals._expression_parsing import parse_into_exprs +from narwhals.dependencies import get_dask_dataframe if TYPE_CHECKING: - from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.typing import IntoDaskExpr @@ -135,6 +138,41 @@ def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: return reduce(lambda x, y: x + y, parse_into_exprs(*exprs, namespace=self)) + def concat( + self, + items: Iterable[DaskLazyFrame], + *, + how: str = "vertical", + ) -> DaskLazyFrame: + dd = get_dask_dataframe() + + if len(list(items)) == 0: + msg = "No items to concatenate" + raise ValueError(msg) + native_frames = [i._native_frame for i in items] + axis: int + if how == "vertical": + all_columns = set(chain(*[i.columns for i in native_frames])) + if not all(set(i.columns) == all_columns for i in native_frames): + msg = "unable to vstack with non-matching columns" + raise TypeError(msg) + axis = 0 + elif how == "horizontal": + if len({len(i) for i in native_frames}) != 1: + msg = "cannot vertically concatenate dataframes of different length" + raise ValueError(msg) + axis = 1 + else: + msg = ( + "Only valid options for concat are 'vertical' and 'horizontal' " + f"({how} not recognised)" + ) + raise NotImplementedError(msg) + return DaskLazyFrame( + dd.concat(native_frames, axis=axis, join="inner"), + backend_version=self._backend_version, + ) + def _create_expr_from_series(self, _: Any) -> NoReturn: msg = "`_create_expr_from_series` for DaskNamespace exists only for compatibility" raise NotImplementedError(msg) diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 970220bf2..94b8d7de1 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -6,9 +6,7 @@ from tests.utils import compare_dicts -def test_concat_horizontal(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_concat_horizontal(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = nw.from_native(constructor(data)) @@ -29,9 +27,7 @@ def test_concat_horizontal(constructor: Any, request: Any) -> None: nw.concat([]) -def test_concat_vertical(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_concat_vertical(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = ( nw.from_native(constructor(data)).rename({"a": "c", "b": "d"}).drop("z").lazy()