Skip to content

Commit

Permalink
add to_py_scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
EdAbati committed Oct 16, 2024
1 parent 0e12138 commit 859863f
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ Here are the top-level functions available in Narwhals.
- when
- show_versions
- to_native
- to_py_scalar
show_source: false
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from narwhals.translate import get_native_namespace
from narwhals.translate import narwhalify
from narwhals.translate import to_native
from narwhals.translate import to_py_scalar
from narwhals.utils import is_ordered_categorical
from narwhals.utils import maybe_align_index
from narwhals.utils import maybe_convert_dtypes
Expand All @@ -79,6 +80,7 @@
"maybe_reset_index",
"maybe_set_index",
"get_native_namespace",
"to_py_scalar",
"all",
"all_horizontal",
"any_horizontal",
Expand Down
5 changes: 5 additions & 0 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def get_cudf() -> Any:
return sys.modules.get("cudf", None)


def get_cupy() -> Any:
"""Get cupy module (if already imported - else return None)."""
return sys.modules.get("cupy", None)


def get_pyarrow() -> Any: # pragma: no cover
"""Get pyarrow module (if already imported - else return None)."""
return sys.modules.get("pyarrow", None)
Expand Down
2 changes: 2 additions & 0 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from narwhals.translate import _from_native_impl
from narwhals.translate import get_native_namespace as nw_get_native_namespace
from narwhals.translate import to_native
from narwhals.translate import to_py_scalar
from narwhals.typing import IntoDataFrameT
from narwhals.typing import IntoFrameT
from narwhals.utils import is_ordered_categorical as nw_is_ordered_categorical
Expand Down Expand Up @@ -2249,6 +2250,7 @@ def from_dict(
"dependencies",
"to_native",
"from_native",
"to_py_scalar",
"is_ordered_categorical",
"maybe_align_index",
"maybe_convert_dtypes",
Expand Down
29 changes: 29 additions & 0 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import overload

from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_cupy
from narwhals.dependencies import get_dask
from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_modin
Expand Down Expand Up @@ -775,8 +776,36 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return decorator(func)


def to_py_scalar(scalar_like: Any) -> Any:
"""If a scalar is not Python native, tries to convert it to Python native.
Examples:
>>> import narwhals as nw
>>> import pandas as pd
>>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
>>> nw.to_py_scalar(df["a"].item(0))
1
>>> import pyarrow as pa
>>> df = nw.from_native(pa.table({"a": [1, 2, 3]}))
>>> nw.to_py_scalar(df["a"].item(0))
1
>>> nw.to_py_scalar(1)
1
"""
pa = get_pyarrow()
if pa and isinstance(scalar_like, pa.Scalar):
return scalar_like.as_py()

cupy = get_cupy()
if cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1:
return scalar_like.item()

return scalar_like


__all__ = [
"get_native_namespace",
"to_native",
"narwhalify",
"to_py_scalar",
]
14 changes: 14 additions & 0 deletions tests/translate/to_py_scalar_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import narwhals as nw
from narwhals.dependencies import get_cupy
from tests.utils import Constructor


def test_to_py_scalar(constructor_eager: Constructor) -> None:
df = nw.from_native(constructor_eager({"a": [1, 2, 3]}))
assert nw.to_py_scalar(df["a"].item(0)) == 1


def test_to_py_scalar_cudf_array(constructor_cudf: Constructor) -> None:
df = nw.from_native(constructor_cudf({"a": [1, 2, 3]}))
if cupy := get_cupy():
assert isinstance(nw.to_py_scalar(df["a"]), cupy.ndarray)

0 comments on commit 859863f

Please sign in to comment.