From 859863f626488048ee75d758057af1e8751d5d02 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 16 Oct 2024 19:24:53 +0200 Subject: [PATCH] add to_py_scalar --- docs/api-reference/narwhals.md | 1 + narwhals/__init__.py | 2 ++ narwhals/dependencies.py | 5 +++++ narwhals/stable/v1/__init__.py | 2 ++ narwhals/translate.py | 29 ++++++++++++++++++++++++++++ tests/translate/to_py_scalar_test.py | 14 ++++++++++++++ 6 files changed, 53 insertions(+) create mode 100644 tests/translate/to_py_scalar_test.py diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index 044b20e0a..e13b57018 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -38,4 +38,5 @@ Here are the top-level functions available in Narwhals. - when - show_versions - to_native + - to_py_scalar show_source: false diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 124f10c45..7b624cda4 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -54,6 +54,7 @@ from narwhals.translate import get_native_namespace from narwhals.translate import narwhalify from narwhals.translate import to_native +from narwhals.translate import to_py_scalar from narwhals.utils import is_ordered_categorical from narwhals.utils import maybe_align_index from narwhals.utils import maybe_convert_dtypes @@ -79,6 +80,7 @@ "maybe_reset_index", "maybe_set_index", "get_native_namespace", + "to_py_scalar", "all", "all_horizontal", "any_horizontal", diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 144c57c8a..1f9ae19f5 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -46,6 +46,11 @@ def get_cudf() -> Any: return sys.modules.get("cudf", None) +def get_cupy() -> Any: + """Get cupy module (if already imported - else return None).""" + return sys.modules.get("cupy", None) + + def get_pyarrow() -> Any: # pragma: no cover """Get pyarrow module (if already imported - else return None).""" return sys.modules.get("pyarrow", None) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index e9aac4cf4..e9f313afd 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -49,6 +49,7 @@ from narwhals.translate import _from_native_impl from narwhals.translate import get_native_namespace as nw_get_native_namespace from narwhals.translate import to_native +from narwhals.translate import to_py_scalar from narwhals.typing import IntoDataFrameT from narwhals.typing import IntoFrameT from narwhals.utils import is_ordered_categorical as nw_is_ordered_categorical @@ -2249,6 +2250,7 @@ def from_dict( "dependencies", "to_native", "from_native", + "to_py_scalar", "is_ordered_categorical", "maybe_align_index", "maybe_convert_dtypes", diff --git a/narwhals/translate.py b/narwhals/translate.py index 4c23f6d91..4566bda32 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -9,6 +9,7 @@ from typing import overload from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_cupy from narwhals.dependencies import get_dask from narwhals.dependencies import get_dask_expr from narwhals.dependencies import get_modin @@ -775,8 +776,36 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator(func) +def to_py_scalar(scalar_like: Any) -> Any: + """If a scalar is not Python native, tries to convert it to Python native. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]})) + >>> nw.to_py_scalar(df["a"].item(0)) + 1 + >>> import pyarrow as pa + >>> df = nw.from_native(pa.table({"a": [1, 2, 3]})) + >>> nw.to_py_scalar(df["a"].item(0)) + 1 + >>> nw.to_py_scalar(1) + 1 + """ + pa = get_pyarrow() + if pa and isinstance(scalar_like, pa.Scalar): + return scalar_like.as_py() + + cupy = get_cupy() + if cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1: + return scalar_like.item() + + return scalar_like + + __all__ = [ "get_native_namespace", "to_native", "narwhalify", + "to_py_scalar", ] diff --git a/tests/translate/to_py_scalar_test.py b/tests/translate/to_py_scalar_test.py new file mode 100644 index 000000000..3d8bdadef --- /dev/null +++ b/tests/translate/to_py_scalar_test.py @@ -0,0 +1,14 @@ +import narwhals as nw +from narwhals.dependencies import get_cupy +from tests.utils import Constructor + + +def test_to_py_scalar(constructor_eager: Constructor) -> None: + df = nw.from_native(constructor_eager({"a": [1, 2, 3]})) + assert nw.to_py_scalar(df["a"].item(0)) == 1 + + +def test_to_py_scalar_cudf_array(constructor_cudf: Constructor) -> None: + df = nw.from_native(constructor_cudf({"a": [1, 2, 3]})) + if cupy := get_cupy(): + assert isinstance(nw.to_py_scalar(df["a"]), cupy.ndarray)