Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds copy parameter to __array__ for numpy 2.0 #9393

Merged
merged 3 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ Bug fixes
- Fix issue with passing parameters to ZarrStore.open_store when opening
datatree in zarr format (:issue:`9376`, :pull:`9377`).
By `Alfonso Ladino <https://github.com/aladinor>`_
- Fix deprecation warning that was raised when calling ``np.array`` on an ``xr.DataArray``
in NumPy 2.0 (:issue:`9312`, :pull:`9393`)
By `Andrew Scherer <https://github.com/andrew-s28>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
18 changes: 16 additions & 2 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,22 @@ def __int__(self: Any) -> int:
def __complex__(self: Any) -> complex:
return complex(self.values)

def __array__(self: Any, dtype: DTypeLike | None = None) -> np.ndarray:
return np.asarray(self.values, dtype=dtype)
def __array__(
self: Any, dtype: DTypeLike | None = None, copy: bool | None = None
) -> np.ndarray:
if not copy:
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
copy = None
elif np.lib.NumpyVersion(np.__version__) <= "1.28.0":
copy = False
else:
# 2.0.0 dev versions, handle cases where copy may or may not exist
try:
np.array([1]).__array__(copy=None)
copy = None
except TypeError:
copy = False
return np.array(self.values, dtype=dtype, copy=copy)

def __repr__(self) -> str:
return formatting.array_repr(self)
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7065,6 +7065,14 @@ def test_from_numpy(self) -> None:
np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3]))
np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6]))

def test_to_numpy(self) -> None:
arr = np.array([1, 2, 3])
da = xr.DataArray(arr, dims="x", coords={"lat": ("x", [4, 5, 6])})

with assert_no_warnings():
np.testing.assert_equal(np.asarray(da), arr)
np.testing.assert_equal(np.array(da), arr)

@requires_dask
def test_from_dask(self) -> None:
da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])})
Expand Down
Loading