Skip to content

Commit

Permalink
fix: ensure that __copy__ and __deepcopy__ are enabled. (#1695)
Browse files Browse the repository at this point in the history
* fix: ensure that  and  are enabled.

* Implemented all the (missing) copy/deepcopy methods.

* refactor: use private parameters where possible

* fix: use `_length` not `_zeros_length` in `RegularArray.copy()``

NB. I did think about introducing a `zeros_length` attribute, but I am concerned that we would need to take a great deal of care in handling this parameter for no good cause. Instead, we should make it such that the `length` parameter is the only externally visible one.

* refactor: replace `Content.__deepcopy__` with abstract method

This will ensure we *have* to implement this routine. Although we could use a dynamic deepcopy, it opens us up to more bugs (explicit vs implicit) and hides some of the copying logic.

* refactor: remove unneeded optional in `__deepcopy__` signature.

* fix: use `copy.deepcopy` to dispatch copying

This avoids handling deep copying in two places

* fix: `deepcopy(behavior)` in `Array/Record.__deepcopy__`

* fix: use `copy.deepcopy` in tests

Co-authored-by: Angus Hollands <goosey15@gmail.com>
  • Loading branch information
jpivarski and agoose77 authored Sep 14, 2022
1 parent 8a338cb commit b5545a8
Show file tree
Hide file tree
Showing 19 changed files with 198 additions and 71 deletions.
11 changes: 11 additions & 0 deletions src/awkward/_v2/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
mask=copy.deepcopy(self._mask, memo),
content=copy.deepcopy(self._content, memo),
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(
self,
mask,
Expand Down
11 changes: 11 additions & 0 deletions src/awkward/_v2/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
mask=copy.deepcopy(self._mask, memo),
content=copy.deepcopy(self._content, memo),
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(
self, mask, content, valid_when, identifier=None, parameters=None, nplike=None
):
Expand Down
15 changes: 5 additions & 10 deletions src/awkward/_v2/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,16 +1559,11 @@ def with_parameter(self, key, value):

return out

def __deepcopy__(self, memo=None):
cls = self.__class__
new_instance = cls.__new__(cls)
memo = {id(self): new_instance}
for k, v in self.__dict__.items():
if k == "_nplike":
new_instance._nplike = v
else:
setattr(new_instance, k, copy.deepcopy(v, memo))
return new_instance
def __copy__(self):
raise ak._v2._util.error(NotImplementedError)

def __deepcopy__(self, memo):
raise ak._v2._util.error(NotImplementedError)

def _jax_flatten(self):
from awkward._v2._connect.jax import _find_numpyarray_nodes, AuxData
Expand Down
11 changes: 11 additions & 0 deletions src/awkward/_v2/contents/emptyarray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import copy

import awkward as ak
from awkward._v2.contents.content import Content, unset
from awkward._v2.forms.emptyform import EmptyForm
Expand All @@ -24,6 +26,15 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(self, identifier=None, parameters=None, nplike=None):
if nplike is None:
nplike = numpy
Expand Down
11 changes: 11 additions & 0 deletions src/awkward/_v2/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
index=copy.deepcopy(self._index, memo),
content=copy.deepcopy(self._content, memo),
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(self, index, content, identifier=None, parameters=None, nplike=None):
if not (
isinstance(index, Index)
Expand Down
11 changes: 11 additions & 0 deletions src/awkward/_v2/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
index=copy.deepcopy(self._index, memo),
content=copy.deepcopy(self._content, memo),
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(self, index, content, identifier=None, parameters=None, nplike=None):
if not (
isinstance(index, Index)
Expand Down
12 changes: 12 additions & 0 deletions src/awkward/_v2/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
starts=copy.deepcopy(self._starts, memo),
stops=copy.deepcopy(self._stops, memo),
content=copy.deepcopy(self._content, memo),
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(
self, starts, stops, content, identifier=None, parameters=None, nplike=None
):
Expand Down
11 changes: 11 additions & 0 deletions src/awkward/_v2/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
offsets=copy.deepcopy(self._offsets, memo),
content=copy.deepcopy(self._content, memo),
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(self, offsets, content, identifier=None, parameters=None, nplike=None):
if not isinstance(offsets, Index) and offsets.dtype in (
np.dtype(np.int32),
Expand Down
19 changes: 11 additions & 8 deletions src/awkward/_v2/contents/numpyarray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import copy

import awkward as ak
from awkward._v2.contents.content import Content, unset
from awkward._v2.forms.numpyform import NumpyForm
Expand All @@ -27,6 +28,16 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
data=copy.deepcopy(self._data, memo),
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(self, data, identifier=None, parameters=None, nplike=None):
if nplike is None:
nplike = ak.nplike.of(data)
Expand Down Expand Up @@ -1364,14 +1375,6 @@ def _to_nplike(self, nplike):
nplike=nplike,
)

def __deepcopy__(self, memo=None):
return ak._v2.contents.NumpyArray(
copy.deepcopy(self._data),
copy.deepcopy(self._identifier),
copy.deepcopy(self._parameters),
self._nplike,
)

def _layout_equal(self, other, index_dtype=True, numpyarray=True):
if numpyarray:
return (
Expand Down
11 changes: 11 additions & 0 deletions src/awkward/_v2/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ def copy(
self._nplike if nplike is unset else nplike,
)

def __copy__(self):
return self.copy()

def __deepcopy__(self, memo):
return self.copy(
contents=[copy.deepcopy(x, memo) for x in self._contents],
fields=copy.deepcopy(self._fields, memo),
identifier=copy.deepcopy(self._identifier, memo),
parameters=copy.deepcopy(self._parameters, memo),
)

def __init__(
self,
contents,
Expand Down
Loading

0 comments on commit b5545a8

Please sign in to comment.